Skip to content

Commit

Permalink
WIP: refactoring to allow attributes on discriminants enum.
Browse files Browse the repository at this point in the history
  • Loading branch information
azriel91 committed Sep 20, 2018
1 parent 4085ea7 commit 7fa547c
Show file tree
Hide file tree
Showing 4 changed files with 71 additions and 22 deletions.
16 changes: 10 additions & 6 deletions strum_macros/src/enum_discriminants.rs
@@ -1,7 +1,7 @@
use proc_macro2::{Span, TokenStream};
use syn;

use helpers::{extract_meta, extract_meta_attrs, unique_meta_attr};
use helpers::{extract_meta, extract_meta_idents, unique_meta_ident, unique_meta_list};

pub fn enum_discriminants_inner(ast: &syn::DeriveInput) -> TokenStream {
let name = &ast.ident;
Expand All @@ -13,17 +13,21 @@ pub fn enum_discriminants_inner(ast: &syn::DeriveInput) -> TokenStream {
};

let type_meta = extract_meta(&ast.attrs);
let discriminant_derives = extract_meta_attrs(&type_meta, "strum_discriminants_derive");
let discriminant_meta = unique_meta_list(&type_meta, "strum_discriminants");
let derives =
discriminant_meta.map_or_else(|| vec![], |meta| extract_meta_idents(&[meta], "derive"));

let derives = quote! {
#[derive(Clone, Copy, Debug, PartialEq, Eq, #(#discriminant_derives),*)]
#[derive(Clone, Copy, Debug, PartialEq, Eq, #(#derives),*)]
};

let default_name = syn::Ident::new(
&format!("{}Discriminants", name.to_string()),
Span::call_site(),
);
let enum_discriminants_name =
unique_meta_attr(&type_meta, "strum_discriminants_name").unwrap_or(&default_name);
let discriminants_name = discriminant_meta
.map(|meta| unique_meta_ident(&[meta], "name").unwrap_or(&default_name))
.unwrap_or(&default_name);

let mut discriminants = Vec::new();
for variant in variants {
Expand All @@ -43,7 +47,7 @@ pub fn enum_discriminants_inner(ast: &syn::DeriveInput) -> TokenStream {
quote!{
/// Auto-generated discriminant enum variants
#derives
#vis enum #enum_discriminants_name {
#vis enum #discriminants_name {
#(#discriminants),*
}
}
Expand Down
63 changes: 56 additions & 7 deletions strum_macros/src/helpers.rs
Expand Up @@ -10,9 +10,21 @@ pub fn extract_meta(attrs: &[Attribute]) -> Vec<Meta> {
.collect()
}

pub fn extract_meta_attrs<'meta>(meta: &'meta [Meta], attr: &str) -> Vec<&'meta Ident> {
/// Returns the `Meta` s that are `List`s, and match the given `attr` name.
///
/// For example, `extract_meta_lists(metas, "strum")` returns both `#[strum(Something)]` and
/// `#[strum(SomethingElse)]` for the following declaration.
///
/// ```rust,ignore
/// #[derive(Debug)]
/// #[strum(Something)]
/// #[strum(SomethingElse)]
/// struct MyStruct {}
/// ```
pub fn extract_meta_lists_refs<'meta>(metas: &[&'meta Meta], attr: &str) -> Vec<&'meta Meta> {
use syn::NestedMeta;
meta.iter()
metas
.iter()
// Get all the attributes with our tag on them.
.filter_map(|meta| match *meta {
Meta::List(ref metalist) => {
Expand All @@ -24,17 +36,54 @@ pub fn extract_meta_attrs<'meta>(meta: &'meta [Meta], attr: &str) -> Vec<&'meta
}
_ => None,
}).flat_map(|nested| nested)
// Get all the inner elements as long as they start with ser.
.filter_map(|nested| match *nested {
NestedMeta::Meta(ref meta) => Some(meta),
_ => None,
}).collect()
}

pub fn extract_meta_lists<'meta>(metas: &'meta [Meta], attr: &str) -> Vec<&'meta Meta> {
extract_meta_lists_refs(&metas.iter().collect::<Vec<_>>(), attr)
}

