From 9279196177bd0716e436ea23ba4bc59533590b3e Mon Sep 17 00:00:00 2001 From: Cai Bear Date: Wed, 13 Mar 2024 22:59:04 -0700 Subject: [PATCH] Fix documented unsound code in serde impl. --- src/serde/de.rs | 46 ++++++++++++++++++++++- src/serde/ser.rs | 98 ++++++++++++++++++++++++++++++++++++++++++++---- 2 files changed, 136 insertions(+), 8 deletions(-) diff --git a/src/serde/de.rs b/src/serde/de.rs index 3cb597f..94ba782 100644 --- a/src/serde/de.rs +++ b/src/serde/de.rs @@ -401,6 +401,7 @@ impl<'de> Deserializer<'de> for DecoderWrapper<'_, 'de> { decoders: &'a mut (SerdeDecoder<'de>, SerdeDecoder<'de>), input: &'a mut &'de [u8], len: usize, + key_deserialized: bool, } impl<'de> MapAccess<'de> for Access<'_, 'de> { @@ -414,6 +415,9 @@ impl<'de> Deserializer<'de> for DecoderWrapper<'_, 'de> { guard_zst::(self.len)?; if self.len != 0 { self.len -= 1; + // Safety: Make sure next_value_seed is called at most once after each len decrement. + // We don't care if DeserializeSeed fails after this (not critical to safety). + self.key_deserialized = true; Ok(Some(DeserializeSeed::deserialize( seed, DecoderWrapper { @@ -426,12 +430,17 @@ impl<'de> Deserializer<'de> for DecoderWrapper<'_, 'de> { } } - // TODO(unsound): could be called more than len times by buggy safe code and go out of bounds. #[inline(always)] fn next_value_seed(&mut self, seed: V) -> Result where V: DeserializeSeed<'de>, { + // Safety: Make sure next_value_seed is called at most once after each len decrement + // since only len values exist. + assert!( + std::mem::take(&mut self.key_deserialized), + "next_value_seed before next_key_seed" + ); DeserializeSeed::deserialize( seed, DecoderWrapper { @@ -440,6 +449,7 @@ impl<'de> Deserializer<'de> for DecoderWrapper<'_, 'de> { }, ) } + // TODO implement next_entry_seed to avoid checking key_deserialized. #[inline(always)] fn size_hint(&self) -> Option { @@ -451,6 +461,7 @@ impl<'de> Deserializer<'de> for DecoderWrapper<'_, 'de> { decoders, input: self.input, len, + key_deserialized: false, // No keys have been deserialized yet, so next_value_seed can't be called. }) } @@ -561,6 +572,8 @@ impl<'de> VariantAccess<'de> for DecoderWrapper<'_, 'de> { #[cfg(test)] mod tests { + use serde::de::MapAccess; + use serde::Deserializer; use std::collections::BTreeMap; #[test] @@ -621,4 +634,35 @@ mod tests { // Complex. test!(vec![(None, 3), (Some(4), 5)], Vec<(Option, u8)>); } + + #[test] + #[should_panic = "next_value_seed before next_key_seed"] + fn map_incorrect_len_values() { + let mut map = BTreeMap::new(); + map.insert(1u8, 2u8); + let input = crate::serialize(&map).unwrap(); + + let w = super::DecoderWrapper { + decoder: &mut super::SerdeDecoder::Unspecified { length: 1 }, + input: &mut input.as_slice(), + }; + + struct Visitor; + impl<'de> serde::de::Visitor<'de> for Visitor { + type Value = (); + fn expecting(&self, _: &mut std::fmt::Formatter) -> std::fmt::Result { + unreachable!() + } + fn visit_map(self, mut map: A) -> Result + where + A: MapAccess<'de>, + { + assert_eq!(map.next_key::().unwrap().unwrap(), 1u8); + assert_eq!(map.next_value::().unwrap(), 2u8); + map.next_value::().unwrap(); + Ok(()) + } + } + w.deserialize_map(Visitor).unwrap(); + } } diff --git a/src/serde/ser.rs b/src/serde/ser.rs index 9af687f..19c7c82 100644 --- a/src/serde/ser.rs +++ b/src/serde/ser.rs @@ -262,7 +262,7 @@ macro_rules! impl_ser { impl<'a> Serializer for EncoderWrapper<'a> { type Ok = (); type Error = Error; - type SerializeSeq = EncoderWrapper<'a>; + type SerializeSeq = SeqSerializer<'a>; type SerializeTuple = TupleSerializer<'a>; type SerializeTupleStruct = TupleSerializer<'a>; type SerializeTupleVariant = TupleSerializer<'a>; @@ -357,9 +357,10 @@ impl<'a> Serializer for EncoderWrapper<'a> { let b = specify!(self, Seq); b.0.encode(&len); b.1.reserve_fast(len); - Ok(Self { + Ok(SeqSerializer { lazy: &mut b.1, index_alloc: self.index_alloc, + len, }) } @@ -446,6 +447,8 @@ impl<'a> Serializer for EncoderWrapper<'a> { Ok(MapSerializer { encoders: &mut b.1, index_alloc: self.index_alloc, + len, + key_serialized: false, // No keys have been serialized yet, so serialize_value can't be called. }) } @@ -481,11 +484,18 @@ macro_rules! ok_error_end { }; } -impl SerializeSeq for EncoderWrapper<'_> { +struct SeqSerializer<'a> { + lazy: &'a mut LazyEncoder, + index_alloc: &'a mut usize, + len: usize, +} + +impl SerializeSeq for SeqSerializer<'_> { ok_error_end!(); - // TODO(unsound): could be called more than len times by buggy safe code but we only reserved len. #[inline(always)] fn serialize_element(&mut self, value: &T) -> Result<()> { + // Safety: Make sure safe code doesn't lie about len and cause UB since we've only reserved len elements. + self.len = self.len.checked_sub(1).expect("length mismatch"); value.serialize(EncoderWrapper { lazy: &mut *self.lazy, index_alloc: &mut *self.index_alloc, @@ -531,39 +541,50 @@ impl_tuple!(SerializeStructVariant, serialize_field, _key); struct MapSerializer<'a> { encoders: &'a mut (LazyEncoder, LazyEncoder), // (keys, values) index_alloc: &'a mut usize, + len: usize, + key_serialized: bool, } impl SerializeMap for MapSerializer<'_> { ok_error_end!(); - // TODO(unsound): could be called more than len times by buggy safe code but we only reserved len. #[inline(always)] fn serialize_key(&mut self, key: &T) -> Result<()> where T: Serialize, { + // Safety: Make sure safe code doesn't lie about len and cause UB since we've only reserved len keys/values. + self.len = self.len.checked_sub(1).expect("length mismatch"); + // Safety: Make sure serialize_value is called at most once after each serialize_key. + self.key_serialized = true; key.serialize(EncoderWrapper { lazy: &mut self.encoders.0, index_alloc: &mut *self.index_alloc, }) } - // TODO(unsound): could be called more than len times by buggy safe code but we only reserved len. #[inline(always)] fn serialize_value(&mut self, value: &T) -> Result<()> where T: Serialize, { + // Safety: Make sure serialize_value is called at most once after each serialize_key. + assert!( + std::mem::take(&mut self.key_serialized), + "serialize_value before serialize_key" + ); value.serialize(EncoderWrapper { lazy: &mut self.encoders.1, index_alloc: &mut *self.index_alloc, }) } + // TODO implement serialize_entry to avoid checking key_serialized. } #[cfg(test)] mod tests { - use serde::ser::SerializeTuple; + use serde::ser::{SerializeMap, SerializeSeq, SerializeTuple}; use serde::{Serialize, Serializer}; + use std::num::NonZeroUsize; #[test] fn enum_256_variants() { @@ -613,4 +634,67 @@ mod tests { } let _ = crate::serialize(&vec![TupleN(1), TupleN(2)]); } + + // Has to be a macro because it borrows something on the stack and returns it. + macro_rules! new_wrapper { + () => { + super::EncoderWrapper { + lazy: &mut super::LazyEncoder::Unspecified { + reserved: NonZeroUsize::new(1), + }, + index_alloc: &mut 0, + } + }; + } + + #[test] + fn seq_valid() { + let w = new_wrapper!(); + let mut seq = w.serialize_seq(Some(1)).unwrap(); + let _ = seq.serialize_element(&0u8); // serialize_seq 1 == serialize 1. + } + + #[test] + #[should_panic = "length mismatch"] + fn seq_incorrect_len() { + let w = new_wrapper!(); + let mut seq = w.serialize_seq(Some(1)).unwrap(); + let _ = seq.serialize_element(&0u8); // serialize_seq 1 != serialize 2. + let _ = seq.serialize_element(&0u8); + } + + #[test] + fn map_valid() { + let w = new_wrapper!(); + let mut map = w.serialize_map(Some(1)).unwrap(); + let _ = map.serialize_key(&0u8); // serialize_map 1 == (key, value). + let _ = map.serialize_value(&0u8); + } + + #[test] + #[should_panic = "length mismatch"] + fn map_incorrect_len_keys() { + let w = new_wrapper!(); + let mut map = w.serialize_map(Some(1)).unwrap(); + let _ = map.serialize_key(&0u8); // serialize_map 1 != (key, _) (key, _) + let _ = map.serialize_key(&0u8); + } + + #[test] + #[should_panic = "serialize_value before serialize_key"] + fn map_value_before_key() { + let w = new_wrapper!(); + let mut map = w.serialize_map(Some(1)).unwrap(); + let _ = map.serialize_value(&0u8); + } + + #[test] + #[should_panic = "serialize_value before serialize_key"] + fn map_incorrect_len_values() { + let w = new_wrapper!(); + let mut map = w.serialize_map(Some(1)).unwrap(); + let _ = map.serialize_key(&0u8); // serialize_map 1 != (key, value) (_, value). + let _ = map.serialize_value(&0u8); + let _ = map.serialize_value(&0u8); + } }