diff --git a/strum_macros/src/enum_discriminants.rs b/strum_macros/src/enum_discriminants.rs
index f819a702..6cd61707 100644
--- a/strum_macros/src/enum_discriminants.rs
+++ b/strum_macros/src/enum_discriminants.rs
@@ -1,7 +1,7 @@
use proc_macro2::{Span, TokenStream};
use syn;
-use helpers::{extract_meta, extract_meta_attrs, unique_meta_attr};
+use helpers::{extract_meta, extract_meta_idents, unique_meta_ident, unique_meta_list};
pub fn enum_discriminants_inner(ast: &syn::DeriveInput) -> TokenStream {
let name = &ast.ident;
@@ -13,17 +13,21 @@ pub fn enum_discriminants_inner(ast: &syn::DeriveInput) -> TokenStream {
};
let type_meta = extract_meta(&ast.attrs);
- let discriminant_derives = extract_meta_attrs(&type_meta, "strum_discriminants_derive");
+ let discriminant_meta = unique_meta_list(&type_meta, "strum_discriminants");
+ let derives =
+ discriminant_meta.map_or_else(|| vec![], |meta| extract_meta_idents(&[meta], "derive"));
+
let derives = quote! {
- #[derive(Clone, Copy, Debug, PartialEq, Eq, #(#discriminant_derives),*)]
+ #[derive(Clone, Copy, Debug, PartialEq, Eq, #(#derives),*)]
};
let default_name = syn::Ident::new(
&format!("{}Discriminants", name.to_string()),
Span::call_site(),
);
- let enum_discriminants_name =
- unique_meta_attr(&type_meta, "strum_discriminants_name").unwrap_or(&default_name);
+ let discriminants_name = discriminant_meta
+ .map(|meta| unique_meta_ident(&[meta], "name").unwrap_or(&default_name))
+ .unwrap_or(&default_name);
let mut discriminants = Vec::new();
for variant in variants {
@@ -43,7 +47,7 @@ pub fn enum_discriminants_inner(ast: &syn::DeriveInput) -> TokenStream {
quote!{
/// Auto-generated discriminant enum variants
#derives
- #vis enum #enum_discriminants_name {
+ #vis enum #discriminants_name {
#(#discriminants),*
}
}
diff --git a/strum_macros/src/helpers.rs b/strum_macros/src/helpers.rs
index 5388f993..385fa6ef 100644
--- a/strum_macros/src/helpers.rs
+++ b/strum_macros/src/helpers.rs
@@ -10,9 +10,21 @@ pub fn extract_meta(attrs: &[Attribute]) -> Vec {
.collect()
}
-pub fn extract_meta_attrs<'meta>(meta: &'meta [Meta], attr: &str) -> Vec<&'meta Ident> {
+/// Returns the `Meta` s that are `List`s, and match the given `attr` name.
+///
+/// For example, `extract_meta_lists(metas, "strum")` returns both `#[strum(Something)]` and
+/// `#[strum(SomethingElse)]` for the following declaration.
+///
+/// ```rust,ignore
+/// #[derive(Debug)]
+/// #[strum(Something)]
+/// #[strum(SomethingElse)]
+/// struct MyStruct {}
+/// ```
+pub fn extract_meta_lists_refs<'meta>(metas: &[&'meta Meta], attr: &str) -> Vec<&'meta Meta> {
use syn::NestedMeta;
- meta.iter()
+ metas
+ .iter()
// Get all the attributes with our tag on them.
.filter_map(|meta| match *meta {
Meta::List(ref metalist) => {
@@ -24,17 +36,54 @@ pub fn extract_meta_attrs<'meta>(meta: &'meta [Meta], attr: &str) -> Vec<&'meta
}
_ => None,
}).flat_map(|nested| nested)
- // Get all the inner elements as long as they start with ser.
+ .filter_map(|nested| match *nested {
+ NestedMeta::Meta(ref meta) => Some(meta),
+ _ => None,
+ }).collect()
+}
+
+pub fn extract_meta_lists<'meta>(metas: &'meta [Meta], attr: &str) -> Vec<&'meta Meta> {
+ extract_meta_lists_refs(&metas.iter().collect::>(), attr)
+}
+
+pub fn unique_meta_list<'meta>(metas: &'meta [Meta], attr: &str) -> Option<&'meta Meta> {
+ let mut curr = extract_meta_lists(&metas, attr);
+ if curr.len() > 1 {
+ panic!(
+ "More than one `{}` attribute found on type, {:?}",
+ attr, curr
+ );
+ }
+
+ curr.pop()
+}
+
+/// Returns the `Ident`s from the `Meta::List`s that match the given `attr` name.
+///
+/// For example, `extract_meta_lists(something_metas, "Something")` returns `Abc`, `Def`, and `Ghi` for
+/// the following declaration.
+///
+/// ```rust,ignore
+/// #[derive(Debug)]
+/// #[strum(Something(Abc, Def), Something(Ghi))]
+/// struct MyStruct {}
+/// ```
+pub fn extract_meta_idents<'meta>(metas: &[&'meta Meta], attr: &str) -> Vec<&'meta Ident> {
+ extract_meta_lists_refs(metas, attr)
+ .into_iter()
.filter_map(|meta| match *meta {
- NestedMeta::Meta(Meta::Word(ref ident)) => Some(ident),
+ Meta::Word(ref ident) => Some(ident),
_ => None,
}).collect()
}
-pub fn unique_meta_attr<'meta>(attrs: &'meta [Meta], attr: &str) -> Option<&'meta Ident> {
- let mut curr = extract_meta_attrs(attrs, attr);
+pub fn unique_meta_ident<'meta>(metas: &[&'meta Meta], attr: &str) -> Option<&'meta Ident> {
+ let mut curr = extract_meta_idents(metas, attr);
if curr.len() > 1 {
- panic!("More than one `{}` attribute found on type", attr);
+ panic!(
+ "More than one `{}` attribute found on type: {:?}",
+ attr, curr
+ );
}
curr.pop()
diff --git a/strum_macros/src/lib.rs b/strum_macros/src/lib.rs
index 8f7b4567..560586c2 100644
--- a/strum_macros/src/lib.rs
+++ b/strum_macros/src/lib.rs
@@ -115,10 +115,7 @@ pub fn enum_properties(input: proc_macro::TokenStream) -> proc_macro::TokenStrea
toks.into()
}
-#[proc_macro_derive(
- EnumDiscriminants,
- attributes(strum, strum_discriminants_derive, strum_discriminants_name)
-)]
+#[proc_macro_derive(EnumDiscriminants, attributes(strum, strum_discriminants))]
pub fn enum_discriminants(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
let ast = syn::parse(input).unwrap();
diff --git a/strum_tests/tests/enum_discriminants.rs b/strum_tests/tests/enum_discriminants.rs
index 712fabb6..c081c503 100644
--- a/strum_tests/tests/enum_discriminants.rs
+++ b/strum_tests/tests/enum_discriminants.rs
@@ -6,7 +6,7 @@ use strum::IntoEnumIterator;
#[allow(dead_code)]
#[derive(Debug, Eq, PartialEq, EnumDiscriminants)]
-#[strum_discriminants_derive(EnumIter)]
+#[strum_discriminants(derive(EnumIter))]
enum Simple {
Variant0,
Variant1,
@@ -25,7 +25,7 @@ struct NonDefault;
#[allow(dead_code)]
#[derive(Debug, EnumDiscriminants)]
-#[strum_discriminants_derive(EnumIter)]
+#[strum_discriminants(derive(EnumIter))]
enum WithFields {
Variant0(NonDefault),
Variant1 { a: NonDefault },
@@ -47,7 +47,7 @@ trait Bar {}
#[allow(dead_code)]
#[derive(Debug, Eq, PartialEq, EnumDiscriminants)]
-#[strum_discriminants_derive(EnumIter)]
+#[strum_discriminants(derive(EnumIter))]
enum Complicated {
/// With Docs
A(U),
@@ -87,8 +87,7 @@ fn with_default_test() {
#[allow(dead_code)]
#[derive(Debug, Eq, PartialEq, EnumDiscriminants)]
-#[strum_discriminants_derive(EnumIter)]
-#[strum_discriminants_name(EnumBoo)]
+#[strum_discriminants(name(EnumBoo), derive(EnumIter))]
enum Renamed {
Variant0(bool),
Variant1(i32),