Skip to content

Commit

Permalink
Handle errors instead of panicking in derive macro.
Browse files Browse the repository at this point in the history
  • Loading branch information
caibear committed Mar 14, 2024
1 parent 9aaaf39 commit 99b7b67
Show file tree
Hide file tree
Showing 7 changed files with 362 additions and 374 deletions.
5 changes: 2 additions & 3 deletions bitcode_derive/src/attribute.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,12 +42,12 @@ impl BitcodeAttr {
return err(nested, "duplicate");
}
*b = Some(bound_type);
Ok(())
} else {
return err(nested, "can only apply bound to fields");
err(nested, "can only apply bound to fields")
}
}
}
Ok(())
}
}

Expand Down Expand Up @@ -81,7 +81,6 @@ impl BitcodeAttrs {
Ok(ret)
}

#[allow(unused)] // TODO
pub fn parse_variant(attrs: &[Attribute], _derive_attrs: &Self) -> Result<Self> {
let mut ret = Self::new(AttrType::Variant);
ret.parse_inner(attrs)?;
Expand Down
7 changes: 4 additions & 3 deletions bitcode_derive/src/bound.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,16 @@ impl FieldBounds {
}
}

pub fn apply_to_generics(self, generics: &mut syn::Generics) {
pub fn added_to(self, mut generics: syn::Generics) -> syn::Generics {
for (bound, (fields, extra_bound_types)) in self.bounds {
*generics = with_bound(&fields, extra_bound_types, generics, &bound);
generics = with_bound(&fields, extra_bound_types, &generics, &bound);
}
generics
}
}

