diff --git a/Cargo.toml b/Cargo.toml index 6cf91e2..839307c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -18,6 +18,7 @@ arrayvec = { version = "0.7", default-features = false, optional = true } bitcode_derive = { version = "0.6.3", path = "./bitcode_derive", optional = true } bytemuck = { version = "1.14", features = [ "min_const_generics", "must_cast" ] } glam = { version = ">=0.21", default-features = false, optional = true } +rust_decimal = { version = "1.36.0", default-features = false, optional = true } serde = { version = "1.0", default-features = false, features = [ "alloc" ], optional = true } [dev-dependencies] @@ -37,8 +38,8 @@ zstd = "0.13.0" [features] derive = [ "dep:bitcode_derive" ] -std = [ "serde?/std", "glam?/std", "arrayvec?/std" ] -default = [ "derive", "std" ] +std = [ "serde?/std", "glam?/std", "arrayvec?/std", "rust_decimal?/std" ] +default = [ "derive", "std", "rust_decimal" ] [package.metadata.docs.rs] features = [ "derive", "serde", "std" ] @@ -48,4 +49,4 @@ features = [ "derive", "serde", "std" ] #lto = true [lints.rust] -unexpected_cfgs = { level = "warn", check-cfg = ['cfg(fuzzing)'] } \ No newline at end of file +unexpected_cfgs = { level = "warn", check-cfg = ['cfg(fuzzing)'] } diff --git a/fuzz/Cargo.toml b/fuzz/Cargo.toml index 653b3c4..10b7d56 100644 --- a/fuzz/Cargo.toml +++ b/fuzz/Cargo.toml @@ -10,8 +10,9 @@ cargo-fuzz = true [dependencies] arrayvec = { version = "0.7", features = ["serde"] } -bitcode = { path = "..", features = [ "arrayvec", "serde" ] } +bitcode = { path = "..", features = [ "arrayvec", "serde", "rust_decimal" ] } libfuzzer-sys = "0.4" +rust_decimal = "1.36.0" serde = { version ="1.0", features = [ "derive" ] } # Prevent this from interfering with workspaces @@ -22,4 +23,4 @@ members = ["."] name = "fuzz" path = "fuzz_targets/fuzz.rs" test = false -doc = false \ No newline at end of file +doc = false diff --git a/fuzz/fuzz_targets/fuzz.rs b/fuzz/fuzz_targets/fuzz.rs index 6f489de..190e7c8 100644 --- a/fuzz/fuzz_targets/fuzz.rs +++ b/fuzz/fuzz_targets/fuzz.rs @@ -9,6 +9,7 @@ use std::collections::{BTreeMap, HashMap}; use std::fmt::Debug; use std::num::NonZeroU32; use std::time::Duration; +use rust_decimal::Decimal; use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddrV4, SocketAddr, SocketAddrV6}; #[inline(never)] @@ -209,6 +210,7 @@ fuzz_target!(|data: &[u8]| { ArrayString<70>, ArrayVec, ArrayVec, + Decimal, Duration, Ipv4Addr, Ipv6Addr, diff --git a/src/derive/mod.rs b/src/derive/mod.rs index fc4c6fc..6b114da 100644 --- a/src/derive/mod.rs +++ b/src/derive/mod.rs @@ -5,7 +5,7 @@ use alloc::vec::Vec; use core::num::NonZeroUsize; mod array; -mod convert; +pub(crate) mod convert; mod duration; mod empty; mod impls; diff --git a/src/ext/mod.rs b/src/ext/mod.rs index b7b26f1..0c79a30 100644 --- a/src/ext/mod.rs +++ b/src/ext/mod.rs @@ -3,6 +3,8 @@ mod arrayvec; #[cfg(feature = "glam")] #[rustfmt::skip] // Makes impl_struct! calls way longer. mod glam; +#[cfg(feature = "rust_decimal")] +mod rust_decimal; #[allow(unused)] macro_rules! impl_struct { diff --git a/src/ext/rust_decimal.rs b/src/ext/rust_decimal.rs new file mode 100644 index 0000000..f4685b6 --- /dev/null +++ b/src/ext/rust_decimal.rs @@ -0,0 +1,84 @@ +use crate::{ + convert::{self, ConvertFrom}, + Decode, Encode, +}; +use bytemuck::CheckedBitPattern; +use rust_decimal::Decimal; + +type DecimalConversion = (u32, u32, u32, Flags); + +impl ConvertFrom<&Decimal> for DecimalConversion { + fn convert_from(value: &Decimal) -> Self { + let unpacked = value.unpack(); + ( + unpacked.lo, + unpacked.mid, + unpacked.hi, + Flags::new(unpacked.scale, unpacked.negative), + ) + } +} + +impl ConvertFrom for Decimal { + fn convert_from(value: DecimalConversion) -> Self { + Self::from_parts( + value.0, + value.1, + value.2, + value.3.negative(), + value.3.scale(), + ) + } +} + +impl Encode for Decimal { + type Encoder = convert::ConvertIntoEncoder; +} +impl<'a> Decode<'a> for Decimal { + type Decoder = convert::ConvertFromDecoder<'a, DecimalConversion>; +} + +impl ConvertFrom<&Flags> for u8 { + fn convert_from(flags: &Flags) -> Self { + flags.0 + } +} + +impl Encode for Flags { + type Encoder = convert::ConvertIntoEncoder; +} + +/// A u8 guaranteed to satisfy (flags >> 1) <= 28. Prevents Decimal::from_parts from misbehaving. +#[derive(Copy, Clone)] +#[repr(transparent)] +pub struct Flags(u8); + +impl Flags { + #[inline(always)] + fn new(scale: u32, negative: bool) -> Self { + Self((scale as u8) << 1 | negative as u8) + } + + #[inline(always)] + fn scale(&self) -> u32 { + (self.0 >> 1) as u32 + } + + #[inline(always)] + fn negative(&self) -> bool { + self.0 & 1 == 1 + } +} + +// Safety: u8 and Flags have the same layout since Flags is #[repr(transparent)]. +unsafe impl CheckedBitPattern for Flags { + type Bits = u8; + #[inline(always)] + fn is_valid_bit_pattern(bits: &Self::Bits) -> bool { + (*bits >> 1) <= 28 + } +} + +impl<'a> Decode<'a> for Flags { + type Decoder = crate::int::CheckedIntDecoder<'a, Flags, u8>; +} diff --git a/src/serde/de.rs b/src/serde/de.rs index 9252dcd..8779083 100644 --- a/src/serde/de.rs +++ b/src/serde/de.rs @@ -119,8 +119,8 @@ macro_rules! specify { #[cold] fn cold<'de>(decoder: &mut SerdeDecoder<'de>, input: &mut &'de [u8]) -> Result<()> { let &mut SerdeDecoder::Unspecified { length } = decoder else { - type_changed!() - }; + type_changed!() + }; *decoder = SerdeDecoder::$variant(Default::default()); decoder.populate(input, length) } @@ -128,10 +128,10 @@ macro_rules! specify { } } let SerdeDecoder::$variant(d) = &mut *$self.decoder else { - // Safety: `cold` gets called when decoder isn't the correct decoder. `cold` either - // errors or sets lazy to the correct decoder. - unsafe { core::hint::unreachable_unchecked() }; - }; + // Safety: `cold` gets called when decoder isn't the correct decoder. `cold` either + // errors or sets lazy to the correct decoder. + unsafe { core::hint::unreachable_unchecked() }; + }; d }}; }