diff --git a/Cargo.toml b/Cargo.toml index dff51e3..332a4a5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "zeptohttpc" description = "minimal HTTP client using http and httparse crates" -version = "0.8.0" +version = "0.8.1" authors = ["Adam Reichold "] edition = "2018" rust-version = "1.51" diff --git a/src/stream.rs b/src/stream.rs index 2642976..570cc14 100644 --- a/src/stream.rs +++ b/src/stream.rs @@ -41,17 +41,26 @@ use webpki_roots::TLS_SERVER_ROOTS; use super::{happy_eyeballs::connect, timeout::Timeout, Error, Options}; -pub enum Stream { - Tcp(TcpStream), - TcpWithTimeout(TcpStream, Timeout), - #[cfg(feature = "native-tls")] - NativeTls(TlsStream), - #[cfg(feature = "native-tls")] - NativeTlsWithTimeout(TlsStream, Timeout), - #[cfg(feature = "rustls")] - Rustls(Box>), - #[cfg(feature = "rustls")] - RustlsWithTimeout(Box>, Timeout), +pub struct Stream(Box); + +trait Inner: Read + Write + Send {} + +impl Inner for S where S: Read + Write + Send {} + +impl Read for Stream { + fn read(&mut self, buf: &mut [u8]) -> IoResult { + self.0.read(buf) + } +} + +impl Write for Stream { + fn write(&mut self, buf: &[u8]) -> IoResult { + self.0.write(buf) + } + + fn flush(&mut self) -> IoResult<()> { + self.0.flush() + } } impl Stream { @@ -63,40 +72,66 @@ impl Stream { ) -> Result { let stream = connect(host, port, opts)?; - match opts.deadline { + let inner: Box = match opts.deadline { #[cfg(feature = "native-tls")] None if scheme == &Scheme::HTTPS => { let stream = perform_native_tls_handshake(stream, host, opts.tls_connector)?; - Ok(Self::NativeTls(stream)) + Box::new(stream) } #[cfg(feature = "rustls")] None if scheme == &Scheme::HTTPS => { let stream = perform_rustls_handshake(stream, host, opts.client_config)?; - Ok(Self::Rustls(Box::new(stream))) + Box::new(HandleCloseNotify(stream)) } - None => Ok(Self::Tcp(stream)), + None => Box::new(stream), #[cfg(feature = "native-tls")] Some(deadline) if scheme == &Scheme::HTTPS => { let timeout = Timeout::start(&stream, deadline)?; let stream = perform_native_tls_handshake(stream, host, opts.tls_connector)?; - Ok(Self::NativeTlsWithTimeout(stream, timeout)) + Box::new(WithTimeout(stream, timeout)) } #[cfg(feature = "rustls")] Some(deadline) if scheme == &Scheme::HTTPS => { let timeout = Timeout::start(&stream, deadline)?; let stream = perform_rustls_handshake(stream, host, opts.client_config)?; - Ok(Self::RustlsWithTimeout(Box::new(stream), timeout)) + Box::new(WithTimeout(HandleCloseNotify(stream), timeout)) } Some(deadline) => { let timeout = Timeout::start(&stream, deadline)?; - Ok(Self::TcpWithTimeout(stream, timeout)) + Box::new(WithTimeout(stream, timeout)) } - } + }; + + Ok(Self(inner)) + } +} + +struct WithTimeout(S, Timeout); + +impl Read for WithTimeout +where + S: Read, +{ + fn read(&mut self, buf: &mut [u8]) -> IoResult { + self.1.read(&mut self.0, buf) + } +} + +impl Write for WithTimeout +where + S: Write, +{ + fn write(&mut self, buf: &[u8]) -> IoResult { + self.0.write(buf) + } + + fn flush(&mut self) -> IoResult<()> { + self.0.flush() } } @@ -178,63 +213,33 @@ fn perform_rustls_handshake( Ok(StreamOwned::new(conn, stream)) } -impl Read for Stream { +#[cfg(feature = "rustls")] +struct HandleCloseNotify(StreamOwned); + +#[cfg(feature = "rustls")] +impl Read for HandleCloseNotify { fn read(&mut self, buf: &mut [u8]) -> IoResult { - match self { - Self::Tcp(stream) => stream.read(buf), - Self::TcpWithTimeout(stream, timeout) => timeout.read(stream, buf), - #[cfg(feature = "native-tls")] - Self::NativeTls(stream) => stream.read(buf), - #[cfg(feature = "native-tls")] - Self::NativeTlsWithTimeout(stream, timeout) => timeout.read(stream, buf), - #[cfg(feature = "rustls")] - Self::Rustls(stream) => { - let res = stream.read(buf); - handle_close_notify(res, stream) - } - #[cfg(feature = "rustls")] - Self::RustlsWithTimeout(stream, timeout) => { - let res = timeout.read(stream, buf); - handle_close_notify(res, stream) + let res = self.0.read(buf); + + match res { + Err(err) if err.kind() == ConnectionAborted => { + self.0.conn.send_close_notify(); + self.0.conn.complete_io(&mut self.0.sock)?; + + Ok(0) } + res => res, } } } -impl Write for Stream { +#[cfg(feature = "rustls")] +impl Write for HandleCloseNotify { fn write(&mut self, buf: &[u8]) -> IoResult { - match self { - Self::Tcp(stream) | Self::TcpWithTimeout(stream, _) => stream.write(buf), - #[cfg(feature = "native-tls")] - Self::NativeTls(stream) | Self::NativeTlsWithTimeout(stream, _) => stream.write(buf), - #[cfg(feature = "rustls")] - Self::Rustls(stream) | Self::RustlsWithTimeout(stream, _) => stream.write(buf), - } + self.0.write(buf) } fn flush(&mut self) -> IoResult<()> { - match self { - Self::Tcp(stream) | Self::TcpWithTimeout(stream, _) => stream.flush(), - #[cfg(feature = "native-tls")] - Self::NativeTls(stream) | Self::NativeTlsWithTimeout(stream, _) => stream.flush(), - #[cfg(feature = "rustls")] - Self::Rustls(stream) | Self::RustlsWithTimeout(stream, _) => stream.flush(), - } - } -} - -#[cfg(feature = "rustls")] -fn handle_close_notify( - res: IoResult, - stream: &mut StreamOwned, -) -> IoResult { - match res { - Err(err) if err.kind() == ConnectionAborted => { - stream.conn.send_close_notify(); - stream.conn.complete_io(&mut stream.sock)?; - - Ok(0) - } - res => res, + self.0.flush() } }