From 26ebfbd22c237a658784deefa37097992a3c63b9 Mon Sep 17 00:00:00 2001 From: Azriel Hoh Date: Wed, 19 Sep 2018 17:19:30 +1200 Subject: [PATCH] Added `EnumDiscriminants` derive. Issue #33 --- strum_macros/src/enum_discriminants.rs | 47 ++++++++++++++ strum_macros/src/helpers.rs | 21 ++++++ strum_macros/src/lib.rs | 13 ++++ strum_tests/tests/enum_discriminants.rs | 86 +++++++++++++++++++++++++ 4 files changed, 167 insertions(+) create mode 100644 strum_macros/src/enum_discriminants.rs create mode 100644 strum_tests/tests/enum_discriminants.rs diff --git a/strum_macros/src/enum_discriminants.rs b/strum_macros/src/enum_discriminants.rs new file mode 100644 index 00000000..91440ef1 --- /dev/null +++ b/strum_macros/src/enum_discriminants.rs @@ -0,0 +1,47 @@ +use proc_macro2::{Span, TokenStream}; +use syn; + +use helpers::{extract_meta, extract_meta_attrs}; + +pub fn enum_discriminants_inner(ast: &syn::DeriveInput) -> TokenStream { + let name = &ast.ident; + let vis = &ast.vis; + let enum_discriminants_name = syn::Ident::new( + &format!("{}Discriminants", name.to_string()), + Span::call_site(), + ); + + let variants = match ast.data { + syn::Data::Enum(ref v) => &v.variants, + _ => panic!("EnumDiscriminants only works on Enums"), + }; + + let type_meta = extract_meta(&ast.attrs); + let discriminant_derives = extract_meta_attrs(&type_meta, "strum_discriminants_derive"); + let derives = quote! { + #[derive(Clone, Copy, Debug, PartialEq, Eq, #(#discriminant_derives),*)] + }; + + let mut discriminants = Vec::new(); + for variant in variants { + let ident = &variant.ident; + + // Don't copy across the "strum" meta attribute. + let attrs = variant.attrs.iter().filter(|attr| { + attr.interpret_meta().map_or(true, |meta| match meta { + syn::Meta::List(ref metalist) => metalist.ident != "strum", + _ => true, + }) + }); + + discriminants.push(quote!{ #(#attrs)* #ident }); + } + + quote!{ + /// Auto-generated discriminant enum variants + #derives + #vis enum #enum_discriminants_name { + #(#discriminants),* + } + } +} diff --git a/strum_macros/src/helpers.rs b/strum_macros/src/helpers.rs index df4e0da9..a62aa756 100644 --- a/strum_macros/src/helpers.rs +++ b/strum_macros/src/helpers.rs @@ -10,6 +10,27 @@ pub fn extract_meta(attrs: &[Attribute]) -> Vec { .collect() } +pub fn extract_meta_attrs<'meta>(meta: &'meta [Meta], attr: &str) -> Vec<&'meta Ident> { + use syn::NestedMeta; + meta.iter() + // Get all the attributes with our tag on them. + .filter_map(|meta| match *meta { + Meta::List(ref metalist) => { + if metalist.ident == attr { + Some(&metalist.nested) + } else { + None + } + } + _ => None, + }).flat_map(|nested| nested) + // Get all the inner elements as long as they start with ser. + .filter_map(|meta| match *meta { + NestedMeta::Meta(Meta::Word(ref ident)) => Some(ident), + _ => None, + }).collect() +} + pub fn extract_attrs(meta: &[Meta], attr: &str, prop: &str) -> Vec { use syn::{Lit, MetaNameValue, NestedMeta}; meta.iter() diff --git a/strum_macros/src/lib.rs b/strum_macros/src/lib.rs index c52d0fa7..a48e79ba 100644 --- a/strum_macros/src/lib.rs +++ b/strum_macros/src/lib.rs @@ -19,6 +19,7 @@ extern crate proc_macro2; mod as_ref_str; mod case_style; mod display; +mod enum_discriminants; mod enum_iter; mod enum_messages; mod enum_properties; @@ -113,3 +114,15 @@ pub fn enum_properties(input: proc_macro::TokenStream) -> proc_macro::TokenStrea debug_print_generated(&ast, &toks); toks.into() } + +#[proc_macro_derive( + EnumDiscriminants, + attributes(strum, strum_discriminants_derive) +)] +pub fn enum_discriminants(input: proc_macro::TokenStream) -> proc_macro::TokenStream { + let ast = syn::parse(input).unwrap(); + + let toks = enum_discriminants::enum_discriminants_inner(&ast); + debug_print_generated(&ast, &toks); + toks.into() +} diff --git a/strum_tests/tests/enum_discriminants.rs b/strum_tests/tests/enum_discriminants.rs new file mode 100644 index 00000000..d635e31c --- /dev/null +++ b/strum_tests/tests/enum_discriminants.rs @@ -0,0 +1,86 @@ +extern crate strum; +#[macro_use] +extern crate strum_macros; + +use strum::IntoEnumIterator; + +#[allow(dead_code)] +#[derive(Debug, Eq, PartialEq, EnumDiscriminants)] +#[strum_discriminants_derive(EnumIter)] +enum Simple { + Variant0, + Variant1, +} + +#[test] +fn simple_test() { + let discriminants = SimpleDiscriminants::iter().collect::>(); + let expected = vec![SimpleDiscriminants::Variant0, SimpleDiscriminants::Variant1]; + + assert_eq!(expected, discriminants); +} + +#[derive(Debug)] +struct NonDefault; + +#[allow(dead_code)] +#[derive(Debug, EnumDiscriminants)] +#[strum_discriminants_derive(EnumIter)] +enum WithFields { + Variant0(NonDefault), + Variant1 { a: NonDefault }, +} + +#[test] +fn fields_test() { + let discriminants = WithFieldsDiscriminants::iter().collect::>(); + let expected = vec![ + WithFieldsDiscriminants::Variant0, + WithFieldsDiscriminants::Variant1, + ]; + + assert_eq!(expected, discriminants); +} + +trait Foo {} +trait Bar {} + +#[allow(dead_code)] +#[derive(Debug, Eq, PartialEq, EnumDiscriminants)] +#[strum_discriminants_derive(EnumIter)] +enum Complicated { + /// With Docs + A(U), + B { + v: V, + }, + C, +} + +#[test] +fn complicated_test() { + let discriminants = ComplicatedDiscriminants::iter().collect::>(); + let expected = vec![ + ComplicatedDiscriminants::A, + ComplicatedDiscriminants::B, + ComplicatedDiscriminants::C, + ]; + + assert_eq!(expected, discriminants); +} + +// This test exists to ensure that we do not copy across the `#[strum(default = "true")]` meta +// attribute. If we do without deriving any `strum` derivations on the generated discriminant enum, +// Rust will generate a compiler error saying it doesn't understand the `strum` attribute. +#[allow(dead_code)] +#[derive(Debug, EnumDiscriminants)] +enum WithDefault { + #[strum(default = "true")] + A(String), + B, +} + +#[test] +fn with_default_test() { + assert!(WithDefaultDiscriminants::A != WithDefaultDiscriminants::B); +}