Skip to content

Commit

Permalink
Instead of enum dispatch, use generics and dynamic dispatch to build …
Browse files Browse the repository at this point in the history
…up the basic stream stack as already done for the rest of the I/O stack.
  • Loading branch information
adamreichold committed Nov 27, 2023
1 parent 56a2f4b commit dcea9c2
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 68 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
[package]
name = "zeptohttpc"
description = "minimal HTTP client using http and httparse crates"
version = "0.8.0"
version = "0.8.1"
authors = ["Adam Reichold <[email protected]>"]
edition = "2018"
rust-version = "1.51"
Expand Down
139 changes: 72 additions & 67 deletions src/stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,17 +41,26 @@ use webpki_roots::TLS_SERVER_ROOTS;

use super::{happy_eyeballs::connect, timeout::Timeout, Error, Options};

pub enum Stream {
Tcp(TcpStream),
TcpWithTimeout(TcpStream, Timeout),
#[cfg(feature = "native-tls")]
NativeTls(TlsStream<TcpStream>),
#[cfg(feature = "native-tls")]
NativeTlsWithTimeout(TlsStream<TcpStream>, Timeout),
#[cfg(feature = "rustls")]
Rustls(Box<StreamOwned<ClientConnection, TcpStream>>),
#[cfg(feature = "rustls")]
RustlsWithTimeout(Box<StreamOwned<ClientConnection, TcpStream>>, Timeout),
pub struct Stream(Box<dyn Inner>);

trait Inner: Read + Write + Send {}

impl<S> Inner for S where S: Read + Write + Send {}

impl Read for Stream {
fn read(&mut self, buf: &mut [u8]) -> IoResult<usize> {
self.0.read(buf)
}
}

impl Write for Stream {
fn write(&mut self, buf: &[u8]) -> IoResult<usize> {
self.0.write(buf)
}

fn flush(&mut self) -> IoResult<()> {
self.0.flush()
}
}

impl Stream {
Expand All @@ -63,40 +72,66 @@ impl Stream {
) -> Result<Self, Error> {
let stream = connect(host, port, opts)?;

match opts.deadline {
let inner: Box<dyn Inner> = match opts.deadline {
#[cfg(feature = "native-tls")]
None if scheme == &Scheme::HTTPS => {
let stream = perform_native_tls_handshake(stream, host, opts.tls_connector)?;

Ok(Self::NativeTls(stream))
Box::new(stream)
}
#[cfg(feature = "rustls")]
None if scheme == &Scheme::HTTPS => {
let stream = perform_rustls_handshake(stream, host, opts.client_config)?;

Ok(Self::Rustls(Box::new(stream)))
Box::new(HandleCloseNotify(stream))
}
None => Ok(Self::Tcp(stream)),
None => Box::new(stream),
#[cfg(feature = "native-tls")]
Some(deadline) if scheme == &Scheme::HTTPS => {
let timeout = Timeout::start(&stream, deadline)?;
let stream = perform_native_tls_handshake(stream, host, opts.tls_connector)?;

Ok(Self::NativeTlsWithTimeout(stream, timeout))
Box::new(WithTimeout(stream, timeout))
}
#[cfg(feature = "rustls")]
Some(deadline) if scheme == &Scheme::HTTPS => {
let timeout = Timeout::start(&stream, deadline)?;
let stream = perform_rustls_handshake(stream, host, opts.client_config)?;

Ok(Self::RustlsWithTimeout(Box::new(stream), timeout))
Box::new(WithTimeout(HandleCloseNotify(stream), timeout))
}
Some(deadline) => {
let timeout = Timeout::start(&stream, deadline)?;

Ok(Self::TcpWithTimeout(stream, timeout))
Box::new(WithTimeout(stream, timeout))
}
}
};

Ok(Self(inner))
}
}

struct WithTimeout<S>(S, Timeout);

impl<S> Read for WithTimeout<S>
where
S: Read,
{
fn read(&mut self, buf: &mut [u8]) -> IoResult<usize> {
self.1.read(&mut self.0, buf)
}
}

impl<S> Write for WithTimeout<S>
where
S: Write,
{
fn write(&mut self, buf: &[u8]) -> IoResult<usize> {
self.0.write(buf)
}

fn flush(&mut self) -> IoResult<()> {
self.0.flush()
}
}

Expand Down Expand Up @@ -178,63 +213,33 @@ fn perform_rustls_handshake(
Ok(StreamOwned::new(conn, stream))
}

impl Read for Stream {
#[cfg(feature = "rustls")]
struct HandleCloseNotify(StreamOwned<ClientConnection, TcpStream>);

#[cfg(feature = "rustls")]
impl Read for HandleCloseNotify {
fn read(&mut self, buf: &mut [u8]) -> IoResult<usize> {
match self {
Self::Tcp(stream) => stream.read(buf),
Self::TcpWithTimeout(stream, timeout) => timeout.read(stream, buf),
#[cfg(feature = "native-tls")]
Self::NativeTls(stream) => stream.read(buf),
#[cfg(feature = "native-tls")]
Self::NativeTlsWithTimeout(stream, timeout) => timeout.read(stream, buf),
#[cfg(feature = "rustls")]
Self::Rustls(stream) => {
let res = stream.read(buf);
handle_close_notify(res, stream)
}
#[cfg(feature = "rustls")]
Self::RustlsWithTimeout(stream, timeout) => {
let res = timeout.read(stream, buf);
handle_close_notify(res, stream)
let res = self.0.read(buf);

match res {
Err(err) if err.kind() == ConnectionAborted => {
self.0.conn.send_close_notify();
self.0.conn.complete_io(&mut self.0.sock)?;

Ok(0)
}
res => res,
}
}
}

impl Write for Stream {
#[cfg(feature = "rustls")]
impl Write for HandleCloseNotify {
fn write(&mut self, buf: &[u8]) -> IoResult<usize> {
match self {
Self::Tcp(stream) | Self::TcpWithTimeout(stream, _) => stream.write(buf),
#[cfg(feature = "native-tls")]
Self::NativeTls(stream) | Self::NativeTlsWithTimeout(stream, _) => stream.write(buf),
#[cfg(feature = "rustls")]
Self::Rustls(stream) | Self::RustlsWithTimeout(stream, _) => stream.write(buf),
}
self.0.write(buf)
}

fn flush(&mut self) -> IoResult<()> {
match self {
Self::Tcp(stream) | Self::TcpWithTimeout(stream, _) => stream.flush(),
#[cfg(feature = "native-tls")]
Self::NativeTls(stream) | Self::NativeTlsWithTimeout(stream, _) => stream.flush(),
#[cfg(feature = "rustls")]
Self::Rustls(stream) | Self::RustlsWithTimeout(stream, _) => stream.flush(),
}
}
}

#[cfg(feature = "rustls")]
fn handle_close_notify(
res: IoResult<usize>,
stream: &mut StreamOwned<ClientConnection, TcpStream>,
) -> IoResult<usize> {
match res {
Err(err) if err.kind() == ConnectionAborted => {
stream.conn.send_close_notify();
stream.conn.complete_io(&mut stream.sock)?;

Ok(0)
}
res => res,
self.0.flush()
}
}

0 comments on commit dcea9c2

Please sign in to comment.