Skip to content

Commit

Permalink
Use single strum_discriminants top level attribute.
Browse files Browse the repository at this point in the history
  • Loading branch information
azriel91 committed Sep 20, 2018
1 parent 18e60da commit c28616d
Show file tree
Hide file tree
Showing 5 changed files with 67 additions and 63 deletions.
10 changes: 5 additions & 5 deletions README.md
Expand Up @@ -301,7 +301,7 @@ Strum has implemented the following macros:
wish to determine the variant of an enum from a String, but the variants contain any
non-`Default` fields. By default, the generated enum has the followign derives:
`Clone, Copy, Debug, PartialEq, Eq`. You can add additional derives using the
`#[strum_discriminants_derive(AdditionalDerive)]` attribute.
`#[strum_discriminants(derive(AdditionalDerive))]` attribute.

Here's an example:

Expand All @@ -317,7 +317,7 @@ Strum has implemented the following macros:

#[allow(dead_code)]
#[derive(Debug, EnumDiscriminants)]
#[strum_discriminants_derive(EnumString)]
#[strum_discriminants(derive(EnumString))]
enum MyEnum {
Variant0(NonDefault),
Variant1 { a: NonDefault },
Expand All @@ -331,7 +331,7 @@ Strum has implemented the following macros:
}
```

You can also rename the generated enum using the `#[strum_discriminants_name(OtherName)]`
You can also rename the generated enum using the `#[strum_discriminants(name(OtherName))]`
attribute:

```rust
Expand All @@ -342,8 +342,8 @@ Strum has implemented the following macros:

#[allow(dead_code)]
#[derive(Debug, EnumDiscriminants)]
#[strum_discriminants_derive(EnumIter)]
#[strum_discriminants_name(MyVariants)]
#[strum_discriminants(derive(EnumIter))]
#[strum_discriminants(name(MyVariants))]
enum MyEnum {
Variant0(bool),
Variant1 { a: bool },
Expand Down
9 changes: 4 additions & 5 deletions strum/src/lib.rs
Expand Up @@ -295,7 +295,7 @@
//! wish to determine the variant of an enum from a String, but the variants contain any
//! non-`Default` fields. By default, the generated enum has the followign derives:
//! `Clone, Copy, Debug, PartialEq, Eq`. You can add additional derives using the
//! `#[strum_discriminants_derive(AdditionalDerive)]` attribute.
//! `#[strum_discriminants(derive(AdditionalDerive))]` attribute.
//!
//! Here's an example:
//!
Expand All @@ -311,7 +311,7 @@
//!
//! #[allow(dead_code)]
//! #[derive(Debug, EnumDiscriminants)]
//! #[strum_discriminants_derive(EnumString)]
//! #[strum_discriminants(derive(EnumString))]
//! enum MyEnum {
//! Variant0(NonDefault),
//! Variant1 { a: NonDefault },
Expand All @@ -325,7 +325,7 @@
//! }
//! ```
//!
//! You can also rename the generated enum using the `#[strum_discriminants_name(OtherName)]`
//! You can also rename the generated enum using the `#[strum_discriminants(name(OtherName))]`
//! attribute:
//!
//! ```rust
Expand All @@ -336,8 +336,7 @@
//!
//! #[allow(dead_code)]
//! #[derive(Debug, EnumDiscriminants)]
//! #[strum_discriminants_derive(EnumIter)]
//! #[strum_discriminants_name(MyVariants)]
//! #[strum_discriminants(name(MyVariants), derive(EnumIter))]
//! enum MyEnum {
//! Variant0(bool),
//! Variant1 { a: bool },
Expand Down
2 changes: 1 addition & 1 deletion strum_macros/Cargo.toml
Expand Up @@ -20,4 +20,4 @@ name = "strum_macros"
heck = "0.3"
proc-macro2 = "0.4"
quote = "0.6"
syn = { version = "0.15", features = ["parsing"] }
syn = { version = "0.15", features = ["parsing", "extra-traits"] }
26 changes: 20 additions & 6 deletions strum_macros/src/enum_discriminants.rs
@@ -1,7 +1,7 @@
use proc_macro2::{Span, TokenStream};
use syn;

use helpers::{extract_meta, extract_meta_idents, unique_meta_ident, unique_meta_list};
use helpers::{extract_list_metas, extract_meta, get_meta_ident, get_meta_list, unique_meta_list};

