Skip to content

Commit

Permalink
Support rust_decimal.
Browse files Browse the repository at this point in the history
  • Loading branch information
finnbear committed Sep 20, 2024
1 parent 195318d commit a1a5cc0
Show file tree
Hide file tree
Showing 6 changed files with 119 additions and 8 deletions.
3 changes: 2 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ arrayvec = { version = "0.7", default-features = false, optional = true }
bitcode_derive = { version = "0.6.3", path = "./bitcode_derive", optional = true }
bytemuck = { version = "1.14", features = [ "min_const_generics", "must_cast" ] }
glam = { version = ">=0.21", default-features = false, optional = true }
rust_decimal = { version = "1.36", default-features = false, optional = true }
serde = { version = "1.0", default-features = false, features = [ "alloc" ], optional = true }

[dev-dependencies]
Expand All @@ -37,7 +38,7 @@ zstd = "0.13.0"

[features]
derive = [ "dep:bitcode_derive" ]
std = [ "serde?/std", "glam?/std", "arrayvec?/std" ]
std = [ "serde?/std", "glam?/std", "arrayvec?/std", "rust_decimal?/std" ]
default = [ "derive", "std" ]

[package.metadata.docs.rs]
Expand Down
3 changes: 2 additions & 1 deletion fuzz/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,9 @@ cargo-fuzz = true

[dependencies]
arrayvec = { version = "0.7", features = ["serde"] }
bitcode = { path = "..", features = [ "arrayvec", "serde" ] }
bitcode = { path = "..", features = [ "arrayvec", "serde", "rust_decimal" ] }
libfuzzer-sys = "0.4"
rust_decimal = "1.36.0"
serde = { version ="1.0", features = [ "derive" ] }

# Prevent this from interfering with workspaces
Expand Down
2 changes: 2 additions & 0 deletions fuzz/fuzz_targets/fuzz.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ use std::collections::{BTreeMap, HashMap};
use std::fmt::Debug;
use std::num::NonZeroU32;
use std::time::Duration;
use rust_decimal::Decimal;
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddrV4, SocketAddr, SocketAddrV6};

#[inline(never)]
Expand Down Expand Up @@ -209,6 +210,7 @@ fuzz_target!(|data: &[u8]| {
ArrayString<70>,
ArrayVec<u8, 5>,
ArrayVec<u8, 70>,
Decimal,
Duration,
Ipv4Addr,
Ipv6Addr,
Expand Down
2 changes: 2 additions & 0 deletions src/ext/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ mod arrayvec;
#[cfg(feature = "glam")]
#[rustfmt::skip] // Makes impl_struct! calls way longer.
mod glam;
#[cfg(feature = "rust_decimal")]
mod rust_decimal;

#[allow(unused)]
macro_rules! impl_struct {
Expand Down
105 changes: 105 additions & 0 deletions src/ext/rust_decimal.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
use crate::{
convert::{self, ConvertFrom},
Decode, Encode,
};
use bytemuck::CheckedBitPattern;
use rust_decimal::Decimal;

type DecimalConversion = (u32, u32, u32, Flags);

impl ConvertFrom<&Decimal> for DecimalConversion {
fn convert_from(value: &Decimal) -> Self {
let unpacked = value.unpack();
(
unpacked.lo,
unpacked.mid,
unpacked.hi,
Flags::new(unpacked.scale, unpacked.negative),
)
}
}

impl ConvertFrom<DecimalConversion> for Decimal {
fn convert_from(value: DecimalConversion) -> Self {
Self::from_parts(
value.0,
value.1,
value.2,
value.3.negative(),
value.3.scale(),
)
}
}

impl Encode for Decimal {
type Encoder = convert::ConvertIntoEncoder<DecimalConversion>;
}

impl<'a> Decode<'a> for Decimal {
type Decoder = convert::ConvertFromDecoder<'a, DecimalConversion>;
}

impl ConvertFrom<&Flags> for u8 {
fn convert_from(flags: &Flags) -> Self {
flags.0
}
}

impl Encode for Flags {
type Encoder = convert::ConvertIntoEncoder<u8>;
}

/// A u8 guaranteed to satisfy (flags >> 1) <= 28. Prevents Decimal::from_parts from misbehaving.
#[derive(Copy, Clone)]
#[repr(transparent)]
pub struct Flags(u8);

impl Flags {
#[inline(always)]
fn new(scale: u32, negative: bool) -> Self {
Self((scale as u8) << 1 | negative as u8)
}

#[inline(always)]
fn scale(&self) -> u32 {
(self.0 >> 1) as u32
}

#[inline(always)]
fn negative(&self) -> bool {
self.0 & 1 == 1
}
}

// Safety: u8 and Flags have the same layout since Flags is #[repr(transparent)].
unsafe impl CheckedBitPattern for Flags {
type Bits = u8;
#[inline(always)]
fn is_valid_bit_pattern(bits: &Self::Bits) -> bool {
(*bits >> 1) <= 28
}
}

impl<'a> Decode<'a> for Flags {
type Decoder = crate::int::CheckedIntDecoder<'a, Flags, u8>;
}

#[cfg(test)]
mod tests {
use crate::{decode, encode};
use rust_decimal::Decimal;

#[test]
fn rust_decimal() {
let vs = [
Decimal::from(0),
Decimal::from(-1),
Decimal::from(1) / Decimal::from(2),
Decimal::from(1),
Decimal::from(999999999999999999u64),
];
for v in vs {
assert_eq!(decode::<Decimal>(&encode(&v)).unwrap(), v);
}
}
}
12 changes: 6 additions & 6 deletions src/serde/de.rs
Original file line number Diff line number Diff line change
Expand Up @@ -120,8 +120,8 @@ macro_rules! specify {
#[rustfmt::skip]
fn cold<'de>(decoder: &mut SerdeDecoder<'de>, input: &mut &'de [u8]) -> Result<()> {
let &mut SerdeDecoder::Unspecified { length } = decoder else {
type_changed!()
};
type_changed!()
};
*decoder = SerdeDecoder::$variant(Default::default());
decoder.populate(input, length)
}
Expand All @@ -130,10 +130,10 @@ macro_rules! specify {
}
#[rustfmt::skip]
let SerdeDecoder::$variant(d) = &mut *$self.decoder else {
// Safety: `cold` gets called when decoder isn't the correct decoder. `cold` either
// errors or sets lazy to the correct decoder.
unsafe { core::hint::unreachable_unchecked() };
};
// Safety: `cold` gets called when decoder isn't the correct decoder. `cold` either
// errors or sets lazy to the correct decoder.
unsafe { core::hint::unreachable_unchecked() };
};
d
}};
}
Expand Down

0 comments on commit a1a5cc0

Please sign in to comment.