Skip to content

Commit

Permalink
Fix documented unsound code in serde impl.
Browse files Browse the repository at this point in the history
  • Loading branch information
caibear committed Mar 14, 2024
1 parent 7f2104c commit 9279196
Show file tree
Hide file tree
Showing 2 changed files with 136 additions and 8 deletions.
46 changes: 45 additions & 1 deletion src/serde/de.rs
Original file line number Diff line number Diff line change
Expand Up @@ -401,6 +401,7 @@ impl<'de> Deserializer<'de> for DecoderWrapper<'_, 'de> {
decoders: &'a mut (SerdeDecoder<'de>, SerdeDecoder<'de>),
input: &'a mut &'de [u8],
len: usize,
key_deserialized: bool,
}

impl<'de> MapAccess<'de> for Access<'_, 'de> {
Expand All @@ -414,6 +415,9 @@ impl<'de> Deserializer<'de> for DecoderWrapper<'_, 'de> {
guard_zst::<K::Value>(self.len)?;
if self.len != 0 {
self.len -= 1;
// Safety: Make sure next_value_seed is called at most once after each len decrement.
// We don't care if DeserializeSeed fails after this (not critical to safety).
self.key_deserialized = true;
Ok(Some(DeserializeSeed::deserialize(
seed,
DecoderWrapper {
Expand All @@ -426,12 +430,17 @@ impl<'de> Deserializer<'de> for DecoderWrapper<'_, 'de> {
}
}

// Panics if called without a preceding next_key_seed that decremented len, so it cannot run more than len times.
#[inline(always)]
fn next_value_seed<V>(&mut self, seed: V) -> Result<V::Value>
where
V: DeserializeSeed<'de>,
{
// Safety: Make sure next_value_seed is called at most once after each len decrement
// since only len values exist.
assert!(
std::mem::take(&mut self.key_deserialized),
"next_value_seed before next_key_seed"
);
DeserializeSeed::deserialize(
seed,
DecoderWrapper {
Expand All @@ -440,6 +449,7 @@ impl<'de> Deserializer<'de> for DecoderWrapper<'_, 'de> {
},
)
}
// TODO implement next_entry_seed to avoid checking key_deserialized.

#[inline(always)]
fn size_hint(&self) -> Option<usize> {
Expand All @@ -451,6 +461,7 @@ impl<'de> Deserializer<'de> for DecoderWrapper<'_, 'de> {
decoders,
input: self.input,
len,
key_deserialized: false, // No keys have been deserialized yet, so next_value_seed can't be called.
})
}

Expand Down Expand Up @@ -561,6 +572,8 @@ impl<'de> VariantAccess<'de> for DecoderWrapper<'_, 'de> {

#[cfg(test)]
mod tests {
use serde::de::MapAccess;
use serde::Deserializer;
use std::collections::BTreeMap;

#[test]
Expand Down Expand Up @@ -621,4 +634,35 @@ mod tests {
// Complex.
test!(vec![(None, 3), (Some(4), 5)], Vec<(Option<u8>, u8)>);
}

#[test]
#[should_panic = "next_value_seed before next_key_seed"]
fn map_incorrect_len_values() {
    // Serialize a one-entry map to use as the decoder input.
    let encoded = crate::serialize(&BTreeMap::from([(1u8, 2u8)])).unwrap();

    let wrapper = super::DecoderWrapper {
        decoder: &mut super::SerdeDecoder::Unspecified { length: 1 },
        input: &mut encoded.as_slice(),
    };

    // Reads the single (key, value) pair, then requests a second value with no key in between.
    struct ExtraValueVisitor;
    impl<'de> serde::de::Visitor<'de> for ExtraValueVisitor {
        type Value = ();

        fn expecting(&self, _: &mut std::fmt::Formatter) -> std::fmt::Result {
            unreachable!()
        }

        fn visit_map<A>(self, mut access: A) -> Result<Self::Value, A::Error>
        where
            A: MapAccess<'de>,
        {
            let key = access.next_key::<u8>().unwrap().unwrap();
            assert_eq!(key, 1u8);
            let value = access.next_value::<u8>().unwrap();
            assert_eq!(value, 2u8);
            // No key was requested since the last value, so this call is expected to panic.
            access.next_value::<u8>().unwrap();
            Ok(())
        }
    }

    wrapper.deserialize_map(ExtraValueVisitor).unwrap();
}
}
98 changes: 91 additions & 7 deletions src/serde/ser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,7 @@ macro_rules! impl_ser {
impl<'a> Serializer for EncoderWrapper<'a> {
type Ok = ();
type Error = Error;
type SerializeSeq = EncoderWrapper<'a>;
type SerializeSeq = SeqSerializer<'a>;
type SerializeTuple = TupleSerializer<'a>;
type SerializeTupleStruct = TupleSerializer<'a>;
type SerializeTupleVariant = TupleSerializer<'a>;
Expand Down Expand Up @@ -357,9 +357,10 @@ impl<'a> Serializer for EncoderWrapper<'a> {
let b = specify!(self, Seq);
b.0.encode(&len);
b.1.reserve_fast(len);
Ok(Self {
Ok(SeqSerializer {
lazy: &mut b.1,
index_alloc: self.index_alloc,
len,
})
}

Expand Down Expand Up @@ -446,6 +447,8 @@ impl<'a> Serializer for EncoderWrapper<'a> {
Ok(MapSerializer {
encoders: &mut b.1,
index_alloc: self.index_alloc,
len,
key_serialized: false, // No keys have been serialized yet, so serialize_value can't be called.
})
}

Expand Down Expand Up @@ -481,11 +484,18 @@ macro_rules! ok_error_end {
};
}

impl SerializeSeq for EncoderWrapper<'_> {
/// `SerializeSeq` state constructed in the `Serializer` impl for `EncoderWrapper`.
struct SeqSerializer<'a> {
// Encoder handed to the `EncoderWrapper` that serializes each element.
lazy: &'a mut LazyEncoder,
// Reborrowed into the `EncoderWrapper` built for each element.
index_alloc: &'a mut usize,
// Elements that may still be serialized; `serialize_element` decrements it and panics
// with "length mismatch" once it is zero.
len: usize,
}

impl SerializeSeq for SeqSerializer<'_> {
ok_error_end!();
// Panics with "length mismatch" if called more than len times, since only len elements were reserved.
#[inline(always)]
fn serialize_element<T: Serialize + ?Sized>(&mut self, value: &T) -> Result<()> {
// Safety: Make sure safe code doesn't lie about len and cause UB since we've only reserved len elements.
self.len = self.len.checked_sub(1).expect("length mismatch");
value.serialize(EncoderWrapper {
lazy: &mut *self.lazy,
index_alloc: &mut *self.index_alloc,
Expand Down Expand Up @@ -531,39 +541,50 @@ impl_tuple!(SerializeStructVariant, serialize_field, _key);
/// `SerializeMap` state constructed in the `Serializer` impl for `EncoderWrapper`.
struct MapSerializer<'a> {
encoders: &'a mut (LazyEncoder, LazyEncoder), // (keys, values)
// Reborrowed into the `EncoderWrapper` built for each key and each value.
index_alloc: &'a mut usize,
// Keys that may still be serialized; `serialize_key` decrements it and panics with
// "length mismatch" once it is zero.
len: usize,
// Set by `serialize_key` and cleared by `serialize_value`, which asserts it was set.
key_serialized: bool,
}

impl SerializeMap for MapSerializer<'_> {
    ok_error_end!();

    /// Serializes one map key into the key encoder (`encoders.0`).
    ///
    /// # Panics
    /// Panics with "length mismatch" when called more times than the `len` this
    /// serializer was created with.
    #[inline(always)]
    fn serialize_key<T: Serialize + ?Sized>(&mut self, key: &T) -> Result<()> {
        // Safety: only len keys/values were reserved, so safe code that lies about len
        // must panic here rather than cause UB.
        self.len = match self.len.checked_sub(1) {
            Some(remaining) => remaining,
            None => panic!("length mismatch"),
        };
        // Safety: permit exactly one serialize_value call for this key.
        self.key_serialized = true;

        let keys = &mut self.encoders.0;
        key.serialize(EncoderWrapper {
            lazy: keys,
            index_alloc: &mut *self.index_alloc,
        })
    }

    /// Serializes one map value into the value encoder (`encoders.1`).
    ///
    /// # Panics
    /// Panics with "serialize_value before serialize_key" unless `serialize_key` was
    /// called since the previous `serialize_value`.
    #[inline(always)]
    fn serialize_value<T: Serialize + ?Sized>(&mut self, value: &T) -> Result<()> {
        // Safety: consume the flag so at most one value is written per serialized key.
        let key_pending = std::mem::replace(&mut self.key_serialized, false);
        assert!(key_pending, "serialize_value before serialize_key");

        let values = &mut self.encoders.1;
        value.serialize(EncoderWrapper {
            lazy: values,
            index_alloc: &mut *self.index_alloc,
        })
    }
    // TODO implement serialize_entry to avoid checking key_serialized.
}

#[cfg(test)]
mod tests {
use serde::ser::SerializeTuple;
use serde::ser::{SerializeMap, SerializeSeq, SerializeTuple};
use serde::{Serialize, Serializer};
use std::num::NonZeroUsize;

#[test]
fn enum_256_variants() {
Expand Down Expand Up @@ -613,4 +634,67 @@ mod tests {
}
let _ = crate::serialize(&vec![TupleN(1), TupleN(2)]);
}

// Expands to an `EncoderWrapper` that mutably borrows a freshly created
// `LazyEncoder::Unspecified` (with `reserved: NonZeroUsize::new(1)`) and a `0` index counter.
// Has to be a macro because it borrows something on the stack and returns it.
macro_rules! new_wrapper {
() => {
super::EncoderWrapper {
lazy: &mut super::LazyEncoder::Unspecified {
reserved: NonZeroUsize::new(1),
},
index_alloc: &mut 0,
}
};
}

#[test]
fn seq_valid() {
    // Declared length is 1 and exactly one element is written, so nothing should panic.
    let wrapper = new_wrapper!();
    let mut seq_ser = wrapper.serialize_seq(Some(1)).unwrap();
    let _ = seq_ser.serialize_element(&0u8);
}

#[test]
#[should_panic = "length mismatch"]
fn seq_incorrect_len() {
    let wrapper = new_wrapper!();
    let mut seq_ser = wrapper.serialize_seq(Some(1)).unwrap();
    // First element matches the declared length of 1.
    let _ = seq_ser.serialize_element(&0u8);
    // Second element exceeds the declared length and must panic.
    let _ = seq_ser.serialize_element(&0u8);
}

#[test]
fn map_valid() {
    // Declared length is 1 and exactly one (key, value) pair is written.
    let wrapper = new_wrapper!();
    let mut map_ser = wrapper.serialize_map(Some(1)).unwrap();
    let _ = map_ser.serialize_key(&0u8);
    let _ = map_ser.serialize_value(&0u8);
}

#[test]
#[should_panic = "length mismatch"]
fn map_incorrect_len_keys() {
    let wrapper = new_wrapper!();
    let mut map_ser = wrapper.serialize_map(Some(1)).unwrap();
    // First key matches the declared length of 1.
    let _ = map_ser.serialize_key(&0u8);
    // Second key exceeds the declared length and must panic.
    let _ = map_ser.serialize_key(&0u8);
}

#[test]
#[should_panic = "serialize_value before serialize_key"]
fn map_value_before_key() {
    let wrapper = new_wrapper!();
    let mut map_ser = wrapper.serialize_map(Some(1)).unwrap();
    // No key has been serialized yet, so writing a value must panic.
    let _ = map_ser.serialize_value(&0u8);
}

#[test]
#[should_panic = "serialize_value before serialize_key"]
fn map_incorrect_len_values() {
    let wrapper = new_wrapper!();
    let mut map_ser = wrapper.serialize_map(Some(1)).unwrap();
    // One complete (key, value) pair matches the declared length of 1.
    let _ = map_ser.serialize_key(&0u8);
    let _ = map_ser.serialize_value(&0u8);
    // A second value without a new key must panic.
    let _ = map_ser.serialize_value(&0u8);
}
}

0 comments on commit 9279196

Please sign in to comment.