pub fn unique_meta_list<'meta>(metas: &'meta [Meta], attr: &str) -> Option<&'meta Meta> {
let mut curr = extract_meta_lists(&metas, attr);
if curr.len() > 1 {
panic!(
"More than one `{}` attribute found on type, {:?}",
attr, curr
);
}

curr.pop()
}

/// Returns the `Ident`s from the `Meta::List`s that match the given `attr` name.
///
/// For example, `extract_meta_lists(something_metas, "Something")` returns `Abc`, `Def`, and `Ghi` for
/// the following declaration.
///
/// ```rust,ignore
/// #[derive(Debug)]
/// #[strum(Something(Abc, Def), Something(Ghi))]
/// struct MyStruct {}
/// ```
pub fn extract_meta_idents<'meta>(metas: &[&'meta Meta], attr: &str) -> Vec<&'meta Ident> {
extract_meta_lists_refs(metas, attr)
.into_iter()
.filter_map(|meta| match *meta {
NestedMeta::Meta(Meta::Word(ref ident)) => Some(ident),
Meta::Word(ref ident) => Some(ident),
_ => None,
}).collect()
}

pub fn unique_meta_attr<'meta>(attrs: &'meta [Meta], attr: &str) -> Option<&'meta Ident> {
let mut curr = extract_meta_attrs(attrs, attr);
pub fn unique_meta_ident<'meta>(metas: &[&'meta Meta], attr: &str) -> Option<&'meta Ident> {
let mut curr = extract_meta_idents(metas, attr);
if curr.len() > 1 {
panic!("More than one `{}` attribute found on type", attr);
panic!(
"More than one `{}` attribute found on type: {:?}",
attr, curr
);
}

curr.pop()
Expand Down
5 changes: 1 addition & 4 deletions strum_macros/src/lib.rs
Expand Up @@ -115,10 +115,7 @@ pub fn enum_properties(input: proc_macro::TokenStream) -> proc_macro::TokenStrea
toks.into()
}

#[proc_macro_derive(
EnumDiscriminants,
attributes(strum, strum_discriminants_derive, strum_discriminants_name)
)]
#[proc_macro_derive(EnumDiscriminants, attributes(strum, strum_discriminants))]
pub fn enum_discriminants(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
let ast = syn::parse(input).unwrap();

Expand Down
9 changes: 4 additions & 5 deletions strum_tests/tests/enum_discriminants.rs
Expand Up @@ -6,7 +6,7 @@ use strum::IntoEnumIterator;

#[allow(dead_code)]
#[derive(Debug, Eq, PartialEq, EnumDiscriminants)]
#[strum_discriminants_derive(EnumIter)]
#[strum_discriminants(derive(EnumIter))]
enum Simple {
Variant0,
Variant1,
Expand All @@ -25,7 +25,7 @@ struct NonDefault;

#[allow(dead_code)]
#[derive(Debug, EnumDiscriminants)]
#[strum_discriminants_derive(EnumIter)]
#[strum_discriminants(derive(EnumIter))]
enum WithFields {
Variant0(NonDefault),
Variant1 { a: NonDefault },
Expand All @@ -47,7 +47,7 @@ trait Bar {}

#[allow(dead_code)]
#[derive(Debug, Eq, PartialEq, EnumDiscriminants)]
#[strum_discriminants_derive(EnumIter)]
#[strum_discriminants(derive(EnumIter))]
enum Complicated<U: Foo, V: Bar> {
/// With Docs
A(U),
Expand Down Expand Up @@ -87,8 +87,7 @@ fn with_default_test() {

#[allow(dead_code)]
#[derive(Debug, Eq, PartialEq, EnumDiscriminants)]
#[strum_discriminants_derive(EnumIter)]
#[strum_discriminants_name(EnumBoo)]
#[strum_discriminants(name(EnumBoo), derive(EnumIter))]
enum Renamed {
Variant0(bool),
Variant1(i32),
Expand Down

0 comments on commit 7fa547c

Please sign in to comment.