diff --git a/Cargo.toml b/Cargo.toml index 6cf91e2..152246f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -18,6 +18,7 @@ arrayvec = { version = "0.7", default-features = false, optional = true } bitcode_derive = { version = "0.6.3", path = "./bitcode_derive", optional = true } bytemuck = { version = "1.14", features = [ "min_const_generics", "must_cast" ] } glam = { version = ">=0.21", default-features = false, optional = true } +rust_decimal = { version = "1.36", default-features = false, optional = true } serde = { version = "1.0", default-features = false, features = [ "alloc" ], optional = true } [dev-dependencies] @@ -37,7 +38,7 @@ zstd = "0.13.0" [features] derive = [ "dep:bitcode_derive" ] -std = [ "serde?/std", "glam?/std", "arrayvec?/std" ] +std = [ "serde?/std", "glam?/std", "arrayvec?/std", "rust_decimal?/std" ] default = [ "derive", "std" ] [package.metadata.docs.rs] diff --git a/fuzz/Cargo.toml b/fuzz/Cargo.toml index 653b3c4..4d189cf 100644 --- a/fuzz/Cargo.toml +++ b/fuzz/Cargo.toml @@ -10,8 +10,9 @@ cargo-fuzz = true [dependencies] arrayvec = { version = "0.7", features = ["serde"] } -bitcode = { path = "..", features = [ "arrayvec", "serde" ] } +bitcode = { path = "..", features = [ "arrayvec", "serde", "rust_decimal" ] } libfuzzer-sys = "0.4" +rust_decimal = "1.36.0" serde = { version ="1.0", features = [ "derive" ] } # Prevent this from interfering with workspaces diff --git a/fuzz/fuzz_targets/fuzz.rs b/fuzz/fuzz_targets/fuzz.rs index 6f489de..190e7c8 100644 --- a/fuzz/fuzz_targets/fuzz.rs +++ b/fuzz/fuzz_targets/fuzz.rs @@ -9,6 +9,7 @@ use std::collections::{BTreeMap, HashMap}; use std::fmt::Debug; use std::num::NonZeroU32; use std::time::Duration; +use rust_decimal::Decimal; use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddrV4, SocketAddr, SocketAddrV6}; #[inline(never)] @@ -209,6 +210,7 @@ fuzz_target!(|data: &[u8]| { ArrayString<70>, ArrayVec, ArrayVec, + Decimal, Duration, Ipv4Addr, Ipv6Addr, diff --git a/src/ext/mod.rs b/src/ext/mod.rs index b7b26f1..0c79a30 100644 --- a/src/ext/mod.rs +++ b/src/ext/mod.rs @@ -3,6 +3,8 @@ mod arrayvec; #[cfg(feature = "glam")] #[rustfmt::skip] // Makes impl_struct! calls way longer. mod glam; +#[cfg(feature = "rust_decimal")] +mod rust_decimal; #[allow(unused)] macro_rules! impl_struct { diff --git a/src/ext/rust_decimal.rs b/src/ext/rust_decimal.rs new file mode 100644 index 0000000..2f9e4f1 --- /dev/null +++ b/src/ext/rust_decimal.rs @@ -0,0 +1,109 @@ +use crate::{ + convert::{self, impl_convert, ConvertFrom}, + Decode, Encode, +}; +use bytemuck::CheckedBitPattern; +use rust_decimal::Decimal; +type DecimalConversion = (u32, u32, u32, Flags); + +impl ConvertFrom<&Decimal> for DecimalConversion { + fn convert_from(value: &Decimal) -> Self { + let unpacked = value.unpack(); + ( + unpacked.lo, + unpacked.mid, + unpacked.hi, + Flags::new(unpacked.scale, unpacked.negative), + ) + } +} + +impl ConvertFrom for Decimal { + fn convert_from(value: DecimalConversion) -> Self { + let scale = value.3.scale(); + // Should make Decimal::from_parts faster, once it can be inlined, + // since it can skip division. + // Safety: impl CheckedBitPattern for Flags guarantees this. + unsafe { + if scale > 28 { + core::hint::unreachable_unchecked(); + } + } + let mut ret = Self::from_parts(value.0, value.1, value.2, false, scale); + ret.set_sign_negative(value.3.negative()); + ret + } +} + +impl_convert!(Decimal, DecimalConversion); + +impl ConvertFrom<&Flags> for u8 { + fn convert_from(flags: &Flags) -> Self { + flags.0 + } +} + +impl Encode for Flags { + type Encoder = convert::ConvertIntoEncoder; +} + +/// A u8 guaranteed to satisfy (flags >> 1) <= 28. Prevents Decimal::from_parts from misbehaving. +#[derive(Copy, Clone)] +#[repr(transparent)] +pub struct Flags(u8); + +impl Flags { + #[inline(always)] + fn new(scale: u32, negative: bool) -> Self { + Self((scale as u8) << 1 | negative as u8) + } + + #[inline(always)] + fn scale(&self) -> u32 { + (self.0 >> 1) as u32 + } + + #[inline(always)] + fn negative(&self) -> bool { + self.0 & 1 == 1 + } +} + +// Safety: u8 and Flags have the same layout since Flags is #[repr(transparent)]. +unsafe impl CheckedBitPattern for Flags { + type Bits = u8; + #[inline(always)] + fn is_valid_bit_pattern(bits: &Self::Bits) -> bool { + (*bits >> 1) <= 28 + } +} + +impl<'a> Decode<'a> for Flags { + type Decoder = crate::int::CheckedIntDecoder<'a, Flags, u8>; +} + +#[cfg(test)] +mod tests { + use crate::{decode, encode}; + use rust_decimal::Decimal; + use std::str::FromStr; + + #[test] + fn rust_decimal() { + let vs = [ + Decimal::from(0), + Decimal::from_f64_retain(-0f64).unwrap(), + Decimal::from(-1), + Decimal::from(1) / Decimal::from(2), + Decimal::from(1), + Decimal::from(999999999999999999u64), + Decimal::from_str("3.100").unwrap(), + ]; + for v in vs { + let d = decode::(&encode(&v)).unwrap(); + assert_eq!(d, v); + assert_eq!(d.is_sign_negative(), v.is_sign_negative()); + assert_eq!(d.scale(), v.scale()); + } + } +} diff --git a/src/serde/de.rs b/src/serde/de.rs index da09a81..d26265a 100644 --- a/src/serde/de.rs +++ b/src/serde/de.rs @@ -120,8 +120,8 @@ macro_rules! specify { #[rustfmt::skip] fn cold<'de>(decoder: &mut SerdeDecoder<'de>, input: &mut &'de [u8]) -> Result<()> { let &mut SerdeDecoder::Unspecified { length } = decoder else { - type_changed!() - }; + type_changed!() + }; *decoder = SerdeDecoder::$variant(Default::default()); decoder.populate(input, length) } @@ -130,10 +130,10 @@ macro_rules! specify { } #[rustfmt::skip] let SerdeDecoder::$variant(d) = &mut *$self.decoder else { - // Safety: `cold` gets called when decoder isn't the correct decoder. `cold` either - // errors or sets lazy to the correct decoder. - unsafe { core::hint::unreachable_unchecked() }; - }; + // Safety: `cold` gets called when decoder isn't the correct decoder. `cold` either + // errors or sets lazy to the correct decoder. + unsafe { core::hint::unreachable_unchecked() }; + }; d }}; }