Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support case-insensitive EnumString #157

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
7 changes: 7 additions & 0 deletions strum/src/additional_attributes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@
//! );
//! ```
//!
//! You can also apply the `#[strum(ascii_case_insensitive)]` attribute to the enum,
//! and this has the same effect of applying it to every variant.
//!
//! Custom attributes are applied to a variant by adding `#[strum(parameter="value")]` to the variant.
//!
//! - `serialize="..."`: Changes the text that `FromStr()` looks for when parsing a string. This attribute can
Expand All @@ -58,6 +61,10 @@
//!
//! - `disabled`: removes variant from generated code.
//!
//! - `ascii_case_insensitive`: makes the comparison to this variant case insensitive (ASCII only).
//! If the whole enum is marked `ascii_case_insensitive`, you can specify `ascii_case_insensitive = false`
//! to disable case insensitivity on this variant.
//!
//! - `message=".."`: Adds a message to enum variant. This is used in conjunction with the `EnumMessage`
//! trait to associate a message with a variant. If `detailed_message` is not provided,
//! then `message` will also be returned when get_detailed_message() is called.
Expand Down
35 changes: 30 additions & 5 deletions strum_macros/src/helpers/metadata.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use syn::{
parse::{Parse, ParseStream},
punctuated::Punctuated,
spanned::Spanned,
Attribute, DeriveInput, Ident, LitStr, Path, Token, Variant, Visibility,
Attribute, DeriveInput, Ident, LitBool, LitStr, Path, Token, Variant, Visibility,
};

use super::case_style::CaseStyle;
Expand All @@ -28,28 +28,39 @@ pub mod kw {
custom_keyword!(disabled);
custom_keyword!(default);
custom_keyword!(props);
custom_keyword!(ascii_case_insensitive);
}

pub enum EnumMeta {
SerializeAll {
kw: kw::serialize_all,
case_style: CaseStyle,
},
AsciiCaseInsensitive(kw::ascii_case_insensitive),
}

impl Parse for EnumMeta {
fn parse(input: ParseStream) -> syn::Result<Self> {
let kw = input.parse::<kw::serialize_all>()?;
input.parse::<Token![=]>()?;
let case_style = input.parse()?;
Ok(EnumMeta::SerializeAll { kw, case_style })
let lookahead = input.lookahead1();
if lookahead.peek(kw::serialize_all) {
let kw = input.parse::<kw::serialize_all>()?;
input.parse::<Token![=]>()?;
let case_style = input.parse()?;
Ok(EnumMeta::SerializeAll { kw, case_style })
} else if lookahead.peek(kw::ascii_case_insensitive) {
let kw = input.parse()?;
Ok(EnumMeta::AsciiCaseInsensitive(kw))
} else {
Err(lookahead.error())
}
}
}

