From 518baf9c0b73c92b4ea4406fe15e005c6d71535a Mon Sep 17 00:00:00 2001 From: John Nunley Date: Thu, 3 Nov 2022 06:53:01 -0700 Subject: [PATCH] Allow `repr(transparent)` to be used generically in `derive(Pod)` (#139) * Enabled transparent generics * Move trait checks to implementation block * Replace add_trait_marker impl --- derive/src/lib.rs | 46 +++++++---- derive/src/traits.rs | 181 ++++++++++++++++++++---------------------- derive/tests/basic.rs | 34 ++++++-- src/internal.rs | 4 +- 4 files changed, 149 insertions(+), 116 deletions(-) diff --git a/derive/src/lib.rs b/derive/src/lib.rs index 8a877f8..b5c106c 100644 --- a/derive/src/lib.rs +++ b/derive/src/lib.rs @@ -9,7 +9,8 @@ use quote::quote; use syn::{parse_macro_input, DeriveInput, Result}; use crate::traits::{ - AnyBitPattern, Contiguous, Derivable, CheckedBitPattern, NoUninit, Pod, TransparentWrapper, Zeroable, + AnyBitPattern, CheckedBitPattern, Contiguous, Derivable, NoUninit, Pod, + TransparentWrapper, Zeroable, }; /// Derive the `Pod` trait for a struct @@ -56,8 +57,9 @@ pub fn derive_pod(input: proc_macro::TokenStream) -> proc_macro::TokenStream { pub fn derive_anybitpattern( input: proc_macro::TokenStream, ) -> proc_macro::TokenStream { - let expanded = - derive_marker_trait::(parse_macro_input!(input as DeriveInput)); + let expanded = derive_marker_trait::(parse_macro_input!( + input as DeriveInput + )); proc_macro::TokenStream::from(expanded) } @@ -99,8 +101,8 @@ pub fn derive_zeroable( /// for the `NoUninit` trait. /// /// The following constraints need to be satisfied for the macro to succeed -/// (the rest of the constraints are guaranteed by the `NoUninit` subtrait bounds, -/// i.e. the type must be `Sized + Copy + 'static`): +/// (the rest of the constraints are guaranteed by the `NoUninit` subtrait +/// bounds, i.e. the type must be `Sized + Copy + 'static`): /// /// If applied to a struct: /// - All fields in the struct must implement `NoUninit` @@ -129,9 +131,9 @@ pub fn derive_no_uninit( /// definition and `is_valid_bit_pattern` method for the type automatically. /// /// The following constraints need to be satisfied for the macro to succeed -/// (the rest of the constraints are guaranteed by the `CheckedBitPattern` subtrait bounds, -/// i.e. are guaranteed by the requirements of the `NoUninit` trait which `CheckedBitPattern` -/// is a subtrait of): +/// (the rest of the constraints are guaranteed by the `CheckedBitPattern` +/// subtrait bounds, i.e. are guaranteed by the requirements of the `NoUninit` +/// trait which `CheckedBitPattern` is a subtrait of): /// /// If applied to a struct: /// - All fields must implement `CheckedBitPattern` @@ -142,8 +144,9 @@ pub fn derive_no_uninit( pub fn derive_maybe_pod( input: proc_macro::TokenStream, ) -> proc_macro::TokenStream { - let expanded = - derive_marker_trait::(parse_macro_input!(input as DeriveInput)); + let expanded = derive_marker_trait::(parse_macro_input!( + input as DeriveInput + )); proc_macro::TokenStream::from(expanded) } @@ -228,17 +231,19 @@ fn derive_marker_trait(input: DeriveInput) -> TokenStream { } fn derive_marker_trait_inner( - input: DeriveInput, + mut input: DeriveInput, ) -> Result { + // Enforce Pod on all generic fields. + let trait_ = Trait::ident(&input)?; + add_trait_marker(&mut input.generics, &trait_); + let name = &input.ident; let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl(); - let trait_ = Trait::ident(); Trait::check_attributes(&input.data, &input.attrs)?; let asserts = Trait::asserts(&input)?; - let trait_params = Trait::generic_params(&input)?; let (trait_impl_extras, trait_impl) = Trait::trait_impl(&input)?; let implies_trait = if let Some(implies_trait) = Trait::implies_trait() { @@ -252,10 +257,23 @@ fn derive_marker_trait_inner( #trait_impl_extras - unsafe impl #impl_generics #trait_ #trait_params for #name #ty_generics #where_clause { + unsafe impl #impl_generics #trait_ for #name #ty_generics #where_clause { #trait_impl } #implies_trait }) } + +/// Add a trait marker to the generics if it is not already present +fn add_trait_marker(generics: &mut syn::Generics, trait_name: &syn::Path) { + // Get each generic type parameter. + let type_params = generics + .type_params() + .map(|param| ¶m.ident) + .map(|param| syn::parse_quote!( + #param: #trait_name + )).collect::>(); + + generics.make_where_clause().predicates.extend(type_params); +} diff --git a/derive/src/traits.rs b/derive/src/traits.rs index 815e77c..2e3fec9 100644 --- a/derive/src/traits.rs +++ b/derive/src/traits.rs @@ -1,42 +1,35 @@ #![allow(unused_imports)] use proc_macro2::{Ident, Span, TokenStream, TokenTree}; use quote::{quote, quote_spanned, ToTokens}; -use syn::{*, - parse::{Parse, Parser, ParseStream}, +use syn::{ + parse::{Parse, ParseStream, Parser}, punctuated::Punctuated, spanned::Spanned, - Result, + Result, *, }; macro_rules! bail { - ($msg:expr $(,)?) => ( + ($msg:expr $(,)?) => { return Err(Error::new(Span::call_site(), &$msg[..])) - ); + }; - ( $msg:expr => $span_to_blame:expr $(,)? ) => ( + ( $msg:expr => $span_to_blame:expr $(,)? ) => { return Err(Error::new_spanned(&$span_to_blame, $msg)) - ); + }; } pub trait Derivable { - fn ident() -> TokenStream; + fn ident(input: &DeriveInput) -> Result; fn implies_trait() -> Option { None } - fn generic_params(_input: &DeriveInput) -> Result { - Ok(quote!()) - } fn asserts(_input: &DeriveInput) -> Result { Ok(quote!()) } - fn check_attributes( - _ty: &Data, _attributes: &[Attribute], - ) -> Result<()> { + fn check_attributes(_ty: &Data, _attributes: &[Attribute]) -> Result<()> { Ok(()) } - fn trait_impl( - _input: &DeriveInput, - ) -> Result<(TokenStream, TokenStream)> { + fn trait_impl(_input: &DeriveInput) -> Result<(TokenStream, TokenStream)> { Ok((quote!(), quote!())) } } @@ -44,14 +37,15 @@ pub trait Derivable { pub struct Pod; impl Derivable for Pod { - fn ident() -> TokenStream { - quote!(::bytemuck::Pod) + fn ident(_: &DeriveInput) -> Result { + Ok(syn::parse_quote!(::bytemuck::Pod)) } fn asserts(input: &DeriveInput) -> Result { let repr = get_repr(&input.attrs)?; - let completly_packed = repr.packed == Some(1); + let completly_packed = + repr.packed == Some(1) || repr.repr == Repr::Transparent; if !completly_packed && !input.generics.params.is_empty() { bail!("\ @@ -69,7 +63,7 @@ impl Derivable for Pod { None }; let assert_fields_are_pod = - generate_fields_are_trait(input, Self::ident())?; + generate_fields_are_trait(input, Self::ident(input)?)?; Ok(quote!( #assert_no_padding @@ -81,9 +75,7 @@ impl Derivable for Pod { } } - fn check_attributes( - _ty: &Data, attributes: &[Attribute], - ) -> Result<()> { + fn check_attributes(_ty: &Data, attributes: &[Attribute]) -> Result<()> { let repr = get_repr(attributes)?; match repr.repr { Repr::C => Ok(()), @@ -98,8 +90,8 @@ impl Derivable for Pod { pub struct AnyBitPattern; impl Derivable for AnyBitPattern { - fn ident() -> TokenStream { - quote!(::bytemuck::AnyBitPattern) + fn ident(_: &DeriveInput) -> Result { + Ok(syn::parse_quote!(::bytemuck::AnyBitPattern)) } fn implies_trait() -> Option { @@ -109,8 +101,10 @@ impl Derivable for AnyBitPattern { fn asserts(input: &DeriveInput) -> Result { match &input.data { Data::Union(_) => Ok(quote!()), // unions are always `AnyBitPattern` - Data::Struct(_) => generate_fields_are_trait(input, Self::ident()), - Data::Enum(_) => bail!("Deriving AnyBitPattern is not supported for enums"), + Data::Struct(_) => generate_fields_are_trait(input, Self::ident(input)?), + Data::Enum(_) => { + bail!("Deriving AnyBitPattern is not supported for enums") + } } } } @@ -118,14 +112,14 @@ impl Derivable for AnyBitPattern { pub struct Zeroable; impl Derivable for Zeroable { - fn ident() -> TokenStream { - quote!(::bytemuck::Zeroable) + fn ident(_: &DeriveInput) -> Result { + Ok(syn::parse_quote!(::bytemuck::Zeroable)) } fn asserts(input: &DeriveInput) -> Result { match &input.data { Data::Union(_) => Ok(quote!()), // unions are always `Zeroable` - Data::Struct(_) => generate_fields_are_trait(input, Self::ident()), + Data::Struct(_) => generate_fields_are_trait(input, Self::ident(input)?), Data::Enum(_) => bail!("Deriving Zeroable is not supported for enums"), } } @@ -134,13 +128,11 @@ impl Derivable for Zeroable { pub struct NoUninit; impl Derivable for NoUninit { - fn ident() -> TokenStream { - quote!(::bytemuck::NoUninit) + fn ident(_: &DeriveInput) -> Result { + Ok(syn::parse_quote!(::bytemuck::NoUninit)) } - fn check_attributes( - ty: &Data, attributes: &[Attribute], - ) -> Result<()> { + fn check_attributes(ty: &Data, attributes: &[Attribute]) -> Result<()> { let repr = get_repr(attributes)?; match ty { Data::Struct(_) => match repr.repr { @@ -165,7 +157,7 @@ impl Derivable for NoUninit { Data::Struct(DataStruct { .. }) => { let assert_no_padding = generate_assert_no_padding(&input)?; let assert_fields_are_no_padding = - generate_fields_are_trait(&input, Self::ident())?; + generate_fields_are_trait(&input, Self::ident(input)?)?; Ok(quote!( #assert_no_padding @@ -179,13 +171,11 @@ impl Derivable for NoUninit { Ok(quote!()) } } - Data::Union(_) => bail!("NoUninit cannot be derived for unions"), // shouldn't be possible since we already error in attribute check for this case + Data::Union(_) => bail!("NoUninit cannot be derived for unions"), /* shouldn't be possible since we already error in attribute check for this case */ } } - fn trait_impl( - _input: &DeriveInput, - ) -> Result<(TokenStream, TokenStream)> { + fn trait_impl(_input: &DeriveInput) -> Result<(TokenStream, TokenStream)> { Ok((quote!(), quote!())) } } @@ -193,13 +183,11 @@ impl Derivable for NoUninit { pub struct CheckedBitPattern; impl Derivable for CheckedBitPattern { - fn ident() -> TokenStream { - quote!(::bytemuck::CheckedBitPattern) + fn ident(_: &DeriveInput) -> Result { + Ok(syn::parse_quote!(::bytemuck::CheckedBitPattern)) } - fn check_attributes( - ty: &Data, attributes: &[Attribute], - ) -> Result<()> { + fn check_attributes(ty: &Data, attributes: &[Attribute]) -> Result<()> { let repr = get_repr(attributes)?; match ty { Data::Struct(_) => match repr.repr { @@ -223,24 +211,23 @@ impl Derivable for CheckedBitPattern { match &input.data { Data::Struct(DataStruct { .. }) => { let assert_fields_are_maybe_pod = - generate_fields_are_trait(&input, Self::ident())?; + generate_fields_are_trait(&input, Self::ident(input)?)?; Ok(assert_fields_are_maybe_pod) } - Data::Enum(_) => Ok(quote!()), // nothing needed, already guaranteed OK by NoUninit - Data::Union(_) => bail!("Internal error in CheckedBitPattern derive"), // shouldn't be possible since we already error in attribute check for this case + Data::Enum(_) => Ok(quote!()), /* nothing needed, already guaranteed + * OK by NoUninit */ + Data::Union(_) => bail!("Internal error in CheckedBitPattern derive"), /* shouldn't be possible since we already error in attribute check for this case */ } } - fn trait_impl( - input: &DeriveInput, - ) -> Result<(TokenStream, TokenStream)> { + fn trait_impl(input: &DeriveInput) -> Result<(TokenStream, TokenStream)> { match &input.data { Data::Struct(DataStruct { fields, .. }) => { generate_checked_bit_pattern_struct(&input.ident, fields, &input.attrs) - }, + } Data::Enum(_) => generate_checked_bit_pattern_enum(input), - Data::Union(_) => bail!("Internal error in CheckedBitPattern derive"), // shouldn't be possible since we already error in attribute check for this case + Data::Union(_) => bail!("Internal error in CheckedBitPattern derive"), /* shouldn't be possible since we already error in attribute check for this case */ } } } @@ -266,20 +253,20 @@ impl TransparentWrapper { } impl Derivable for TransparentWrapper { - fn ident() -> TokenStream { - quote!(::bytemuck::TransparentWrapper) - } - - fn generic_params(input: &DeriveInput) -> Result { + fn ident(input: &DeriveInput) -> Result { let fields = get_struct_fields(input)?; - match Self::get_wrapper_type(&input.attrs, &fields) { - | Some(ty) => Ok(quote!(<#ty>)), - | None => bail!("\ + let ty = match Self::get_wrapper_type(&input.attrs, &fields) { + Some(ty) => ty, + None => bail!( + "\ when deriving TransparentWrapper for a struct with more than one field \ you need to specify the transparent field using #[transparent(T)]\ - "), - } + " + ), + }; + + Ok(syn::parse_quote!(::bytemuck::TransparentWrapper<#ty>)) } fn asserts(input: &DeriveInput) -> Result { @@ -301,15 +288,15 @@ impl Derivable for TransparentWrapper { } } - fn check_attributes( - _ty: &Data, attributes: &[Attribute], - ) -> Result<()> { + fn check_attributes(_ty: &Data, attributes: &[Attribute]) -> Result<()> { let repr = get_repr(attributes)?; match repr.repr { Repr::Transparent => Ok(()), _ => { - bail!("TransparentWrapper requires the struct to be #[repr(transparent)]") + bail!( + "TransparentWrapper requires the struct to be #[repr(transparent)]" + ) } } } @@ -318,13 +305,11 @@ impl Derivable for TransparentWrapper { pub struct Contiguous; impl Derivable for Contiguous { - fn ident() -> TokenStream { - quote!(::bytemuck::Contiguous) + fn ident(_: &DeriveInput) -> Result { + Ok(syn::parse_quote!(::bytemuck::Contiguous)) } - fn trait_impl( - input: &DeriveInput, - ) -> Result<(TokenStream, TokenStream)> { + fn trait_impl(input: &DeriveInput) -> Result<(TokenStream, TokenStream)> { let repr = get_repr(&input.attrs)?; let integer_ty = if let Some(integer_ty) = repr.repr.as_integer_type() { @@ -422,7 +407,8 @@ fn generate_checked_bit_pattern_struct( let field_name = &field_names[..]; let field_ty = &field_tys[..]; - let derive_dbg = quote!(#[cfg_attr(not(target_arch = "spirv"), derive(Debug))]); + let derive_dbg = + quote!(#[cfg_attr(not(target_arch = "spirv"), derive(Debug))]); Ok(( quote! { @@ -456,7 +442,11 @@ fn generate_checked_bit_pattern_enum( (i64::max_value(), i64::min_value(), 0), |(min, max, count), res| { let discriminant = res?; - Ok::<_, Error>((i64::min(min, discriminant), i64::max(max, discriminant), count + 1)) + Ok::<_, Error>(( + i64::min(min, discriminant), + i64::max(max, discriminant), + count + 1, + )) }, )?; @@ -503,9 +493,7 @@ fn generate_checked_bit_pattern_enum( /// Check that a struct has no padding by asserting that the size of the struct /// is equal to the sum of the size of it's fields -fn generate_assert_no_padding( - input: &DeriveInput, -) -> Result { +fn generate_assert_no_padding(input: &DeriveInput) -> Result { let struct_type = &input.ident; let span = input.ident.span(); let fields = get_fields(input)?; @@ -529,7 +517,7 @@ fn generate_assert_no_padding( /// Check that all fields implement a given trait fn generate_fields_are_trait( - input: &DeriveInput, trait_: TokenStream, + input: &DeriveInput, trait_: syn::Path, ) -> Result { let (impl_generics, _ty_generics, where_clause) = input.generics.split_for_impl(); @@ -574,28 +562,30 @@ fn get_simple_attr(attributes: &[Attribute], attr_name: &str) -> Option { fn get_repr(attributes: &[Attribute]) -> Result { attributes .iter() - .filter_map(|attr| if attr.path.is_ident("repr") { - Some(attr.parse_args::()) - } else { - None + .filter_map(|attr| { + if attr.path.is_ident("repr") { + Some(attr.parse_args::()) + } else { + None + } }) .try_fold(Representation::default(), |a, b| { let b = b?; Ok(Representation { repr: match (a.repr, b.repr) { - | (a, Repr::Rust) => a, - | (Repr::Rust, b) => b, - | _ => bail!("conflicting representation hints"), + (a, Repr::Rust) => a, + (Repr::Rust, b) => b, + _ => bail!("conflicting representation hints"), }, packed: match (a.packed, b.packed) { - | (a, None) => a, - | (None, b) => b, - | _ => bail!("conflicting representation hints"), + (a, None) => a, + (None, b) => b, + _ => bail!("conflicting representation hints"), }, align: match (a.align, b.align) { - | (a, None) => a, - | (None, b) => b, - | _ => bail!("conflicting representation hints"), + (a, None) => a, + (None, b) => b, + _ => bail!("conflicting representation hints"), }, }) }) @@ -719,7 +709,8 @@ macro_rules! mk_repr {( )); } } -)} use mk_repr; +)} +use mk_repr; struct VariantDiscriminantIterator<'a, I: Iterator + 'a> { inner: I, @@ -767,9 +758,7 @@ fn parse_int_expr(expr: &Expr) -> Result { Expr::Unary(ExprUnary { op: UnOp::Neg(_), expr, .. }) => { parse_int_expr(expr).map(|int| -int) } - Expr::Lit(ExprLit { lit: Lit::Int(int), .. }) => { - int.base10_parse() - } + Expr::Lit(ExprLit { lit: Lit::Int(int), .. }) => int.base10_parse(), _ => bail!("Not an integer expression"), } } diff --git a/derive/tests/basic.rs b/derive/tests/basic.rs index 9d5667d..25d1781 100644 --- a/derive/tests/basic.rs +++ b/derive/tests/basic.rs @@ -1,7 +1,8 @@ #![allow(dead_code)] use bytemuck::{ - AnyBitPattern, Contiguous, CheckedBitPattern, NoUninit, Pod, TransparentWrapper, Zeroable, + AnyBitPattern, CheckedBitPattern, Contiguous, NoUninit, Pod, + TransparentWrapper, Zeroable, }; use std::marker::PhantomData; @@ -138,9 +139,26 @@ struct CheckedBitPatternStruct { #[repr(C)] struct AnyBitPatternTest { a: u16, - b: u16 + b: u16, } +/// ```compile_fail +/// use bytemuck::{Pod, Zeroable}; +/// +/// #[derive(Pod, Zeroable)] +/// #[repr(transparent)] +/// struct TransparentSingle(T); +/// +/// struct NotPod(u32); +/// +/// let _: u32 = bytemuck::cast(TransparentSingle(NotPod(0u32))); +/// ``` +#[derive( + Debug, Copy, Clone, PartialEq, Eq, Pod, Zeroable, TransparentWrapper, +)] +#[repr(transparent)] +struct NewtypeWrapperTest(T); + #[test] fn fails_cast_contiguous() { let can_cast = CheckedBitPatternEnumWithValues::is_valid_bit_pattern(&5); @@ -149,7 +167,8 @@ fn fails_cast_contiguous() { #[test] fn passes_cast_contiguous() { - let res = bytemuck::checked::from_bytes::(&[2u8]); + let res = + bytemuck::checked::from_bytes::(&[2u8]); assert_eq!(*res, CheckedBitPatternEnumWithValues::C); } @@ -162,7 +181,9 @@ fn fails_cast_noncontiguous() { #[test] fn passes_cast_noncontiguous() { let res = - bytemuck::checked::from_bytes::(&[56u8]); + bytemuck::checked::from_bytes::(&[ + 56u8, + ]); assert_eq!(*res, CheckedBitPatternEnumNonContiguous::E); } @@ -177,7 +198,10 @@ fn fails_cast_struct() { fn passes_cast_struct() { let pod = [0u8, 8u8]; let res = bytemuck::checked::from_bytes::(&pod); - assert_eq!(*res, CheckedBitPatternStruct { a: 0, b: CheckedBitPatternEnumNonContiguous::B }); + assert_eq!( + *res, + CheckedBitPatternStruct { a: 0, b: CheckedBitPatternEnumNonContiguous::B } + ); } #[test] diff --git a/src/internal.rs b/src/internal.rs index fd60779..2984d26 100644 --- a/src/internal.rs +++ b/src/internal.rs @@ -23,7 +23,9 @@ possibility code branch. #[cfg(not(target_arch = "spirv"))] #[cold] #[inline(never)] -pub(crate) fn something_went_wrong(_src: &str, _err: D) -> ! { +pub(crate) fn something_went_wrong( + _src: &str, _err: D, +) -> ! { // Note(Lokathor): Keeping the panic here makes the panic _formatting_ go // here too, which helps assembly readability and also helps keep down // the inline pressure.