From 99b7b678af0ac909a50399c639e2f97770f29383 Mon Sep 17 00:00:00 2001 From: Cai Bear Date: Wed, 13 Mar 2024 18:09:05 -0700 Subject: [PATCH] Handle errors instead of panicking in derive macro. --- bitcode_derive/src/attribute.rs | 5 +- bitcode_derive/src/bound.rs | 7 +- bitcode_derive/src/decode.rs | 275 ++++++++++++-------------------- bitcode_derive/src/encode.rs | 254 ++++++++++------------------- bitcode_derive/src/lib.rs | 37 ++--- bitcode_derive/src/shared.rs | 157 +++++++++++++++++- src/derive/mod.rs | 1 + 7 files changed, 362 insertions(+), 374 deletions(-) diff --git a/bitcode_derive/src/attribute.rs b/bitcode_derive/src/attribute.rs index a9e944d..31980f1 100644 --- a/bitcode_derive/src/attribute.rs +++ b/bitcode_derive/src/attribute.rs @@ -42,12 +42,12 @@ impl BitcodeAttr { return err(nested, "duplicate"); } *b = Some(bound_type); + Ok(()) } else { - return err(nested, "can only apply bound to fields"); + err(nested, "can only apply bound to fields") } } } - Ok(()) } } @@ -81,7 +81,6 @@ impl BitcodeAttrs { Ok(ret) } - #[allow(unused)] // TODO pub fn parse_variant(attrs: &[Attribute], _derive_attrs: &Self) -> Result { let mut ret = Self::new(AttrType::Variant); ret.parse_inner(attrs)?; diff --git a/bitcode_derive/src/bound.rs b/bitcode_derive/src/bound.rs index 232a742..99f6d83 100644 --- a/bitcode_derive/src/bound.rs +++ b/bitcode_derive/src/bound.rs @@ -23,15 +23,16 @@ impl FieldBounds { } } - pub fn apply_to_generics(self, generics: &mut syn::Generics) { + pub fn added_to(self, mut generics: syn::Generics) -> syn::Generics { for (bound, (fields, extra_bound_types)) in self.bounds { - *generics = with_bound(&fields, extra_bound_types, generics, &bound); + generics = with_bound(&fields, extra_bound_types, &generics, &bound); } + generics } } // Based on https://github.com/serde-rs/serde/blob/0c6a2bbf794abe966a4763f5b7ff23acb535eb7f/serde_derive/src/bound.rs#L94-L314 -pub fn with_bound( +fn with_bound( fields: &[syn::Field], extra_bound_types: Vec, generics: &syn::Generics, diff --git a/bitcode_derive/src/decode.rs b/bitcode_derive/src/decode.rs index 622e2fd..ab4b765 100644 --- a/bitcode_derive/src/decode.rs +++ b/bitcode_derive/src/decode.rs @@ -1,14 +1,10 @@ -use crate::attribute::BitcodeAttrs; -use crate::bound::FieldBounds; -use crate::shared::{ - destructure_fields, field_name, remove_lifetimes, replace_lifetimes, ReplaceSelves, -}; -use crate::{err, private}; +use crate::private; +use crate::shared::{remove_lifetimes, replace_lifetimes, variant_index}; use proc_macro2::{Ident, Span, TokenStream}; use quote::quote; use syn::{ - parse_quote, Data, DeriveInput, Fields, GenericParam, Lifetime, LifetimeParam, Path, - PredicateLifetime, Result, Type, WherePredicate, + parse_quote, GenericParam, Generics, Lifetime, LifetimeParam, Path, PredicateLifetime, Type, + WherePredicate, }; const DE_LIFETIME: &str = "__de"; @@ -17,8 +13,7 @@ fn de_lifetime() -> Lifetime { } #[derive(Copy, Clone)] -#[repr(u8)] -enum Item { +pub enum Item { Type, Default, Populate, @@ -35,7 +30,9 @@ impl Item { Self::DecodeInPlace, ]; const COUNT: usize = Self::ALL.len(); +} +impl crate::shared::Item for Item { fn field_impl( self, field_name: TokenStream, @@ -89,11 +86,11 @@ impl Item { } } - pub fn variant_impls( + fn enum_impl( self, variant_count: usize, - mut pattern: impl FnMut(usize) -> TokenStream, - mut inner: impl FnMut(Self, usize) -> TokenStream, + pattern: impl Fn(usize) -> TokenStream, + inner: impl Fn(Self, usize) -> TokenStream, ) -> TokenStream { // if variant_count is 0 or 1 variants don't have to be decoded. let decode_variants = variant_count > 1; @@ -146,9 +143,7 @@ impl Item { if inner.is_empty() { quote! {} } else { - let i: u8 = i - .try_into() - .expect("enums with more than 256 variants are not supported"); // TODO don't panic. + let i = variant_index(i); let length = decode_variants .then(|| { quote! { @@ -176,7 +171,7 @@ impl Item { unsafe { std::hint::unreachable_unchecked() } }; } - let mut pattern = |i: usize| { + let pattern = |i: usize| { let pattern = pattern(i); matches!(self, Self::DecodeInPlace) .then(|| { @@ -194,7 +189,7 @@ impl Item { .map(|i| { let inner = inner(item, i); let pattern = pattern(i); - let i: u8 = i.try_into().unwrap(); // Already checked in reserve impl. + let i = variant_index(i); quote! { #i => { #inner @@ -225,184 +220,110 @@ impl Item { } } } +} - // TODO dedup with encode.rs - fn field_impls( - self, - global_prefix: Option<&str>, - fields: &Fields, - parent_attrs: &BitcodeAttrs, - bounds: &mut FieldBounds, - ) -> Result { - fields - .iter() - .enumerate() - .map(move |(i, field)| { - let field_attrs = BitcodeAttrs::parse_field(&field.attrs, parent_attrs)?; - - let name = field_name(i, field, false); - let real_name = field_name(i, field, true); +pub struct Decode; +impl crate::shared::Derive<{ Item::COUNT }> for Decode { + type Item = Item; + const ALL: [Self::Item; Item::COUNT] = Item::ALL; - let global_name = global_prefix - .map(|global_prefix| { - let ident = - Ident::new(&format!("{global_prefix}{name}"), Span::call_site()); - quote! { #ident } - }) - .unwrap_or_else(|| name.clone()); + fn bound(&self) -> Path { + let private = private(); + let de = de_lifetime(); + parse_quote!(#private::Decode<#de>) + } - let field_impl = self.field_impl(name, global_name, real_name, &field.ty); + fn derive_impl( + &self, + output: [TokenStream; Item::COUNT], + ident: Ident, + mut generics: Generics, + ) -> TokenStream { + let input_generics = generics.clone(); + let (_, input_generics, _) = input_generics.split_for_impl(); + let input_ty = quote! { #ident #input_generics }; - let private = private(); - let de = de_lifetime(); - let bound: Path = parse_quote!(#private::Decode<#de>); - bounds.add_bound_type(field.clone(), &field_attrs, bound); - Ok(field_impl) - }) - .collect() - } -} + // Add 'de lifetime after isolating input_generics. + let de = de_lifetime(); + let de_where_predicate = WherePredicate::Lifetime(PredicateLifetime { + lifetime: de.clone(), + colon_token: parse_quote!(:), + bounds: generics + .params + .iter() + .filter_map(|p| { + if let GenericParam::Lifetime(p) = p { + Some(p.lifetime.clone()) + } else { + None + } + }) + .collect(), + }); -struct Output([TokenStream; Item::COUNT]); + // Push de_param after bounding 'de: 'a. + let de_param = GenericParam::Lifetime(LifetimeParam::new(de.clone())); + generics.params.push(de_param.clone()); + generics + .make_where_clause() + .predicates + .push(de_where_predicate); -impl Output { - fn haunt(mut self) -> Self { - let type_ = &mut self.0[Item::Type as usize]; - if type_.is_empty() { - let de = de_lifetime(); - *type_ = quote! { __spooky: std::marker::PhantomData<&#de ()>, }; - } - let default = &mut self.0[Item::Default as usize]; - if default.is_empty() { - *default = quote! { __spooky: Default::default(), }; - } - self - } -} + let combined_generics = generics.clone(); + let (impl_generics, _, where_clause) = combined_generics.split_for_impl(); -pub fn derive_impl(mut input: DeriveInput) -> Result { - let attrs = BitcodeAttrs::parse_derive(&input.attrs)?; - let mut generics = input.generics; - let mut bounds = FieldBounds::default(); + // Decoder can't contain any lifetimes from input (which would limit reuse of decoder). + remove_lifetimes(&mut generics); + generics.params.push(de_param); // Re-add de_param since remove_lifetimes removed it. + let (decoder_impl_generics, decoder_generics, decoder_where_clause) = + generics.split_for_impl(); - let ident = input.ident; - syn::visit_mut::visit_data_mut(&mut ReplaceSelves(&ident), &mut input.data); - let output = (match input.data { - Data::Struct(data_struct) => { - let destructure_fields = &destructure_fields(&data_struct.fields); - Output(Item::ALL.map(|item| { - let field_impls = item - .field_impls(None, &data_struct.fields, &attrs, &mut bounds) - .unwrap(); // TODO don't unwrap - item.struct_impl(&ident, destructure_fields, &field_impls) - })) + let [mut type_body, mut default_body, populate_body, decode_in_place_body] = output; + if type_body.is_empty() { + type_body = quote! { __spooky: std::marker::PhantomData<&#de ()>, }; } - Data::Enum(data_enum) => { - let variant_count = data_enum.variants.len(); - Output(Item::ALL.map(|item| { - item.variant_impls( - variant_count, - |i| { - let variant = &data_enum.variants[i]; - let variant_name = &variant.ident; - let destructure_fields = destructure_fields(&variant.fields); - quote! { - #ident::#variant_name #destructure_fields - } - }, - |item, i| { - let variant = &data_enum.variants[i]; - let global_prefix = format!("{}_", &variant.ident); - let attrs = BitcodeAttrs::parse_variant(&variant.attrs, &attrs).unwrap(); // TODO don't unwrap. - item.field_impls(Some(&global_prefix), &variant.fields, &attrs, &mut bounds) - .unwrap() // TODO don't unwrap. - }, - ) - })) + if default_body.is_empty() { + default_body = quote! { __spooky: Default::default(), }; } - Data::Union(u) => err(&u.union_token, "unions are not supported")?, - }) - .haunt(); - bounds.apply_to_generics(&mut generics); - let input_generics = generics.clone(); - let (_, input_generics, _) = input_generics.split_for_impl(); - let input_ty = quote! { #ident #input_generics }; + let decoder_ident = Ident::new(&format!("{ident}Decoder"), Span::call_site()); + let decoder_ty = quote! { #decoder_ident #decoder_generics }; + let private = private(); - // Add 'de lifetime after isolating input_generics. - let de = de_lifetime(); - let de_where_predicate = WherePredicate::Lifetime(PredicateLifetime { - lifetime: de.clone(), - colon_token: parse_quote!(:), - bounds: generics - .params - .iter() - .filter_map(|p| { - if let GenericParam::Lifetime(p) = p { - Some(p.lifetime.clone()) - } else { - None + quote! { + const _: () = { + impl #impl_generics #private::Decode<#de> for #input_ty #where_clause { + type Decoder = #decoder_ty; } - }) - .collect(), - }); - - // Push de_param after bounding 'de: 'a. - let de_param = GenericParam::Lifetime(LifetimeParam::new(de.clone())); - generics.params.push(de_param.clone()); - generics - .make_where_clause() - .predicates - .push(de_where_predicate); - - let combined_generics = generics.clone(); - let (impl_generics, _, where_clause) = combined_generics.split_for_impl(); - - // Decoder can't contain any lifetimes from input (which would limit reuse of decoder). - remove_lifetimes(&mut generics); - generics.params.push(de_param); // Re-add de_param since remove_lifetimes removed it. - let (decoder_impl_generics, decoder_generics, decoder_where_clause) = generics.split_for_impl(); - - let Output([type_body, default_body, populate_body, decode_in_place_body]) = output; - let decoder_ident = Ident::new(&format!("{ident}Decoder"), Span::call_site()); - let decoder_ty = quote! { #decoder_ident #decoder_generics }; - let private = private(); - let ret = quote! { - const _: () = { - impl #impl_generics #private::Decode<#de> for #input_ty #where_clause { - type Decoder = #decoder_ty; - } - - #[allow(non_snake_case)] - pub struct #decoder_ident #decoder_impl_generics #decoder_where_clause { - #type_body - } + #[allow(non_snake_case)] + pub struct #decoder_ident #decoder_impl_generics #decoder_where_clause { + #type_body + } - // Avoids bounding #impl_generics: Default. - impl #decoder_impl_generics std::default::Default for #decoder_ty #decoder_where_clause { - fn default() -> Self { - Self { - #default_body + // Avoids bounding #impl_generics: Default. + impl #decoder_impl_generics std::default::Default for #decoder_ty #decoder_where_clause { + fn default() -> Self { + Self { + #default_body + } } } - } - impl #decoder_impl_generics #private::View<#de> for #decoder_ty #decoder_where_clause { - fn populate(&mut self, input: &mut &#de [u8], __length: usize) -> #private::Result<()> { - #populate_body - Ok(()) + impl #decoder_impl_generics #private::View<#de> for #decoder_ty #decoder_where_clause { + fn populate(&mut self, input: &mut &#de [u8], __length: usize) -> #private::Result<()> { + #populate_body + Ok(()) + } } - } - impl #impl_generics #private::Decoder<#de, #input_ty> for #decoder_ty #where_clause { - #[cfg_attr(not(debug_assertions), inline(always))] - fn decode_in_place(&mut self, out: &mut std::mem::MaybeUninit<#input_ty>) { - #decode_in_place_body + impl #impl_generics #private::Decoder<#de, #input_ty> for #decoder_ty #where_clause { + #[cfg_attr(not(debug_assertions), inline(always))] + fn decode_in_place(&mut self, out: &mut std::mem::MaybeUninit<#input_ty>) { + #decode_in_place_body + } } - } - }; - }; - // panic!("{ret}"); - Ok(ret) + }; + } + } } diff --git a/bitcode_derive/src/encode.rs b/bitcode_derive/src/encode.rs index 9338bdc..73b312f 100644 --- a/bitcode_derive/src/encode.rs +++ b/bitcode_derive/src/encode.rs @@ -1,15 +1,11 @@ -use crate::attribute::BitcodeAttrs; -use crate::bound::FieldBounds; -use crate::shared::{ - destructure_fields, field_name, remove_lifetimes, replace_lifetimes, ReplaceSelves, -}; -use crate::{err, private}; +use crate::private; +use crate::shared::{remove_lifetimes, replace_lifetimes, variant_index}; use proc_macro2::{Ident, Span, TokenStream}; use quote::quote; -use syn::{parse_quote, Data, DeriveInput, Fields, Path, Result, Type}; +use syn::{parse_quote, Generics, Path, Type}; #[derive(Copy, Clone)] -enum Item { +pub enum Item { Type, Default, Encode, @@ -17,7 +13,6 @@ enum Item { CollectInto, Reserve, } - impl Item { const ALL: [Self; 6] = [ Self::Type, @@ -28,7 +23,8 @@ impl Item { Self::Reserve, ]; const COUNT: usize = Self::ALL.len(); - +} +impl crate::shared::Item for Item { fn field_impl( self, field_name: TokenStream, @@ -101,11 +97,11 @@ impl Item { } } - pub fn variant_impls( + fn enum_impl( self, variant_count: usize, - mut pattern: impl FnMut(usize) -> TokenStream, - mut inner: impl FnMut(Self, usize) -> TokenStream, + pattern: impl Fn(usize) -> TokenStream, + inner: impl Fn(Self, usize) -> TokenStream, ) -> TokenStream { // if variant_count is 0 or 1 variants don't have to be encoded. let encode_variants = variant_count > 1; @@ -139,9 +135,7 @@ impl Item { let variants: TokenStream = (0..variant_count) .map(|i| { let pattern = pattern(i); - let i: u8 = i - .try_into() - .expect("enums with more than 256 variants are not supported"); // TODO don't panic. + let i = variant_index(i); quote! { #pattern => #i, } @@ -157,8 +151,8 @@ impl Item { .unwrap_or_default(); let inners: TokenStream = (0..variant_count) .map(|i| { - // We don't know the exact number of this variant since there are more than one so we have to - // reserve one at a time. + // We don't know the exact number of this variant since there is more than + // one, so we have to reserve one at a time. let reserve = encode_variants .then(|| { let reserve = inner(Self::Reserve, i); @@ -189,7 +183,13 @@ impl Item { }) .unwrap_or_default() } - Self::EncodeVectored => unimplemented!(), // TODO encode enum vectored. + // This is a copy of Encode::encode_vectored's default impl (which provides no speedup). + // TODO optimize enum encode_vectored. + Self::EncodeVectored => quote! { + for t in i { + self.encode(t); + } + }, Self::CollectInto => { let variants = encode_variants .then(|| { @@ -217,170 +217,86 @@ impl Item { } } } - - fn field_impls( - self, - global_prefix: Option<&str>, - fields: &Fields, - parent_attrs: &BitcodeAttrs, - bounds: &mut FieldBounds, - ) -> Result { - fields - .iter() - .enumerate() - .map(move |(i, field)| { - let field_attrs = BitcodeAttrs::parse_field(&field.attrs, parent_attrs)?; - - let name = field_name(i, field, false); - let real_name = field_name(i, field, true); - - let global_name = global_prefix - .map(|global_prefix| { - let ident = - Ident::new(&format!("{global_prefix}{name}"), Span::call_site()); - quote! { #ident } - }) - .unwrap_or_else(|| name.clone()); - - let field_impl = self.field_impl(name, global_name, real_name, &field.ty); - let private = private(); - let bound: Path = parse_quote!(#private::Encode); - bounds.add_bound_type(field.clone(), &field_attrs, bound); - Ok(field_impl) - }) - .collect() - } } -struct Output([TokenStream; Item::COUNT]); - -pub fn derive_impl(mut input: DeriveInput) -> Result { - let attrs = BitcodeAttrs::parse_derive(&input.attrs)?; - let mut generics = input.generics; - let mut bounds = FieldBounds::default(); - - let ident = input.ident; - syn::visit_mut::visit_data_mut(&mut ReplaceSelves(&ident), &mut input.data); - - let (output, is_encode_vectored) = match input.data { - Data::Struct(data_struct) => { - let destructure_fields = &destructure_fields(&data_struct.fields); - ( - Output(Item::ALL.map(|item| { - let field_impls = item - .field_impls(None, &data_struct.fields, &attrs, &mut bounds) - .unwrap(); // TODO don't unwrap - item.struct_impl(&ident, destructure_fields, &field_impls) - })), - true, - ) - } - Data::Enum(data_enum) => { - let variant_count = data_enum.variants.len(); - ( - Output(Item::ALL.map(|item| { - if matches!(item, Item::EncodeVectored) { - return Default::default(); // Unimplemented for now. - } - - item.variant_impls( - variant_count, - |i| { - let variant = &data_enum.variants[i]; - let variant_name = &variant.ident; - let destructure_fields = destructure_fields(&variant.fields); - quote! { - #ident::#variant_name #destructure_fields - } - }, - |item, i| { - let variant = &data_enum.variants[i]; - let global_prefix = format!("{}_", &variant.ident); - let attrs = - BitcodeAttrs::parse_variant(&variant.attrs, &attrs).unwrap(); // TODO don't unwrap. - item.field_impls( - Some(&global_prefix), - &variant.fields, - &attrs, - &mut bounds, - ) - .unwrap() // TODO don't unwrap. - }, - ) - })), - false, - ) - } - Data::Union(u) => err(&u.union_token, "unions are not supported")?, - }; +pub struct Encode; +impl crate::shared::Derive<{ Item::COUNT }> for Encode { + type Item = Item; + const ALL: [Self::Item; Item::COUNT] = Item::ALL; - bounds.apply_to_generics(&mut generics); - let input_generics = generics.clone(); - let (impl_generics, input_generics, where_clause) = input_generics.split_for_impl(); - let input_ty = quote! { #ident #input_generics }; + fn bound(&self) -> Path { + let private = private(); + parse_quote!(#private::Encode) + } - // Encoder can't contain any lifetimes from input (which would limit reuse of encoder). - remove_lifetimes(&mut generics); - let (encoder_impl_generics, encoder_generics, encoder_where_clause) = generics.split_for_impl(); + fn derive_impl( + &self, + output: [TokenStream; Item::COUNT], + ident: Ident, + mut generics: Generics, + ) -> TokenStream { + let input_generics = generics.clone(); + let (impl_generics, input_generics, where_clause) = input_generics.split_for_impl(); + let input_ty = quote! { #ident #input_generics }; - let Output( - [type_body, default_body, encode_body, encode_vectored_body, collect_into_body, reserve_body], - ) = output; - let encoder_ident = Ident::new(&format!("{ident}Encoder"), Span::call_site()); - let encoder_ty = quote! { #encoder_ident #encoder_generics }; - let private = private(); + // Encoder can't contain any lifetimes from input (which would limit reuse of encoder). + remove_lifetimes(&mut generics); + let (encoder_impl_generics, encoder_generics, encoder_where_clause) = + generics.split_for_impl(); - let encode_vectored = is_encode_vectored.then(|| quote! { - // #[cfg_attr(not(debug_assertions), inline(always))] - // #[inline(never)] - fn encode_vectored<'__v>(&mut self, i: impl Iterator + Clone) where #input_ty: '__v { - #[allow(unused_imports)] - use #private::Buffer as _; - #encode_vectored_body - } - }); + let [type_body, default_body, encode_body, encode_vectored_body, collect_into_body, reserve_body] = + output; + let encoder_ident = Ident::new(&format!("{ident}Encoder"), Span::call_site()); + let encoder_ty = quote! { #encoder_ident #encoder_generics }; + let private = private(); - let ret = quote! { - const _: () = { - impl #impl_generics #private::Encode for #input_ty #where_clause { - type Encoder = #encoder_ty; - } + quote! { + const _: () = { + impl #impl_generics #private::Encode for #input_ty #where_clause { + type Encoder = #encoder_ty; + } - #[allow(non_snake_case)] - pub struct #encoder_ident #encoder_impl_generics #encoder_where_clause { - #type_body - } + #[allow(non_snake_case)] + pub struct #encoder_ident #encoder_impl_generics #encoder_where_clause { + #type_body + } - // Avoids bounding #impl_generics: Default. - impl #encoder_impl_generics std::default::Default for #encoder_ty #encoder_where_clause { - fn default() -> Self { - Self { - #default_body + // Avoids bounding #impl_generics: Default. + impl #encoder_impl_generics std::default::Default for #encoder_ty #encoder_where_clause { + fn default() -> Self { + Self { + #default_body + } } } - } - impl #impl_generics #private::Encoder<#input_ty> for #encoder_ty #where_clause { - #[cfg_attr(not(debug_assertions), inline(always))] - fn encode(&mut self, v: &#input_ty) { - #[allow(unused_imports)] - use #private::Buffer as _; - #encode_body - } - #encode_vectored - } + impl #impl_generics #private::Encoder<#input_ty> for #encoder_ty #where_clause { + #[cfg_attr(not(debug_assertions), inline(always))] + fn encode(&mut self, v: &#input_ty) { + #[allow(unused_imports)] + use #private::Buffer as _; + #encode_body + } - impl #encoder_impl_generics #private::Buffer for #encoder_ty #encoder_where_clause { - fn collect_into(&mut self, out: &mut Vec) { - #collect_into_body + // #[cfg_attr(not(debug_assertions), inline(always))] + // #[inline(never)] + fn encode_vectored<'__v>(&mut self, i: impl Iterator + Clone) where #input_ty: '__v { + #[allow(unused_imports)] + use #private::Buffer as _; + #encode_vectored_body + } } - fn reserve(&mut self, __additional: std::num::NonZeroUsize) { - #reserve_body + impl #encoder_impl_generics #private::Buffer for #encoder_ty #encoder_where_clause { + fn collect_into(&mut self, out: &mut Vec) { + #collect_into_body + } + + fn reserve(&mut self, __additional: std::num::NonZeroUsize) { + #reserve_body + } } - } - }; - }; - // panic!("{ret}"); - Ok(ret) + }; + } + } } diff --git a/bitcode_derive/src/lib.rs b/bitcode_derive/src/lib.rs index a566032..2874eca 100644 --- a/bitcode_derive/src/lib.rs +++ b/bitcode_derive/src/lib.rs @@ -1,7 +1,10 @@ +use crate::decode::Decode; +use crate::encode::Encode; +use crate::shared::Derive; use proc_macro::TokenStream; use quote::quote; use syn::spanned::Spanned; -use syn::{parse_macro_input, DeriveInput}; +use syn::{parse_macro_input, DeriveInput, Error}; mod attribute; mod bound; @@ -9,27 +12,25 @@ mod decode; mod encode; mod shared; -#[proc_macro_derive(Encode, attributes(bitcode))] -pub fn derive_encode(input: TokenStream) -> TokenStream { - let input = parse_macro_input!(input as DeriveInput); - encode::derive_impl(input) - .unwrap_or_else(syn::Error::into_compile_error) - .into() +macro_rules! derive { + ($fn_name:ident, $trait_:ident) => { + #[proc_macro_derive($trait_, attributes(bitcode))] + pub fn $fn_name(input: TokenStream) -> TokenStream { + $trait_ + .derive(parse_macro_input!(input as DeriveInput)) + .unwrap_or_else(Error::into_compile_error) + .into() + } + }; } +derive!(derive_encode, Encode); +derive!(derive_decode, Decode); -#[proc_macro_derive(Decode, attributes(bitcode))] -pub fn derive_decode(input: TokenStream) -> TokenStream { - let input = parse_macro_input!(input as DeriveInput); - decode::derive_impl(input) - .unwrap_or_else(syn::Error::into_compile_error) - .into() +pub(crate) fn error(spanned: &impl Spanned, s: &str) -> Error { + Error::new(spanned.span(), s.to_owned()) } -pub(crate) fn error(spanned: &impl Spanned, s: &str) -> syn::Error { - syn::Error::new(spanned.span(), s.to_owned()) -} - -pub(crate) fn err(spanned: &impl Spanned, s: &str) -> Result { +pub(crate) fn err(spanned: &impl Spanned, s: &str) -> Result { Err(error(spanned, s)) } diff --git a/bitcode_derive/src/shared.rs b/bitcode_derive/src/shared.rs index 5841642..0bd28a6 100644 --- a/bitcode_derive/src/shared.rs +++ b/bitcode_derive/src/shared.rs @@ -1,9 +1,158 @@ +use crate::attribute::BitcodeAttrs; +use crate::bound::FieldBounds; +use crate::err; use proc_macro2::{Ident, Span, TokenStream}; use quote::{quote, ToTokens}; use syn::visit_mut::VisitMut; -use syn::{Field, Fields, GenericParam, Generics, Index, Lifetime, Type, WherePredicate}; +use syn::{ + Data, DataStruct, DeriveInput, Field, Fields, GenericParam, Generics, Index, Lifetime, Path, + Result, Type, WherePredicate, +}; -pub fn destructure_fields(fields: &Fields) -> TokenStream { +type VariantIndex = u8; +pub fn variant_index(i: usize) -> VariantIndex { + i.try_into().unwrap() +} + +pub trait Item: Copy + Sized { + fn field_impl( + self, + field_name: TokenStream, + global_field_name: TokenStream, + real_field_name: TokenStream, + field_type: &Type, + ) -> TokenStream; + + fn struct_impl( + self, + ident: &Ident, + destructure_fields: &TokenStream, + do_fields: &TokenStream, + ) -> TokenStream; + + fn enum_impl( + self, + variant_count: usize, + pattern: impl Fn(usize) -> TokenStream, + inner: impl Fn(Self, usize) -> TokenStream, + ) -> TokenStream; + + fn field_impls(self, global_prefix: Option<&str>, fields: &Fields) -> TokenStream { + fields + .iter() + .enumerate() + .map(move |(i, field)| { + let name = field_name(i, field, false); + let real_name = field_name(i, field, true); + let global_name = global_prefix + .map(|global_prefix| { + let ident = + Ident::new(&format!("{global_prefix}{name}"), Span::call_site()); + quote! { #ident } + }) + .unwrap_or_else(|| name.clone()); + + self.field_impl(name, global_name, real_name, &field.ty) + }) + .collect() + } +} + +pub trait Derive { + type Item: Item; + const ALL: [Self::Item; ITEM_COUNT]; + + /// `Encode` in `T: Encode`. + fn bound(&self) -> Path; + + /// Generates the derive implementation. + fn derive_impl( + &self, + output: [TokenStream; ITEM_COUNT], + ident: Ident, + generics: Generics, + ) -> TokenStream; + + fn field_attrs( + &self, + fields: &Fields, + attrs: &BitcodeAttrs, + bounds: &mut FieldBounds, + ) -> Result> { + fields + .iter() + .map(|field| { + let field_attrs = BitcodeAttrs::parse_field(&field.attrs, attrs)?; + bounds.add_bound_type(field.clone(), &field_attrs, self.bound()); + Ok(field_attrs) + }) + .collect() + } + + fn derive(&self, mut input: DeriveInput) -> Result { + let attrs = BitcodeAttrs::parse_derive(&input.attrs)?; + let ident = input.ident; + syn::visit_mut::visit_data_mut(&mut ReplaceSelves(&ident), &mut input.data); + let mut bounds = FieldBounds::default(); + + let output = match input.data { + Data::Struct(DataStruct { ref fields, .. }) => { + // Only used for adding `bounds`. Would be used by `#[bitcode(with_serde)]`. + let field_attrs = self.field_attrs(fields, &attrs, &mut bounds)?; + let _ = field_attrs; + + let destructure_fields = &destructure_fields(fields); + Self::ALL.map(|item| { + let field_impls = item.field_impls(None, fields); + item.struct_impl(&ident, destructure_fields, &field_impls) + }) + } + Data::Enum(data_enum) => { + let max_variants = VariantIndex::MAX as usize + 1; + if data_enum.variants.len() > max_variants { + return err( + &ident, + &format!("enums with more than {max_variants} variants are not supported"), + ); + } + + // Only used for adding `bounds`. Would be used by `#[bitcode(with_serde)]`. + let variant_attrs = data_enum + .variants + .iter() + .map(|variant| { + let attrs = BitcodeAttrs::parse_variant(&variant.attrs, &attrs)?; + self.field_attrs(&variant.fields, &attrs, &mut bounds) + }) + .collect::>>()?; + let _ = variant_attrs; + + Self::ALL.map(|item| { + item.enum_impl( + data_enum.variants.len(), + |i| { + let variant = &data_enum.variants[i]; + let variant_name = &variant.ident; + let destructure_fields = destructure_fields(&variant.fields); + quote! { + #ident::#variant_name #destructure_fields + } + }, + |item, i| { + let variant = &data_enum.variants[i]; + let global_prefix = format!("{}_", &variant.ident); + item.field_impls(Some(&global_prefix), &variant.fields) + }, + ) + }) + } + Data::Union(_) => err(&ident, "unions are not supported")?, + }; + Ok(self.derive_impl(output, ident, bounds.added_to(input.generics))) + } +} + +fn destructure_fields(fields: &Fields) -> TokenStream { let field_names = fields .iter() .enumerate() @@ -19,7 +168,7 @@ pub fn destructure_fields(fields: &Fields) -> TokenStream { } } -pub fn field_name(i: usize, field: &Field, real: bool) -> TokenStream { +fn field_name(i: usize, field: &Field, real: bool) -> TokenStream { field .ident .as_ref() @@ -60,7 +209,7 @@ impl VisitMut for ReplaceLifetimes<'_> { } } -pub struct ReplaceSelves<'a>(pub &'a Ident); +struct ReplaceSelves<'a>(pub &'a Ident); impl VisitMut for ReplaceSelves<'_> { fn visit_ident_mut(&mut self, ident: &mut Ident) { if ident == "Self" { diff --git a/src/derive/mod.rs b/src/derive/mod.rs index 75fc6e8..67a05b8 100644 --- a/src/derive/mod.rs +++ b/src/derive/mod.rs @@ -136,6 +136,7 @@ mod tests { A(u8), } + // cargo expand --lib --tests | grep -A15 Two #[derive(Encode, Decode)] enum Two { A(u8),