pub fn enum_discriminants_inner(ast: &syn::DeriveInput) -> TokenStream {
let name = &ast.ident;
Expand All @@ -12,23 +12,37 @@ pub fn enum_discriminants_inner(ast: &syn::DeriveInput) -> TokenStream {
_ => panic!("EnumDiscriminants only works on Enums"),
};

// Derives for the generated enum
let type_meta = extract_meta(&ast.attrs);
let discriminant_meta = unique_meta_list(&type_meta, "strum_discriminants");
let derives =
discriminant_meta.map_or_else(|| vec![], |meta| extract_meta_idents(&[meta], "derive"));
let discriminant_attrs = unique_meta_list(type_meta.iter(), "strum_discriminants")
.map(|meta| extract_list_metas(meta).collect::<Vec<_>>());
let derives = discriminant_attrs.as_ref().map_or_else(
|| vec![],
|meta| {
get_meta_list(meta.iter().map(|&m| m), "derive")
.flat_map(extract_list_metas)
.filter_map(get_meta_ident)
.collect::<Vec<_>>()
},
);

let derives = quote! {
#[derive(Clone, Copy, Debug, PartialEq, Eq, #(#derives),*)]
};

// Work out the name
let default_name = syn::Ident::new(
&format!("{}Discriminants", name.to_string()),
Span::call_site(),
);
let discriminants_name = discriminant_meta
.map(|meta| unique_meta_ident(&[meta], "name").unwrap_or(&default_name))
let discriminants_name = discriminant_attrs
.as_ref()
.and_then(|meta| unique_meta_list(meta.iter().map(|&m| m), "name"))
.map(extract_list_metas)
.and_then(|metas| metas.filter_map(get_meta_ident).next())
.unwrap_or(&default_name);

// Add the variants without fields, but exclude the `strum` meta item
let mut discriminants = Vec::new();
for variant in variants {
let ident = &variant.ident;
Expand Down
83 changes: 37 additions & 46 deletions strum_macros/src/helpers.rs
@@ -1,5 +1,5 @@
use heck::{CamelCase, KebabCase, MixedCase, ShoutySnakeCase, SnakeCase, TitleCase};
use syn::{Attribute, Ident, Meta};
use syn::{Attribute, Ident, Meta, MetaList};

use case_style::CaseStyle;

Expand All @@ -10,83 +10,74 @@ pub fn extract_meta(attrs: &[Attribute]) -> Vec<Meta> {
.collect()
}

/// Returns the `Meta` s that are `List`s, and match the given `attr` name.
/// Returns the `MetaList`s with the given attr name.
///
/// For example, `extract_meta_lists(metas, "strum")` returns both `#[strum(Something)]` and
/// `#[strum(SomethingElse)]` for the following declaration.
/// For example, `get_meta_list(type_meta.iter(), "strum_discriminant")` for the following snippet
/// will return an iterator with `#[strum_discriminant(derive(EnumIter))]` and
/// `#[strum_discriminant(name(MyEnumVariants))]`.
///
/// ```rust,ignore
/// #[derive(Debug)]
/// #[strum(Something)]
/// #[strum(SomethingElse)]
/// struct MyStruct {}
/// #[strum_discriminant(derive(EnumIter))]
/// #[strum_discriminant(name(MyEnumVariants))]
/// enum MyEnum { A }
/// ```
pub fn extract_meta_lists_refs<'meta>(metas: &[&'meta Meta], attr: &str) -> Vec<&'meta Meta> {
use syn::NestedMeta;
pub fn get_meta_list<'meta, MetaIt>(
metas: MetaIt,
attr: &'meta str,
) -> impl Iterator<Item = &'meta MetaList>
where
MetaIt: Iterator<Item = &'meta Meta>,
{
metas
.iter()
// Get all the attributes with our tag on them.
.filter_map(|meta| match *meta {
.filter_map(move |meta| match meta {
Meta::List(ref metalist) => {
if metalist.ident == attr {
Some(&metalist.nested)
Some(metalist)
} else {
None
}
}
_ => None,
}).flat_map(|nested| nested)
.filter_map(|nested| match *nested {
NestedMeta::Meta(ref meta) => Some(meta),
_ => None,
}).collect()
})
}

pub fn extract_meta_lists<'meta>(metas: &'meta [Meta], attr: &str) -> Vec<&'meta Meta> {
extract_meta_lists_refs(&metas.iter().collect::<Vec<_>>(), attr)
}

pub fn unique_meta_list<'meta>(metas: &'meta [Meta], attr: &str) -> Option<&'meta Meta> {
let mut curr = extract_meta_lists(&metas, attr);
pub fn unique_meta_list<'meta, MetaIt>(metas: MetaIt, attr: &'meta str) -> Option<&'meta MetaList>
where
MetaIt: Iterator<Item = &'meta Meta>,
{
let mut curr = get_meta_list(metas.into_iter(), attr).collect::<Vec<_>>();
if curr.len() > 1 {
panic!(
"More than one `{}` attribute found on type, {:?}",
attr, curr
);
panic!("More than one `{}` attribute found on type", attr);
}

curr.pop()
}

pub fn extract_list_metas<'meta>(metalist: &'meta MetaList) -> impl Iterator<Item = &'meta Meta> {
use syn::NestedMeta;
metalist.nested.iter().filter_map(|nested| match *nested {
NestedMeta::Meta(ref meta) => Some(meta),
_ => None,
})
}

/// Returns the `Ident`s from the `Meta::List`s that match the given `attr` name.
///
/// For example, `extract_meta_lists(something_metas, "Something")` returns `Abc`, `Def`, and `Ghi` for
/// For example, `extract_meta_idents(something_metas, "Something")` returns `Abc`, `Def`, and `Ghi` for
/// the following declaration.
///
/// ```rust,ignore
/// #[derive(Debug)]
/// #[strum(Something(Abc, Def), Something(Ghi))]
/// struct MyStruct {}
/// ```
pub fn extract_meta_idents<'meta>(metas: &[&'meta Meta], attr: &str) -> Vec<&'meta Ident> {
extract_meta_lists_refs(metas, attr)
.into_iter()
.filter_map(|meta| match *meta {
Meta::Word(ref ident) => Some(ident),
_ => None,
}).collect()
}

pub fn unique_meta_ident<'meta>(metas: &[&'meta Meta], attr: &str) -> Option<&'meta Ident> {
let mut curr = extract_meta_idents(metas, attr);
if curr.len() > 1 {
panic!(
"More than one `{}` attribute found on type: {:?}",
attr, curr
);
pub fn get_meta_ident<'meta>(meta: &'meta Meta) -> Option<&'meta Ident> {
match *meta {
Meta::Word(ref ident) => Some(ident),
_ => None,
}

curr.pop()
}

pub fn extract_attrs(meta: &[Meta], attr: &str, prop: &str) -> Vec<String> {
Expand Down

0 comments on commit c28616d

Please sign in to comment.