diff --git a/.github/config/cargo-deny.toml b/.github/config/cargo-deny.toml
index e5ef022bbe..c358aff429 100644
--- a/.github/config/cargo-deny.toml
+++ b/.github/config/cargo-deny.toml
@@ -1,7 +1,4 @@
 [advisories]
-vulnerability = "deny"
-unmaintained = "deny"
-notice = "deny"
 yanked = "deny"
 ignore = [
     "RUSTSEC-2021-0139", # criterion, structopt, and tracing-subscriber (test dependencies) use ansi_term
@@ -15,7 +12,6 @@ skip-tree = [
     { name = "cuckoofilter" },
 
     # all of these are going to be just test dependencies
-    { name = "aes-gcm" },
     { name = "bach" },
     { name = "bolero" },
     { name = "criterion" },
@@ -30,16 +26,10 @@
 ]
 
 [sources]
-allow-git = [
-    "https://github.com/camshaft/aya", # TODO: Remove once aya supports XdpMaps - https://github.com/aya-rs/aya/pull/527
-]
 unknown-registry = "deny"
 unknown-git = "deny"
 
 [licenses]
-unlicensed = "deny"
-allow-osi-fsf-free = "neither"
-copyleft = "deny"
 confidence-threshold = 0.9
 # ignore licenses for private crates
 private = { ignore = true }
@@ -47,6 +37,7 @@ allow = [
     "Apache-2.0",
     "BSD-2-Clause",
     "BSD-3-Clause",
+    "CC0-1.0",
     "ISC",
     "MIT",
     "OpenSSL",
diff --git a/dc/s2n-quic-dc/Cargo.toml b/dc/s2n-quic-dc/Cargo.toml
index 66444db0bb..4c9f667379 100644
--- a/dc/s2n-quic-dc/Cargo.toml
+++ b/dc/s2n-quic-dc/Cargo.toml
@@ -21,17 +21,20 @@ bolero-generator = { version = "0.11", optional = true }
 bytes = "1"
 crossbeam-channel = "0.5"
 crossbeam-queue = { version = "0.3" }
+flurry = "0.5"
 libc = "0.2"
 num-rational = { version = "0.4", default-features = false }
 once_cell = "1"
+rand = { version = "0.8", features = ["small_rng"] }
 s2n-codec = { version = "=0.40.0", path = "../../common/s2n-codec", default-features = false }
 s2n-quic-core = { version = "=0.40.0", path = "../../quic/s2n-quic-core", default-features = false }
 s2n-quic-platform = { version = "=0.40.0", path = "../../quic/s2n-quic-platform" }
 slotmap = "1"
 thiserror = "1"
-tokio = { version = "1", features = ["io-util"], optional = true }
+tokio = { version = "1", features = ["sync"] }
 tracing = "0.1"
 zerocopy = { version = "0.7", features = ["derive"] }
+zeroize = "1"
 
 [dev-dependencies]
 bolero = "0.11"
@@ -39,4 +42,4 @@ bolero-generator = "0.11"
 insta = "1"
 s2n-codec = { path = "../../common/s2n-codec", features = ["testing"] }
 s2n-quic-core = { path = "../../quic/s2n-quic-core", features = ["testing"] }
-tokio = { version = "1", features = ["io-util"] }
+tokio = { version = "1", features = ["sync"] }
diff --git a/dc/s2n-quic-dc/src/datagram.rs b/dc/s2n-quic-dc/src/datagram.rs
new file mode 100644
index 0000000000..2b866fc0bc
--- /dev/null
+++ b/dc/s2n-quic-dc/src/datagram.rs
@@ -0,0 +1,4 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// SPDX-License-Identifier: Apache-2.0
+
+pub mod tunneled;
diff --git a/dc/s2n-quic-dc/src/datagram/tunneled.rs b/dc/s2n-quic-dc/src/datagram/tunneled.rs
new file mode 100644
index 0000000000..8b3c230f65
--- /dev/null
+++ b/dc/s2n-quic-dc/src/datagram/tunneled.rs
@@ -0,0 +1,8 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// SPDX-License-Identifier: Apache-2.0
+
+pub mod recv;
+pub mod send;
+
+pub use recv::Receiver;
+pub use send::Sender;
diff --git a/dc/s2n-quic-dc/src/datagram/tunneled/recv.rs b/dc/s2n-quic-dc/src/datagram/tunneled/recv.rs
new file mode 100644
index 0000000000..b57cc8adef
--- /dev/null
+++ b/dc/s2n-quic-dc/src/datagram/tunneled/recv.rs
@@ -0,0 +1,80 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// SPDX-License-Identifier: Apache-2.0
+
+use crate::{
+    crypto::{decrypt, UninitSlice},
+    packet::datagram::{decoder, Tag},
+};
+use s2n_codec::{decoder_invariant, DecoderBufferMut, DecoderError};
+use s2n_quic_core::packet::number::{PacketNumberSpace, SlidingWindow, SlidingWindowError};
+
+pub use crate::crypto::decrypt::Error;
+pub use decoder::Packet;
+
+#[derive(Default)]
+pub struct Endpoint {}
+
+impl Endpoint {
+    pub fn parse<'a>(&self, payload: &'a mut [u8]) -> Option<(Packet<'a>, &'a mut [u8])> {
+        let buffer = DecoderBufferMut::new(payload);
+        let (packet, buffer) = Packet::decode(buffer, TagValidator, 16).ok()?;
+        let buffer = buffer.into_less_safe_slice();
+        Some((packet, buffer))
+    }
+}
+
+struct TagValidator;
+
+impl decoder::Validator for TagValidator {
+    #[inline]
+    fn validate_tag(&mut self, tag: Tag) -> Result<(), DecoderError> {
+        decoder_invariant!(!tag.ack_eliciting(), "expected tunneled datagram");
+        decoder_invariant!(
+            !tag.has_application_header(),
+            "application headers currently unsupported"
+        );
+        Ok(())
+    }
+}
+
+pub struct Receiver<K: decrypt::Key> {
+    key: K,
+}
+
+impl<K: decrypt::Key> Receiver<K> {
+    pub fn new(key: K) -> Self {
+        Self { key }
+    }
+
+    pub fn recv_into(
+        &mut self,
+        packet: &Packet,
+        payload_out: &mut UninitSlice,
+    ) -> Result<(), Error> {
+        debug_assert_eq!(packet.payload().len(), payload_out.len());
+
+        self.key.decrypt(
+            packet.crypto_nonce(),
+            packet.header(),
+            packet.payload(),
+            packet.auth_tag(),
+            payload_out,
+        )?;
+
+        Ok(())
+    }
+}
+
+#[derive(Default)]
+pub struct SeenFilter {
+    window: SlidingWindow,
+}
+
+impl SeenFilter {
+    #[inline]
+    pub fn on_packet(&mut self, packet: &Packet) -> Result<(), SlidingWindowError> {
+        let packet_number =
+            PacketNumberSpace::ApplicationData.new_packet_number(packet.packet_number());
+        self.window.insert(packet_number)
+    }
+}
diff --git a/dc/s2n-quic-dc/src/datagram/tunneled/send.rs b/dc/s2n-quic-dc/src/datagram/tunneled/send.rs
new file mode 100644
index 0000000000..5524e76c7f
--- /dev/null
+++ b/dc/s2n-quic-dc/src/datagram/tunneled/send.rs
@@ -0,0 +1,97 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// SPDX-License-Identifier: Apache-2.0
+
+use crate::{
+    control,
+    crypto::encrypt,
+    packet::{self, datagram::encoder},
+};
+use core::sync::atomic::{AtomicU64, Ordering};
+use s2n_codec::EncoderBuffer;
+use s2n_quic_core::{ensure, varint::VarInt};
+
+#[derive(Clone, Copy, Debug)]
+pub enum Error {
+    PayloadTooLarge,
+    PacketBufferTooSmall,
+    PacketNumberExhaustion,
+}
+
+pub struct Sender<E: encrypt::Key> {
+    encrypt_key: E,
+    packet_number: AtomicU64,
+}
+
+impl<E> Sender<E>
+where
+    E: encrypt::Key,
+{
+    #[inline]
+    pub fn new(encrypt_key: E) -> Self {
+        Self {
+            encrypt_key,
+            packet_number: AtomicU64::new(0),
+        }
+    }
+
+    #[inline]
+    pub fn estimated_send_size(&self, cleartext_payload_len: usize) -> Option<usize> {
+        let payload_len = packet::PayloadLen::try_from(cleartext_payload_len).ok()?;
+        Some(encoder::estimate_len(
+            VarInt::ZERO,
+            None,
+            VarInt::ZERO,
+            payload_len,
+            E::tag_len(&self.encrypt_key),
+        ))
+    }
+
+    #[inline]
+    pub fn send_into<C>(
+        &self,
+        control_port: &C,
+        mut cleartext_payload: &[u8],
+        encrypted_packet: &mut [u8],
+    ) -> Result<usize, Error>
+    where
+        C: control::Controller,
+    {
+        let packet_number = self.packet_number.fetch_add(1, Ordering::Relaxed);
+        let packet_number =
+            VarInt::new(packet_number).map_err(|_| Error::PacketNumberExhaustion)?;
+
+        let payload_len = packet::PayloadLen::try_from(cleartext_payload.len())
+            .map_err(|_| Error::PayloadTooLarge)?;
+
+        let estimated_packet_len = self
+            .estimated_send_size(cleartext_payload.len())
+            .ok_or(Error::PayloadTooLarge)?;
+
+        // ensure the descriptor has enough capacity after MTU/allocation
+        ensure!(
+            encrypted_packet.len() >= estimated_packet_len,
+            Err(Error::PacketBufferTooSmall)
+        );
+
+        let actual_packet_len = {
+            let source_control_port = control_port.source_port();
+
+            let encoder = EncoderBuffer::new(encrypted_packet);
+
+            encoder::encode(
+                encoder,
+                source_control_port,
+                Some(packet_number),
+                None,
+                VarInt::ZERO,
+                &mut &[][..],
+                &(),
+                payload_len,
+                &mut cleartext_payload,
+                &self.encrypt_key,
+            )
+        };
+
+        Ok(actual_packet_len)
+    }
+}
diff --git a/dc/s2n-quic-dc/src/lib.rs b/dc/s2n-quic-dc/src/lib.rs
index a361fd8b89..e98c61d057 100644
--- a/dc/s2n-quic-dc/src/lib.rs
+++ b/dc/s2n-quic-dc/src/lib.rs
@@ -6,6 +6,7 @@ pub mod congestion;
 pub mod control;
 pub mod credentials;
 pub mod crypto;
+pub mod datagram;
 pub mod msg;
 pub mod packet;
 pub mod path;
@@ -13,3 +14,5 @@ pub mod pool;
 pub mod recovery;
 pub mod socket;
 pub mod stream;
+
+pub use s2n_quic_core::dc::{Version, SUPPORTED_VERSIONS};
diff --git a/dc/s2n-quic-dc/src/packet/reset.rs b/dc/s2n-quic-dc/src/packet/reset.rs
deleted file mode 100644
index 5d8d63c992..0000000000
--- a/dc/s2n-quic-dc/src/packet/reset.rs
+++ /dev/null
@@ -1,47 +0,0 @@
-// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
-// SPDX-License-Identifier: Apache-2.0
-
-use core::{
-    fmt,
-    ops::{Deref, DerefMut},
-};
-use s2n_codec::zerocopy_value_codec;
-use s2n_quic_core::varint::VarInt;
-use zerocopy::{AsBytes, FromBytes, Unaligned};
-
-#[derive(Clone, Copy, PartialEq, Eq, AsBytes, FromBytes, Unaligned)]
-#[repr(C)]
-pub struct Tag(u8);
-
-zerocopy_value_codec!(Tag);
-
-impl Default for Tag {
-    #[inline]
-    fn default() -> Self {
-        Self(0b0110_0000)
-    }
-}
-
-/*
-impl fmt::Debug for Tag {
-    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
-        f.debug_struct("datagram::Tag")
-            .field("mode", &self.mode())
-            .finish()
-    }
-}
-
-impl Tag {
-    #[inline]
-    pub fn mode(&self) -> Mode {
-
-    }
-}
-
-#[derive(Clone, Copy, Debug)]
-pub enum Mode {
-    Early,
-    Authenticated,
-    Stateless,
-}
-*/
diff --git a/dc/s2n-quic-dc/src/packet/secret_control/request_additional_generation.rs b/dc/s2n-quic-dc/src/packet/secret_control/request_additional_generation.rs
deleted file mode 100644
index 088da0b3e7..0000000000
--- a/dc/s2n-quic-dc/src/packet/secret_control/request_additional_generation.rs
+++ /dev/null
@@ -1,68 +0,0 @@
-// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
-// SPDX-License-Identifier: Apache-2.0
-
-use super::*;
-
-impl_tag!(REQUEST_ADDITIONAL_GENERATION);
-impl_packet!(RequestAdditionalGeneration);
-
-#[derive(Clone, Copy, Debug, PartialEq, Eq)]
-#[cfg_attr(test, derive(bolero_generator::TypeGenerator))]
-pub struct RequestAdditionalGeneration {
-    pub credential_id: credentials::Id,
-    pub generation_id: u32,
-}
-
-impl RequestAdditionalGeneration {
-    #[inline]
-    pub fn encode<C>(&self, mut encoder: EncoderBuffer, crypto: &mut C) -> usize
-    where
-        C: encrypt::Key,
-    {
-        let generation_id = self.generation_id;
-
-        encoder.encode(&Tag::default());
-        encoder.encode(&&self.credential_id[..]);
-        encoder.encode(&VarInt::from(generation_id));
-
-        encoder::finish(
-            encoder,
-            Nonce::RequestAdditionalGeneration { generation_id },
-            crypto,
-        )
-    }
-
-    #[inline]
-    pub fn nonce(&self) -> Nonce {
-        Nonce::RequestAdditionalGeneration {
-            generation_id: self.generation_id,
-        }
-    }
-
-    #[cfg(test)]
-    fn validate(&self) -> Option<()> {
-        Some(())
-    }
-}
-
-impl<'a> DecoderValue<'a> for RequestAdditionalGeneration {
-    #[inline]
-    fn decode(buffer: DecoderBuffer<'a>) -> R<'a, Self> {
-        let (tag, buffer) = buffer.decode::<Tag>()?;
-        decoder_invariant!(tag == Tag::default(), "invalid tag");
-        let (credential_id, buffer) = buffer.decode()?;
-        let (generation_id, buffer) = decoder::sized(buffer)?;
-        let value = Self {
-            credential_id,
-            generation_id,
-        };
-        Ok((value, buffer))
-    }
-}
-
-#[cfg(test)]
-mod tests {
-    use super::*;
-
-    impl_tests!(RequestAdditionalGeneration);
-}
diff --git a/dc/s2n-quic-dc/src/packet/secret_control/request_shards.rs b/dc/s2n-quic-dc/src/packet/secret_control/request_shards.rs
deleted file mode 100644
index d24758e7c6..0000000000
--- a/dc/s2n-quic-dc/src/packet/secret_control/request_shards.rs
+++ /dev/null
@@ -1,69 +0,0 @@
-// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
-// SPDX-License-Identifier: Apache-2.0
-
-use super::*;
-
-impl_tag!(REQUEST_SHARDS);
-impl_packet!(RequestShards);
-
-#[derive(Clone, Copy, Debug, PartialEq, Eq)]
-#[cfg_attr(test, derive(bolero_generator::TypeGenerator))]
-pub struct RequestShards {
-    pub credential_id: credentials::Id,
-    pub receiving_shards: u16,
-    pub shard_width: u64,
-}
-
-impl RequestShards {
-    #[inline]
-    pub fn encode<C>(&self, mut encoder: EncoderBuffer, crypto: &mut C) -> usize
-    where
-        C: encrypt::Key,
-    {
-        encoder.encode(&Tag::default());
-        encoder.encode(&self.credential_id);
-        encoder.encode(&VarInt::from(self.receiving_shards));
-        encoder.encode(&self.shard_width);
-
-        encoder::finish(encoder, self.nonce(), crypto)
-    }
-
-    #[inline]
-    pub fn nonce(&self) -> Nonce {
-        Nonce::RequestShards {
-            receiving_shards: self.receiving_shards,
-            shard_width: self.shard_width,
-        }
-    }
-
-    #[cfg(test)]
-    fn validate(&self) -> Option<()> {
-        Some(())
-    }
-}
-
-impl<'a> DecoderValue<'a> for RequestShards {
-    #[inline]
-    fn decode(buffer: DecoderBuffer<'a>) -> R<'a, Self> {
-        let (tag, buffer) = buffer.decode::<Tag>()?;
-        decoder_invariant!(tag == Tag::default(), "invalid tag");
-        let (credential_id, buffer) = buffer.decode()?;
-        let (receiving_shards, buffer) = buffer.decode::<VarInt>()?;
-        let (shard_width, buffer) = buffer.decode()?;
-        let value = Self {
-            credential_id,
-            receiving_shards: receiving_shards
-                .try_into()
-                .map_err(|_| DecoderError::InvariantViolation("receiving_shards too big"))?,
-            shard_width,
-        };
-        Ok((value, buffer))
-    }
-}
-
-#[cfg(test)]
-mod tests {
-    use super::*;
-
-    impl_tests!(RequestShards);
-}
diff --git a/dc/s2n-quic-dc/src/path.rs b/dc/s2n-quic-dc/src/path.rs
index 3485d5eed3..9a98ad630b 100644
--- a/dc/s2n-quic-dc/src/path.rs
+++ b/dc/s2n-quic-dc/src/path.rs
@@ -1,13 +1,16 @@
 // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
 // SPDX-License-Identifier: Apache-2.0
 
-use core::time::Duration;
 use s2n_quic_core::{
     path::{Handle, MaxMtu, Tuple},
     varint::VarInt,
 };
 
-static DEFAULT_MAX_DATA: once_cell::sync::Lazy<VarInt> = once_cell::sync::Lazy::new(|| {
+pub mod secret;
+#[cfg(any(test, feature = "testing"))]
+pub mod testing;
+
+pub static DEFAULT_MAX_DATA: once_cell::sync::Lazy<VarInt> = once_cell::sync::Lazy::new(|| {
     std::env::var("DC_QUIC_DEFAULT_MAX_DATA")
         .ok()
         .and_then(|v| v.parse().ok())
@@ -15,7 +18,7 @@
         .into()
 });
 
-static DEFAULT_MTU: once_cell::sync::Lazy<MaxMtu> = once_cell::sync::Lazy::new(|| {
+pub static DEFAULT_MTU: once_cell::sync::Lazy<MaxMtu> = once_cell::sync::Lazy::new(|| {
     let default_mtu = if cfg!(target_os = "linux") {
         8940
     } else {
@@ -30,7 +33,7 @@
         .unwrap()
 });
 
-static DEFAULT_IDLE_TIMEOUT: once_cell::sync::Lazy<u32> = once_cell::sync::Lazy::new(|| {
+pub static DEFAULT_IDLE_TIMEOUT: once_cell::sync::Lazy<u32> = once_cell::sync::Lazy::new(|| {
     std::env::var("DC_QUIC_DEFAULT_IDLE_TIMEOUT")
         .ok()
         .and_then(|v| v.parse().ok())
@@ -39,9 +42,6 @@
         .unwrap()
 });
 
-#[cfg(any(test, feature = "testing"))]
-pub mod testing;
-
 pub trait Controller {
     type Handle: Handle;
 
@@ -56,31 +56,3 @@ impl Controller for Tuple {
     fn handle(&self) -> &Self::Handle {
         self
     }
 }
-
-#[derive(Clone, Copy, Debug)]
-pub struct Parameters {
-    pub max_mtu: MaxMtu,
-    pub remote_max_data: VarInt,
-    pub local_send_max_data: VarInt,
-    pub local_recv_max_data: VarInt,
-    pub idle_timeout_secs: u32,
-}
-
-impl Default for Parameters {
-    fn default() -> Self {
-        Self {
-            max_mtu: *DEFAULT_MTU,
-            remote_max_data: *DEFAULT_MAX_DATA,
-            local_send_max_data: *DEFAULT_MAX_DATA,
-            local_recv_max_data: *DEFAULT_MAX_DATA,
-            idle_timeout_secs: *DEFAULT_IDLE_TIMEOUT,
-        }
-    }
-}
-
-impl Parameters {
-    #[inline]
-    pub fn idle_timeout(&self) -> Duration {
-        Duration::from_secs(self.idle_timeout_secs as _)
-    }
-}
diff --git a/dc/s2n-quic-dc/src/path/secret.rs b/dc/s2n-quic-dc/src/path/secret.rs
new file mode 100644
index 0000000000..fa1edcdb75
--- /dev/null
+++ b/dc/s2n-quic-dc/src/path/secret.rs
@@ -0,0 +1,14 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// SPDX-License-Identifier: Apache-2.0
+
+mod key;
+pub mod map;
+#[doc(hidden)]
+pub mod receiver;
+#[doc(hidden)]
+pub mod schedule;
+mod sender;
+pub mod stateless_reset;
+
+pub use key::{Opener, Sealer};
+pub use map::Map;
diff --git a/dc/s2n-quic-dc/src/path/secret/key.rs b/dc/s2n-quic-dc/src/path/secret/key.rs
new file mode 100644
index 0000000000..f85f330489
--- /dev/null
+++ b/dc/s2n-quic-dc/src/path/secret/key.rs
@@ -0,0 +1,148 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// SPDX-License-Identifier: Apache-2.0
+
+use super::map;
+use crate::{
+    credentials::Credentials,
+    crypto::{awslc, decrypt, encrypt, IntoNonce, UninitSlice},
+};
+use core::mem::MaybeUninit;
+use zeroize::Zeroize;
+
+#[derive(Debug)]
+pub struct Sealer {
+    pub(super) sealer: awslc::EncryptKey,
+}
+
+impl encrypt::Key for Sealer {
+    #[inline]
+    fn credentials(&self) -> &Credentials {
+        self.sealer.credentials()
+    }
+
+    #[inline]
+    fn tag_len(&self) -> usize {
+        self.sealer.tag_len()
+    }
+
+    #[inline]
+    fn encrypt<N: IntoNonce>(
+        &self,
+        nonce: N,
+        header: &[u8],
+        extra_payload: Option<&[u8]>,
+        payload_and_tag: &mut [u8],
+    ) {
+        self.sealer
+            .encrypt(nonce, header, extra_payload, payload_and_tag)
+    }
+
+    #[inline]
+    fn retransmission_tag(
+        &self,
+        original_packet_number: u64,
+        retransmission_packet_number: u64,
+        tag_out: &mut [u8],
+    ) {
+        self.sealer.retransmission_tag(
+            original_packet_number,
+            retransmission_packet_number,
+            tag_out,
+        )
+    }
+}
+
+#[derive(Debug)]
+pub struct Opener {
+    pub(super) opener: awslc::DecryptKey,
+    pub(super) dedup: map::Dedup,
+}
+
+impl Opener {
+    /// Disables replay prevention, allowing the decryption key to be reused.
+    ///
+    /// ## Safety
+    /// Disabling replay prevention is insecure because it makes it possible for
+    /// active network attackers to cause a peer to accept previously processed
+    /// data as new. For example, if a packet contains a mutating request such
+    /// as adding +1 to a value in a database, an attacker can keep replaying
+    /// packets to increment the value beyond what the original legitimate
+    /// sender of the packet intended.
+    pub unsafe fn disable_replay_prevention(&mut self) {
+        self.dedup.disable();
+    }
+
+    /// Ensures the key has not been used before
+    #[inline]
+    fn on_decrypt_success(&self, payload: &mut UninitSlice) -> decrypt::Result {
+        self.dedup.check(&self.opener).map_err(|e| {
+            let payload = unsafe {
+                let ptr = payload.as_mut_ptr() as *mut MaybeUninit<u8>;
+                let len = payload.len();
+                core::slice::from_raw_parts_mut(ptr, len)
+            };
+            payload.zeroize();
+            e
+        })?;
+
+        Ok(())
+    }
+}
+
+impl decrypt::Key for Opener {
+    #[inline]
+    fn credentials(&self) -> &Credentials {
+        self.opener.credentials()
+    }
+
+    #[inline]
+    fn tag_len(&self) -> usize {
+        self.opener.tag_len()
+    }
+
+    #[inline]
+    fn decrypt<N: IntoNonce>(
+        &self,
+        nonce: N,
+        header: &[u8],
+        payload_in: &[u8],
+        tag: &[u8],
+        payload_out: &mut UninitSlice,
+    ) -> decrypt::Result {
+        self.opener
+            .decrypt(nonce, header, payload_in, tag, payload_out)?;
+
+        self.on_decrypt_success(payload_out)?;
+
+        Ok(())
+    }
+
+    #[inline]
+    fn decrypt_in_place<N: IntoNonce>(
+        &self,
+        nonce: N,
+        header: &[u8],
+        payload_and_tag: &mut [u8],
+    ) -> decrypt::Result {
+        self.opener
+            .decrypt_in_place(nonce, header, payload_and_tag)?;
+
+        self.on_decrypt_success(UninitSlice::new(payload_and_tag))?;
+
+        Ok(())
+    }
+
+    #[inline]
+    fn retransmission_tag(
+        &self,
+        original_packet_number: u64,
+        retransmission_packet_number: u64,
+        tag_out: &mut [u8],
+    ) {
+        self.opener.retransmission_tag(
+            original_packet_number,
+            retransmission_packet_number,
+            tag_out,
+        )
+    }
+}
diff --git a/dc/s2n-quic-dc/src/path/secret/map.rs b/dc/s2n-quic-dc/src/path/secret/map.rs
new file mode 100644
index 0000000000..22cba586f6
--- /dev/null
+++ b/dc/s2n-quic-dc/src/path/secret/map.rs
@@ -0,0 +1,865 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// SPDX-License-Identifier: Apache-2.0
+
+use super::{
+    receiver,
+    schedule::{self, Initiator},
+    sender, stateless_reset, Opener, Sealer,
+};
+use crate::{
+    credentials::{Credentials, Id},
+    crypto,
+    packet::{secret_control as control, Packet},
+};
+use rand::Rng as _;
+use s2n_codec::EncoderBuffer;
+use s2n_quic_core::{
+    dc::{self, ApplicationParams},
+    event::api::EndpointType,
+};
+use std::{
+    fmt,
+    net::{Ipv4Addr, SocketAddr},
+    sync::{
+        atomic::{AtomicBool, AtomicU64, AtomicUsize, Ordering},
+        Arc, Mutex,
+    },
+    time::Duration,
+};
+use zeroize::Zeroizing;
+
+const TLS_EXPORTER_LABEL: &str = "EXPERIMENTAL EXPORTER s2n-quic-dc";
+const TLS_EXPORTER_CONTEXT: &str = "";
+const TLS_EXPORTER_LENGTH: usize = schedule::EXPORT_SECRET_LEN;
+
+// FIXME: Most of this comment is not true today; we're expecting to implement the details
+// contained here. This is presented as a roadmap.
+/// This map caches path secrets derived from handshakes.
+///
+/// The cache is configurable on two axes:
+///
+/// * Maximum size (in megabytes)
+/// * Maximum per-peer/secret derivation per-second rate (in derived secrets, e.g., accepted/opened streams)
+///
+/// Each entry in the cache will take around 550 bytes plus 15 bits per derived secret at the
+/// maximum rate (corresponding to no false positives in replay prevention for 15 seconds).
+#[derive(Clone)]
+pub struct Map {
+    pub(super) state: Arc<State>,
+}
+
+// # Managing memory consumption
+//
+// For regular rotation with live peers, we retain at most two secrets: one derived from the most
+// recent locally initiated handshake and the most recent remote initiated handshake (from our
+// perspective). We guarantee that at most one handshake is ongoing for a given peer pair at a
+// time, so both sides will have at least one mutually trusted entry after the handshake. If a peer
+// is only acting as a client or only as a server, then one of the peer maps will always be empty.
+//
+// Previous entries can safely be removed after a grace period (EVICTION_TIME). EVICTION_TIME
+// is only needed because a stream/datagram might be opening/sent concurrently with the new
+// handshake (e.g., during regular rotation), and we don't want that to fail spuriously.
+//
+// We also need to manage secrets for no longer existing peers. These are peers where typically the
+// underlying host has gone away and/or the address for it has changed. At 95% occupancy for the
+// maximum size allowed, we will remove least recently used secrets (1% of these per minute). Usage
+// is defined by access to the entry in the map. Unfortunately we lack any good way to authenticate
+// a peer as *not* having credentials, especially after the peer is gone. It's possible that in the
+// future information could also come from the TLS provider.
+pub(super) struct State {
+    // This is in number of entries.
+    max_capacity: usize,
+
+    // peers is the most recent entry originating from a locally *or* remote initiated handshake.
+    //
+    // Handshakes use s2n-quic and the SocketAddr is the address of the handshake socket. Since
+    // s2n-quic only has Client or Server endpoints, a given SocketAddr can only be used for
+    // exactly one of a locally initiated handshake or a remote initiated handshake. As a result we
+    // can use a single map to store both kinds and treat them identically.
+    //
+    // In the future it's likely we'll want to build bidirectional support, in which case splitting
+    // this into two maps (per the discussion in "Managing memory consumption" above) will be
+    // needed.
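+    //
+    // A minimal sketch of the lookup pattern used throughout this file (the
+    // `addr` variable is hypothetical; flurry reads go through an epoch guard):
+    //
+    //     let guard = state.peers.guard();
+    //     if let Some(entry) = state.peers.get(&addr, &guard) {
+    //         entry.mark_live(state.cleaner.epoch());
+    //     }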
+    pub(super) peers: flurry::HashMap<SocketAddr, Arc<Entry>>,
+
+    // This is used for deduplicating outgoing handshakes. We manage this here as it's a
+    // property required for correctness (see comment on the struct).
+    //
+    // FIXME: make use of this.
+    #[allow(unused)]
+    pub(super) ongoing_handshakes: flurry::HashMap<SocketAddr, ()>,
+
+    // Stores the set of SocketAddr for which we received a UnknownPathSecret packet.
+    // When handshake_with is called we will allow a new handshake if this contains a socket; this
+    // is a temporary solution until we implement proper background handshaking.
+    pub(super) requested_handshakes: flurry::HashSet<SocketAddr>,
+
+    // All known entries.
+    pub(super) ids: flurry::HashMap<Id, Arc<Entry>>,
+
+    pub(super) signer: stateless_reset::Signer,
+
+    // This socket is used *only* for sending secret control packets.
+    // FIXME: This will get replaced with sending on a handshake socket associated with the map.
+    pub(super) control_socket: std::net::UdpSocket,
+
+    pub(super) receiver_shared: Arc<receiver::Shared>,
+
+    handled_control_packets: AtomicUsize,
+
+    cleaner: Cleaner,
+}
+
+struct Cleaner {
+    should_stop: AtomicBool,
+    thread: Mutex<Option<std::thread::JoinHandle<()>>>,
+    epoch: AtomicU64,
+}
+
+impl Drop for Cleaner {
+    fn drop(&mut self) {
+        self.stop();
+    }
+}
+
+impl Cleaner {
+    fn new() -> Cleaner {
+        Cleaner {
+            should_stop: AtomicBool::new(false),
+            thread: Mutex::new(None),
+            epoch: AtomicU64::new(1),
+        }
+    }
+
+    fn stop(&self) {
+        self.should_stop.store(true, Ordering::Relaxed);
+        if let Some(thread) =
+            std::mem::take(&mut *self.thread.lock().unwrap_or_else(|e| e.into_inner()))
+        {
+            thread.thread().unpark();
+
+            // If this isn't getting dropped on the cleaner thread,
+            // then wait for the background thread to finish exiting.
+            if std::thread::current().id() != thread.thread().id() {
+                // We expect this to terminate very quickly.
+                thread.join().unwrap();
+            }
+        }
+    }
+
+    fn spawn_thread(&self, state: Arc<State>) {
+        let state = Arc::downgrade(&state);
+        let handle = std::thread::spawn(move || loop {
+            let Some(state) = state.upgrade() else {
+                break;
+            };
+            if state.cleaner.should_stop.load(Ordering::Relaxed) {
+                break;
+            }
+            state.cleaner.clean(&state, EVICTION_CYCLES);
+            let pause = rand::thread_rng().gen_range(5..60);
+            drop(state);
+            std::thread::park_timeout(Duration::from_secs(pause));
+        });
+        *self.thread.lock().unwrap() = Some(handle);
+    }
+
+    /// Clean up dead items.
+    // In local benchmarking, iterating a 500,000 element flurry::HashMap takes about
+    // 60-70ms. With contention, etc. it might be longer, but this is not an overly long
+    // time given that we expect to run this in a background thread once a minute.
+    //
+    // This is exposed as a method primarily for tests to directly invoke.
+    fn clean(&self, state: &State, eviction_cycles: u64) {
+        let current_epoch = self.epoch.fetch_add(1, Ordering::Relaxed);
+
+        // FIXME: Rather than just tracking one minimum, we might want to try to do some counting
+        // as we iterate to have a higher likelihood of identifying 1% of peers falling into the
+        // epoch we pick. Exactly how to do that without collecting a ~full distribution by epoch
+        // is not clear though, and we'd prefer to avoid allocating extra memory here.
+        //
+        // As-is we're just hoping that once-per-minute oldest-epoch identification and removal is
+        // enough that we keep the capacity below 100%. We could have a mode that starts just
+        // randomly evicting entries if we hit 100%, but even this feels like an annoying modality
+        // to deal with.
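+        //
+        // The two passes below: (1) evict entries that have been retired for at
+        // least `eviction_cycles` epochs while tracking the minimum `used_at`
+        // epoch among live entries, then (2) if we are above 95% occupancy,
+        // remove ~1% of entries, preferring those stuck at that minimum epoch.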
+        let mut minimum = u64::MAX;
+        {
+            let guard = state.ids.guard();
+            for (id, entry) in state.ids.iter(&guard) {
+                let retired_at = entry.retired.0.load(Ordering::Relaxed);
+                if retired_at == 0 {
+                    // Find the minimum non-retired epoch currently in the set.
+                    minimum = std::cmp::min(entry.used_at.load(Ordering::Relaxed), minimum);
+
+                    // Not retired.
+                    continue;
+                }
+                // Avoid panics on overflow (which should never happen...)
+                if current_epoch.saturating_sub(retired_at) >= eviction_cycles {
+                    state.ids.remove(id, &guard);
+                }
+            }
+        }
+
+        if state.ids.len() <= (state.max_capacity * 95 / 100) {
+            return;
+        }
+
+        let mut to_remove = std::cmp::max(state.ids.len() / 100, 1);
+        let guard = state.ids.guard();
+        for (id, entry) in state.ids.iter(&guard) {
+            if to_remove > 0 {
+                // Only remove entries at the minimum epoch. This hopefully means that we will
+                // remove fairly stale entries.
+                if entry.used_at.load(Ordering::Relaxed) == minimum {
+                    state.ids.remove(id, &guard);
+                    to_remove -= 1;
+                }
+            } else {
+                break;
+            }
+        }
+    }
+
+    fn epoch(&self) -> u64 {
+        self.epoch.load(Ordering::Relaxed)
+    }
+}
+
+const EVICTION_CYCLES: u64 = if cfg!(test) { 0 } else { 10 };
+
+impl Map {
+    pub fn new(signer: stateless_reset::Signer) -> Self {
+        // FIXME: Avoid unwrap and the whole socket.
+        //
+        // We only ever send on this socket - but we really should be sending on the same
+        // socket as used by an associated s2n-quic handshake runtime, and receiving control packets
+        // from that socket as well. Not exactly clear on how to achieve that yet though (both
+        // ownership wise since the map doesn't have direct access to handshakes and in terms
+        // of implementation).
+        let control_socket = std::net::UdpSocket::bind((Ipv4Addr::UNSPECIFIED, 0)).unwrap();
+        control_socket.set_nonblocking(true).unwrap();
+        let state = State {
+            // This is around 500MB with the current entry size.
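+            // (That is, max_capacity times the per-entry footprint; see the
+            // size estimate in the `Map` docs above.)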
+            max_capacity: 500_000,
+            peers: Default::default(),
+            ongoing_handshakes: Default::default(),
+            requested_handshakes: Default::default(),
+            ids: Default::default(),
+            cleaner: Cleaner::new(),
+            signer,
+
+            receiver_shared: receiver::Shared::new(),
+
+            handled_control_packets: AtomicUsize::new(0),
+            control_socket,
+        };
+
+        let state = Arc::new(state);
+
+        state.cleaner.spawn_thread(state.clone());
+
+        Self { state }
+    }
+
+    pub fn drop_state(&self) {
+        self.state.peers.pin().clear();
+        self.state.ids.pin().clear();
+    }
+
+    pub fn contains(&self, peer: SocketAddr) -> bool {
+        self.state.peers.pin().contains_key(&peer)
+            && !self.state.requested_handshakes.pin().contains(&peer)
+    }
+
+    pub fn sealer(&self, peer: SocketAddr) -> Option<(Sealer, ApplicationParams)> {
+        let peers_guard = self.state.peers.guard();
+        let state = self.state.peers.get(&peer, &peers_guard)?;
+        state.mark_live(self.state.cleaner.epoch());
+
+        let sealer = state.uni_sealer();
+        Some((sealer, state.parameters))
+    }
+
+    pub fn opener(&self, credentials: &Credentials, control_out: &mut Vec<u8>) -> Option<Opener> {
+        let state = self.pre_authentication(credentials, control_out)?;
+        let opener = state.uni_opener(self.clone(), credentials);
+        Some(opener)
+    }
+
+    pub fn pair_for_peer(&self, peer: SocketAddr) -> Option<(Sealer, Opener, ApplicationParams)> {
+        let peers_guard = self.state.peers.guard();
+        let state = self.state.peers.get(&peer, &peers_guard)?;
+        state.mark_live(self.state.cleaner.epoch());
+
+        let (sealer, opener) = state.bidi_local();
+
+        Some((sealer, opener, state.parameters))
+    }
+
+    pub fn pair_for_credentials(
+        &self,
+        credentials: &Credentials,
+        control_out: &mut Vec<u8>,
+    ) -> Option<(Sealer, Opener, ApplicationParams)> {
+        let state = self.pre_authentication(credentials, control_out)?;
+
+        let params = state.parameters;
+        let (sealer, opener) = state.bidi_remote(self.clone(), credentials);
+
+        Some((sealer, opener, params))
+    }
+
+    /// This can be called from anywhere to ask the map to handle a packet.
+    ///
+    /// For secret control packets, this will process those.
+    /// For other packets, the map may collect metrics but will otherwise drop the packets.
+    pub fn handle_unexpected_packet(&self, packet: &Packet) {
+        match packet {
+            Packet::Stream(_) => {
+                // no action for now. FIXME: Add metrics.
+            }
+            Packet::Datagram(_) => {
+                // no action for now. FIXME: Add metrics.
+            }
+            Packet::Control(_) => {
+                // no action for now. FIXME: Add metrics.
+            }
+            Packet::StaleKey(packet) => self.handle_control_packet(&(*packet).into()),
+            Packet::ReplayDetected(packet) => self.handle_control_packet(&(*packet).into()),
+            Packet::UnknownPathSecret(packet) => self.handle_control_packet(&(*packet).into()),
+        }
+    }
+
+    pub fn handle_unknown_secret_packet(&self, packet: &control::unknown_path_secret::Packet) {
+        let ids_guard = self.state.ids.guard();
+        let Some(state) = self.state.ids.get(packet.credential_id(), &ids_guard) else {
+            return;
+        };
+        // Do not mark as live; this is only lightly authenticated.
+
+        // ensure the packet is authentic
+        if packet.authenticate(&state.sender.stateless_reset).is_none() {
+            return;
+        }
+
+        self.state
+            .handled_control_packets
+            .fetch_add(1, Ordering::Relaxed);
+
+        // FIXME: More actively schedule a new handshake.
+        // See the comment on requested_handshakes for details.
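+        //
+        // Recording the peer here also makes `Map::contains` return false for
+        // it, which is what steers the caller toward a fresh handshake.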
+        self.state.requested_handshakes.pin().insert(state.peer);
+    }
+
+    pub fn handle_control_packet(&self, packet: &control::Packet) {
+        if let control::Packet::UnknownPathSecret(ref packet) = &packet {
+            return self.handle_unknown_secret_packet(packet);
+        }
+
+        let ids_guard = self.state.ids.guard();
+        let Some(state) = self.state.ids.get(packet.credential_id(), &ids_guard) else {
+            // If we get a control packet we don't have a registered path secret for, ignore the
+            // packet.
+            return;
+        };
+
+        let key = state.sender.control_secret(&state.secret);
+
+        match packet {
+            control::Packet::StaleKey(packet) => {
+                let Some(packet) = packet.authenticate(key) else {
+                    return;
+                };
+                state.mark_live(self.state.cleaner.epoch());
+                state.sender.update_for_stale_key(packet.min_key_id);
+                self.state
+                    .handled_control_packets
+                    .fetch_add(1, Ordering::Relaxed);
+            }
+            control::Packet::ReplayDetected(packet) => {
+                let Some(_packet) = packet.authenticate(key) else {
+                    return;
+                };
+                self.state
+                    .handled_control_packets
+                    .fetch_add(1, Ordering::Relaxed);
+
+                // If we see replay then we're going to assume that we should re-handshake in the
+                // background with this peer. Currently we can't handshake in the background (only
+                // in the foreground on next handshake_with).
+                //
+                // Note that there's no good way for us to prevent an attacker causing us to hit
+                // this code: they can always trivially replay a packet we send. At most we could
+                // de-duplicate *receiving* so there's one handshake per sent packet at most, but
+                // that's not particularly useful: we expect to send a lot of new packets that
+                // could be harvested.
+                //
+                // Handshaking will be rate limited per destination peer (and at least
+                // de-duplicated).
+                self.state.requested_handshakes.pin().insert(state.peer);
+            }
+            control::Packet::UnknownPathSecret(_) => unreachable!(),
+        }
+    }
+
+    fn pre_authentication(
+        &self,
+        identity: &Credentials,
+        control_out: &mut Vec<u8>,
+    ) -> Option<Arc<Entry>> {
+        let ids_guard = self.state.ids.guard();
+        let Some(state) = self.state.ids.get(&identity.id, &ids_guard) else {
+            let packet = control::UnknownPathSecret {
+                credential_id: identity.id,
+            };
+            control_out.resize(control::UnknownPathSecret::PACKET_SIZE, 0);
+            let stateless_reset = self.state.signer.sign(&identity.id);
+            let encoder = EncoderBuffer::new(control_out);
+            packet.encode(encoder, &stateless_reset);
+            return None;
+        };
+        state.mark_live(self.state.cleaner.epoch());
+
+        match state.receiver.pre_authentication(identity) {
+            Ok(()) => {}
+            Err(e) => {
+                self.send_control(state, identity, e);
+                control_out.resize(control::UnknownPathSecret::PACKET_SIZE, 0);
+
+                return None;
+            }
+        }
+
+        Some(state.clone())
+    }
+
+    pub(super) fn insert(&self, entry: Arc<Entry>) {
+        // On insert clear our interest in a handshake.
+        self.state.requested_handshakes.pin().remove(&entry.peer);
+        entry.mark_live(self.state.cleaner.epoch());
+        let id = *entry.secret.id();
+        let peer = entry.peer;
+        let ids_guard = self.state.ids.guard();
+        if self
+            .state
+            .ids
+            .insert(id, entry.clone(), &ids_guard)
+            .is_some()
+        {
+            // FIXME: Make insertion fallible and fail handshakes instead?
+            panic!("inserting a path secret ID twice");
+        }
+
+        let peers_guard = self.state.peers.guard();
+        if let Some(prev) = self.state.peers.insert(peer, entry, &peers_guard) {
+            // This shouldn't happen due to the panic above, but just in case something went wrong
+            // with the secret map we double check here.
+            // FIXME: Make insertion fallible and fail handshakes instead?
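+            // `prev` is the old entry for the same peer (e.g. after a
+            // rehandshake); it must carry a different path secret ID, since a
+            // true duplicate would have panicked in the `ids` insert above.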
+            assert_ne!(*prev.secret.id(), id, "duplicate path secret id");
+
+            prev.retire(self.state.cleaner.epoch());
+        }
+    }
+
+    pub(super) fn signer(&self) -> &stateless_reset::Signer {
+        &self.state.signer
+    }
+
+    #[doc(hidden)]
+    #[cfg(any(test, feature = "testing"))]
+    pub fn for_test_with_peers(
+        peers: Vec<(schedule::Ciphersuite, dc::Version, SocketAddr)>,
+    ) -> (Self, Vec<Id>) {
+        let provider = Self::new(Default::default());
+        let mut secret = [0; 32];
+        aws_lc_rs::rand::fill(&mut secret).unwrap();
+        let mut stateless_reset = [0; 16];
+        aws_lc_rs::rand::fill(&mut stateless_reset).unwrap();
+
+        let receiver_shared = receiver::Shared::new();
+
+        let mut ids = Vec::with_capacity(peers.len());
+        for (idx, (ciphersuite, version, peer)) in peers.into_iter().enumerate() {
+            secret[..8].copy_from_slice(&(idx as u64).to_be_bytes()[..]);
+            stateless_reset[..8].copy_from_slice(&(idx as u64).to_be_bytes()[..]);
+            let secret = schedule::Secret::new(
+                ciphersuite,
+                version,
+                s2n_quic_core::endpoint::Type::Client,
+                &secret,
+            );
+            ids.push(*secret.id());
+            let sender = sender::State::new(stateless_reset);
+            let entry = Entry::new(
+                peer,
+                secret,
+                sender,
+                receiver_shared.clone().new_receiver(),
+                testing::test_application_params(),
+            );
+            let entry = Arc::new(entry);
+            provider.insert(entry);
+        }
+
+        (provider, ids)
+    }
+
+    #[doc(hidden)]
+    #[cfg(any(test, feature = "testing"))]
+    pub fn test_insert(&self, peer: SocketAddr) {
+        let mut secret = [0; 32];
+        aws_lc_rs::rand::fill(&mut secret).unwrap();
+        let secret = schedule::Secret::new(
+            schedule::Ciphersuite::AES_GCM_128_SHA256,
+            dc::SUPPORTED_VERSIONS[0],
+            s2n_quic_core::endpoint::Type::Client,
+            &secret,
+        );
+        let sender = sender::State::new([0; 16]);
+        let receiver = self.state.receiver_shared.clone().new_receiver();
+        let entry = Entry::new(
+            peer,
+            secret,
+            sender,
+            receiver,
+            testing::test_application_params(),
+        );
+        self.insert(Arc::new(entry));
+    }
+
+    fn send_control(&self, entry: &Entry, credentials: &Credentials, error: receiver::Error) {
+        let mut buffer = [0; control::MAX_PACKET_SIZE];
+        let buffer = error.to_packet(entry, credentials, &mut buffer);
+        let dst = entry.peer;
+        self.send_control_packet(dst, buffer);
+    }
+
+    pub(crate) fn send_control_packet(&self, dst: SocketAddr, buffer: &[u8]) {
+        match self.state.control_socket.send_to(buffer, dst) {
+            Ok(_) => {
+                // all done
+            }
+            Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => {
+                // ignore would block -- we're not going to queue up control packet messages.
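+                // Secret control packets are best-effort hints: the peer that
+                // provoked this one will provoke it again with its next packet,
+                // so dropping under send-buffer pressure is acceptable.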
+ } + Err(e) => { + tracing::warn!("Failed to send control packet to {:?}: {:?}", dst, e); + } + } + } +} + +impl receiver::Error { + pub(super) fn to_packet<'buffer>( + self, + entry: &Entry, + credentials: &Credentials, + buffer: &'buffer mut [u8; control::MAX_PACKET_SIZE], + ) -> &'buffer [u8] { + debug_assert_eq!(entry.secret.id(), &credentials.id); + let encoder = EncoderBuffer::new(&mut buffer[..]); + let length = match self { + receiver::Error::AlreadyExists => control::ReplayDetected { + credential_id: credentials.id, + rejected_key_id: credentials.key_id, + } + .encode(encoder, &entry.secret.control_sealer()), + receiver::Error::Unknown => control::StaleKey { + credential_id: credentials.id, + min_key_id: entry.receiver.minimum_unseen_key_id(), + } + .encode(encoder, &entry.secret.control_sealer()), + }; + &buffer[..length] + } +} + +#[derive(Debug)] +pub(super) struct Entry { + peer: SocketAddr, + secret: schedule::Secret, + retired: IsRetired, + // Last time the entry was pulled out of the State map. + // This is not necessarily the last time the entry was used but it's close enough for our + // purposes: if the entry is not being pulled out of the State map, it's hopefully not going to + // start getting pulled out shortly. This is used for the LRU mechanism, see the Cleaner impl + // for details. + used_at: AtomicU64, + sender: sender::State, + receiver: receiver::State, + parameters: ApplicationParams, +} + +// Retired is 0 if not yet retired. Otherwise it stores the background cleaner epoch at which it +// retired; that epoch increments roughly once per minute. +#[derive(Default)] +struct IsRetired(AtomicU64); + +impl fmt::Debug for IsRetired { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_tuple("IsRetired").field(&self.retired()).finish() + } +} + +impl IsRetired { + fn retired(&self) -> bool { + self.0.load(Ordering::Relaxed) != 0 + } +} + +impl Entry { + pub fn new( + peer: SocketAddr, + secret: schedule::Secret, + sender: sender::State, + receiver: receiver::State, + parameters: ApplicationParams, + ) -> Self { + Self { + peer, + secret, + retired: Default::default(), + used_at: AtomicU64::new(0), + sender, + receiver, + parameters, + } + } + + fn retire(&self, at_epoch: u64) { + self.retired.0.store(at_epoch, Ordering::Relaxed); + } + + fn mark_live(&self, at_epoch: u64) { + self.used_at.store(at_epoch, Ordering::Relaxed); + } + + fn uni_sealer(&self) -> Sealer { + let key_id = self.sender.next_key_id(); + let sealer = self.secret.application_sealer(key_id); + + Sealer { sealer } + } + + fn uni_opener(self: Arc, map: Map, credentials: &Credentials) -> Opener { + let opener = self.secret.application_opener(credentials.key_id); + + let dedup = Dedup::new(self, map); + + Opener { opener, dedup } + } + + fn bidi_local(&self) -> (Sealer, Opener) { + let key_id = self.sender.next_key_id(); + let (sealer, opener) = self.secret.application_pair(key_id, Initiator::Local); + let sealer = Sealer { sealer }; + + // we don't need to dedup locally-initiated openers + let dedup = Dedup::disabled(); + + let opener = Opener { opener, dedup }; + + (sealer, opener) + } + + fn bidi_remote(self: Arc, map: Map, credentials: &Credentials) -> (Sealer, Opener) { + let (sealer, opener) = self + .secret + .application_pair(credentials.key_id, Initiator::Remote); + let sealer = Sealer { sealer }; + + let dedup = Dedup::new(self, map); + + let opener = Opener { opener, dedup }; + + (sealer, opener) + } +} + +pub struct Dedup { + cell: once_cell::sync::OnceCell, + init: 
+}
+
+/// SAFETY: `init` cell is synchronized by `OnceCell`
+unsafe impl Sync for Dedup {}
+
+impl Dedup {
+    #[inline]
+    fn new(entry: Arc<Entry>, map: Map) -> Self {
+        // TODO potentially record a timestamp of when this was created to try and detect long
+        // delays of processing the first packet.
+        Self {
+            cell: Default::default(),
+            init: core::cell::Cell::new(Some((entry, map))),
+        }
+    }
+
+    #[inline]
+    fn disabled() -> Self {
+        Self {
+            cell: once_cell::sync::OnceCell::with_value(Ok(())),
+            init: core::cell::Cell::new(None),
+        }
+    }
+
+    #[inline]
+    pub(crate) fn disable(&self) {
+        // TODO
+    }
+
+    #[inline]
+    pub fn check(&self, c: &impl crypto::decrypt::Key) -> crypto::decrypt::Result {
+        *self.cell.get_or_init(|| {
+            match self.init.take() {
+                Some((entry, map)) => {
+                    let creds = c.credentials();
+                    match entry.receiver.post_authentication(creds) {
+                        Ok(()) => Ok(()),
+                        Err(receiver::Error::AlreadyExists) => {
+                            map.send_control(&entry, creds, receiver::Error::AlreadyExists);
+                            Err(crypto::decrypt::Error::ReplayDefinitelyDetected)
+                        }
+                        Err(receiver::Error::Unknown) => {
+                            map.send_control(&entry, creds, receiver::Error::Unknown);
+                            Err(crypto::decrypt::Error::ReplayPotentiallyDetected {
+                                gap: Some(
+                                    (*entry.receiver.minimum_unseen_key_id())
+                                        // This should never be negative, but saturate anyway to
+                                        // avoid wildly large numbers.
+                                        .saturating_sub(*creds.key_id),
+                                ),
+                            })
+                        }
+                    }
+                }
+                None => {
+                    // Dedup has been poisoned! TODO log this
+                    Err(crypto::decrypt::Error::ReplayPotentiallyDetected { gap: None })
+                }
+            }
+        })
+    }
+}
+
+impl fmt::Debug for Dedup {
+    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+        f.debug_struct("Dedup").field("cell", &self.cell).finish()
+    }
+}
+
+pub struct HandshakingPath {
+    peer: SocketAddr,
+    dc_version: dc::Version,
+    parameters: ApplicationParams,
+    endpoint_type: s2n_quic_core::endpoint::Type,
+    secret: Option<schedule::Secret>,
+    map: Map,
+}
+
+impl HandshakingPath {
+    fn new(connection_info: &dc::ConnectionInfo, map: Map) -> Self {
+        let endpoint_type = match connection_info.endpoint_type {
+            EndpointType::Server { .. } => s2n_quic_core::endpoint::Type::Server,
+            EndpointType::Client { .. } => s2n_quic_core::endpoint::Type::Client,
+        };
+
+        Self {
+            peer: connection_info.remote_address.clone().into(),
+            dc_version: connection_info.dc_version,
+            parameters: connection_info.application_params,
+            endpoint_type,
+            secret: None,
+            map,
+        }
+    }
+}
+
+impl dc::Endpoint for Map {
+    type Path = HandshakingPath;
+
+    fn new_path(&mut self, connection_info: &dc::ConnectionInfo) -> Option<Self::Path> {
+        Some(HandshakingPath::new(connection_info, self.clone()))
+    }
+}
+
+impl dc::Path for HandshakingPath {
+    fn on_path_secrets_ready(
+        &mut self,
+        session: &impl s2n_quic_core::crypto::tls::TlsSession,
+    ) -> Result<Vec<s2n_quic_core::stateless_reset::Token>, s2n_quic_core::transport::Error> {
+        let mut material = Zeroizing::new([0; TLS_EXPORTER_LENGTH]);
+        session
+            .tls_exporter(
+                TLS_EXPORTER_LABEL.as_bytes(),
+                TLS_EXPORTER_CONTEXT.as_bytes(),
+                &mut *material,
+            )
+            .unwrap();
+
+        let cipher_suite = match session.cipher_suite() {
+            s2n_quic_core::crypto::tls::CipherSuite::TLS_AES_128_GCM_SHA256 => {
+                schedule::Ciphersuite::AES_GCM_128_SHA256
+            }
+            s2n_quic_core::crypto::tls::CipherSuite::TLS_AES_256_GCM_SHA384 => {
+                schedule::Ciphersuite::AES_GCM_256_SHA384
+            }
+            _ => return Err(s2n_quic_core::transport::Error::INTERNAL_ERROR),
+        };
+
+        let secret =
+            schedule::Secret::new(cipher_suite, self.dc_version, self.endpoint_type, &material);
+
+        let stateless_reset = self.map.signer().sign(secret.id());
+        self.secret = Some(secret);
+
+        Ok(vec![stateless_reset.into()])
+    }
+
+    fn on_peer_stateless_reset_tokens<'a>(
+        &mut self,
+        stateless_reset_tokens: impl Iterator<Item = &'a s2n_quic_core::stateless_reset::Token>,
+    ) {
+        // TODO: support multiple stateless reset tokens
+        let sender = sender::State::new(
+            stateless_reset_tokens
+                .into_iter()
+                .next()
+                .unwrap()
+                .into_inner(),
+        );
+
+        let receiver = self.map.state.receiver_shared.clone().new_receiver();
+
+        let entry = Entry::new(
+            self.peer,
+            self.secret
+                .take()
+                .expect("peer tokens are only received after secrets are ready"),
+            sender,
+            receiver,
+            self.parameters,
+        );
+        let entry = Arc::new(entry);
+        self.map.insert(entry);
+    }
+}
+
+#[cfg(any(test, feature = "testing"))]
+pub mod testing {
+    use s2n_quic_core::{
+        connection::Limits, dc::ApplicationParams, transport::parameters::InitialFlowControlLimits,
+    };
+
+    pub fn test_application_params() -> ApplicationParams {
+        ApplicationParams::new(
+            s2n_quic_core::path::MaxMtu::default().into(),
+            &InitialFlowControlLimits::default(),
+            &Limits::default(),
+        )
+    }
+}
+
+#[cfg(test)]
+mod test;
diff --git a/dc/s2n-quic-dc/src/path/secret/map/test.rs b/dc/s2n-quic-dc/src/path/secret/map/test.rs
new file mode 100644
index 0000000000..08f3abafa3
--- /dev/null
+++ b/dc/s2n-quic-dc/src/path/secret/map/test.rs
@@ -0,0 +1,254 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// SPDX-License-Identifier: Apache-2.0
+
+use super::{receiver, sender};
+use std::{
+    collections::HashSet,
+    net::{Ipv4Addr, SocketAddrV4},
+};
+
+use super::*;
+
+const VERSION: dc::Version = dc::SUPPORTED_VERSIONS[0];
+
+fn fake_entry(peer: u16) -> Arc<Entry> {
+    let mut secret = [0; 32];
+    aws_lc_rs::rand::fill(&mut secret).unwrap();
+    Arc::new(Entry::new(
+        SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::LOCALHOST, peer)),
+        schedule::Secret::new(
+            schedule::Ciphersuite::AES_GCM_128_SHA256,
+            VERSION,
+            s2n_quic_core::endpoint::Type::Client,
+            &secret,
+        ),
+        sender::State::new([0; 16]),
+        receiver::State::without_shared(),
+        super::testing::test_application_params(),
+    ))
+}
+
+#[test]
+fn cleans_after_delay() {
+    let signer = stateless_reset::Signer::new(b"secret");
+    let map = Map::new(signer);
+
+    let first = fake_entry(1);
+    let second = fake_entry(1);
+    let third = fake_entry(1);
+    map.insert(first.clone());
+    map.insert(second.clone());
+
+    let guard = map.state.ids.guard();
+    assert!(map.state.ids.contains_key(first.secret.id(), &guard));
+    assert!(map.state.ids.contains_key(second.secret.id(), &guard));
+
+    map.state.cleaner.clean(&map.state, 1);
+    map.state.cleaner.clean(&map.state, 1);
+
+    map.insert(third.clone());
+
+    assert!(!map.state.ids.contains_key(first.secret.id(), &guard));
+    assert!(map.state.ids.contains_key(second.secret.id(), &guard));
+    assert!(map.state.ids.contains_key(third.secret.id(), &guard));
+}
+
+#[test]
+fn thread_shutdown() {
+    let signer = stateless_reset::Signer::new(b"secret");
+    let map = Map::new(signer);
+    let state = Arc::downgrade(&map.state);
+    drop(map);
+
+    let iterations = 10;
+    let max_time = core::time::Duration::from_secs(2);
+
+    for _ in 0..iterations {
+        // Nothing is holding on to the state, so the thread should shut down (when its `Weak`
+        // upgrade fails, or on the next loop around if that somehow doesn't happen).
+        if state.strong_count() == 0 {
+            return;
+        }
+        std::thread::sleep(max_time / iterations);
+    }
+
+    panic!("thread did not shut down after {max_time:?}");
+}
+
+#[derive(Debug, Default)]
+struct Model {
+    invariants: HashSet<Invariant>,
+}
+
+#[derive(bolero::TypeGenerator, Debug, Copy, Clone)]
+enum Operation {
+    Insert { ip: u8, path_secret_id: TestId },
+    AdvanceTime,
+    ReceiveUnknown { path_secret_id: TestId },
+}
+
+#[derive(bolero::TypeGenerator, PartialEq, Eq, Hash, Copy, Clone)]
+struct TestId(u8);
+
+impl fmt::Debug for TestId {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        f.debug_tuple("TestId")
+            .field(&self.0)
+            .field(&self.id())
+            .finish()
+    }
+}
+
+impl TestId {
+    fn secret(self) -> schedule::Secret {
+        let mut export_secret = [0; 32];
+        export_secret[0] = self.0;
+        schedule::Secret::new(
+            schedule::Ciphersuite::AES_GCM_128_SHA256,
+            VERSION,
+            s2n_quic_core::endpoint::Type::Client,
+            &export_secret,
+        )
+    }
+
+    fn id(self) -> Id {
+        *self.secret().id()
+    }
+}
+
+#[derive(Debug, PartialEq, Eq, Hash, Copy, Clone)]
+enum Invariant {
+    ContainsIp(SocketAddr),
+    ContainsId(Id),
+    IdRemoved(Id),
+}
+
+impl Model {
+    fn perform(&mut self, operation: Operation, state: &Map) {
+        match operation {
+            Operation::Insert { ip, path_secret_id } => {
+                let ip = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::from([0, 0, 0, ip]), 0));
+                let secret = path_secret_id.secret();
+                let id = *secret.id();
+
+                let stateless_reset = state.state.signer.sign(&id);
+                state.insert(Arc::new(Entry::new(
+                    ip,
+                    secret,
+                    sender::State::new(stateless_reset),
+                    state.state.receiver_shared.clone().new_receiver(),
+                    super::testing::test_application_params(),
+                )));
+
+                self.invariants.insert(Invariant::ContainsIp(ip));
+                self.invariants.insert(Invariant::ContainsId(id));
+            }
+            Operation::AdvanceTime => {
+                let mut invalidated = Vec::new();
+                let ids = state.state.ids.guard();
+                self.invariants.retain(|invariant| {
+                    if let Invariant::ContainsId(id) = invariant {
+                        if state.state.ids.get(id, &ids).unwrap().retired.retired() {
+                            invalidated.push(*id);
+                            return false;
+                        }
+                    }
+
+                    true
+                });
+                for id in invalidated {
+                    assert!(self.invariants.insert(Invariant::IdRemoved(id)), "{id:?}");
+                }
+
+                // Evict all stale records *now*.
+                state.state.cleaner.clean(&state.state, 0);
+            }
+            Operation::ReceiveUnknown { path_secret_id } => {
+                let id = path_secret_id.id();
+                // This is signing with the "wrong" signer, but currently all of the signers used
+                // in this test are keyed the same way so it doesn't matter.
+                let stateless_reset = state.state.signer.sign(&id);
+                let packet =
+                    crate::packet::secret_control::unknown_path_secret::Packet::new_for_test(
+                        id,
+                        &stateless_reset,
+                    );
+                state.handle_unknown_secret_packet(&packet);
+
+                // ReceiveUnknown does not cause any action with respect to our invariants; no
+                // updates required.
+ } + } + } + + fn check_invariants(&self, state: &State) { + let peers = state.peers.guard(); + let ids = state.ids.guard(); + for invariant in self.invariants.iter() { + match invariant { + Invariant::ContainsIp(ip) => { + assert!(state.peers.contains_key(ip, &peers), "{:?}", ip); + } + Invariant::ContainsId(id) => { + assert!(state.ids.contains_key(id, &ids), "{:?}", id); + } + Invariant::IdRemoved(id) => { + assert!( + !state.ids.contains_key(id, &ids), + "{:?}", + state.ids.get(id, &ids) + ); + } + } + } + } +} + +fn has_duplicate_pids(ops: &[Operation]) -> bool { + let mut ids = HashSet::new(); + for op in ops.iter() { + match op { + Operation::Insert { + ip: _, + path_secret_id, + } => { + if !ids.insert(path_secret_id) { + return true; + } + } + Operation::AdvanceTime => {} + Operation::ReceiveUnknown { path_secret_id: _ } => { + // no-op, we're fine receiving unknown pids. + } + } + } + + false +} + +#[test] +fn check_invariants() { + bolero::check!() + .with_type::>() + .with_iterations(100_000) + .for_each(|input: &Vec| { + if has_duplicate_pids(input) { + // Ignore this attempt. + return; + } + + let mut model = Model::default(); + let signer = stateless_reset::Signer::new(b"secret"); + let map = Map::new(signer); + + // Avoid background work interfering with testing. + map.state.cleaner.stop(); + + model.check_invariants(&map.state); + + for op in input { + model.perform(*op, &map); + model.check_invariants(&map.state); + } + }) +} diff --git a/dc/s2n-quic-dc/src/path/secret/receiver.rs b/dc/s2n-quic-dc/src/path/secret/receiver.rs new file mode 100644 index 0000000000..601e1c37f4 --- /dev/null +++ b/dc/s2n-quic-dc/src/path/secret/receiver.rs @@ -0,0 +1,340 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +use crate::credentials::{Credentials, Id, KeyId}; +use s2n_quic_core::packet::number::{ + PacketNumber, PacketNumberSpace, SlidingWindow, SlidingWindowError, +}; +use std::{ + cell::UnsafeCell, + ptr::NonNull, + sync::{ + atomic::{AtomicU64, Ordering}, + Arc, Mutex, + }, +}; + +const SHARED_ENTRIES: usize = 1 << 20; +// Maximum page size on current machines (macOS aarch64 has 16kb pages) +// +// mmap is documented as failing if we don't request a page boundary. Currently our sizes work out +// such that rounding is useless, but this is good future proofing. +const MAX_PAGE: usize = 16_384; +const SHARED_ALLOCATION: usize = { + let element = std::mem::size_of::(); + let size = element * SHARED_ENTRIES; + // TODO use `next_multiple_of` once MSRV is >=1.73 + (size + MAX_PAGE - 1) / MAX_PAGE * MAX_PAGE +}; + +#[derive(Debug)] +pub struct Shared { + secret: u64, + backing: NonNull, +} + +unsafe impl Send for Shared {} +unsafe impl Sync for Shared {} + +impl Drop for Shared { + fn drop(&mut self) { + unsafe { + if libc::munmap(self.backing.as_ptr().cast(), SHARED_ALLOCATION) != 0 { + // Avoid panicking in a destructor, just let the memory leak while logging. We + // expect this to be essentially a global singleton in most production cases so + // likely we're exiting the process anyway. + eprintln!( + "Failed to unmap memory: {:?}", + std::io::Error::last_os_error() + ); + } + } + } +} + +const fn assert_copy() {} + +struct SharedSlot { + id: UnsafeCell, + key_id: AtomicU64, +} + +impl SharedSlot { + fn try_lock(&self) -> Option> { + let current = self.key_id.load(Ordering::Relaxed); + if current & LOCK != 0 { + // If we are already locked, then give up. 
+            // A concurrent thread updated this slot; any write we do would squash that thread's
+            // write. Doing so if that thread remove()d may make sense in the future, but not right
+            // now.
+            return None;
+        }
+        let Ok(_) = self.key_id.compare_exchange(
+            current,
+            current | LOCK,
+            Ordering::Acquire,
+            Ordering::Relaxed,
+        ) else {
+            return None;
+        };
+
+        Some(SharedSlotGuard {
+            slot: self,
+            key_id: current,
+        })
+    }
+}
+
+struct SharedSlotGuard<'a> {
+    slot: &'a SharedSlot,
+    key_id: u64,
+}
+
+impl<'a> SharedSlotGuard<'a> {
+    fn write_id(&mut self, id: Id) {
+        // Store the new ID.
+        // SAFETY: We hold the lock since we are in the guard.
+        unsafe {
+            // Note: no destructor is run for the previously stored element, but Id is Copy.
+            // If we did want to run a destructor we'd have to ensure that we replaced a PRESENT
+            // entry.
+            assert_copy::<Id>();
+            std::ptr::write(self.slot.id.get(), id);
+        }
+    }
+
+    fn id(&self) -> Id {
+        // SAFETY: We hold the lock, so copying out the Id is safe.
+        unsafe { *self.slot.id.get() }
+    }
+}
+
+impl<'a> Drop for SharedSlotGuard<'a> {
+    fn drop(&mut self) {
+        self.slot.key_id.store(self.key_id, Ordering::Release);
+    }
+}
+
+const LOCK: u64 = 1 << 62;
+const PRESENT: u64 = 1 << 63;
+
+impl Shared {
+    pub fn new() -> Arc<Shared> {
+        let mut secret = [0; 8];
+        aws_lc_rs::rand::fill(&mut secret).expect("random is available");
+        let shared = Shared {
+            secret: u64::from_ne_bytes(secret),
+            backing: unsafe {
+                // Note: We rely on the zero-initialization provided by the kernel. That ensures
+                // that an entry in the map is not LOCK'd to begin with and is not PRESENT as well.
+                let ptr = libc::mmap(
+                    std::ptr::null_mut(),
+                    SHARED_ALLOCATION,
+                    libc::PROT_READ | libc::PROT_WRITE,
+                    libc::MAP_ANONYMOUS | libc::MAP_PRIVATE,
+                    0,
+                    0,
+                );
+                // MAP_FAILED is -1, i.e. usize::MAX when cast.
+                if ptr as usize == usize::MAX {
+                    panic!(
+                        "Failed to allocate backing allocation for shared: {:?}",
+                        std::io::Error::last_os_error()
+                    );
+                }
+                NonNull::new(ptr).unwrap().cast()
+            },
+        };
+
+        // We need to modify the slot to which an all-zero path secret ID and key ID map. Otherwise
+        // we'd return Err(AlreadyExists) for that entry, which isn't correct - it has not been
+        // inserted or removed, so it should be Err(Unknown).
+        //
+        // This is the only slot that needs modification. All other slots are never used for lookup
+        // of this set of credentials and so containing this set of credentials is fine.
+        let slot = shared.slot(&Credentials {
+            id: Id::from([0; 16]),
+            key_id: KeyId::new(0).unwrap(),
+        });
+        // The max key ID is never used by senders (checked on the sending side), while avoiding
+        // taking a full bit out of the range of key IDs. We also statically return Unknown for it
+        // on removal to avoid a non-local invariant.
+        slot.key_id.store(KeyId::MAX.as_u64(), Ordering::Relaxed);
+
+        Arc::new(shared)
+    }
+
+    pub fn new_receiver(self: Arc<Self>) -> State {
+        State::with_shared(self)
+    }
+
+    fn insert(&self, identity: &Credentials) {
+        let slot = self.slot(identity);
+        let Some(mut guard) = slot.try_lock() else {
+            return;
+        };
+        guard.write_id(identity.id);
+        guard.key_id = *identity.key_id | PRESENT;
+    }
+
+    fn remove(&self, identity: &Credentials) -> Result<(), Error> {
+        // See `new` for details.
+        if identity.key_id == KeyId::MAX.as_u64() {
+            return Err(Error::Unknown);
+        }
+
+        let slot = self.slot(identity);
+        let previous = slot.key_id.load(Ordering::Relaxed);
+        if previous & LOCK != 0 {
+            // If we are already locked, then give up.
+
+impl Shared {
+    pub fn new() -> Arc<Shared> {
+        let mut secret = [0; 8];
+        aws_lc_rs::rand::fill(&mut secret).expect("random is available");
+        let shared = Shared {
+            secret: u64::from_ne_bytes(secret),
+            backing: unsafe {
+                // Note: We rely on the zero-initialization provided by the kernel. That ensures
+                // that an entry in the map is not LOCK'd to begin with and is not PRESENT as well.
+                let ptr = libc::mmap(
+                    std::ptr::null_mut(),
+                    SHARED_ALLOCATION,
+                    libc::PROT_READ | libc::PROT_WRITE,
+                    libc::MAP_ANONYMOUS | libc::MAP_PRIVATE,
+                    0,
+                    0,
+                );
+                // mmap returns MAP_FAILED, i.e. -1, on failure.
+                if ptr as usize == usize::MAX {
+                    panic!(
+                        "Failed to allocate backing allocation for shared: {:?}",
+                        std::io::Error::last_os_error()
+                    );
+                }
+                NonNull::new(ptr).unwrap().cast()
+            },
+        };
+
+        // We need to modify the slot to which an all-zero path secret ID and key ID map. Otherwise
+        // we'd return Err(AlreadyExists) for that entry which isn't correct - it has not been
+        // inserted or removed, so it should be Err(Unknown).
+        //
+        // This is the only slot that needs modification. All other slots are never used for lookup
+        // of this set of credentials and so containing this set of credentials is fine.
+        let slot = shared.slot(&Credentials {
+            id: Id::from([0; 16]),
+            key_id: KeyId::new(0).unwrap(),
+        });
+        // The max key ID is never used by senders (checked on the sending side), while avoiding
+        // taking a full bit out of the range of key IDs. We also statically return Unknown for it
+        // on removal to avoid a non-local invariant.
+        slot.key_id.store(KeyId::MAX.as_u64(), Ordering::Relaxed);
+
+        Arc::new(shared)
+    }
+
+    pub fn new_receiver(self: Arc<Self>) -> State {
+        State::with_shared(self)
+    }
+
+    fn insert(&self, identity: &Credentials) {
+        let slot = self.slot(identity);
+        let Some(mut guard) = slot.try_lock() else {
+            return;
+        };
+        guard.write_id(identity.id);
+        guard.key_id = *identity.key_id | PRESENT;
+    }
+
+    fn remove(&self, identity: &Credentials) -> Result<(), Error> {
+        // See `new` for details.
+        if identity.key_id == KeyId::MAX.as_u64() {
+            return Err(Error::Unknown);
+        }
+
+        let slot = self.slot(identity);
+        let previous = slot.key_id.load(Ordering::Relaxed);
+        if previous & LOCK != 0 {
+            // If we are already locked, then give up.
+            // A concurrent thread updated this slot, any write we do would squash that thread's
+            // write. No concurrent thread could have inserted what we're looking for since
+            // both insert and remove for a single path secret ID run under a Mutex.
+            return Err(Error::Unknown);
+        }
+        if previous & (!PRESENT) != *identity.key_id {
+            // If the currently stored entry does not match our desired KeyId,
+            // then we don't know whether this key has been replayed or not.
+            return Err(Error::Unknown);
+        }
+
+        let Some(mut guard) = slot.try_lock() else {
+            // Don't try to win the race by spinning, let the other thread proceed.
+            return Err(Error::Unknown);
+        };
+
+        // Check if the path secret ID matches.
+        if guard.id() != identity.id {
+            return Err(Error::Unknown);
+        }
+
+        // Ok, at this point we know that the key ID and the path secret ID both match.
+
+        let ret = if guard.key_id & PRESENT != 0 {
+            Ok(())
+        } else {
+            Err(Error::AlreadyExists)
+        };
+
+        // Release the lock, removing the PRESENT bit (which may already be missing).
+        guard.key_id = *identity.key_id;
+
+        ret
+    }
+
+    fn index(&self, identity: &Credentials) -> usize {
+        let hash = u64::from_ne_bytes(identity.id[..8].try_into().unwrap())
+            ^ *identity.key_id
+            ^ self.secret;
+        let index = hash & (SHARED_ENTRIES as u64 - 1);
+        index as usize
+    }
+
+    fn slot(&self, identity: &Credentials) -> &SharedSlot {
+        let index = self.index(identity);
+        // SAFETY: in-bounds -- the & above truncates such that we're always in the appropriate
+        // range that we allocated with mmap above.
+        //
+        // Casting to a reference is safe -- the Slot type has an UnsafeCell around all of the data
+        // (either inside the atomic or directly).
+        unsafe { self.backing.as_ptr().add(index).as_ref().unwrap_unchecked() }
+    }
+}
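`index` can use a mask instead of a modulo only because `SHARED_ENTRIES` is a power of two; a quick standalone check:

```rust
// Sketch of why `hash & (SHARED_ENTRIES - 1)` is a valid (and cheap) modulo:
// for a power-of-two table size, the mask and the remainder agree.
const SHARED_ENTRIES: usize = 1 << 20;

fn main() {
    assert!(SHARED_ENTRIES.is_power_of_two());
    let hash: u64 = 0xdead_beef_dead_beef;
    let masked = hash & (SHARED_ENTRIES as u64 - 1);
    assert_eq!(masked, hash % SHARED_ENTRIES as u64);
    println!("index = {masked}");
}
```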
+
+#[derive(Debug)]
+pub struct State {
+    // Minimum that we're potentially willing to accept.
+    // This is lazily updated and so may be out of date.
+    min_key_id: AtomicU64,
+
+    // This is the maximum ID we've seen so far. This is sent to peers for when we cannot determine
+    // if the packet sent is replayed as it falls outside our replay window. Peers use this
+    // information to resynchronize on the latest state.
+    max_seen_key_id: AtomicU64,
+
+    seen: Mutex<SlidingWindow>,
+
+    shared: Option<Arc<Shared>>,
+}
+
+#[derive(Clone, Copy, Debug, PartialEq, Eq, thiserror::Error)]
+pub enum Error {
+    /// This indicates that we know about this element and it *definitely* already exists.
+    #[error("packet definitely already seen before")]
+    AlreadyExists,
+    /// We don't know whether we've seen this element before. It may or may not have already been
+    /// received.
+    #[error("packet may have been seen before")]
+    Unknown,
+}
+
+impl State {
+    pub fn without_shared() -> State {
+        State {
+            min_key_id: Default::default(),
+            max_seen_key_id: Default::default(),
+            seen: Default::default(),
+            shared: None,
+        }
+    }
+
+    pub fn with_shared(shared: Arc<Shared>) -> State {
+        State {
+            min_key_id: Default::default(),
+            max_seen_key_id: Default::default(),
+            seen: Default::default(),
+            shared: Some(shared),
+        }
+    }
+
+    pub fn pre_authentication(&self, identity: &Credentials) -> Result<(), Error> {
+        if self.min_key_id.load(Ordering::Relaxed) > *identity.key_id {
+            return Err(Error::Unknown);
+        }
+
+        Ok(())
+    }
+
+    pub fn minimum_unseen_key_id(&self) -> KeyId {
+        KeyId::try_from(self.max_seen_key_id.load(Ordering::Relaxed) + 1).unwrap()
+    }
+
+    /// Called after decryption has been performed
+    pub fn post_authentication(&self, identity: &Credentials) -> Result<(), Error> {
+        let key_id = identity.key_id;
+        self.max_seen_key_id.fetch_max(*key_id, Ordering::Relaxed);
+        let pn = PacketNumberSpace::Initial.new_packet_number(key_id);
+
+        // Note: intentionally retaining this lock across potential insertion into the shared map.
+        // This avoids the case where we have evicted an entry but cannot see it in the shared map
+        // yet from a concurrent thread. This should not be required for correctness but helps
+        // reasoning about the state of the world.
+        let mut seen = self.seen.lock().unwrap();
+        match seen.insert_with_evicted(pn) {
+            Ok(evicted) => {
+                if let Some(shared) = &self.shared {
+                    // FIXME: Consider bounding the number of evicted entries to insert or
+                    // otherwise optimizing? This can run for at most 128 entries today...
+                    for evicted in evicted {
+                        shared.insert(&Credentials {
+                            id: identity.id,
+                            key_id: PacketNumber::as_varint(evicted),
+                        });
+                    }
+                }
+                Ok(())
+            }
+            Err(SlidingWindowError::TooOld) => {
+                if let Some(shared) = &self.shared {
+                    shared.remove(identity)
+                } else {
+                    Err(Error::Unknown)
+                }
+            }
+            Err(SlidingWindowError::Duplicate) => Err(Error::AlreadyExists),
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests;
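`post_authentication` combines a bounded local window with the shared overflow map: duplicates inside the window are definite (`AlreadyExists`), while anything older than the window is ambiguous (`Unknown`) unless the shared map can answer. A toy model of just the window semantics, for intuition only (this is not the s2n-quic-core `SlidingWindow` implementation):

```rust
// A 128-bit window anchored at the largest ID seen: bit n tracks `max - n`.
#[derive(Debug, PartialEq)]
enum Outcome {
    Ok,
    AlreadyExists,
    Unknown,
}

struct Window {
    max: u64,
    bits: u128,
}

impl Window {
    fn new() -> Self {
        Self { max: 0, bits: 0 }
    }

    fn insert(&mut self, id: u64) -> Outcome {
        if id > self.max {
            // Advance the window, shifting old entries out.
            let shift = id - self.max;
            self.bits = if shift >= 128 { 0 } else { self.bits << shift };
            self.bits |= 1;
            self.max = id;
            Outcome::Ok
        } else if self.max - id >= 128 {
            Outcome::Unknown // too old: fell off the window
        } else if self.bits & (1 << (self.max - id)) != 0 {
            Outcome::AlreadyExists
        } else {
            self.bits |= 1 << (self.max - id);
            Outcome::Ok
        }
    }
}

fn main() {
    let mut w = Window::new();
    assert_eq!(w.insert(5), Outcome::Ok);
    assert_eq!(w.insert(5), Outcome::AlreadyExists);
    assert_eq!(w.insert(500), Outcome::Ok);
    assert_eq!(w.insert(5), Outcome::Unknown); // outside the window now
}
```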
diff --git a/dc/s2n-quic-dc/src/path/secret/receiver/tests.rs b/dc/s2n-quic-dc/src/path/secret/receiver/tests.rs
new file mode 100644
index 0000000000..8e49634a2b
--- /dev/null
+++ b/dc/s2n-quic-dc/src/path/secret/receiver/tests.rs
@@ -0,0 +1,352 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// SPDX-License-Identifier: Apache-2.0
+
+use super::*;
+use bolero::check;
+use rand::{seq::SliceRandom, Rng, SeedableRng};
+use std::collections::{binary_heap::PeekMut, BinaryHeap, HashSet};
+
+#[test]
+fn check() {
+    check!().with_type::<Vec<KeyId>>().for_each(|ops| {
+        let mut oracle = std::collections::HashSet::new();
+        let subject = State::with_shared(Shared::new());
+        let id = Id::from([0; 16]);
+        for op in ops {
+            let expected = oracle.insert(*op);
+            let actual = subject
+                .post_authentication(&Credentials { id, key_id: *op })
+                .is_ok();
+            // Even if we expected this to be a new value, it may have already been marked as
+            // "seen" by the set. However, we should never return a false OK (i.e., claim that
+            // the value was not seen when it actually was).
+            if !expected {
+                assert!(!actual);
+            }
+        }
+    });
+}
+
+#[test]
+fn check_ordered() {
+    check!().with_type::<Vec<KeyId>>().for_each(|ops| {
+        let mut ops = ops.clone();
+        ops.sort();
+        let mut oracle = std::collections::HashSet::new();
+        let subject = State::with_shared(Shared::new());
+        let id = Id::from([0; 16]);
+        for op in ops {
+            let expected = oracle.insert(op);
+            let actual = subject
+                .post_authentication(&Credentials { id, key_id: op })
+                .is_ok();
+            assert_eq!(actual, expected);
+        }
+    });
+}
+
+#[test]
+fn check_u16() {
+    check!().with_type::<Vec<u16>>().for_each(|ops| {
+        let mut oracle = std::collections::HashSet::new();
+        let subject = State::with_shared(Shared::new());
+        for op in ops {
+            let op = KeyId::new(*op as u64).unwrap();
+            let expected = oracle.insert(op);
+            let id = Id::from([0; 16]);
+            let actual = subject
+                .post_authentication(&Credentials { id, key_id: op })
+                .is_ok();
+            // Even if we expected this to be a new value, it may have already been marked as
+            // "seen" by the set. However, we should never return a false OK (i.e., claim that
+            // the value was not seen when it actually was).
+            //
+            // Note that even though u16::MAX < SHARED_ENTRIES, this is still not able to be
+            // 100% reliable because not all evicted entries from the local set are put into
+            // the backing allocation.
+            if !expected {
+                assert!(!actual);
+            }
+        }
+    });
+}
+
+#[test]
+fn check_ordered_u16() {
+    check!().with_type::<Vec<u16>>().for_each(|ops| {
+        let mut ops = ops.clone();
+        ops.sort();
+        let mut oracle = std::collections::HashSet::new();
+        let subject = State::with_shared(Shared::new());
+        let id = Id::from([0; 16]);
+        for op in ops {
+            let op = KeyId::new(op as u64).unwrap();
+            let expected = oracle.insert(op);
+            let actual = subject
+                .post_authentication(&Credentials { id, key_id: op })
+                .is_ok();
+            assert_eq!(actual, expected);
+        }
+    });
+}
+
+#[test]
+fn shared() {
+    let subject = Shared::new();
+    let id1 = Id::from([0; 16]);
+    let mut id2 = Id::from([0; 16]);
+    // This is a part of the path secret ID not used for hashing.
+    id2[10] = 1;
+    let key1 = KeyId::new(0).unwrap();
+    let key2 = KeyId::new(1).unwrap();
+    subject.insert(&Credentials {
+        id: id1,
+        key_id: key1,
+    });
+    assert_eq!(
+        subject.remove(&Credentials {
+            id: id1,
+            key_id: key1,
+        }),
+        Ok(())
+    );
+    assert_eq!(
+        subject.remove(&Credentials {
+            id: id1,
+            key_id: key1,
+        }),
+        Err(Error::AlreadyExists)
+    );
+    subject.insert(&Credentials {
+        id: id2,
+        key_id: key1,
+    });
+    assert_eq!(
+        subject.remove(&Credentials {
+            id: id1,
+            key_id: key1,
+        }),
+        Err(Error::Unknown)
+    );
+    assert_eq!(
+        subject.remove(&Credentials {
+            id: id1,
+            key_id: key2,
+        }),
+        Err(Error::Unknown)
+    );
+    // Removal never taints an entry, so this is still fine.
+    assert_eq!(
+        subject.remove(&Credentials {
+            id: id2,
+            key_id: key1,
+        }),
+        Ok(())
+    );
+}
+
+// This test is not particularly interesting; it's mostly just the same as the random tests above
+// which insert ordered and unordered values. Mostly it tests that we continue to allow 129 IDs of
+// arbitrary reordering.
+#[test]
+fn check_shuffled_chunks() {
+    check!()
+        .with_type::<(u64, u8)>()
+        .for_each(|&(seed, chunk_size)| {
+            check_shuffled_chunks_inner(seed, chunk_size);
+        });
+}
+
+#[test]
+fn check_shuffled_chunks_specific() {
+    check_shuffled_chunks_inner(0xf323243, 10);
+    check_shuffled_chunks_inner(0xf323243, 63);
+    check_shuffled_chunks_inner(0xf323243, 129);
+}
+
+fn check_shuffled_chunks_inner(seed: u64, chunk_size: u8) {
+    eprintln!("======== starting test run ({seed} {chunk_size}) ==========");
+    if chunk_size == 0 || chunk_size >= 129 {
+        // Needs at least 1 in the chunk.
+        //
+        // Chunk sizes that are larger than the local set are not guaranteed to pass, since they
+        // may skip entirely over the 129-element window which then isn't inserted at all into our
+        // backup/shared set.
+        return;
+    }
+    let mut model = Model::default();
+    let mut rng = rand::rngs::SmallRng::seed_from_u64(seed);
+    let mut deltas = (-(chunk_size as i32 / 2)..(chunk_size as i32 / 2)).collect::<Vec<_>>();
+    for initial in (128u32..100_000u32).step_by(chunk_size as usize) {
+        deltas.shuffle(&mut rng);
+        for delta in deltas.iter() {
+            model.insert(initial.checked_add_signed(*delta).unwrap() as u64);
+        }
+    }
+}
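For intuition, the generator below reproduces the shuffled-chunk pattern standalone and checks the property the test relies on: every ID lands within one chunk of its sorted position. The bounds here are illustrative, not taken from the test:

```rust
use rand::{seq::SliceRandom, SeedableRng};

fn main() {
    let chunk_size = 16i32;
    let mut rng = rand::rngs::SmallRng::seed_from_u64(0xf323243);
    // Deltas around each chunk center, shuffled per chunk.
    let mut deltas = (-(chunk_size / 2)..(chunk_size / 2)).collect::<Vec<_>>();
    let mut out = Vec::new();
    for center in (128u32..512).step_by(chunk_size as usize) {
        deltas.shuffle(&mut rng);
        for delta in &deltas {
            out.push(center.checked_add_signed(*delta).unwrap() as u64);
        }
    }
    // Every ID is within one chunk of its in-order position; the first
    // emitted value in sorted order is 128 - chunk_size / 2 = 120.
    for (i, id) in out.iter().enumerate() {
        let expected = i as i64 + 120;
        assert!((*id as i64 - expected).abs() <= chunk_size as i64);
    }
    println!("reordering stayed within one chunk");
}
```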
+
+// This represents the commonly seen behavior in production where a small percentage of inserted
+// keys are potentially significantly delayed. Currently our percentage is fixed, but the delay is
+// not; its minimum is set by our test here and its maximum is always at most SHARED_ENTRIES.
+//
+// This ensures that in the common case we see in production our receiver map, presuming no
+// contention in the shared map, is reliably able to return accurate results.
+#[test]
+fn check_delayed() {
+    check!()
+        .with_type::<(u64, u16)>()
+        .for_each(|&(seed, delay)| {
+            check_delayed_inner(seed, delay);
+        });
+}
+
+#[test]
+fn check_delayed_specific() {
+    check_delayed_inner(0xf323243, 10);
+    check_delayed_inner(0xf323243, 63);
+    check_delayed_inner(0xf323243, 129);
+}
+
+// delay represents the *minimum* delay a delayed entry sees. The maximum is up to SHARED_ENTRIES.
+fn check_delayed_inner(seed: u64, delay: u16) {
+    // We expect that the shared map is always big enough to absorb our delay.
+    // (This is statically true; u16::MAX < SHARED_ENTRIES.)
+    assert!((delay as usize) < SHARED_ENTRIES);
+    let delay = delay as u64;
+    eprintln!("======== starting test run ({seed} {delay}) ==========");
+    let mut model = Model::default();
+    let mut rng = rand::rngs::SmallRng::seed_from_u64(seed);
+    // Reverse the first element (insert_before) to ensure we pop the smallest pending ID first.
+    // Max on the second element (id_to_insert) to ensure that we go in least-favorable order if
+    // there are multiple elements to insert, inserting the most recent first and only afterwards
+    // older entries.
+    let mut buffered: BinaryHeap<(std::cmp::Reverse<u64>, u64)> = BinaryHeap::new();
+    for id in 0..(SHARED_ENTRIES as u64 * 3) {
+        while let Some(peeked) = buffered.peek_mut() {
+            // min-heap means that if the first entry isn't the one we want, then there's no entry
+            // that we want.
+            if (peeked.0).0 == id {
+                model.insert(peeked.1);
+                PeekMut::pop(peeked);
+            } else {
+                break;
+            }
+        }
+        // Every 128th ID gets put in immediately, the rest are delayed by a random amount.
+        // This ensures that we always evict all the gaps as we move forward into the backing set.
+        // In production, this roughly means that at least 1/128 = 0.7% of packets arrive in
+        // relative order to each other. (That's an approximation; it's not obvious how to derive
+        // a simple statement of the exact guarantee we're trying to provide here.)
+        if id % 128 != 0 {
+            // ...until some random interval no more than SHARED_ENTRIES away.
+            let insert_before = rng.gen_range(id + 1 + delay..id + SHARED_ENTRIES as u64);
+            buffered.push((std::cmp::Reverse(insert_before), id));
+        } else {
+            model.insert(id);
+        }
+    }
+}
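`std::collections::BinaryHeap` is a max-heap, which is why the test wraps the release ID in `Reverse` to get min-heap behavior. A minimal sketch:

```rust
use std::cmp::Reverse;
use std::collections::BinaryHeap;

fn main() {
    // (release_at, id): Reverse on the first element makes pop() return the
    // entry with the smallest release ID first.
    let mut buffered: BinaryHeap<(Reverse<u64>, u64)> = BinaryHeap::new();
    buffered.push((Reverse(30), 2));
    buffered.push((Reverse(10), 0));
    buffered.push((Reverse(20), 1));

    let mut order = Vec::new();
    while let Some((Reverse(release_at), id)) = buffered.pop() {
        order.push((release_at, id));
    }
    assert_eq!(order, vec![(10, 0), (20, 1), (30, 2)]);
}
```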
+
+struct Model {
+    insert_order: Vec<u64>,
+    oracle: HashSet<u64>,
+    subject: State,
+}
+
+impl Default for Model {
+    fn default() -> Self {
+        Self {
+            oracle: Default::default(),
+            insert_order: Vec::new(),
+            subject: State::with_shared(Shared::new()),
+        }
+    }
+}
+
+impl Model {
+    fn insert(&mut self, op: u64) {
+        let pid = Id::from([0; 16]);
+        let id = KeyId::new(op).unwrap();
+        let expected = self.oracle.insert(op);
+        if expected {
+            self.insert_order.push(op);
+        }
+        let actual = self.subject.post_authentication(&Credentials {
+            id: pid,
+            key_id: id,
+        });
+        if actual.is_ok() != expected {
+            let mut oracle = self.oracle.iter().collect::<Vec<_>>();
+            oracle.sort_unstable();
+            panic!(
+                "Inserting {:?} failed, in oracle: {}, in subject: {:?}, inserted: {:?}",
+                op, expected, actual, self.insert_order
+            );
+        }
+    }
+}
+
+#[test]
+fn shared_no_collisions() {
+    let mut seen = HashSet::new();
+    let shared = Shared::new();
+    for key_id in 0..SHARED_ENTRIES as u64 {
+        let index = shared.index(&Credentials {
+            id: Id::from([0; 16]),
+            key_id: KeyId::new(key_id).unwrap(),
+        });
+        assert!(seen.insert(index));
+    }
+
+    // The next entry should collide, since we will wrap around.
+    let index = shared.index(&Credentials {
+        id: Id::from([0; 16]),
+        key_id: KeyId::new(SHARED_ENTRIES as u64 + 1).unwrap(),
+    });
+    assert!(!seen.insert(index));
+}
+
+#[test]
+fn shared_id_pair_no_collisions() {
+    let shared = Shared::new();
+
+    // Two random IDs. Exact constants shouldn't matter much, we're mainly aiming to test overall
+    // quality of our mapping from Id + KeyId.
+    let id1 = Id::from(u128::to_ne_bytes(0x25add729cce683cd0cda41d35436bdc6));
+    let id2 = Id::from(u128::to_ne_bytes(0x2862115d0691fe180f2aeb26af3c2e5e));
+
+    for key_id in 0..SHARED_ENTRIES as u64 {
+        let index1 = shared.index(&Credentials {
+            id: id1,
+            key_id: KeyId::new(key_id).unwrap(),
+        });
+        let index2 = shared.index(&Credentials {
+            id: id2,
+            key_id: KeyId::new(key_id).unwrap(),
+        });
+
+        // Our path secret IDs are sufficiently different that we expect that for any given key ID
+        // we map to a different slot. This test is not *really* saying much since it's highly
+        // dependent on the exact values of the path secret IDs, but it prevents simple bugs like
+        // ignoring the IDs entirely.
+        assert_ne!(index1, index2);
+    }
+}
+
+// Confirms that we start out without any entries present in the map.
+#[test]
+fn shared_no_entries() {
+    let shared = Shared::new();
+    // We have to check all slots to be sure. The index used for lookup is going to be shuffled due
+    // to the hashing in of the secret. We need to use an all-zero path secret ID since the entries
+    // in the map start out zero-initialized today.
+    for key_id in 0..SHARED_ENTRIES as u64 {
+        assert_eq!(
+            shared.remove(&Credentials {
+                id: Id::from([0; 16]),
+                key_id: KeyId::new(key_id).unwrap(),
+            }),
+            Err(Error::Unknown)
+        );
+    }
+}
diff --git a/dc/s2n-quic-dc/src/path/secret/schedule.rs b/dc/s2n-quic-dc/src/path/secret/schedule.rs
new file mode 100644
index 0000000000..bdc5ce484d
--- /dev/null
+++ b/dc/s2n-quic-dc/src/path/secret/schedule.rs
@@ -0,0 +1,289 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// SPDX-License-Identifier: Apache-2.0
+
+use crate::{
+    credentials::{Credentials, Id},
+    crypto::awslc::{DecryptKey, EncryptKey},
+};
+use aws_lc_rs::{
+    aead::{self, NONCE_LEN},
+    hkdf::{self, Prk},
+};
+use s2n_quic_core::{dc, varint::VarInt};
+
+pub use s2n_quic_core::endpoint;
+pub const MAX_KEY_LEN: usize = 32;
+
+#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
+#[allow(non_camel_case_types)]
+pub enum Ciphersuite {
+    AES_GCM_128_SHA256,
+    #[allow(dead_code)]
+    AES_GCM_256_SHA384,
+}
+
+impl Ciphersuite {
+    #[inline]
+    pub fn aead(&self) -> &'static aead::Algorithm {
+        match self {
+            Self::AES_GCM_128_SHA256 => &aead::AES_128_GCM,
+            Self::AES_GCM_256_SHA384 => &aead::AES_256_GCM,
+        }
+    }
+
+    #[inline]
+    pub fn hkdf(&self) -> hkdf::Algorithm {
+        match self {
+            Self::AES_GCM_128_SHA256 => hkdf::HKDF_SHA256,
+            Self::AES_GCM_256_SHA384 => hkdf::HKDF_SHA384,
+        }
+    }
+}
+
+impl hkdf::KeyType for Ciphersuite {
+    #[inline]
+    fn len(&self) -> usize {
+        match self {
+            Self::AES_GCM_128_SHA256 => 16,
+            Self::AES_GCM_256_SHA384 => 32,
+        }
+    }
+}
+
+#[derive(Clone, Copy, Debug)]
+pub enum Initiator {
+    Local,
+    Remote,
+}
+
+impl Initiator {
+    #[inline]
+    fn label(self, endpoint: endpoint::Type) -> &'static [u8] {
+        use endpoint::Type::*;
+        use Initiator::*;
+
+        match (endpoint, self) {
+            (Client, Local) | (Server, Remote) => b" client",
+            (Server, Local) | (Client, Remote) => b" server",
+        }
+    }
+}
+
+#[derive(Clone, Copy, Debug)]
+pub enum Direction {
+    Send,
+    Receive,
+}
+
+impl Direction {
+    #[inline]
+    fn label(self, endpoint: endpoint::Type) -> &'static [u8] {
+        use endpoint::Type::*;
+        use Direction::*;
+
+        match (endpoint, self) {
+            (Client, Send) | (Server, Receive) => b" client",
+            (Server, Send) | (Client, Receive) => b" server",
+        }
+    }
+}
+
+pub const EXPORT_SECRET_LEN: usize = 32;
+pub type ExportSecret = [u8; 32];
+
+#[derive(Debug)]
+pub struct Secret {
+    id: Id,
+    prk: Prk,
+    endpoint: endpoint::Type,
+    ciphersuite: Ciphersuite,
+}
+
+impl Secret {
+    #[inline]
+    pub fn new(
+        ciphersuite: Ciphersuite,
+        _version: dc::Version,
+        endpoint: endpoint::Type,
+        export_secret: &ExportSecret,
+    ) -> Self {
+        let prk = Prk::new_less_safe(ciphersuite.hkdf(), export_secret);
+
+        let mut v = Self {
+            id: Default::default(),
+            prk,
+            endpoint,
+            ciphersuite,
+        };
+
+        let mut id = Id::default();
+        v.expand(&[&[16], b" pid"], &mut *id);
+        v.id = id;
+
+        v
+    }
+
+    #[inline]
+    pub fn id(&self) -> &Id {
+        &self.id
+    }
+
+    #[inline]
+    pub fn application_pair(
+        &self,
+        key_id: VarInt,
+        initiator: Initiator,
+    ) -> (EncryptKey, DecryptKey) {
+        let creds = Credentials {
+            id: self.id,
+            key_id,
+        };
+
+        let ciphersuite = &self.ciphersuite;
+        let mut out = [0u8; (NONCE_LEN + MAX_KEY_LEN) * 2];
+        let key_len = hkdf::KeyType::len(ciphersuite);
+        let out_len = (NONCE_LEN + key_len) * 2;
+        let (out, _) = out.split_at_mut(out_len);
+        self.expand(
+            &[
+                &[out_len as u8],
+                b" bidi",
+                initiator.label(self.endpoint),
+                &key_id.to_be_bytes(),
+            ],
+            out,
+        );
"more secret" data at the beginning + // + // here we derive + // + // (client_key, server_key, client_iv, server_iv) + let (client_key, out) = out.split_at(key_len); + let (server_key, out) = out.split_at(key_len); + let (client_iv, server_iv) = out.split_at(NONCE_LEN); + let client_iv = client_iv.try_into().unwrap(); + let server_iv = server_iv.try_into().unwrap(); + let aead = ciphersuite.aead(); + + match self.endpoint { + endpoint::Type::Client => { + let sealer = EncryptKey::new(creds, client_key, client_iv, aead); + let opener = DecryptKey::new(creds, server_key, server_iv, aead); + (sealer, opener) + } + endpoint::Type::Server => { + let sealer = EncryptKey::new(creds, server_key, server_iv, aead); + let opener = DecryptKey::new(creds, client_key, client_iv, aead); + (sealer, opener) + } + } + } + + #[inline] + pub fn application_sealer(&self, key_id: VarInt) -> EncryptKey { + let creds = Credentials { + id: self.id, + key_id, + }; + + self.derive_application_key(Direction::Send, key_id, |alg, key, iv| { + EncryptKey::new(creds, key, iv, alg) + }) + } + + #[inline] + pub fn application_opener(&self, key_id: VarInt) -> DecryptKey { + let creds = Credentials { + id: self.id, + key_id, + }; + + self.derive_application_key(Direction::Receive, key_id, |alg, key, iv| { + DecryptKey::new(creds, key, iv, alg) + }) + } + + #[inline] + fn derive_application_key(&self, direction: Direction, key_id: VarInt, f: F) -> R + where + F: FnOnce(&'static aead::Algorithm, &[u8], [u8; NONCE_LEN]) -> R, + { + let mut out = [0u8; NONCE_LEN + MAX_KEY_LEN]; + let key_len = hkdf::KeyType::len(&self.ciphersuite); + let out_len = NONCE_LEN + key_len; + let (out, _) = out.split_at_mut(out_len); + self.expand( + &[ + &[out_len as u8], + b" uni", + direction.label(self.endpoint), + &key_id.to_be_bytes(), + ], + out, + ); + // if the hash is ever broken, it's better to put the "more secret" data at the beginning + let (key, iv) = out.split_at(key_len); + let iv = iv.try_into().unwrap(); + f(self.ciphersuite.aead(), key, iv) + } + + pub fn control_sealer(&self) -> EncryptKey { + let creds = Credentials { + id: *self.id(), + key_id: VarInt::ZERO, + }; + + self.derive_control_key(Direction::Send, |alg, key, iv| { + EncryptKey::new(creds, key, iv, alg) + }) + } + + pub fn control_opener(&self) -> DecryptKey { + let creds = Credentials { + id: *self.id(), + key_id: VarInt::ZERO, + }; + + self.derive_control_key(Direction::Receive, |alg, key, iv| { + DecryptKey::new(creds, key, iv, alg) + }) + } + + #[inline] + fn derive_control_key(&self, direction: Direction, f: F) -> R + where + F: FnOnce(&'static aead::Algorithm, &[u8], [u8; NONCE_LEN]) -> R, + { + let mut out = [0u8; NONCE_LEN + MAX_KEY_LEN]; + let key_len = hkdf::KeyType::len(&self.ciphersuite); + let out_len = NONCE_LEN + key_len; + let (out, _) = out.split_at_mut(out_len); + self.expand( + &[&[out_len as u8], b" ctl", direction.label(self.endpoint)], + out, + ); + // if the hash is ever broken, it's better to put the "more secret" data at the beginning + let (key, iv) = out.split_at(key_len); + let iv = iv.try_into().unwrap(); + f(self.ciphersuite.aead(), key, iv) + } + + #[inline] + fn expand(&self, label: &[&[u8]], out: &mut [u8]) { + self.prk + .expand(label, OutLen(out.len())) + .unwrap() + .fill(out) + .unwrap(); + } +} + +#[derive(Clone, Copy)] +pub struct OutLen(pub usize); + +impl hkdf::KeyType for OutLen { + #[inline] + fn len(&self) -> usize { + self.0 + } +} diff --git a/dc/s2n-quic-dc/src/path/secret/sender.rs 
+
+    #[inline]
+    pub fn application_sealer(&self, key_id: VarInt) -> EncryptKey {
+        let creds = Credentials {
+            id: self.id,
+            key_id,
+        };
+
+        self.derive_application_key(Direction::Send, key_id, |alg, key, iv| {
+            EncryptKey::new(creds, key, iv, alg)
+        })
+    }
+
+    #[inline]
+    pub fn application_opener(&self, key_id: VarInt) -> DecryptKey {
+        let creds = Credentials {
+            id: self.id,
+            key_id,
+        };
+
+        self.derive_application_key(Direction::Receive, key_id, |alg, key, iv| {
+            DecryptKey::new(creds, key, iv, alg)
+        })
+    }
+
+    #[inline]
+    fn derive_application_key<F, R>(&self, direction: Direction, key_id: VarInt, f: F) -> R
+    where
+        F: FnOnce(&'static aead::Algorithm, &[u8], [u8; NONCE_LEN]) -> R,
+    {
+        let mut out = [0u8; NONCE_LEN + MAX_KEY_LEN];
+        let key_len = hkdf::KeyType::len(&self.ciphersuite);
+        let out_len = NONCE_LEN + key_len;
+        let (out, _) = out.split_at_mut(out_len);
+        self.expand(
+            &[
+                &[out_len as u8],
+                b" uni",
+                direction.label(self.endpoint),
+                &key_id.to_be_bytes(),
+            ],
+            out,
+        );
+        // if the hash is ever broken, it's better to put the "more secret" data at the beginning
+        let (key, iv) = out.split_at(key_len);
+        let iv = iv.try_into().unwrap();
+        f(self.ciphersuite.aead(), key, iv)
+    }
+
+    pub fn control_sealer(&self) -> EncryptKey {
+        let creds = Credentials {
+            id: *self.id(),
+            key_id: VarInt::ZERO,
+        };
+
+        self.derive_control_key(Direction::Send, |alg, key, iv| {
+            EncryptKey::new(creds, key, iv, alg)
+        })
+    }
+
+    pub fn control_opener(&self) -> DecryptKey {
+        let creds = Credentials {
+            id: *self.id(),
+            key_id: VarInt::ZERO,
+        };
+
+        self.derive_control_key(Direction::Receive, |alg, key, iv| {
+            DecryptKey::new(creds, key, iv, alg)
+        })
+    }
+
+    #[inline]
+    fn derive_control_key<F, R>(&self, direction: Direction, f: F) -> R
+    where
+        F: FnOnce(&'static aead::Algorithm, &[u8], [u8; NONCE_LEN]) -> R,
+    {
+        let mut out = [0u8; NONCE_LEN + MAX_KEY_LEN];
+        let key_len = hkdf::KeyType::len(&self.ciphersuite);
+        let out_len = NONCE_LEN + key_len;
+        let (out, _) = out.split_at_mut(out_len);
+        self.expand(
+            &[&[out_len as u8], b" ctl", direction.label(self.endpoint)],
+            out,
+        );
+        // if the hash is ever broken, it's better to put the "more secret" data at the beginning
+        let (key, iv) = out.split_at(key_len);
+        let iv = iv.try_into().unwrap();
+        f(self.ciphersuite.aead(), key, iv)
+    }
+
+    #[inline]
+    fn expand(&self, label: &[&[u8]], out: &mut [u8]) {
+        self.prk
+            .expand(label, OutLen(out.len()))
+            .unwrap()
+            .fill(out)
+            .unwrap();
+    }
+}
+
+#[derive(Clone, Copy)]
+pub struct OutLen(pub usize);
+
+impl hkdf::KeyType for OutLen {
+    #[inline]
+    fn len(&self) -> usize {
+        self.0
+    }
+}
diff --git a/dc/s2n-quic-dc/src/path/secret/sender.rs b/dc/s2n-quic-dc/src/path/secret/sender.rs
new file mode 100644
index 0000000000..3f50eb8617
--- /dev/null
+++ b/dc/s2n-quic-dc/src/path/secret/sender.rs
@@ -0,0 +1,89 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// SPDX-License-Identifier: Apache-2.0
+
+use super::schedule;
+use crate::crypto::awslc::DecryptKey;
+use once_cell::sync::OnceCell;
+use s2n_quic_core::varint::VarInt;
+use std::sync::atomic::{AtomicU64, Ordering};
+
+#[derive(Debug)]
+pub struct State {
+    current_id: AtomicU64,
+    pub(super) stateless_reset: [u8; 16],
+    control_secret: OnceCell<DecryptKey>,
+}
+
+impl State {
+    pub fn new(stateless_reset: [u8; 16]) -> Self {
+        Self {
+            current_id: AtomicU64::new(0),
+            stateless_reset,
+            control_secret: Default::default(),
+        }
+    }
+
+    pub fn next_key_id(&self) -> VarInt {
+        let id = self
+            .current_id
+            .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |current| {
+                VarInt::try_from(current + 1)
+                    .ok()
+                    // Make sure we can always +1. This is a useful property for StaleKey packets
+                    // which send a minimum *not yet seen* ID. In practice it shouldn't matter
+                    // since we are assuming we can't hit 2^62, but this helps localize handling
+                    // that edge to this code.
+                    .filter(|id| *id != VarInt::MAX)
+                    .map(|id| *id)
+            });
+
+        let id = id.expect("2^62 integer incremented per-path will not wrap");
+
+        // The atomic will not be incremented (i.e., would have panic'd above) if we do not fit
+        // into a VarInt.
+        VarInt::try_from(id).unwrap()
+    }
+
+    #[inline]
+    pub fn control_secret(&self, secret: &schedule::Secret) -> &DecryptKey {
+        self.control_secret.get_or_init(|| secret.control_opener())
+    }
+
+    /// Update the sender for a received stale key packet.
+    ///
+    /// This increments the current ID we are sending at to at least the ID provided in the packet.
+    ///
+    /// Note that this packet can be replayed without detection, so we must deal with
+    /// authenticated but arbitrarily old IDs here. In the future we may want to guard against
+    /// advancing too quickly (e.g., due to bit flips), but for now we ignore that problem.
+    pub(super) fn update_for_stale_key(&self, min_key_id: VarInt) {
+        // Update the key to the new minimum to start at.
+        self.current_id.fetch_max(*min_key_id, Ordering::Relaxed);
+    }
+}
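`next_key_id` leans on `fetch_update`, which retries the CAS loop internally and surfaces the closure's `None` as an `Err`. A simplified standalone model of a capped counter in the same style (the cap handling is condensed relative to the real VarInt filter):

```rust
use std::sync::atomic::{AtomicU64, Ordering};

const MAX: u64 = (1 << 62) - 1;

// Returning `None` from the closure aborts the update, which is how the
// sender turns key-ID exhaustion into a panic instead of wrapping.
fn next_id(counter: &AtomicU64) -> u64 {
    counter
        .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |current| {
            (current < MAX).then_some(current + 1)
        })
        .expect("2^62 integer incremented per-path will not wrap")
}

fn main() {
    let counter = AtomicU64::new(0);
    // fetch_update returns the previous value, like fetch_add.
    assert_eq!(next_id(&counter), 0);
    assert_eq!(next_id(&counter), 1);
    counter.store(MAX, Ordering::Relaxed);
    // The closure would now return `None`, so this call would panic:
    // next_id(&counter);
}
```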
+
+#[test]
+#[should_panic = "2^62 integer incremented"]
+fn sender_does_not_wrap() {
+    let state = State::new([0; 16]);
+    assert_eq!(*state.next_key_id(), 0);
+
+    state.current_id.store((1 << 62) - 3, Ordering::Relaxed);
+
+    assert_eq!(*state.next_key_id(), (1 << 62) - 3);
+    assert_eq!(*state.next_key_id(), (1 << 62) - 2);
+    assert_eq!(*state.next_key_id(), (1 << 62) - 1);
+    // should panic
+    state.next_key_id();
+}
+
+#[test]
+fn update_restarts_sequence() {
+    let state = State::new([0; 16]);
+    assert_eq!(*state.next_key_id(), 0);
+
+    state.update_for_stale_key(VarInt::new(3).unwrap());
+
+    // Update should start at the minimum trusted key ID on the other side.
+    assert_eq!(*state.next_key_id(), 3);
+}
diff --git a/dc/s2n-quic-dc/src/path/secret/stateless_reset.rs b/dc/s2n-quic-dc/src/path/secret/stateless_reset.rs
new file mode 100644
index 0000000000..f8e4a9bc76
--- /dev/null
+++ b/dc/s2n-quic-dc/src/path/secret/stateless_reset.rs
@@ -0,0 +1,38 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// SPDX-License-Identifier: Apache-2.0
+
+use super::schedule;
+use crate::credentials::Id;
+use aws_lc_rs::hkdf::{Prk, Salt, HKDF_SHA384};
+
+#[derive(Debug)]
+pub struct Signer {
+    prk: Prk,
+}
+
+impl Default for Signer {
+    fn default() -> Self {
+        let mut secret = [0u8; 32];
+        aws_lc_rs::rand::fill(&mut secret).unwrap();
+        Self::new(&secret)
+    }
+}
+
+impl Signer {
+    pub fn new(secret: &[u8]) -> Self {
+        let prk = Salt::new(HKDF_SHA384, secret).extract(b"rst");
+        Self { prk }
+    }
+
+    pub fn sign(&self, id: &Id) -> [u8; 16] {
+        let mut stateless_reset = [0; 16];
+
+        self.prk
+            .expand(&[&[16], b"rst ", &**id], schedule::OutLen(16))
+            .unwrap()
+            .fill(&mut stateless_reset)
+            .unwrap();
+
+        stateless_reset
+    }
+}
diff --git a/dc/s2n-quic-dc/src/stream.rs b/dc/s2n-quic-dc/src/stream.rs
index ec7c660529..84484bb9e7 100644
--- a/dc/s2n-quic-dc/src/stream.rs
+++ b/dc/s2n-quic-dc/src/stream.rs
@@ -13,6 +13,7 @@ pub mod packet_number;
 pub mod processing;
 pub mod recv;
 pub mod send;
+pub mod server;
 
 bitflags::bitflags! {
     #[derive(Clone, Copy, Debug, PartialEq, Eq)]
diff --git a/dc/s2n-quic-dc/src/stream/recv.rs b/dc/s2n-quic-dc/src/stream/recv.rs
index 892a8268ca..6d656cd717 100644
--- a/dc/s2n-quic-dc/src/stream/recv.rs
+++ b/dc/s2n-quic-dc/src/stream/recv.rs
@@ -1,17 +1,17 @@
 // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
 // SPDX-License-Identifier: Apache-2.0
 
-use super::TransportFeatures;
+use super::{TransportFeatures, DEFAULT_IDLE_TIMEOUT};
 use crate::{
     allocator::Allocator,
     crypto::{decrypt, encrypt, UninitSlice},
     packet::{control, stream},
-    path::Parameters,
 };
 use core::{task::Poll, time::Duration};
 use s2n_codec::{EncoderBuffer, EncoderValue};
 use s2n_quic_core::{
     buffer::{self, reader::storage::Infallible as _},
+    dc::ApplicationParams,
     ensure,
     frame::{self, ack::EcnCounts},
     inet::ExplicitCongestionNotification,
@@ -54,7 +54,11 @@ pub struct Receiver {
 
 impl Receiver {
     #[inline]
-    pub fn new(stream_id: stream::Id, params: &Parameters, features: TransportFeatures) -> Self {
+    pub fn new(
+        stream_id: stream::Id,
+        params: &ApplicationParams,
+        features: TransportFeatures,
+    ) -> Self {
         let initial_max_data = params.local_recv_max_data;
         Self {
             stream_id,
@@ -64,7 +68,7 @@ impl Receiver {
             recovery_ack: Default::default(),
             state: Default::default(),
             idle_timer: Default::default(),
-            idle_timeout: params.idle_timeout(),
+            idle_timeout: params.max_idle_timeout.unwrap_or(DEFAULT_IDLE_TIMEOUT),
             tick_timer: Default::default(),
             _should_transmit: false,
             max_data: initial_max_data,
diff --git a/dc/s2n-quic-dc/src/stream/send/worker.rs b/dc/s2n-quic-dc/src/stream/send/worker.rs
index 80213b4937..8900970a2c 100644
--- a/dc/s2n-quic-dc/src/stream/send/worker.rs
+++ b/dc/s2n-quic-dc/src/stream/send/worker.rs
@@ -8,7 +8,6 @@ use crate::{
         self,
         stream::{self, decoder, encoder},
     },
-    path::Parameters,
     recovery,
     stream::{
         processing,
@@ -16,11 +15,13 @@ use crate::{
             application, buffer, error::Error, filter::Filter, probes,
             transmission::Type as TransmissionType,
         },
+        DEFAULT_IDLE_TIMEOUT,
    },
 };
 use core::{task::Poll, time::Duration};
 use s2n_codec::{DecoderBufferMut, EncoderBuffer};
 use s2n_quic_core::{
+    dc::ApplicationParams,
     ensure,
     frame::{self, FrameMut},
     inet::ExplicitCongestionNotification,
@@ -112,8 +113,8 @@ pub struct PeerActivity {
 
 impl Worker {
     #[inline]
-    pub fn new(stream_id: stream::Id, params: &Parameters) -> Self {
-        let mtu = params.max_mtu;
+    pub fn new(stream_id: stream::Id, params: &ApplicationParams) -> Self {
+        let mtu = params.max_datagram_size;
         let initial_max_data = params.remote_max_data;
         let local_max_data = params.local_send_max_data;
 
@@ -121,7 +122,7 @@ impl Worker {
         let mut unacked_ranges = IntervalSet::new();
         unacked_ranges.insert(VarInt::ZERO..=VarInt::MAX).unwrap();
 
-        let cca = congestion::Controller::new(mtu.into());
+        let cca = congestion::Controller::new(mtu);
         let max_sent_offset = VarInt::ZERO;
 
         Self {
@@ -146,14 +147,14 @@ impl Worker {
             pto_backoff: INITIAL_PTO_BACKOFF,
             inflight_timer: Default::default(),
             idle_timer: Default::default(),
-            idle_timeout: params.idle_timeout(),
+            idle_timeout: params.max_idle_timeout.unwrap_or(DEFAULT_IDLE_TIMEOUT),
             error: None,
             unacked_ranges,
             max_sent_offset,
             max_data: initial_max_data,
             local_max_data_window: local_max_data,
             peer_activity: None,
-            mtu: mtu.into(),
+            mtu,
             max_sent_segment_size: 0,
         }
     }
diff --git a/dc/s2n-quic-dc/src/stream/send/worker/checker.rs b/dc/s2n-quic-dc/src/stream/send/worker/checker.rs
deleted file mode 100644
index 1db1f54532..0000000000
--- a/dc/s2n-quic-dc/src/stream/send/worker/checker.rs
+++ /dev/null
@@ -1,158 +0,0 @@
-// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
-// SPDX-License-Identifier: Apache-2.0
-
-#![cfg_attr(not(debug_assertions), allow(dead_code, unused_imports))]
-
-use s2n_quic_core::{buffer::Reader, interval_set::IntervalSet, varint::VarInt};
-
-#[cfg(debug_assertions)]
-macro_rules! run {
-    ($($tt:tt)*) => {
-        $($tt)*
-    }
-}
-
-#[cfg(not(debug_assertions))]
-macro_rules! run {
-    ($($tt:tt)*) => {};
-}
-
-#[cfg(debug_assertions)]
-#[derive(Clone, Debug, Default)]
-pub struct Checker {
-    acked_ranges: IntervalSet<VarInt>,
-    largest_transmitted_offset: VarInt,
-    max_data: VarInt,
-    highest_seen_offset: Option<VarInt>,
-    final_offset: Option<VarInt>,
-}
-
-#[cfg(not(debug_assertions))]
-#[derive(Clone, Debug, Default)]
-pub struct Checker {}
-
-#[allow(unused_variables)]
-impl Checker {
-    #[inline(always)]
-    pub fn check_payload(&mut self, payload: &impl Reader) {
-        run!({
-            if let Some(final_offset) = payload.final_offset() {
-                self.on_final_offset(final_offset);
-            }
-            self.on_stream_offset(
-                payload.current_offset(),
-                payload.buffered_len().min(u16::MAX as _) as _,
-            );
-        });
-    }
-
-    #[inline(always)]
-    pub fn on_ack(&mut self, offset: VarInt, payload_len: u16) {
-        run!(if payload_len > 0 {
-            self.acked_ranges
-                .insert(offset..offset + VarInt::from_u16(payload_len))
-                .unwrap();
-        });
-    }
-
-    #[inline(always)]
-    pub fn on_max_data(&mut self, max_data: VarInt) {
-        run!({
-            self.max_data = self.max_data.max(max_data);
-        });
-    }
-
-    #[inline(always)]
-    pub fn check_pending_packets(
-        &self,
-        packets: &super::PacketMap,
-        retransmissions: &super::BinaryHeap>,
-    ) {
-        run!({
-            let largest_transmitted_offset = self.largest_transmitted_offset;
-            if largest_transmitted_offset == 0u64 {
-                return;
-            }
-
-            let mut missing = IntervalSet::new();
-            missing
-                .insert(VarInt::ZERO..largest_transmitted_offset)
-                .unwrap();
-            // remove all of the ranges we've acked
-            missing.difference(&self.acked_ranges).unwrap();
-
-            for (_pn, packet) in packets.iter() {
-                let offset = packet.data.stream_offset;
-                let payload_len = packet.data.payload_len;
-                if payload_len > 0 {
-                    missing
-                        .remove(offset..offset + VarInt::from_u16(payload_len))
-                        .unwrap();
-                }
-            }
-
-            for packet in retransmissions.iter() {
-                let offset = packet.stream_offset;
-                let payload_len = packet.payload_len;
-                if payload_len > 0 {
-                    missing
-                        .remove(offset..offset + VarInt::from_u16(payload_len))
-                        .unwrap();
-                }
-            }
-
-            assert!(
-                missing.is_empty(),
-                "missing ranges for retransmission {missing:?}"
-            );
-        });
-    }
-
-    #[inline(always)]
-    pub fn on_stream_transmission(
-        &mut self,
-        offset: VarInt,
-        payload_len: u16,
-        is_retransmission: bool,
-        is_probe: bool,
-    ) {
-        run!({
-            self.on_stream_offset(offset, payload_len);
-
-            if !is_retransmission && !is_probe {
-                assert_eq!(self.largest_transmitted_offset, offset);
-            }
-
-            let end_offset = offset + VarInt::from_u16(payload_len);
-            self.largest_transmitted_offset = self.largest_transmitted_offset.max(end_offset);
-
-            assert!(self.largest_transmitted_offset <= self.max_data);
-        });
-    }
-
-    #[inline(always)]
-    pub fn on_stream_offset(&mut self, offset: VarInt, payload_len: u16) {
-        run!({
-            if let Some(final_offset) = self.final_offset {
-                assert!(offset <= final_offset);
-            }
-
-            match self.highest_seen_offset.as_mut() {
-                Some(prev) => *prev = (*prev).max(offset),
-                None => self.highest_seen_offset = Some(offset),
-            }
-        });
-    }
-
-    #[inline(always)]
-    fn on_final_offset(&mut self, final_offset: VarInt) {
-        run!({
-            self.on_stream_offset(final_offset, 0);
-
-            match self.final_offset {
-                Some(prev) => assert_eq!(prev, final_offset),
-                None => self.final_offset = Some(final_offset),
-            }
-        });
-    }
-}
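The deleted checker's `run!` macro is a reusable pattern for bookkeeping that compiles away outside debug builds; a minimal self-contained sketch of the same idea:

```rust
// In debug builds the macro passes its tokens through; in release builds it
// expands to nothing, so the checks (and their state) cost nothing.
#[cfg(debug_assertions)]
macro_rules! run {
    ($($tt:tt)*) => { $($tt)* }
}

#[cfg(not(debug_assertions))]
macro_rules! run {
    ($($tt:tt)*) => {};
}

#[derive(Default)]
struct Checker {
    #[cfg(debug_assertions)]
    highest: u64,
}

impl Checker {
    fn on_offset(&mut self, offset: u64) {
        run!({
            // Only compiled (and only paid for) in debug builds.
            self.highest = self.highest.max(offset);
        });
        let _ = offset;
    }
}

fn main() {
    let mut c = Checker::default();
    c.on_offset(42);
}
```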
{missing:?}" - ); - }); - } - - #[inline(always)] - pub fn on_stream_transmission( - &mut self, - offset: VarInt, - payload_len: u16, - is_retransmission: bool, - is_probe: bool, - ) { - run!({ - self.on_stream_offset(offset, payload_len); - - if !is_retransmission && !is_probe { - assert_eq!(self.largest_transmitted_offset, offset); - } - - let end_offset = offset + VarInt::from_u16(payload_len); - self.largest_transmitted_offset = self.largest_transmitted_offset.max(end_offset); - - assert!(self.largest_transmitted_offset <= self.max_data); - }); - } - - #[inline(always)] - pub fn on_stream_offset(&mut self, offset: VarInt, payload_len: u16) { - run!({ - if let Some(final_offset) = self.final_offset { - assert!(offset <= final_offset); - } - - match self.highest_seen_offset.as_mut() { - Some(prev) => *prev = (*prev).max(offset), - None => self.highest_seen_offset = Some(offset), - } - }); - } - - #[inline(always)] - fn on_final_offset(&mut self, final_offset: VarInt) { - run!({ - self.on_stream_offset(final_offset, 0); - - match self.final_offset { - Some(prev) => assert_eq!(prev, final_offset), - None => self.final_offset = Some(final_offset), - } - }); - } -} diff --git a/dc/s2n-quic-dc/src/stream/server.rs b/dc/s2n-quic-dc/src/stream/server.rs new file mode 100644 index 0000000000..ae6046f094 --- /dev/null +++ b/dc/s2n-quic-dc/src/stream/server.rs @@ -0,0 +1,68 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +#![allow(clippy::type_complexity)] + +use crate::{credentials::Credentials, msg::recv, packet}; +use s2n_codec::{DecoderBufferMut, DecoderError}; + +pub mod handshake; + +#[derive(Debug)] +pub struct InitialPacket { + pub credentials: Credentials, + pub stream_id: packet::stream::Id, + pub source_control_port: u16, + pub source_stream_port: Option, + pub payload_len: usize, + pub is_zero_offset: bool, + pub is_retransmission: bool, + pub is_fin: bool, +} + +impl InitialPacket { + #[inline] + pub fn peek(recv: &mut recv::Message, tag_len: usize) -> Result { + let segment = recv + .peek_segments() + .next() + .ok_or(DecoderError::UnexpectedEof(1))?; + + let decoder = DecoderBufferMut::new(segment); + // we're just going to assume that all of the packets in this datagram + // pertain to the same stream + let (packet, _remaining) = decoder.decode_parameterized(tag_len)?; + + let packet::Packet::Stream(packet) = packet else { + return Err(DecoderError::InvariantViolation("unexpected packet type")); + }; + + let packet: InitialPacket = packet.into(); + + Ok(packet) + } +} + +impl<'a> From> for InitialPacket { + #[inline] + fn from(packet: packet::stream::decoder::Packet<'a>) -> Self { + let credentials = *packet.credentials(); + let stream_id = *packet.stream_id(); + let source_control_port = packet.source_control_port(); + let source_stream_port = packet.source_stream_port(); + let payload_len = packet.payload().len(); + let is_zero_offset = packet.stream_offset().as_u64() == 0; + let is_retransmission = packet.is_retransmission(); + let is_fin = packet.is_fin(); + Self { + credentials, + stream_id, + source_control_port, + source_stream_port, + is_zero_offset, + payload_len, + is_retransmission, + is_fin, + } + } +} diff --git a/dc/s2n-quic-dc/src/stream/server/handshake.rs b/dc/s2n-quic-dc/src/stream/server/handshake.rs new file mode 100644 index 0000000000..0c84a66785 --- /dev/null +++ b/dc/s2n-quic-dc/src/stream/server/handshake.rs @@ -0,0 +1,111 @@ +// Copyright Amazon.com, Inc. or its affiliates. 
diff --git a/dc/s2n-quic-dc/src/stream/server/handshake.rs b/dc/s2n-quic-dc/src/stream/server/handshake.rs
new file mode 100644
index 0000000000..0c84a66785
--- /dev/null
+++ b/dc/s2n-quic-dc/src/stream/server/handshake.rs
@@ -0,0 +1,111 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// SPDX-License-Identifier: Apache-2.0
+
+use crate::{credentials, msg::recv};
+use core::task::{Context, Poll};
+use std::sync::{Arc, Weak};
+use tokio::sync::mpsc;
+
+type Sender = mpsc::Sender<recv::Message>;
+type ReceiverChan = mpsc::Receiver<recv::Message>;
+type Key = (credentials::Id, u64);
+type HashMap = flurry::HashMap<Key, Sender>;
+
+pub enum Outcome {
+    Forwarded,
+    Created { receiver: Receiver },
+}
+
+pub struct Map {
+    inner: Arc<HashMap>,
+    next: Option<(Sender, ReceiverChan)>,
+    channel_size: usize,
+}
+
+impl Default for Map {
+    #[inline]
+    fn default() -> Self {
+        Self {
+            inner: Default::default(),
+            next: None,
+            channel_size: 15,
+        }
+    }
+}
+
+impl Map {
+    #[inline]
+    pub fn handle(&mut self, packet: &super::InitialPacket, msg: &mut recv::Message) -> Outcome {
+        let stream_id = packet.stream_id.into_varint().as_u64();
+        let (sender, receiver) = self
+            .next
+            .take()
+            .unwrap_or_else(|| mpsc::channel(self.channel_size));
+
+        let key = (packet.credentials.id, stream_id);
+
+        let guard = self.inner.guard();
+        match self.inner.try_insert(key, sender, &guard) {
+            Ok(_) => {
+                drop(guard);
+                let map = Arc::downgrade(&self.inner);
+                tracing::trace!(action = "register", credentials = ?&key.0, stream_id = key.1);
+                let receiver = ReceiverState {
+                    map,
+                    key,
+                    channel: receiver,
+                };
+                let receiver = Receiver(Box::new(receiver));
+                Outcome::Created { receiver }
+            }
+            Err(err) => {
+                self.next = Some((err.not_inserted, receiver));
+
+                tracing::trace!(action = "forward", credentials = ?&key.0, stream_id = key.1);
+                if let Err(err) = err.current.try_send(msg.take()) {
+                    match err {
+                        mpsc::error::TrySendError::Closed(_) => {
+                            // remove the channel from the map since we're closed
+                            self.inner.remove(&key, &guard);
+                            tracing::debug!(stream_id, error = "channel_closed");
+                        }
+                        mpsc::error::TrySendError::Full(_) => {
+                            // drop the packet
+                            let _ = msg;
+                            tracing::debug!(stream_id, error = "channel_full");
+                        }
+                    }
+                }
+
+                Outcome::Forwarded
+            }
+        }
+    }
+}
+
+#[derive(Debug)]
+pub struct Receiver(Box<ReceiverState>);
+
+#[derive(Debug)]
+struct ReceiverState {
+    map: Weak<HashMap>,
+    key: Key,
+    channel: ReceiverChan,
+}
+
+impl Receiver {
+    #[inline]
+    pub fn poll_recv(&mut self, cx: &mut Context) -> Poll<Option<recv::Message>> {
+        self.0.channel.poll_recv(cx)
+    }
+}
+
+impl Drop for Receiver {
+    #[inline]
+    fn drop(&mut self) {
+        if let Some(map) = self.0.map.upgrade() {
+            tracing::trace!(action = "unregister", credentials = ?&self.0.key.0, stream_id = self.0.key.1);
+            let _ = map.remove(&self.0.key, &map.guard());
+        }
+    }
+}
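`Map::handle` is insert-or-forward: the first packet for a `(credentials id, stream id)` key registers a bounded channel and hands the receiver back to the caller, later packets are pushed into that channel, a full channel drops the packet, and a closed channel unregisters the key. A single-threaded model using std types in place of `flurry` and tokio's `mpsc` (names and the channel bound are illustrative):

```rust
use std::collections::hash_map::Entry;
use std::collections::HashMap;
use std::sync::mpsc::{sync_channel, Receiver, SyncSender, TrySendError};

type Key = (u64, u64);

enum Outcome {
    Created(Receiver<Vec<u8>>),
    Forwarded,
}

fn handle(map: &mut HashMap<Key, SyncSender<Vec<u8>>>, key: Key, msg: Vec<u8>) -> Outcome {
    match map.entry(key) {
        Entry::Vacant(slot) => {
            // First packet for this key: register a bounded channel and let
            // the caller own the new stream (it already holds `msg`).
            let (tx, rx) = sync_channel(15);
            slot.insert(tx);
            Outcome::Created(rx)
        }
        Entry::Occupied(slot) => {
            match slot.get().try_send(msg) {
                Ok(()) => {}
                Err(TrySendError::Full(_)) => {} // drop the packet
                Err(TrySendError::Disconnected(_)) => {
                    slot.remove(); // receiver is gone; unregister
                }
            }
            Outcome::Forwarded
        }
    }
}

fn main() {
    let mut map = HashMap::new();
    let key = (1, 42);
    let Outcome::Created(rx) = handle(&mut map, key, b"first".to_vec()) else {
        unreachable!()
    };
    assert!(matches!(
        handle(&mut map, key, b"second".to_vec()),
        Outcome::Forwarded
    ));
    assert_eq!(rx.recv().unwrap(), b"second".to_vec());
}
```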