// Based on https://github.com/serde-rs/serde/blob/0c6a2bbf794abe966a4763f5b7ff23acb535eb7f/serde_derive/src/bound.rs#L94-L314
pub fn with_bound(
fn with_bound(
fields: &[syn::Field],
extra_bound_types: Vec<syn::Type>,
generics: &syn::Generics,
Expand Down
275 changes: 98 additions & 177 deletions bitcode_derive/src/decode.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,10 @@
use crate::attribute::BitcodeAttrs;
use crate::bound::FieldBounds;
use crate::shared::{
destructure_fields, field_name, remove_lifetimes, replace_lifetimes, ReplaceSelves,
};
use crate::{err, private};
use crate::private;
use crate::shared::{remove_lifetimes, replace_lifetimes, variant_index};
use proc_macro2::{Ident, Span, TokenStream};
use quote::quote;
use syn::{
parse_quote, Data, DeriveInput, Fields, GenericParam, Lifetime, LifetimeParam, Path,
PredicateLifetime, Result, Type, WherePredicate,
parse_quote, GenericParam, Generics, Lifetime, LifetimeParam, Path, PredicateLifetime, Type,
WherePredicate,
};

const DE_LIFETIME: &str = "__de";
Expand All @@ -17,8 +13,7 @@ fn de_lifetime() -> Lifetime {
}

#[derive(Copy, Clone)]
#[repr(u8)]
enum Item {
pub enum Item {
Type,
Default,
Populate,
Expand All @@ -35,7 +30,9 @@ impl Item {
Self::DecodeInPlace,
];
const COUNT: usize = Self::ALL.len();
}

impl crate::shared::Item for Item {
fn field_impl(
self,
field_name: TokenStream,
Expand Down Expand Up @@ -89,11 +86,11 @@ impl Item {
}
}

pub fn variant_impls(
fn enum_impl(
self,
variant_count: usize,
mut pattern: impl FnMut(usize) -> TokenStream,
mut inner: impl FnMut(Self, usize) -> TokenStream,
pattern: impl Fn(usize) -> TokenStream,
inner: impl Fn(Self, usize) -> TokenStream,
) -> TokenStream {
// if variant_count is 0 or 1 variants don't have to be decoded.
let decode_variants = variant_count > 1;
Expand Down Expand Up @@ -146,9 +143,7 @@ impl Item {
if inner.is_empty() {
quote! {}
} else {
let i: u8 = i
.try_into()
.expect("enums with more than 256 variants are not supported"); // TODO don't panic.
let i = variant_index(i);
let length = decode_variants
.then(|| {
quote! {
Expand Down Expand Up @@ -176,7 +171,7 @@ impl Item {
unsafe { std::hint::unreachable_unchecked() }
};
}
let mut pattern = |i: usize| {
let pattern = |i: usize| {
let pattern = pattern(i);
matches!(self, Self::DecodeInPlace)
.then(|| {
Expand All @@ -194,7 +189,7 @@ impl Item {
.map(|i| {
let inner = inner(item, i);
let pattern = pattern(i);
let i: u8 = i.try_into().unwrap(); // Already checked in reserve impl.
let i = variant_index(i);
quote! {
#i => {
#inner
Expand Down Expand Up @@ -225,184 +220,110 @@ impl Item {
}
}
}
}

// TODO dedup with encode.rs
fn field_impls(
self,
global_prefix: Option<&str>,
fields: &Fields,
parent_attrs: &BitcodeAttrs,
bounds: &mut FieldBounds,
) -> Result<TokenStream> {
fields
.iter()
.enumerate()
.map(move |(i, field)| {
let field_attrs = BitcodeAttrs::parse_field(&field.attrs, parent_attrs)?;

let name = field_name(i, field, false);
let real_name = field_name(i, field, true);
pub struct Decode;
impl crate::shared::Derive<{ Item::COUNT }> for Decode {
type Item = Item;
const ALL: [Self::Item; Item::COUNT] = Item::ALL;

let global_name = global_prefix
.map(|global_prefix| {
let ident =
Ident::new(&format!("{global_prefix}{name}"), Span::call_site());
quote! { #ident }
})
.unwrap_or_else(|| name.clone());
fn bound(&self) -> Path {
let private = private();
let de = de_lifetime();
parse_quote!(#private::Decode<#de>)
}

let field_impl = self.field_impl(name, global_name, real_name, &field.ty);
fn derive_impl(
&self,
output: [TokenStream; Item::COUNT],
ident: Ident,
mut generics: Generics,
) -> TokenStream {
let input_generics = generics.clone();
let (_, input_generics, _) = input_generics.split_for_impl();
let input_ty = quote! { #ident #input_generics };

let private = private();
let de = de_lifetime();
let bound: Path = parse_quote!(#private::Decode<#de>);
bounds.add_bound_type(field.clone(), &field_attrs, bound);
Ok(field_impl)
})
.collect()
}
}
// Add 'de lifetime after isolating input_generics.
let de = de_lifetime();
let de_where_predicate = WherePredicate::Lifetime(PredicateLifetime {
lifetime: de.clone(),
colon_token: parse_quote!(:),
bounds: generics
.params
.iter()
.filter_map(|p| {
if let GenericParam::Lifetime(p) = p {
Some(p.lifetime.clone())
} else {
None
}
})
.collect(),
});

struct Output([TokenStream; Item::COUNT]);
// Push de_param after bounding 'de: 'a.
let de_param = GenericParam::Lifetime(LifetimeParam::new(de.clone()));
generics.params.push(de_param.clone());
generics
.make_where_clause()
.predicates
.push(de_where_predicate);

impl Output {
fn haunt(mut self) -> Self {
let type_ = &mut self.0[Item::Type as usize];
if type_.is_empty() {
let de = de_lifetime();
*type_ = quote! { __spooky: std::marker::PhantomData<&#de ()>, };
}
let default = &mut self.0[Item::Default as usize];
if default.is_empty() {
*default = quote! { __spooky: Default::default(), };
}
self
}
}
let combined_generics = generics.clone();
let (impl_generics, _, where_clause) = combined_generics.split_for_impl();

pub fn derive_impl(mut input: DeriveInput) -> Result<TokenStream> {
let attrs = BitcodeAttrs::parse_derive(&input.attrs)?;
let mut generics = input.generics;
let mut bounds = FieldBounds::default();
// Decoder can't contain any lifetimes from input (which would limit reuse of decoder).
remove_lifetimes(&mut generics);
generics.params.push(de_param); // Re-add de_param since remove_lifetimes removed it.
let (decoder_impl_generics, decoder_generics, decoder_where_clause) =
generics.split_for_impl();

let ident = input.ident;
syn::visit_mut::visit_data_mut(&mut ReplaceSelves(&ident), &mut input.data);
let output = (match input.data {
Data::Struct(data_struct) => {
let destructure_fields = &destructure_fields(&data_struct.fields);
Output(Item::ALL.map(|item| {
let field_impls = item
.field_impls(None, &data_struct.fields, &attrs, &mut bounds)
.unwrap(); // TODO don't unwrap
item.struct_impl(&ident, destructure_fields, &field_impls)
}))
let [mut type_body, mut default_body, populate_body, decode_in_place_body] = output;
if type_body.is_empty() {
type_body = quote! { __spooky: std::marker::PhantomData<&#de ()>, };
}
Data::Enum(data_enum) => {
let variant_count = data_enum.variants.len();
Output(Item::ALL.map(|item| {
item.variant_impls(
variant_count,
|i| {
let variant = &data_enum.variants[i];
let variant_name = &variant.ident;
let destructure_fields = destructure_fields(&variant.fields);
quote! {
#ident::#variant_name #destructure_fields
}
},
|item, i| {
let variant = &data_enum.variants[i];
let global_prefix = format!("{}_", &variant.ident);
let attrs = BitcodeAttrs::parse_variant(&variant.attrs, &attrs).unwrap(); // TODO don't unwrap.
item.field_impls(Some(&global_prefix), &variant.fields, &attrs, &mut bounds)
.unwrap() // TODO don't unwrap.
},
)
}))
if default_body.is_empty() {
default_body = quote! { __spooky: Default::default(), };
}
Data::Union(u) => err(&u.union_token, "unions are not supported")?,
})
.haunt();

bounds.apply_to_generics(&mut generics);
let input_generics = generics.clone();
let (_, input_generics, _) = input_generics.split_for_impl();
let input_ty = quote! { #ident #input_generics };
let decoder_ident = Ident::new(&format!("{ident}Decoder"), Span::call_site());
let decoder_ty = quote! { #decoder_ident #decoder_generics };
let private = private();

// Add 'de lifetime after isolating input_generics.
let de = de_lifetime();
let de_where_predicate = WherePredicate::Lifetime(PredicateLifetime {
lifetime: de.clone(),
colon_token: parse_quote!(:),
bounds: generics
.params
.iter()
.filter_map(|p| {
if let GenericParam::Lifetime(p) = p {
Some(p.lifetime.clone())
} else {
None
quote! {
const _: () = {
impl #impl_generics #private::Decode<#de> for #input_ty #where_clause {
type Decoder = #decoder_ty;
}
})
.collect(),
});

// Push de_param after bounding 'de: 'a.
let de_param = GenericParam::Lifetime(LifetimeParam::new(de.clone()));
generics.params.push(de_param.clone());
generics
.make_where_clause()
.predicates
.push(de_where_predicate);

let combined_generics = generics.clone();
let (impl_generics, _, where_clause) = combined_generics.split_for_impl();

// Decoder can't contain any lifetimes from input (which would limit reuse of decoder).
remove_lifetimes(&mut generics);
generics.params.push(de_param); // Re-add de_param since remove_lifetimes removed it.
let (decoder_impl_generics, decoder_generics, decoder_where_clause) = generics.split_for_impl();

let Output([type_body, default_body, populate_body, decode_in_place_body]) = output;
let decoder_ident = Ident::new(&format!("{ident}Decoder"), Span::call_site());
let decoder_ty = quote! { #decoder_ident #decoder_generics };
let private = private();

let ret = quote! {
const _: () = {
impl #impl_generics #private::Decode<#de> for #input_ty #where_clause {
type Decoder = #decoder_ty;
}

#[allow(non_snake_case)]
pub struct #decoder_ident #decoder_impl_generics #decoder_where_clause {
#type_body
}
#[allow(non_snake_case)]
pub struct #decoder_ident #decoder_impl_generics #decoder_where_clause {
#type_body
}

// Avoids bounding #impl_generics: Default.
impl #decoder_impl_generics std::default::Default for #decoder_ty #decoder_where_clause {
fn default() -> Self {
Self {
#default_body
// Avoids bounding #impl_generics: Default.
impl #decoder_impl_generics std::default::Default for #decoder_ty #decoder_where_clause {
fn default() -> Self {
Self {
#default_body
}
}
}
}

impl #decoder_impl_generics #private::View<#de> for #decoder_ty #decoder_where_clause {
fn populate(&mut self, input: &mut &#de [u8], __length: usize) -> #private::Result<()> {
#populate_body
Ok(())
impl #decoder_impl_generics #private::View<#de> for #decoder_ty #decoder_where_clause {
fn populate(&mut self, input: &mut &#de [u8], __length: usize) -> #private::Result<()> {
#populate_body
Ok(())
}
}
}

impl #impl_generics #private::Decoder<#de, #input_ty> for #decoder_ty #where_clause {
#[cfg_attr(not(debug_assertions), inline(always))]
fn decode_in_place(&mut self, out: &mut std::mem::MaybeUninit<#input_ty>) {
#decode_in_place_body
impl #impl_generics #private::Decoder<#de, #input_ty> for #decoder_ty #where_clause {
#[cfg_attr(not(debug_assertions), inline(always))]
fn decode_in_place(&mut self, out: &mut std::mem::MaybeUninit<#input_ty>) {
#decode_in_place_body
}
}
}
};
};
// panic!("{ret}");
Ok(ret)
};
}
}
}
Loading

0 comments on commit 99b7b67

Please sign in to comment.