impl Spanned for EnumMeta {
fn span(&self) -> Span {
match self {
EnumMeta::SerializeAll { kw, .. } => kw.span(),
EnumMeta::AsciiCaseInsensitive(kw) => kw.span(),
}
}
}
Expand Down Expand Up @@ -142,6 +153,10 @@ pub enum VariantMeta {
},
Disabled(kw::disabled),
Default(kw::default),
AsciiCaseInsensitive {
kw: kw::ascii_case_insensitive,
value: bool,
},
Props {
kw: kw::props,
props: Vec<(LitStr, LitStr)>,
Expand Down Expand Up @@ -175,6 +190,15 @@ impl Parse for VariantMeta {
Ok(VariantMeta::Disabled(input.parse()?))
} else if lookahead.peek(kw::default) {
Ok(VariantMeta::Default(input.parse()?))
} else if lookahead.peek(kw::ascii_case_insensitive) {
let kw = input.parse()?;
let value = if input.peek(Token![=]) {
let _: Token![=] = input.parse()?;
input.parse::<LitBool>()?.value()
} else {
true
};
Ok(VariantMeta::AsciiCaseInsensitive { kw, value })
} else if lookahead.peek(kw::props) {
let kw = input.parse()?;
let content;
Expand Down Expand Up @@ -216,6 +240,7 @@ impl Spanned for VariantMeta {
VariantMeta::ToString { kw, .. } => kw.span,
VariantMeta::Disabled(kw) => kw.span,
VariantMeta::Default(kw) => kw.span,
VariantMeta::AsciiCaseInsensitive { kw, .. } => kw.span,
VariantMeta::Props { kw, .. } => kw.span,
}
}
Expand Down
10 changes: 10 additions & 0 deletions strum_macros/src/helpers/type_props.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ pub trait HasTypeProperties {
#[derive(Debug, Clone, Default)]
pub struct StrumTypeProperties {
pub case_style: Option<CaseStyle>,
pub ascii_case_insensitive: bool,
pub discriminant_derives: Vec<Path>,
pub discriminant_name: Option<Ident>,
pub discriminant_others: Vec<TokenStream>,
Expand All @@ -28,6 +29,7 @@ impl HasTypeProperties for DeriveInput {
let discriminants_meta = self.get_discriminants_metadata()?;

let mut serialize_all_kw = None;
let mut ascii_case_insensitive_kw = None;
for meta in strum_meta {
match meta {
EnumMeta::SerializeAll { case_style, kw } => {
Expand All @@ -38,6 +40,14 @@ impl HasTypeProperties for DeriveInput {
serialize_all_kw = Some(kw);
output.case_style = Some(case_style);
}
EnumMeta::AsciiCaseInsensitive(kw) => {
if let Some(fst_kw) = ascii_case_insensitive_kw {
return Err(occurrence_error(fst_kw, kw, "ascii_case_insensitive"));
}

ascii_case_insensitive_kw = Some(kw);
output.ascii_case_insensitive = true;
}
}
}

Expand Down
10 changes: 10 additions & 0 deletions strum_macros/src/helpers/variant_props.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ pub trait HasStrumVariantProperties {
pub struct StrumVariantProperties {
pub disabled: Option<kw::disabled>,
pub default: Option<kw::default>,
pub ascii_case_insensitive: Option<bool>,
pub message: Option<LitStr>,
pub detailed_message: Option<LitStr>,
pub string_props: Vec<(LitStr, LitStr)>,
Expand Down Expand Up @@ -65,6 +66,7 @@ impl HasStrumVariantProperties for Variant {
let mut to_string_kw = None;
let mut disabled_kw = None;
let mut default_kw = None;
let mut ascii_case_insensitive_kw = None;
for meta in self.get_metadata()? {
match meta {
VariantMeta::Message { value, kw } => {
Expand Down Expand Up @@ -110,6 +112,14 @@ impl HasStrumVariantProperties for Variant {
default_kw = Some(kw);
output.default = Some(kw);
}
VariantMeta::AsciiCaseInsensitive { kw, value } => {
if let Some(fst_kw) = ascii_case_insensitive_kw {
return Err(occurrence_error(fst_kw, kw, "ascii_case_insensitive"));
}

ascii_case_insensitive_kw = Some(kw);
output.ascii_case_insensitive = Some(value);
}
VariantMeta::Props { props, .. } => {
output.string_props.extend(props);
}
Expand Down
10 changes: 9 additions & 1 deletion strum_macros/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,10 @@ fn debug_print_generated(ast: &DeriveInput, toks: &TokenStream) {
/// // Notice that we can disable certain variants from being found
/// #[strum(disabled)]
/// Yellow,
///
/// // We can make the comparison case insensitive (however Unicode is not supported at the moment)
/// #[strum(ascii_case_insensitive)]
/// Black,
/// }
///
/// /*
Expand All @@ -77,7 +81,9 @@ fn debug_print_generated(ast: &DeriveInput, toks: &TokenStream) {
/// match s {
/// "Red" => ::std::result::Result::Ok(Color::Red),
/// "Green" => ::std::result::Result::Ok(Color::Green { range:Default::default() }),
/// "blue" | "b" => ::std::result::Result::Ok(Color::Blue(Default::default())),
/// "blue" => ::std::result::Result::Ok(Color::Blue(Default::default())),
/// "b" => ::std::result::Result::Ok(Color::Blue(Default::default())),
/// s if s.eq_ignore_ascii_case("Black") => ::std::result::Result::Ok(Color::Black),
/// _ => ::std::result::Result::Err(::strum::ParseError::VariantNotFound),
/// }
/// }
Expand All @@ -95,6 +101,8 @@ fn debug_print_generated(ast: &DeriveInput, toks: &TokenStream) {
/// assert!(color_variant.is_err());
/// // however the variant is still normally usable
/// println!("{:?}", Color::Yellow);
/// let color_variant = Color::from_str("bLACk").unwrap();
/// assert_eq!(Color::Black, color_variant);
/// ```
#[proc_macro_derive(EnumString, attributes(strum))]
pub fn from_string(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
Expand Down
16 changes: 14 additions & 2 deletions strum_macros/src/macros/strings/from_string.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,20 @@ pub fn from_string_inner(ast: &DeriveInput) -> syn::Result<TokenStream> {
continue;
}

let is_ascii_case_insensitive = variant_properties
.ascii_case_insensitive
.unwrap_or(type_properties.ascii_case_insensitive);
// If we don't have any custom variants, add the default serialized name.
let attrs = variant_properties.get_serializations(type_properties.case_style);
let attrs = variant_properties
.get_serializations(type_properties.case_style)
.into_iter()
.map(|serialization| {
if is_ascii_case_insensitive {
quote! { s if s.eq_ignore_ascii_case(#serialization) }
} else {
quote! { #serialization }
}
});

let params = match &variant.fields {
Fields::Unit => quote! {},
Expand All @@ -69,7 +81,7 @@ pub fn from_string_inner(ast: &DeriveInput) -> syn::Result<TokenStream> {
}
};

arms.push(quote! { #(#attrs)|* => ::std::result::Result::Ok(#name::#ident #params) });
arms.push(quote! { #(#attrs => ::std::result::Result::Ok(#name::#ident #params)),* });
}

arms.push(default);
Expand Down
47 changes: 47 additions & 0 deletions strum_tests/tests/from_str.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ enum Color {
Green(String),
#[strum(to_string = "purp")]
Purple,
#[strum(serialize = "blk", serialize = "Black", ascii_case_insensitive)]
Black,
}

#[test]
Expand Down Expand Up @@ -44,6 +46,12 @@ fn color_default() {
);
}

#[test]
fn color_ascii_case_insensitive() {
assert_eq!(Color::Black, Color::from_str("BLK").unwrap());
assert_eq!(Color::Black, Color::from_str("bLaCk").unwrap());
}

#[derive(Debug, Eq, PartialEq, EnumString)]
#[strum(serialize_all = "snake_case")]
enum Brightness {
Expand Down Expand Up @@ -122,3 +130,42 @@ enum Generic<T: Default> {
fn generic_test() {
assert_eq!(Generic::Gen(""), Generic::from_str("Gen").unwrap());
}

#[derive(Debug, Eq, PartialEq, EnumString)]
#[strum(ascii_case_insensitive)]
enum CaseInsensitiveEnum {
NoAttr,
#[strum(ascii_case_insensitive = false)]
NoCaseInsensitive,
#[strum(ascii_case_insensitive = true)]
CaseInsensitive,
}

#[test]
fn case_insensitive_enum_no_attr() {
assert_eq!(
CaseInsensitiveEnum::NoAttr,
CaseInsensitiveEnum::from_str("noattr").unwrap()
);
}

#[test]
fn case_insensitive_enum_no_case_insensitive() {
assert_eq!(
CaseInsensitiveEnum::NoCaseInsensitive,
CaseInsensitiveEnum::from_str("NoCaseInsensitive").unwrap(),
);
assert!(CaseInsensitiveEnum::from_str("nocaseinsensitive").is_err());
}

#[test]
fn case_insensitive_enum_case_insensitive() {
assert_eq!(
CaseInsensitiveEnum::CaseInsensitive,
CaseInsensitiveEnum::from_str("CaseInsensitive").unwrap(),
);
assert_eq!(
CaseInsensitiveEnum::CaseInsensitive,
CaseInsensitiveEnum::from_str("caseinsensitive").unwrap(),
);
}