Skip to content

Commit

Permalink
Support case-insensitive EnumString (#157)
Browse files Browse the repository at this point in the history
Fixes #154
  • Loading branch information
ChayimFriedman2 committed May 3, 2021
1 parent 089aec8 commit ca60910
Show file tree
Hide file tree
Showing 7 changed files with 127 additions and 8 deletions.
7 changes: 7 additions & 0 deletions strum/src/additional_attributes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@
//! );
//! ```
//!
//! You can also apply the `#[strum(ascii_case_insensitive)]` attribute to the enum,
//! and this has the same effect of applying it to every variant.
//!
//! Custom attributes are applied to a variant by adding `#[strum(parameter="value")]` to the variant.
//!
//! - `serialize="..."`: Changes the text that `FromStr()` looks for when parsing a string. This attribute can
Expand All @@ -58,6 +61,10 @@
//!
//! - `disabled`: removes variant from generated code.
//!
//! - `ascii_case_insensitive`: makes the comparison to this variant case insensitive (ASCII only).
//! If the whole enum is marked `ascii_case_insensitive`, you can specify `ascii_case_insensitive = false`
//! to disable case insensitivity on this variant.
//!
//! - `message=".."`: Adds a message to enum variant. This is used in conjunction with the `EnumMessage`
//! trait to associate a message with a variant. If `detailed_message` is not provided,
//! then `message` will also be returned when get_detailed_message() is called.
Expand Down
35 changes: 30 additions & 5 deletions strum_macros/src/helpers/metadata.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use syn::{
parse::{Parse, ParseStream},
punctuated::Punctuated,
spanned::Spanned,
Attribute, DeriveInput, Ident, LitStr, Path, Token, Variant, Visibility,
Attribute, DeriveInput, Ident, LitBool, LitStr, Path, Token, Variant, Visibility,
};

use super::case_style::CaseStyle;
Expand All @@ -28,28 +28,39 @@ pub mod kw {
custom_keyword!(disabled);
custom_keyword!(default);
custom_keyword!(props);
custom_keyword!(ascii_case_insensitive);
}

pub enum EnumMeta {
SerializeAll {
kw: kw::serialize_all,
case_style: CaseStyle,
},
AsciiCaseInsensitive(kw::ascii_case_insensitive),
}

impl Parse for EnumMeta {
fn parse(input: ParseStream) -> syn::Result<Self> {
let kw = input.parse::<kw::serialize_all>()?;
input.parse::<Token![=]>()?;
let case_style = input.parse()?;
Ok(EnumMeta::SerializeAll { kw, case_style })
let lookahead = input.lookahead1();
if lookahead.peek(kw::serialize_all) {
let kw = input.parse::<kw::serialize_all>()?;
input.parse::<Token![=]>()?;
let case_style = input.parse()?;
Ok(EnumMeta::SerializeAll { kw, case_style })
} else if lookahead.peek(kw::ascii_case_insensitive) {
let kw = input.parse()?;
Ok(EnumMeta::AsciiCaseInsensitive(kw))
} else {
Err(lookahead.error())
}
}
}

impl Spanned for EnumMeta {
fn span(&self) -> Span {
match self {
EnumMeta::SerializeAll { kw, .. } => kw.span(),
EnumMeta::AsciiCaseInsensitive(kw) => kw.span(),
}
}
}
Expand Down Expand Up @@ -142,6 +153,10 @@ pub enum VariantMeta {
},
Disabled(kw::disabled),
Default(kw::default),
AsciiCaseInsensitive {
kw: kw::ascii_case_insensitive,
value: bool,
},
Props {
kw: kw::props,
props: Vec<(LitStr, LitStr)>,
Expand Down Expand Up @@ -175,6 +190,15 @@ impl Parse for VariantMeta {
Ok(VariantMeta::Disabled(input.parse()?))
} else if lookahead.peek(kw::default) {
Ok(VariantMeta::Default(input.parse()?))
} else if lookahead.peek(kw::ascii_case_insensitive) {
let kw = input.parse()?;
let value = if input.peek(Token![=]) {
let _: Token![=] = input.parse()?;
input.parse::<LitBool>()?.value()
} else {
true
};
Ok(VariantMeta::AsciiCaseInsensitive { kw, value })
} else if lookahead.peek(kw::props) {
let kw = input.parse()?;
let content;
Expand Down Expand Up @@ -216,6 +240,7 @@ impl Spanned for VariantMeta {
VariantMeta::ToString { kw, .. } => kw.span,
VariantMeta::Disabled(kw) => kw.span,
VariantMeta::Default(kw) => kw.span,
VariantMeta::AsciiCaseInsensitive { kw, .. } => kw.span,
VariantMeta::Props { kw, .. } => kw.span,
}
}
Expand Down
10 changes: 10 additions & 0 deletions strum_macros/src/helpers/type_props.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ pub trait HasTypeProperties {
#[derive(Debug, Clone, Default)]
pub struct StrumTypeProperties {
pub case_style: Option<CaseStyle>,
pub ascii_case_insensitive: bool,
pub discriminant_derives: Vec<Path>,
pub discriminant_name: Option<Ident>,
pub discriminant_others: Vec<TokenStream>,
Expand All @@ -28,6 +29,7 @@ impl HasTypeProperties for DeriveInput {
let discriminants_meta = self.get_discriminants_metadata()?;

let mut serialize_all_kw = None;
let mut ascii_case_insensitive_kw = None;
for meta in strum_meta {
match meta {
EnumMeta::SerializeAll { case_style, kw } => {
Expand All @@ -38,6 +40,14 @@ impl HasTypeProperties for DeriveInput {
serialize_all_kw = Some(kw);
output.case_style = Some(case_style);
}
EnumMeta::AsciiCaseInsensitive(kw) => {
if let Some(fst_kw) = ascii_case_insensitive_kw {
return Err(occurrence_error(fst_kw, kw, "ascii_case_insensitive"));
}

ascii_case_insensitive_kw = Some(kw);
output.ascii_case_insensitive = true;
}
}
}

Expand Down
10 changes: 10 additions & 0 deletions strum_macros/src/helpers/variant_props.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ pub trait HasStrumVariantProperties {
pub struct StrumVariantProperties {
pub disabled: Option<kw::disabled>,
pub default: Option<kw::default>,
pub ascii_case_insensitive: Option<bool>,
pub message: Option<LitStr>,
pub detailed_message: Option<LitStr>,
pub string_props: Vec<(LitStr, LitStr)>,
Expand Down Expand Up @@ -65,6 +66,7 @@ impl HasStrumVariantProperties for Variant {
let mut to_string_kw = None;
let mut disabled_kw = None;
let mut default_kw = None;
let mut ascii_case_insensitive_kw = None;
for meta in self.get_metadata()? {
match meta {
VariantMeta::Message { value, kw } => {
Expand Down Expand Up @@ -110,6 +112,14 @@ impl HasStrumVariantProperties for Variant {
default_kw = Some(kw);
output.default = Some(kw);
}
VariantMeta::AsciiCaseInsensitive { kw, value } => {
if let Some(fst_kw) = ascii_case_insensitive_kw {
return Err(occurrence_error(fst_kw, kw, "ascii_case_insensitive"));
}

ascii_case_insensitive_kw = Some(kw);
output.ascii_case_insensitive = Some(value);
}
VariantMeta::Props { props, .. } => {
output.string_props.extend(props);
}
Expand Down
10 changes: 9 additions & 1 deletion strum_macros/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,10 @@ fn debug_print_generated(ast: &DeriveInput, toks: &TokenStream) {
/// // Notice that we can disable certain variants from being found
/// #[strum(disabled)]
/// Yellow,
///
/// // We can make the comparison case insensitive (however Unicode is not supported at the moment)
/// #[strum(ascii_case_insensitive)]
/// Black,
/// }
///
/// /*
Expand All @@ -77,7 +81,9 @@ fn debug_print_generated(ast: &DeriveInput, toks: &TokenStream) {
/// match s {
/// "Red" => ::std::result::Result::Ok(Color::Red),
/// "Green" => ::std::result::Result::Ok(Color::Green { range:Default::default() }),
/// "blue" | "b" => ::std::result::Result::Ok(Color::Blue(Default::default())),
/// "blue" => ::std::result::Result::Ok(Color::Blue(Default::default())),
/// "b" => ::std::result::Result::Ok(Color::Blue(Default::default())),
/// s if s.eq_ignore_ascii_case("Black") => ::std::result::Result::Ok(Color::Black),
/// _ => ::std::result::Result::Err(::strum::ParseError::VariantNotFound),
/// }
/// }
Expand All @@ -95,6 +101,8 @@ fn debug_print_generated(ast: &DeriveInput, toks: &TokenStream) {
/// assert!(color_variant.is_err());
/// // however the variant is still normally usable
/// println!("{:?}", Color::Yellow);
/// let color_variant = Color::from_str("bLACk").unwrap();
/// assert_eq!(Color::Black, color_variant);
/// ```
#[proc_macro_derive(EnumString, attributes(strum))]
pub fn from_string(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
Expand Down
16 changes: 14 additions & 2 deletions strum_macros/src/macros/strings/from_string.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,20 @@ pub fn from_string_inner(ast: &DeriveInput) -> syn::Result<TokenStream> {
continue;
}

let is_ascii_case_insensitive = variant_properties
.ascii_case_insensitive
.unwrap_or(type_properties.ascii_case_insensitive);
// If we don't have any custom variants, add the default serialized name.
let attrs = variant_properties.get_serializations(type_properties.case_style);
let attrs = variant_properties
.get_serializations(type_properties.case_style)
.into_iter()
.map(|serialization| {
if is_ascii_case_insensitive {
quote! { s if s.eq_ignore_ascii_case(#serialization) }
} else {
quote! { #serialization }
}
});

let params = match &variant.fields {
Fields::Unit => quote! {},
Expand All @@ -69,7 +81,7 @@ pub fn from_string_inner(ast: &DeriveInput) -> syn::Result<TokenStream> {
}
};

arms.push(quote! { #(#attrs)|* => ::std::result::Result::Ok(#name::#ident #params) });
arms.push(quote! { #(#attrs => ::std::result::Result::Ok(#name::#ident #params)),* });
}

arms.push(default);
Expand Down
47 changes: 47 additions & 0 deletions strum_tests/tests/from_str.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ enum Color {
Green(String),
#[strum(to_string = "purp")]
Purple,
#[strum(serialize = "blk", serialize = "Black", ascii_case_insensitive)]
Black,
}

#[test]
Expand Down Expand Up @@ -44,6 +46,12 @@ fn color_default() {
);
}

#[test]
fn color_ascii_case_insensitive() {
assert_eq!(Color::Black, Color::from_str("BLK").unwrap());
assert_eq!(Color::Black, Color::from_str("bLaCk").unwrap());
}

#[derive(Debug, Eq, PartialEq, EnumString)]
#[strum(serialize_all = "snake_case")]
enum Brightness {
Expand Down Expand Up @@ -122,3 +130,42 @@ enum Generic<T: Default> {
fn generic_test() {
assert_eq!(Generic::Gen(""), Generic::from_str("Gen").unwrap());
}

#[derive(Debug, Eq, PartialEq, EnumString)]
#[strum(ascii_case_insensitive)]
enum CaseInsensitiveEnum {
NoAttr,
#[strum(ascii_case_insensitive = false)]
NoCaseInsensitive,
#[strum(ascii_case_insensitive = true)]
CaseInsensitive,
}

#[test]
fn case_insensitive_enum_no_attr() {
assert_eq!(
CaseInsensitiveEnum::NoAttr,
CaseInsensitiveEnum::from_str("noattr").unwrap()
);
}

#[test]
fn case_insensitive_enum_no_case_insensitive() {
assert_eq!(
CaseInsensitiveEnum::NoCaseInsensitive,
CaseInsensitiveEnum::from_str("NoCaseInsensitive").unwrap(),
);
assert!(CaseInsensitiveEnum::from_str("nocaseinsensitive").is_err());
}

#[test]
fn case_insensitive_enum_case_insensitive() {
assert_eq!(
CaseInsensitiveEnum::CaseInsensitive,
CaseInsensitiveEnum::from_str("CaseInsensitive").unwrap(),
);
assert_eq!(
CaseInsensitiveEnum::CaseInsensitive,
CaseInsensitiveEnum::from_str("caseinsensitive").unwrap(),
);
}

0 comments on commit ca60910

Please sign in to comment.