From a1a5cc09aed7f8c38b8ffd875ae9ebdb123af511 Mon Sep 17 00:00:00 2001 From: Finn Bear Date: Fri, 20 Sep 2024 11:32:46 -0700 Subject: [PATCH 1/9] Support rust_decimal. --- Cargo.toml | 3 +- fuzz/Cargo.toml | 3 +- fuzz/fuzz_targets/fuzz.rs | 2 + src/ext/mod.rs | 2 + src/ext/rust_decimal.rs | 105 ++++++++++++++++++++++++++++++++++++++ src/serde/de.rs | 12 ++--- 6 files changed, 119 insertions(+), 8 deletions(-) create mode 100644 src/ext/rust_decimal.rs diff --git a/Cargo.toml b/Cargo.toml index 6cf91e2..152246f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -18,6 +18,7 @@ arrayvec = { version = "0.7", default-features = false, optional = true } bitcode_derive = { version = "0.6.3", path = "./bitcode_derive", optional = true } bytemuck = { version = "1.14", features = [ "min_const_generics", "must_cast" ] } glam = { version = ">=0.21", default-features = false, optional = true } +rust_decimal = { version = "1.36", default-features = false, optional = true } serde = { version = "1.0", default-features = false, features = [ "alloc" ], optional = true } [dev-dependencies] @@ -37,7 +38,7 @@ zstd = "0.13.0" [features] derive = [ "dep:bitcode_derive" ] -std = [ "serde?/std", "glam?/std", "arrayvec?/std" ] +std = [ "serde?/std", "glam?/std", "arrayvec?/std", "rust_decimal?/std" ] default = [ "derive", "std" ] [package.metadata.docs.rs] diff --git a/fuzz/Cargo.toml b/fuzz/Cargo.toml index 653b3c4..4d189cf 100644 --- a/fuzz/Cargo.toml +++ b/fuzz/Cargo.toml @@ -10,8 +10,9 @@ cargo-fuzz = true [dependencies] arrayvec = { version = "0.7", features = ["serde"] } -bitcode = { path = "..", features = [ "arrayvec", "serde" ] } +bitcode = { path = "..", features = [ "arrayvec", "serde", "rust_decimal" ] } libfuzzer-sys = "0.4" +rust_decimal = "1.36.0" serde = { version ="1.0", features = [ "derive" ] } # Prevent this from interfering with workspaces diff --git a/fuzz/fuzz_targets/fuzz.rs b/fuzz/fuzz_targets/fuzz.rs index 6f489de..190e7c8 100644 --- a/fuzz/fuzz_targets/fuzz.rs +++ b/fuzz/fuzz_targets/fuzz.rs @@ -9,6 +9,7 @@ use std::collections::{BTreeMap, HashMap}; use std::fmt::Debug; use std::num::NonZeroU32; use std::time::Duration; +use rust_decimal::Decimal; use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddrV4, SocketAddr, SocketAddrV6}; #[inline(never)] @@ -209,6 +210,7 @@ fuzz_target!(|data: &[u8]| { ArrayString<70>, ArrayVec, ArrayVec, + Decimal, Duration, Ipv4Addr, Ipv6Addr, diff --git a/src/ext/mod.rs b/src/ext/mod.rs index b7b26f1..0c79a30 100644 --- a/src/ext/mod.rs +++ b/src/ext/mod.rs @@ -3,6 +3,8 @@ mod arrayvec; #[cfg(feature = "glam")] #[rustfmt::skip] // Makes impl_struct! calls way longer. mod glam; +#[cfg(feature = "rust_decimal")] +mod rust_decimal; #[allow(unused)] macro_rules! impl_struct { diff --git a/src/ext/rust_decimal.rs b/src/ext/rust_decimal.rs new file mode 100644 index 0000000..58bc0ca --- /dev/null +++ b/src/ext/rust_decimal.rs @@ -0,0 +1,105 @@ +use crate::{ + convert::{self, ConvertFrom}, + Decode, Encode, +}; +use bytemuck::CheckedBitPattern; +use rust_decimal::Decimal; + +type DecimalConversion = (u32, u32, u32, Flags); + +impl ConvertFrom<&Decimal> for DecimalConversion { + fn convert_from(value: &Decimal) -> Self { + let unpacked = value.unpack(); + ( + unpacked.lo, + unpacked.mid, + unpacked.hi, + Flags::new(unpacked.scale, unpacked.negative), + ) + } +} + +impl ConvertFrom for Decimal { + fn convert_from(value: DecimalConversion) -> Self { + Self::from_parts( + value.0, + value.1, + value.2, + value.3.negative(), + value.3.scale(), + ) + } +} + +impl Encode for Decimal { + type Encoder = convert::ConvertIntoEncoder; +} + +impl<'a> Decode<'a> for Decimal { + type Decoder = convert::ConvertFromDecoder<'a, DecimalConversion>; +} + +impl ConvertFrom<&Flags> for u8 { + fn convert_from(flags: &Flags) -> Self { + flags.0 + } +} + +impl Encode for Flags { + type Encoder = convert::ConvertIntoEncoder; +} + +/// A u8 guaranteed to satisfy (flags >> 1) <= 28. Prevents Decimal::from_parts from misbehaving. +#[derive(Copy, Clone)] +#[repr(transparent)] +pub struct Flags(u8); + +impl Flags { + #[inline(always)] + fn new(scale: u32, negative: bool) -> Self { + Self((scale as u8) << 1 | negative as u8) + } + + #[inline(always)] + fn scale(&self) -> u32 { + (self.0 >> 1) as u32 + } + + #[inline(always)] + fn negative(&self) -> bool { + self.0 & 1 == 1 + } +} + +// Safety: u8 and Flags have the same layout since Flags is #[repr(transparent)]. +unsafe impl CheckedBitPattern for Flags { + type Bits = u8; + #[inline(always)] + fn is_valid_bit_pattern(bits: &Self::Bits) -> bool { + (*bits >> 1) <= 28 + } +} + +impl<'a> Decode<'a> for Flags { + type Decoder = crate::int::CheckedIntDecoder<'a, Flags, u8>; +} + +#[cfg(test)] +mod tests { + use crate::{decode, encode}; + use rust_decimal::Decimal; + + #[test] + fn rust_decimal() { + let vs = [ + Decimal::from(0), + Decimal::from(-1), + Decimal::from(1) / Decimal::from(2), + Decimal::from(1), + Decimal::from(999999999999999999u64), + ]; + for v in vs { + assert_eq!(decode::(&encode(&v)).unwrap(), v); + } + } +} diff --git a/src/serde/de.rs b/src/serde/de.rs index da09a81..d26265a 100644 --- a/src/serde/de.rs +++ b/src/serde/de.rs @@ -120,8 +120,8 @@ macro_rules! specify { #[rustfmt::skip] fn cold<'de>(decoder: &mut SerdeDecoder<'de>, input: &mut &'de [u8]) -> Result<()> { let &mut SerdeDecoder::Unspecified { length } = decoder else { - type_changed!() - }; + type_changed!() + }; *decoder = SerdeDecoder::$variant(Default::default()); decoder.populate(input, length) } @@ -130,10 +130,10 @@ macro_rules! specify { } #[rustfmt::skip] let SerdeDecoder::$variant(d) = &mut *$self.decoder else { - // Safety: `cold` gets called when decoder isn't the correct decoder. `cold` either - // errors or sets lazy to the correct decoder. - unsafe { core::hint::unreachable_unchecked() }; - }; + // Safety: `cold` gets called when decoder isn't the correct decoder. `cold` either + // errors or sets lazy to the correct decoder. + unsafe { core::hint::unreachable_unchecked() }; + }; d }}; } From 424032fe592dd81664fde18ce552b7ebfc16caa6 Mon Sep 17 00:00:00 2001 From: Finn Bear Date: Fri, 20 Sep 2024 11:37:05 -0700 Subject: [PATCH 2/9] Use macro. --- src/ext/rust_decimal.rs | 11 ++--------- 1 file changed, 2 insertions(+), 9 deletions(-) diff --git a/src/ext/rust_decimal.rs b/src/ext/rust_decimal.rs index 58bc0ca..492b6ed 100644 --- a/src/ext/rust_decimal.rs +++ b/src/ext/rust_decimal.rs @@ -1,10 +1,9 @@ use crate::{ - convert::{self, ConvertFrom}, + convert::{self, impl_convert, ConvertFrom}, Decode, Encode, }; use bytemuck::CheckedBitPattern; use rust_decimal::Decimal; - type DecimalConversion = (u32, u32, u32, Flags); impl ConvertFrom<&Decimal> for DecimalConversion { @@ -31,13 +30,7 @@ impl ConvertFrom for Decimal { } } -impl Encode for Decimal { - type Encoder = convert::ConvertIntoEncoder; -} - -impl<'a> Decode<'a> for Decimal { - type Decoder = convert::ConvertFromDecoder<'a, DecimalConversion>; -} +impl_convert!(Decimal, DecimalConversion); impl ConvertFrom<&Flags> for u8 { fn convert_from(flags: &Flags) -> Self { From b0aa34c6b92798fa91841f2445a4b65289337b7d Mon Sep 17 00:00:00 2001 From: Finn Bear Date: Mon, 23 Sep 2024 15:12:12 -0700 Subject: [PATCH 3/9] Round-trip -0. --- src/ext/rust_decimal.rs | 13 +++++++++---- src/serde/variant.rs | 8 ++++---- 2 files changed, 13 insertions(+), 8 deletions(-) diff --git a/src/ext/rust_decimal.rs b/src/ext/rust_decimal.rs index 492b6ed..1c49cf7 100644 --- a/src/ext/rust_decimal.rs +++ b/src/ext/rust_decimal.rs @@ -20,13 +20,15 @@ impl ConvertFrom<&Decimal> for DecimalConversion { impl ConvertFrom for Decimal { fn convert_from(value: DecimalConversion) -> Self { - Self::from_parts( + let mut ret = Self::from_parts( value.0, value.1, value.2, - value.3.negative(), + false, value.3.scale(), - ) + ); + ret.set_sign_negative(value.3.negative()); + ret } } @@ -86,13 +88,16 @@ mod tests { fn rust_decimal() { let vs = [ Decimal::from(0), + Decimal::from_f64_retain(-0f64).unwrap(), Decimal::from(-1), Decimal::from(1) / Decimal::from(2), Decimal::from(1), Decimal::from(999999999999999999u64), ]; for v in vs { - assert_eq!(decode::(&encode(&v)).unwrap(), v); + let d = decode::(&encode(&v)).unwrap(); + assert_eq!(d, v); + assert_eq!(d.is_sign_negative(), v.is_sign_negative()); } } } diff --git a/src/serde/variant.rs b/src/serde/variant.rs index 2d96a5f..938e4bf 100644 --- a/src/serde/variant.rs +++ b/src/serde/variant.rs @@ -6,13 +6,13 @@ use core::marker::PhantomData; use core::num::NonZeroUsize; #[derive(Default)] -pub struct VariantEncoder { - data: VecImpl, +pub struct VariantEncoder { + data: VecImpl, } -impl Encoder for VariantEncoder { +impl Encoder for VariantEncoder { #[inline(always)] - fn encode(&mut self, v: &u8) { + fn encode(&mut self, v: &Index) { unsafe { self.data.push_unchecked(*v) }; } } From ee250548f80f27c03a71b6c9937e630344bee01a Mon Sep 17 00:00:00 2001 From: Finn Bear Date: Mon, 23 Sep 2024 15:16:03 -0700 Subject: [PATCH 4/9] Test retention of trailing zero. --- src/ext/rust_decimal.rs | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/ext/rust_decimal.rs b/src/ext/rust_decimal.rs index 1c49cf7..19467fb 100644 --- a/src/ext/rust_decimal.rs +++ b/src/ext/rust_decimal.rs @@ -83,6 +83,7 @@ impl<'a> Decode<'a> for Flags { mod tests { use crate::{decode, encode}; use rust_decimal::Decimal; + use std::str::FromStr; #[test] fn rust_decimal() { @@ -93,11 +94,13 @@ mod tests { Decimal::from(1) / Decimal::from(2), Decimal::from(1), Decimal::from(999999999999999999u64), + Decimal::from_str("3.100").unwrap() ]; for v in vs { let d = decode::(&encode(&v)).unwrap(); assert_eq!(d, v); assert_eq!(d.is_sign_negative(), v.is_sign_negative()); + assert_eq!(d.scale(), v.scale()); } } } From 2c6dffad9a8f3e019faa53d2ce18ada8a88a5bc6 Mon Sep 17 00:00:00 2001 From: Finn Bear Date: Mon, 23 Sep 2024 15:16:33 -0700 Subject: [PATCH 5/9] fmt. --- src/ext/rust_decimal.rs | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/src/ext/rust_decimal.rs b/src/ext/rust_decimal.rs index 19467fb..ac07a25 100644 --- a/src/ext/rust_decimal.rs +++ b/src/ext/rust_decimal.rs @@ -20,13 +20,7 @@ impl ConvertFrom<&Decimal> for DecimalConversion { impl ConvertFrom for Decimal { fn convert_from(value: DecimalConversion) -> Self { - let mut ret = Self::from_parts( - value.0, - value.1, - value.2, - false, - value.3.scale(), - ); + let mut ret = Self::from_parts(value.0, value.1, value.2, false, value.3.scale()); ret.set_sign_negative(value.3.negative()); ret } @@ -94,7 +88,7 @@ mod tests { Decimal::from(1) / Decimal::from(2), Decimal::from(1), Decimal::from(999999999999999999u64), - Decimal::from_str("3.100").unwrap() + Decimal::from_str("3.100").unwrap(), ]; for v in vs { let d = decode::(&encode(&v)).unwrap(); From 6ea4b0b98249d83114972a7e62c1080610b7d54f Mon Sep 17 00:00:00 2001 From: Finn Bear Date: Mon, 23 Sep 2024 15:18:30 -0700 Subject: [PATCH 6/9] Undo mistake. --- src/serde/variant.rs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/serde/variant.rs b/src/serde/variant.rs index 938e4bf..94d33d2 100644 --- a/src/serde/variant.rs +++ b/src/serde/variant.rs @@ -6,13 +6,13 @@ use core::marker::PhantomData; use core::num::NonZeroUsize; #[derive(Default)] -pub struct VariantEncoder { - data: VecImpl, +pub struct VariantEncoder { + data: VecImpl, } -impl Encoder for VariantEncoder { +impl Encoder for VariantEncoder { #[inline(always)] - fn encode(&mut self, v: &Index) { + fn encode(&mut self, v: &u8) { unsafe { self.data.push_unchecked(*v) }; } } From b3a3804a4b9d593aa71d19bff00c956d38b821e1 Mon Sep 17 00:00:00 2001 From: Finn Bear Date: Mon, 23 Sep 2024 15:18:51 -0700 Subject: [PATCH 7/9] Undo mistake 2. --- src/serde/variant.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/serde/variant.rs b/src/serde/variant.rs index 94d33d2..f240305 100644 --- a/src/serde/variant.rs +++ b/src/serde/variant.rs @@ -10,7 +10,7 @@ pub struct VariantEncoder { data: VecImpl, } -impl Encoder for VariantEncoder { +impl Encoder for VariantEncoder { #[inline(always)] fn encode(&mut self, v: &u8) { unsafe { self.data.push_unchecked(*v) }; From cbadccfe518fc40b080ec2fa851b389f7c677934 Mon Sep 17 00:00:00 2001 From: Finn Bear Date: Mon, 23 Sep 2024 15:19:14 -0700 Subject: [PATCH 8/9] Undo mistake 3. --- src/serde/variant.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/serde/variant.rs b/src/serde/variant.rs index f240305..2d96a5f 100644 --- a/src/serde/variant.rs +++ b/src/serde/variant.rs @@ -10,7 +10,7 @@ pub struct VariantEncoder { data: VecImpl, } -impl Encoder for VariantEncoder { +impl Encoder for VariantEncoder { #[inline(always)] fn encode(&mut self, v: &u8) { unsafe { self.data.push_unchecked(*v) }; From 8a598b30b8e013eb2eadd0c563ba46e4697326e7 Mon Sep 17 00:00:00 2001 From: Finn Bear Date: Mon, 23 Sep 2024 15:39:15 -0700 Subject: [PATCH 9/9] Prepare optimization. --- src/ext/rust_decimal.rs | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/src/ext/rust_decimal.rs b/src/ext/rust_decimal.rs index ac07a25..2f9e4f1 100644 --- a/src/ext/rust_decimal.rs +++ b/src/ext/rust_decimal.rs @@ -20,7 +20,16 @@ impl ConvertFrom<&Decimal> for DecimalConversion { impl ConvertFrom for Decimal { fn convert_from(value: DecimalConversion) -> Self { - let mut ret = Self::from_parts(value.0, value.1, value.2, false, value.3.scale()); + let scale = value.3.scale(); + // Should make Decimal::from_parts faster, once it can be inlined, + // since it can skip division. + // Safety: impl CheckedBitPattern for Flags guarantees this. + unsafe { + if scale > 28 { + core::hint::unreachable_unchecked(); + } + } + let mut ret = Self::from_parts(value.0, value.1, value.2, false, scale); ret.set_sign_negative(value.3.negative()); ret }