diff --git a/strum_macros/src/as_ref_str.rs b/strum_macros/src/as_ref_str.rs index e8249807..b22525c5 100644 --- a/strum_macros/src/as_ref_str.rs +++ b/strum_macros/src/as_ref_str.rs @@ -1,7 +1,8 @@ use proc_macro2::TokenStream; use syn; -use helpers::{extract_attrs, extract_meta, is_disabled, unique_attr}; +use case_style::CaseStyle; +use helpers::{convert_case, extract_attrs, extract_meta, is_disabled, unique_attr}; fn get_arms(ast: &syn::DeriveInput) -> Vec { let name = &ast.ident; @@ -11,6 +12,10 @@ fn get_arms(ast: &syn::DeriveInput) -> Vec { _ => panic!("This macro only works on Enums"), }; + let type_meta = extract_meta(&ast.attrs); + let case_style = unique_attr(&type_meta, "strum", "serialize_all") + .map(|style| CaseStyle::from(style.as_ref())); + for variant in variants { use syn::Fields::*; let ident = &variant.ident; @@ -32,23 +37,23 @@ fn get_arms(ast: &syn::DeriveInput) -> Vec { if let Some(n) = attrs.pop() { n } else { - ident.to_string() + convert_case(ident, case_style) } }; let params = match variant.fields { - Unit => quote!{}, - Unnamed(..) => quote!{ (..) }, - Named(..) => quote!{ {..} }, + Unit => quote! {}, + Unnamed(..) => quote! { (..) }, + Named(..) => quote! { {..} }, }; - arms.push(quote!{ #name::#ident #params => #output }); + arms.push(quote! { #name::#ident #params => #output }); } if arms.len() < variants.len() { - arms.push(quote!{ - _ => panic!("AsRef::::as_ref() or AsStaticRef::::as_static() \ - called on disabled variant.") + arms.push(quote! { + _ => panic!("AsRef::::as_ref() or AsStaticRef::::as_static() \ + called on disabled variant.") }) } @@ -59,7 +64,7 @@ pub fn as_ref_str_inner(ast: &syn::DeriveInput) -> TokenStream { let name = &ast.ident; let (impl_generics, ty_generics, where_clause) = ast.generics.split_for_impl(); let arms = get_arms(ast); - quote!{ + quote! { impl #impl_generics ::std::convert::AsRef for #name #ty_generics #where_clause { fn as_ref(&self) -> &str { match *self { @@ -94,7 +99,7 @@ pub fn as_static_str_inner( let arms3 = arms.clone(); match trait_variant { GenerateTraitVariant::AsStaticStr => { - quote!{ + quote! { impl #impl_generics ::strum::AsStaticRef for #name #ty_generics #where_clause { fn as_static(&self) -> &'static str { match *self { diff --git a/strum_tests/tests/as_ref_str.rs b/strum_tests/tests/as_ref_str.rs index 0262f3a3..e0fa77ea 100644 --- a/strum_tests/tests/as_ref_str.rs +++ b/strum_tests/tests/as_ref_str.rs @@ -85,3 +85,29 @@ fn test_into_static_str() { assert_eq!("B", <&'static str>::from(Moo::B::)); assert_eq!("C", <&'static str>::from(Moo::C::(&17))); } + +#[derive(Debug, Eq, PartialEq, AsRefStr, AsStaticStr, IntoStaticStr)] +#[strum(serialize_all = "snake_case")] +enum Brightness { + DarkBlack, + Dim { + glow: usize, + }, + #[strum(serialize = "Bright")] + BrightWhite, +} + +#[test] +fn brightness_serialize_all() { + assert_eq!("dark_black", Brightness::DarkBlack.as_ref()); + assert_eq!("dim", Brightness::Dim { glow: 0 }.as_ref()); + assert_eq!("Bright", Brightness::BrightWhite.as_ref()); + + assert_eq!("dark_black", Brightness::DarkBlack.as_static()); + assert_eq!("dim", Brightness::Dim { glow: 0 }.as_static()); + assert_eq!("Bright", Brightness::BrightWhite.as_static()); + + assert_eq!("dark_black", <&'static str>::from(Brightness::DarkBlack)); + assert_eq!("dim", <&'static str>::from(Brightness::Dim { glow: 0 })); + assert_eq!("Bright", <&'static str>::from(Brightness::BrightWhite)); +}