Skip to content

Commit

Permalink
Allow repr(transparent) to be used generically in derive(Pod) (#139)
Browse files Browse the repository at this point in the history
* Enabled transparent generics

* Move trait checks to implementation block

* Replace add_trait_marker impl
  • Loading branch information
notgull committed Nov 3, 2022
1 parent 7b67524 commit 518baf9
Show file tree
Hide file tree
Showing 4 changed files with 149 additions and 116 deletions.
46 changes: 32 additions & 14 deletions derive/src/lib.rs
Expand Up @@ -9,7 +9,8 @@ use quote::quote;
use syn::{parse_macro_input, DeriveInput, Result};

use crate::traits::{
AnyBitPattern, Contiguous, Derivable, CheckedBitPattern, NoUninit, Pod, TransparentWrapper, Zeroable,
AnyBitPattern, CheckedBitPattern, Contiguous, Derivable, NoUninit, Pod,
TransparentWrapper, Zeroable,
};

/// Derive the `Pod` trait for a struct
Expand Down Expand Up @@ -56,8 +57,9 @@ pub fn derive_pod(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
pub fn derive_anybitpattern(
input: proc_macro::TokenStream,
) -> proc_macro::TokenStream {
let expanded =
derive_marker_trait::<AnyBitPattern>(parse_macro_input!(input as DeriveInput));
let expanded = derive_marker_trait::<AnyBitPattern>(parse_macro_input!(
input as DeriveInput
));

proc_macro::TokenStream::from(expanded)
}
Expand Down Expand Up @@ -99,8 +101,8 @@ pub fn derive_zeroable(
/// for the `NoUninit` trait.
///
/// The following constraints need to be satisfied for the macro to succeed
/// (the rest of the constraints are guaranteed by the `NoUninit` subtrait bounds,
/// i.e. the type must be `Sized + Copy + 'static`):
/// (the rest of the constraints are guaranteed by the `NoUninit` subtrait
/// bounds, i.e. the type must be `Sized + Copy + 'static`):
///
/// If applied to a struct:
/// - All fields in the struct must implement `NoUninit`
Expand Down Expand Up @@ -129,9 +131,9 @@ pub fn derive_no_uninit(
/// definition and `is_valid_bit_pattern` method for the type automatically.
///
/// The following constraints need to be satisfied for the macro to succeed
/// (the rest of the constraints are guaranteed by the `CheckedBitPattern` subtrait bounds,
/// i.e. are guaranteed by the requirements of the `NoUninit` trait which `CheckedBitPattern`
/// is a subtrait of):
/// (the rest of the constraints are guaranteed by the `CheckedBitPattern`
/// subtrait bounds, i.e. are guaranteed by the requirements of the `NoUninit`
/// trait which `CheckedBitPattern` is a subtrait of):
///
/// If applied to a struct:
/// - All fields must implement `CheckedBitPattern`
Expand All @@ -142,8 +144,9 @@ pub fn derive_no_uninit(
pub fn derive_maybe_pod(
input: proc_macro::TokenStream,
) -> proc_macro::TokenStream {
let expanded =
derive_marker_trait::<CheckedBitPattern>(parse_macro_input!(input as DeriveInput));
let expanded = derive_marker_trait::<CheckedBitPattern>(parse_macro_input!(
input as DeriveInput
));

proc_macro::TokenStream::from(expanded)
}
Expand Down Expand Up @@ -228,17 +231,19 @@ fn derive_marker_trait<Trait: Derivable>(input: DeriveInput) -> TokenStream {
}

fn derive_marker_trait_inner<Trait: Derivable>(
input: DeriveInput,
mut input: DeriveInput,
) -> Result<TokenStream> {
// Enforce Pod on all generic fields.
let trait_ = Trait::ident(&input)?;
add_trait_marker(&mut input.generics, &trait_);

let name = &input.ident;

let (impl_generics, ty_generics, where_clause) =
input.generics.split_for_impl();

let trait_ = Trait::ident();
Trait::check_attributes(&input.data, &input.attrs)?;
let asserts = Trait::asserts(&input)?;
let trait_params = Trait::generic_params(&input)?;
let (trait_impl_extras, trait_impl) = Trait::trait_impl(&input)?;

let implies_trait = if let Some(implies_trait) = Trait::implies_trait() {
Expand All @@ -252,10 +257,23 @@ fn derive_marker_trait_inner<Trait: Derivable>(

#trait_impl_extras

unsafe impl #impl_generics #trait_ #trait_params for #name #ty_generics #where_clause {
unsafe impl #impl_generics #trait_ for #name #ty_generics #where_clause {
#trait_impl
}

#implies_trait
})
}

/// Add a trait marker to the generics if it is not already present
fn add_trait_marker(generics: &mut syn::Generics, trait_name: &syn::Path) {
// Get each generic type parameter.
let type_params = generics
.type_params()
.map(|param| &param.ident)
.map(|param| syn::parse_quote!(
#param: #trait_name
)).collect::<Vec<syn::WherePredicate>>();

generics.make_where_clause().predicates.extend(type_params);
}

0 comments on commit 518baf9

Please sign in to comment.