Skip to content

Commit

Permalink
Added EnumDiscriminants derive.
Browse files Browse the repository at this point in the history
  • Loading branch information
azriel91 committed Sep 19, 2018
1 parent d0cfdac commit 26ebfbd
Show file tree
Hide file tree
Showing 4 changed files with 167 additions and 0 deletions.
47 changes: 47 additions & 0 deletions strum_macros/src/enum_discriminants.rs
@@ -0,0 +1,47 @@
use proc_macro2::{Span, TokenStream};
use syn;

use helpers::{extract_meta, extract_meta_attrs};

pub fn enum_discriminants_inner(ast: &syn::DeriveInput) -> TokenStream {
let name = &ast.ident;
let vis = &ast.vis;
let enum_discriminants_name = syn::Ident::new(
&format!("{}Discriminants", name.to_string()),
Span::call_site(),
);

let variants = match ast.data {
syn::Data::Enum(ref v) => &v.variants,
_ => panic!("EnumDiscriminants only works on Enums"),
};

let type_meta = extract_meta(&ast.attrs);
let discriminant_derives = extract_meta_attrs(&type_meta, "strum_discriminants_derive");
let derives = quote! {
#[derive(Clone, Copy, Debug, PartialEq, Eq, #(#discriminant_derives),*)]
};

let mut discriminants = Vec::new();
for variant in variants {
let ident = &variant.ident;

// Don't copy across the "strum" meta attribute.
let attrs = variant.attrs.iter().filter(|attr| {
attr.interpret_meta().map_or(true, |meta| match meta {
syn::Meta::List(ref metalist) => metalist.ident != "strum",
_ => true,
})
});

discriminants.push(quote!{ #(#attrs)* #ident });
}

quote!{
/// Auto-generated discriminant enum variants
#derives
#vis enum #enum_discriminants_name {
#(#discriminants),*
}
}
}
21 changes: 21 additions & 0 deletions strum_macros/src/helpers.rs
Expand Up @@ -10,6 +10,27 @@ pub fn extract_meta(attrs: &[Attribute]) -> Vec<Meta> {
.collect()
}

pub fn extract_meta_attrs<'meta>(meta: &'meta [Meta], attr: &str) -> Vec<&'meta Ident> {
use syn::NestedMeta;
meta.iter()
// Get all the attributes with our tag on them.
.filter_map(|meta| match *meta {
Meta::List(ref metalist) => {
if metalist.ident == attr {
Some(&metalist.nested)
} else {
None
}
}
_ => None,
}).flat_map(|nested| nested)
// Get all the inner elements as long as they start with ser.
.filter_map(|meta| match *meta {
NestedMeta::Meta(Meta::Word(ref ident)) => Some(ident),
_ => None,
}).collect()
}

pub fn extract_attrs(meta: &[Meta], attr: &str, prop: &str) -> Vec<String> {
use syn::{Lit, MetaNameValue, NestedMeta};
meta.iter()
Expand Down
13 changes: 13 additions & 0 deletions strum_macros/src/lib.rs
Expand Up @@ -19,6 +19,7 @@ extern crate proc_macro2;
mod as_ref_str;
mod case_style;
mod display;
mod enum_discriminants;
mod enum_iter;
mod enum_messages;
mod enum_properties;
Expand Down Expand Up @@ -113,3 +114,15 @@ pub fn enum_properties(input: proc_macro::TokenStream) -> proc_macro::TokenStrea
debug_print_generated(&ast, &toks);
toks.into()
}

#[proc_macro_derive(
EnumDiscriminants,
attributes(strum, strum_discriminants_derive)
)]
pub fn enum_discriminants(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
let ast = syn::parse(input).unwrap();

let toks = enum_discriminants::enum_discriminants_inner(&ast);
debug_print_generated(&ast, &toks);
toks.into()
}
86 changes: 86 additions & 0 deletions strum_tests/tests/enum_discriminants.rs
@@ -0,0 +1,86 @@
extern crate strum;
#[macro_use]
extern crate strum_macros;

use strum::IntoEnumIterator;

#[allow(dead_code)]
#[derive(Debug, Eq, PartialEq, EnumDiscriminants)]
#[strum_discriminants_derive(EnumIter)]
enum Simple {
Variant0,
Variant1,
}

#[test]
fn simple_test() {
let discriminants = SimpleDiscriminants::iter().collect::<Vec<_>>();
let expected = vec![SimpleDiscriminants::Variant0, SimpleDiscriminants::Variant1];

assert_eq!(expected, discriminants);
}

#[derive(Debug)]
struct NonDefault;

#[allow(dead_code)]
#[derive(Debug, EnumDiscriminants)]
#[strum_discriminants_derive(EnumIter)]
enum WithFields {
Variant0(NonDefault),
Variant1 { a: NonDefault },
}

#[test]
fn fields_test() {
let discriminants = WithFieldsDiscriminants::iter().collect::<Vec<_>>();
let expected = vec![
WithFieldsDiscriminants::Variant0,
WithFieldsDiscriminants::Variant1,
];

assert_eq!(expected, discriminants);
}

trait Foo {}
trait Bar {}

#[allow(dead_code)]
#[derive(Debug, Eq, PartialEq, EnumDiscriminants)]
#[strum_discriminants_derive(EnumIter)]
enum Complicated<U: Foo, V: Bar> {
/// With Docs
A(U),
B {
v: V,
},
C,
}

#[test]
fn complicated_test() {
let discriminants = ComplicatedDiscriminants::iter().collect::<Vec<_>>();
let expected = vec![
ComplicatedDiscriminants::A,
ComplicatedDiscriminants::B,
ComplicatedDiscriminants::C,
];

assert_eq!(expected, discriminants);
}

// This test exists to ensure that we do not copy across the `#[strum(default = "true")]` meta
// attribute. If we do without deriving any `strum` derivations on the generated discriminant enum,
// Rust will generate a compiler error saying it doesn't understand the `strum` attribute.
#[allow(dead_code)]
#[derive(Debug, EnumDiscriminants)]
enum WithDefault {
#[strum(default = "true")]
A(String),
B,
}

#[test]
fn with_default_test() {
assert!(WithDefaultDiscriminants::A != WithDefaultDiscriminants::B);
}

0 comments on commit 26ebfbd

Please sign in to comment.