Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(s2n-quic-dc): wait to insert in peer map until handshake completes #2358

Merged
merged 1 commit into from
Oct 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 23 additions & 10 deletions dc/s2n-quic-dc/src/path/secret/map.rs
Original file line number Diff line number Diff line change
Expand Up @@ -488,19 +488,21 @@ impl Map {
Some(state.clone())
}

pub(super) fn insert(&self, entry: Arc<Entry>) {
pub(super) fn on_new_path_secrets(&self, entry: Arc<Entry>) {
// On insert clear our interest in a handshake.
self.state.requested_handshakes.pin().remove(&entry.peer);
let id = *entry.secret.id();
let peer = entry.peer;
if self.state.ids.insert(id, entry.clone()).is_some() {
if self.state.ids.insert(*entry.secret.id(), entry).is_some() {
// FIXME: Make insertion fallible and fail handshakes instead?
panic!("inserting a path secret ID twice");
}
}

pub(super) fn on_handshake_complete(&self, entry: Arc<Entry>) {
let id = *entry.secret.id();

if let Some(prev) = self.state.peers.insert(peer, entry) {
// This shouldn't happen due to the panic above, but just in case something went wrong
// with the secret map we double check here.
if let Some(prev) = self.state.peers.insert(entry.peer, entry) {
// This shouldn't happen due to the panic in on_new_path_secrets, but just
// in case something went wrong with the secret map we double check here.
// FIXME: Make insertion fallible and fail handshakes instead?
assert_ne!(*prev.secret.id(), id, "duplicate path secret id");

Expand Down Expand Up @@ -546,7 +548,8 @@ impl Map {
dc::testing::TEST_REHANDSHAKE_PERIOD,
);
let entry = Arc::new(entry);
provider.insert(entry);
provider.on_new_path_secrets(entry.clone());
provider.on_handshake_complete(entry);
}

(provider, ids)
Expand All @@ -573,7 +576,9 @@ impl Map {
dc::testing::TEST_APPLICATION_PARAMS,
dc::testing::TEST_REHANDSHAKE_PERIOD,
);
self.insert(Arc::new(entry));
let entry = Arc::new(entry);
self.on_new_path_secrets(entry.clone());
self.on_handshake_complete(entry);
}

fn send_control(&self, entry: &Entry, credentials: &Credentials, error: receiver::Error) {
Expand Down Expand Up @@ -1057,7 +1062,15 @@ impl dc::Path for HandshakingPath {
);
let entry = Arc::new(entry);
self.entry = Some(entry.clone());
self.map.insert(entry);
self.map.on_new_path_secrets(entry);
}

fn on_dc_handshake_complete(&mut self) {
let entry = self.entry.clone().expect(
"the dc handshake cannot be complete without \
on_peer_stateless_reset_tokens creating a map entry",
);
self.map.on_handshake_complete(entry);
}

fn on_mtu_updated(&mut self, mtu: u16) {
Expand Down
35 changes: 26 additions & 9 deletions dc/s2n-quic-dc/src/path/secret/map/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,16 +41,19 @@ fn cleans_after_delay() {
let first = fake_entry(1);
let second = fake_entry(1);
let third = fake_entry(1);
map.insert(first.clone());
map.insert(second.clone());
map.on_new_path_secrets(first.clone());
map.on_handshake_complete(first.clone());
map.on_new_path_secrets(second.clone());
map.on_handshake_complete(second.clone());

assert!(map.state.ids.contains_key(first.secret.id()));
assert!(map.state.ids.contains_key(second.secret.id()));

map.state.cleaner.clean(&map.state, 1);
map.state.cleaner.clean(&map.state, 1);

map.insert(third.clone());
map.on_new_path_secrets(third.clone());
map.on_handshake_complete(third.clone());

assert!(!map.state.ids.contains_key(first.secret.id()));
assert!(map.state.ids.contains_key(second.secret.id()));
Expand Down Expand Up @@ -86,9 +89,10 @@ struct Model {

#[derive(bolero::TypeGenerator, Debug, Copy, Clone)]
enum Operation {
Insert { ip: u8, path_secret_id: TestId },
NewPathSecret { ip: u8, path_secret_id: TestId },
AdvanceTime,
ReceiveUnknown { path_secret_id: TestId },
HandshakeComplete { path_secret_id: TestId },
}

#[derive(bolero::TypeGenerator, PartialEq, Eq, Hash, Copy, Clone)]
Expand Down Expand Up @@ -130,13 +134,13 @@ enum Invariant {
impl Model {
fn perform(&mut self, operation: Operation, state: &Map) {
match operation {
Operation::Insert { ip, path_secret_id } => {
Operation::NewPathSecret { ip, path_secret_id } => {
let ip = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::from([0, 0, 0, ip]), 0));
let secret = path_secret_id.secret();
let id = *secret.id();

let stateless_reset = state.state.signer.sign(&id);
state.insert(Arc::new(Entry::new(
state.on_new_path_secrets(Arc::new(Entry::new(
ip,
secret,
sender::State::new(stateless_reset),
Expand All @@ -145,9 +149,16 @@ impl Model {
dc::testing::TEST_REHANDSHAKE_PERIOD,
)));

self.invariants.insert(Invariant::ContainsIp(ip));
self.invariants.insert(Invariant::ContainsId(id));
}
Operation::HandshakeComplete { path_secret_id } => {
if let Some(entry) = state.state.ids.get_by_key(&path_secret_id.id()) {
if !state.state.peers.contains_key(&entry.peer) {
state.on_handshake_complete(entry.clone());
}
self.invariants.insert(Invariant::ContainsIp(entry.peer));
}
}
Operation::AdvanceTime => {
let mut invalidated = Vec::new();
self.invariants.retain(|invariant| {
Expand Down Expand Up @@ -232,7 +243,7 @@ fn has_duplicate_pids(ops: &[Operation]) -> bool {
let mut ids = HashSet::new();
for op in ops.iter() {
match op {
Operation::Insert {
Operation::NewPathSecret {
ip: _,
path_secret_id,
} => {
Expand All @@ -244,6 +255,10 @@ fn has_duplicate_pids(ops: &[Operation]) -> bool {
Operation::ReceiveUnknown { path_secret_id: _ } => {
// no-op, we're fine receiving unknown pids.
}
Operation::HandshakeComplete { .. } => {
// no-op, a handshake complete for the same pid as a
// new path secret is expected
}
}
}

Expand Down Expand Up @@ -320,7 +335,9 @@ fn no_memory_growth() {
map.state.cleaner.stop();
for idx in 0..500_000 {
// FIXME: this ends up 2**16 peers in the `peers` map
map.insert(fake_entry(idx as u16));
let entry = fake_entry(idx as u16);
map.on_new_path_secrets(entry.clone());
map.on_handshake_complete(entry)
}
}

Expand Down
4 changes: 4 additions & 0 deletions quic/s2n-quic-core/src/dc/disabled.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,10 @@ impl Path for () {
unimplemented!()
}

fn on_dc_handshake_complete(&mut self) {
unimplemented!()
}

fn on_mtu_updated(&mut self, _mtu: u16) {
unimplemented!()
}
Expand Down
8 changes: 8 additions & 0 deletions quic/s2n-quic-core/src/dc/testing.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ impl MockDcEndpoint {
pub struct MockDcPath {
pub on_path_secrets_ready_count: u8,
pub on_peer_stateless_reset_tokens_count: u8,
pub on_dc_handshake_complete: u8,
pub stateless_reset_tokens: Vec<stateless_reset::Token>,
pub peer_stateless_reset_tokens: Vec<stateless_reset::Token>,
pub mtu: u16,
Expand Down Expand Up @@ -69,6 +70,7 @@ impl dc::Path for MockDcPath {
&mut self,
_session: &impl TlsSession,
) -> Result<Vec<stateless_reset::Token>, transport::Error> {
debug_assert_eq!(0, self.on_path_secrets_ready_count);
self.on_path_secrets_ready_count += 1;
Ok(self.stateless_reset_tokens.clone())
}
Expand All @@ -77,11 +79,17 @@ impl dc::Path for MockDcPath {
&mut self,
stateless_reset_tokens: impl Iterator<Item = &'a stateless_reset::Token>,
) {
debug_assert_eq!(0, self.on_peer_stateless_reset_tokens_count);
self.on_peer_stateless_reset_tokens_count += 1;
self.peer_stateless_reset_tokens
.extend(stateless_reset_tokens);
}

fn on_dc_handshake_complete(&mut self) {
debug_assert_eq!(0, self.on_dc_handshake_complete);
self.on_dc_handshake_complete += 1;
}

fn on_mtu_updated(&mut self, mtu: u16) {
self.mtu = mtu
}
Expand Down
12 changes: 12 additions & 0 deletions quic/s2n-quic-core/src/dc/traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,11 @@ pub trait Path: 'static + Send {
stateless_reset_tokens: impl Iterator<Item = &'a stateless_reset::Token>,
);

/// Called when the peer has confirmed receipt of `DC_STATELESS_RESET_TOKENS`, either
/// by the server sending back its own `DC_STATELESS_RESET_TOKENS` or by the client
/// acknowledging the `DC_STATELESS_RESET_TOKENS` frame was received.
fn on_dc_handshake_complete(&mut self);

/// Called when the MTU has been updated for the path
fn on_mtu_updated(&mut self, mtu: u16);
}
Expand Down Expand Up @@ -73,6 +78,13 @@ impl<P: Path> Path for Option<P> {
}
}

#[inline]
fn on_dc_handshake_complete(&mut self) {
if let Some(path) = self {
path.on_dc_handshake_complete()
}
}

#[inline]
fn on_mtu_updated(&mut self, max_datagram_size: u16) {
if let Some(path) = self {
Expand Down
2 changes: 2 additions & 0 deletions quic/s2n-quic-transport/src/dc/manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,7 @@ impl<Config: endpoint::Config> Manager<Config> {
if Config::ENDPOINT_TYPE.is_server() {
self.stateless_reset_token_sync.send();
} else {
self.path.on_dc_handshake_complete();
publisher.on_dc_state_changed(DcStateChanged {
state: DcState::Complete,
});
Expand All @@ -176,6 +177,7 @@ impl<Config: endpoint::Config> Manager<Config> {
ensure!(self.state.on_stateless_reset_tokens_acked().is_ok());

debug_assert!(Config::ENDPOINT_TYPE.is_server());
self.path.on_dc_handshake_complete();
publisher.on_dc_state_changed(DcStateChanged {
state: DcState::Complete,
});
Expand Down
4 changes: 4 additions & 0 deletions quic/s2n-quic-transport/src/dc/manager/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,9 @@ fn on_peer_dc_stateless_reset_tokens<Config, Endpoint>(

if Config::ENDPOINT_TYPE.is_server() {
assert!(manager.state.is_server_tokens_sent());
assert_eq!(0, manager.path().on_dc_handshake_complete);
} else {
assert_eq!(1, manager.path().on_dc_handshake_complete);
assert!(manager.state.is_complete());
}

Expand All @@ -169,6 +171,7 @@ fn on_packet_ack_client() {

// Client completes when it has received stateless reset tokens from the peer
assert!(!manager.state.is_complete());
assert_eq!(0, manager.path().on_dc_handshake_complete);
}

#[test]
Expand All @@ -182,6 +185,7 @@ fn on_packet_ack_server() {

// Server completes when its stateless reset tokens are acked
assert!(manager.state.is_complete());
assert_eq!(1, manager.path().on_dc_handshake_complete);
}

fn on_packet_ack<Config, Endpoint>(
Expand Down
Loading