diff --git a/quic/s2n-quic/Cargo.toml b/quic/s2n-quic/Cargo.toml index de89c9f6ef..98ac649f27 100644 --- a/quic/s2n-quic/Cargo.toml +++ b/quic/s2n-quic/Cargo.toml @@ -77,7 +77,7 @@ s2n-quic-rustls = { version = "=0.44.1", path = "../s2n-quic-rustls", optional = s2n-quic-tls = { version = "=0.44.1", path = "../s2n-quic-tls", optional = true } s2n-quic-tls-default = { version = "=0.44.1", path = "../s2n-quic-tls-default", optional = true } s2n-quic-transport = { version = "=0.44.1", path = "../s2n-quic-transport" } -tokio = { version = "1", default-features = false } +tokio = { version = "1", default-features = false, features = ["sync"] } zerocopy = { version = "0.7", optional = true, features = ["derive"] } zeroize = { version = "1", optional = true, default-features = false } diff --git a/quic/s2n-quic/src/provider/dc/confirm.rs b/quic/s2n-quic/src/provider/dc/confirm.rs index c76ab3bd67..6a3c9e6333 100644 --- a/quic/s2n-quic/src/provider/dc/confirm.rs +++ b/quic/s2n-quic/src/provider/dc/confirm.rs @@ -2,7 +2,6 @@ // SPDX-License-Identifier: Apache-2.0 use crate::Connection; -use core::task::{Context, Poll, Waker}; use s2n_quic_core::{ connection, connection::Error, @@ -13,6 +12,7 @@ use s2n_quic_core::{ }, }; use std::io; +use tokio::sync::watch; /// `event::Subscriber` used for ensuring an s2n-quic client or server negotiating dc /// waits for the dc handshake to complete @@ -21,58 +21,54 @@ impl ConfirmComplete { /// Blocks the task until the provided connection has either completed the dc handshake or closed /// with an error pub async fn wait_ready(conn: &mut Connection) -> io::Result<()> { - core::future::poll_fn(|cx| { - conn.query_event_context_mut(|context: &mut ConfirmContext| context.poll_ready(cx)) - .map_err(|err| io::Error::new(io::ErrorKind::Other, err))? - }) - .await + let mut receiver = conn + .query_event_context_mut(|context: &mut ConfirmContext| context.sender.subscribe()) + .map_err(|err| io::Error::new(io::ErrorKind::Other, err))?; + + loop { + match &*receiver.borrow_and_update() { + // if we're ready or have errored then let the application know + State::Ready => return Ok(()), + State::Failed(error) => return Err((*error).into()), + State::Waiting(_) => {} + } + + if receiver.changed().await.is_err() { + return Err(io::Error::new( + io::ErrorKind::Other, + "never reached terminal state", + )); + } + } } } -#[derive(Default)] pub struct ConfirmContext { - waker: Option, - state: State, + sender: watch::Sender, +} + +impl Default for ConfirmContext { + fn default() -> Self { + let (sender, _receiver) = watch::channel(State::default()); + Self { sender } + } } impl ConfirmContext { /// Updates the state on the context fn update(&mut self, state: State) { - self.state = state; - - // notify the application that the state was updated - self.wake(); - } - - /// Polls the context for handshake completion - fn poll_ready(&mut self, cx: &mut Context) -> Poll> { - match self.state { - // if we're ready or have errored then let the application know - State::Ready => Poll::Ready(Ok(())), - State::Failed(error) => Poll::Ready(Err(error.into())), - State::Waiting(_) => { - // store the waker so we can notify the application of state updates - self.waker = Some(cx.waker().clone()); - Poll::Pending - } - } - } - - /// notify the application of a state update - fn wake(&mut self) { - if let Some(waker) = self.waker.take() { - waker.wake(); - } + self.sender.send_replace(state); } } impl Drop for ConfirmContext { // make sure the application is notified that we're closing the connection fn drop(&mut self) { - if matches!(self.state, State::Waiting(_)) { - self.state = State::Failed(connection::Error::unspecified()); - } - self.wake(); + self.sender.send_modify(|state| { + if matches!(state, State::Waiting(_)) { + *state = State::Failed(connection::Error::unspecified()); + } + }); } } @@ -107,14 +103,14 @@ impl Subscriber for ConfirmComplete { meta: &ConnectionMeta, event: &events::ConnectionClosed, ) { - ensure!(matches!(context.state, State::Waiting(_))); - - match (&meta.endpoint_type, event.error, &context.state) { - ( - EndpointType::Server { .. }, - Error::Closed { .. }, - State::Waiting(Some(DcState::PathSecretsReady { .. })), - ) => { + ensure!(matches!(*context.sender.borrow(), State::Waiting(_))); + let is_ready = matches!( + *context.sender.borrow(), + State::Waiting(Some(DcState::PathSecretsReady { .. })) + ); + + match (&meta.endpoint_type, event.error, is_ready) { + (EndpointType::Server { .. }, Error::Closed { .. }, true) => { // The client may close the connection immediately after the dc handshake completes, // before it sends acknowledgement of the server's DC_STATELESS_RESET_TOKENS. // Since the server has already moved into the PathSecretsReady state, this can be considered @@ -132,7 +128,7 @@ impl Subscriber for ConfirmComplete { _meta: &ConnectionMeta, event: &events::DcStateChanged, ) { - ensure!(matches!(context.state, State::Waiting(_))); + ensure!(matches!(*context.sender.borrow(), State::Waiting(_))); match event.state { DcState::NoVersionNegotiated { .. } => context.update(State::Failed( diff --git a/quic/s2n-quic/src/tests/dc.rs b/quic/s2n-quic/src/tests/dc.rs index a49846030c..e2f450feb4 100644 --- a/quic/s2n-quic/src/tests/dc.rs +++ b/quic/s2n-quic/src/tests/dc.rs @@ -69,7 +69,7 @@ fn dc_handshake_self_test() -> Result<()> { .with_tls(certificates::CERT_PEM)? .with_dc(MockDcEndpoint::new(&CLIENT_TOKENS))?; - self_test(server, client, None, None)?; + self_test(server, client, true, None, None)?; Ok(()) } @@ -114,7 +114,7 @@ fn dc_mtls_handshake_self_test() -> Result<()> { .with_tls(client_tls)? .with_dc(MockDcEndpoint::new(&SERVER_TOKENS))?; - self_test(server, client, None, None)?; + self_test(server, client, true, None, None)?; Ok(()) } @@ -143,7 +143,7 @@ fn dc_mtls_handshake_auth_failure_self_test() -> Result<()> { } .into(); - self_test(server, client, Some(expected_client_error), None)?; + self_test(server, client, true, Some(expected_client_error), None)?; Ok(()) } @@ -181,6 +181,7 @@ fn dc_mtls_handshake_server_not_supported_self_test() -> Result<()> { self_test( server, client, + true, Some(connection::Error::invalid_configuration( "peer does not support specified dc versions", )), @@ -228,6 +229,7 @@ fn dc_mtls_handshake_client_not_supported_self_test() -> Result<()> { self_test( server, client, + false, Some(expected_client_error), Some(connection::Error::invalid_configuration( "peer does not support specified dc versions", @@ -266,7 +268,7 @@ fn dc_possible_secret_control_packet( .with_dc(dc_endpoint)? .with_packet_interceptor(RandomShort::default())?; - let (client_events, _server_events) = self_test(server, client, None, None)?; + let (client_events, _server_events) = self_test(server, client, true, None, None)?; assert_eq!( 1, @@ -297,6 +299,7 @@ fn dc_possible_secret_control_packet( fn self_test( server: server::Builder, client: client::Builder, + client_has_dc: bool, expected_client_error: Option, expected_server_error: Option, ) -> Result<(DcRecorder, DcRecorder)> { @@ -318,18 +321,21 @@ fn self_test( let addr = server.local_addr()?; + let expected_count = 1 + client_has_dc as usize; spawn(async move { - if let Some(mut conn) = server.accept().await { - let result = dc::ConfirmComplete::wait_ready(&mut conn).await; - - if let Some(error) = expected_server_error { - assert_eq!(error, convert_io_result(result).unwrap()); - - if expected_client_error.is_some() { - conn.close(SERVER_CLOSE_ERROR_CODE.into()); + for _ in 0..expected_count { + if let Some(mut conn) = server.accept().await { + let result = dc::ConfirmComplete::wait_ready(&mut conn).await; + + if let Some(error) = expected_server_error { + assert_eq!(error, convert_io_result(result).unwrap()); + + if expected_client_error.is_some() { + conn.close(SERVER_CLOSE_ERROR_CODE.into()); + } + } else { + assert!(result.is_ok()); } - } else { - assert!(result.is_ok()); } } }); @@ -340,35 +346,41 @@ fn self_test( .with_random(Random::with_seed(456))? .start()?; - let client_events = client_events.clone(); - - primary::spawn(async move { - let connect = Connect::new(addr).with_server_name("localhost"); - let mut conn = client.connect(connect).await.unwrap(); - let result = dc::ConfirmComplete::wait_ready(&mut conn).await; - - if let Some(error) = expected_client_error { - assert_eq!(error, convert_io_result(result).unwrap()); - - if expected_server_error.is_some() { - conn.close(CLIENT_CLOSE_ERROR_CODE.into()); - // wait for the server to assert the expected error before dropping - delay(Duration::from_millis(100)).await; + for _ in 0..expected_count { + primary::spawn({ + let client = client.clone(); + let client_events = client_events.clone(); + async move { + let connect = Connect::new(addr) + .with_server_name("localhost") + .with_deduplicate(client_has_dc); + let mut conn = client.connect(connect).await.unwrap(); + let result = dc::ConfirmComplete::wait_ready(&mut conn).await; + + if let Some(error) = expected_client_error { + assert_eq!(error, convert_io_result(result).unwrap()); + + if expected_server_error.is_some() { + conn.close(CLIENT_CLOSE_ERROR_CODE.into()); + // wait for the server to assert the expected error before dropping + delay(Duration::from_millis(100)).await; + } + } else { + assert!(result.is_ok()); + let client_events = client_events + .dc_state_changed_events() + .lock() + .unwrap() + .clone(); + assert_dc_complete(&client_events); + // wait briefly so the ack for the `DC_STATELESS_RESET_TOKENS` frame from the server is sent + // before the client closes the connection. This is only necessary to confirm the `dc::State` + // on the server moves to `DcState::Complete` + delay(Duration::from_millis(100)).await; + } } - } else { - assert!(result.is_ok()); - let client_events = client_events - .dc_state_changed_events() - .lock() - .unwrap() - .clone(); - assert_dc_complete(&client_events); - // wait briefly so the ack for the `DC_STATELESS_RESET_TOKENS` frame from the server is sent - // before the client closes the connection. This is only necessary to confirm the `dc::State` - // on the server moves to `DcState::Complete` - delay(Duration::from_millis(100)).await; - } - }); + }); + } Ok(addr) })