diff --git a/CHANGELOG.md b/CHANGELOG.md index 6215afc8..050bf40d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,10 @@ # Changelog +## Unreleased + +* [#220](https://github.com/Peternator7/strum/pull/220). Add support for PHF in `EnumString` (opt-in runtime + performance improvements for large enums as `#[strum(use_phf)]`, requires `phf` feature and increases MSRV to `1.46`) + ## 0.24.0 * [#212](https://github.com/Peternator7/strum/pull/212). Fix some clippy lints diff --git a/strum/Cargo.toml b/strum/Cargo.toml index 276c8fad..ca4dd8f3 100644 --- a/strum/Cargo.toml +++ b/strum/Cargo.toml @@ -16,6 +16,7 @@ readme = "../README.md" [dependencies] strum_macros = { path = "../strum_macros", optional = true, version = "0.24" } +phf = { version = "0.10", features = ["macros"], optional = true } [dev-dependencies] strum_macros = { path = "../strum_macros", version = "0.24" } diff --git a/strum/src/lib.rs b/strum/src/lib.rs index 0ce6e532..2711912c 100644 --- a/strum/src/lib.rs +++ b/strum/src/lib.rs @@ -30,6 +30,10 @@ // only for documentation purposes pub mod additional_attributes; +#[cfg(feature = "phf")] +#[doc(hidden)] +pub use phf as _private_phf_reexport_for_macro_if_phf_feature; + /// The `ParseError` enum is a collection of all the possible reasons /// an enum can fail to parse from a string. #[derive(Debug, Clone, Copy, Eq, PartialEq, Hash)] diff --git a/strum_macros/src/helpers/metadata.rs b/strum_macros/src/helpers/metadata.rs index 08ca3847..56e4c78b 100644 --- a/strum_macros/src/helpers/metadata.rs +++ b/strum_macros/src/helpers/metadata.rs @@ -5,7 +5,8 @@ use syn::{ parse2, parse_str, punctuated::Punctuated, spanned::Spanned, - Attribute, DeriveInput, Ident, Lit, LitBool, LitStr, Meta, MetaNameValue, Path, Token, Variant, Visibility, + Attribute, DeriveInput, Ident, Lit, LitBool, LitStr, Meta, MetaNameValue, Path, Token, Variant, + Visibility, }; use super::case_style::CaseStyle; @@ -16,6 +17,7 @@ pub mod kw { // enum metadata custom_keyword!(serialize_all); + custom_keyword!(use_phf); // enum discriminant metadata custom_keyword!(derive); @@ -43,6 +45,7 @@ pub enum EnumMeta { kw: kw::Crate, crate_module_path: Path, }, + UsePhf(kw::use_phf), } impl Parse for EnumMeta { @@ -64,8 +67,9 @@ impl Parse for EnumMeta { crate_module_path, }) } else if lookahead.peek(kw::ascii_case_insensitive) { - let kw = input.parse()?; - Ok(EnumMeta::AsciiCaseInsensitive(kw)) + Ok(EnumMeta::AsciiCaseInsensitive(input.parse()?)) + } else if lookahead.peek(kw::use_phf) { + Ok(EnumMeta::UsePhf(input.parse()?)) } else { Err(lookahead.error()) } @@ -78,6 +82,7 @@ impl Spanned for EnumMeta { EnumMeta::SerializeAll { kw, .. } => kw.span(), EnumMeta::AsciiCaseInsensitive(kw) => kw.span(), EnumMeta::Crate { kw, .. } => kw.span(), + EnumMeta::UsePhf(use_phf) => use_phf.span(), } } } @@ -275,14 +280,19 @@ pub trait VariantExt { impl VariantExt for Variant { fn get_metadata(&self) -> syn::Result> { let result = get_metadata_inner("strum", &self.attrs)?; - self.attrs.iter() + self.attrs + .iter() .filter(|attr| attr.path.is_ident("doc")) .try_fold(result, |mut vec, attr| { - if let Meta::NameValue(MetaNameValue { lit: Lit::Str(value), .. }) = attr.parse_meta()? { - vec.push(VariantMeta::Documentation { value }) - } - Ok(vec) - }) + if let Meta::NameValue(MetaNameValue { + lit: Lit::Str(value), + .. + }) = attr.parse_meta()? + { + vec.push(VariantMeta::Documentation { value }) + } + Ok(vec) + }) } } diff --git a/strum_macros/src/helpers/type_props.rs b/strum_macros/src/helpers/type_props.rs index cdca79f3..0d49e04e 100644 --- a/strum_macros/src/helpers/type_props.rs +++ b/strum_macros/src/helpers/type_props.rs @@ -20,6 +20,7 @@ pub struct StrumTypeProperties { pub discriminant_name: Option, pub discriminant_others: Vec, pub discriminant_vis: Option, + pub use_phf: bool, } impl HasTypeProperties for DeriveInput { @@ -31,6 +32,7 @@ impl HasTypeProperties for DeriveInput { let mut serialize_all_kw = None; let mut ascii_case_insensitive_kw = None; + let mut use_phf_kw = None; let mut crate_module_path_kw = None; for meta in strum_meta { match meta { @@ -50,6 +52,14 @@ impl HasTypeProperties for DeriveInput { ascii_case_insensitive_kw = Some(kw); output.ascii_case_insensitive = true; } + EnumMeta::UsePhf(kw) => { + if let Some(fst_kw) = use_phf_kw { + return Err(occurrence_error(fst_kw, kw, "use_phf")); + } + + use_phf_kw = Some(kw); + output.use_phf = true; + } EnumMeta::Crate { crate_module_path, kw, diff --git a/strum_macros/src/lib.rs b/strum_macros/src/lib.rs index 32f9a780..de16c9c2 100644 --- a/strum_macros/src/lib.rs +++ b/strum_macros/src/lib.rs @@ -47,6 +47,9 @@ fn debug_print_generated(ast: &DeriveInput, toks: &TokenStream) { /// See the [Additional Attributes](https://docs.rs/strum/0.22/strum/additional_attributes/index.html) /// Section for more information on using this feature. /// +/// If you have a large enum, you may want to consider using the `use_phf` attribute here. It leverages +/// perfect hash functions to parse much quicker than a standard `match`. (MSRV 1.46) +/// /// # Example howto use `EnumString` /// ``` /// use std::str::FromStr; @@ -471,11 +474,11 @@ pub fn from_repr(input: proc_macro::TokenStream) -> proc_macro::TokenStream { /// Encode strings into the enum itself. The `strum_macros::EmumMessage` macro implements the `strum::EnumMessage` trait. /// `EnumMessage` looks for `#[strum(message="...")]` attributes on your variants. /// You can also provided a `detailed_message="..."` attribute to create a seperate more detailed message than the first. -/// +/// /// `EnumMessage` also exposes the variants doc comments through `get_documentation()`. This is useful in some scenarios, /// but `get_message` should generally be preferred. Rust doc comments are intended for developer facing documentation, /// not end user messaging. -/// +/// /// ``` /// // You need to bring the trait into scope to use it /// use strum::EnumMessage; diff --git a/strum_macros/src/macros/strings/from_string.rs b/strum_macros/src/macros/strings/from_string.rs index d0754d29..a4cc4d15 100644 --- a/strum_macros/src/macros/strings/from_string.rs +++ b/strum_macros/src/macros/strings/from_string.rs @@ -18,8 +18,15 @@ pub fn from_string_inner(ast: &DeriveInput) -> syn::Result { let strum_module_path = type_properties.crate_module_path(); let mut default_kw = None; - let mut default = quote! { _ => ::core::result::Result::Err(#strum_module_path::ParseError::VariantNotFound) }; - let mut arms = Vec::new(); + let mut default = + quote! { ::core::result::Result::Err(#strum_module_path::ParseError::VariantNotFound) }; + let mut phf_exact_match_arms = Vec::new(); + // We'll use the first one if there are many variants + let mut phf_lowercase_arms = Vec::new(); + // However if there are few variants we'll want to integrate these in the standard match to avoid alloc + let mut case_insensitive_arms_alternative = Vec::new(); + // Later we can add custom arms in there + let mut standard_match_arms = Vec::new(); for variant in variants { let ident = &variant.ident; let variant_properties = variant.get_variant_properties()?; @@ -45,26 +52,11 @@ pub fn from_string_inner(ast: &DeriveInput) -> syn::Result { default_kw = Some(kw); default = quote! { - default => ::core::result::Result::Ok(#name::#ident(default.into())) + ::core::result::Result::Ok(#name::#ident(s.into())) }; continue; } - let is_ascii_case_insensitive = variant_properties - .ascii_case_insensitive - .unwrap_or(type_properties.ascii_case_insensitive); - // If we don't have any custom variants, add the default serialized name. - let attrs = variant_properties - .get_serializations(type_properties.case_style) - .into_iter() - .map(|serialization| { - if is_ascii_case_insensitive { - quote! { s if s.eq_ignore_ascii_case(#serialization) } - } else { - quote! { #serialization } - } - }); - let params = match &variant.fields { Fields::Unit => quote! {}, Fields::Unnamed(fields) => { @@ -81,19 +73,95 @@ pub fn from_string_inner(ast: &DeriveInput) -> syn::Result { } }; - arms.push(quote! { #(#attrs => ::core::result::Result::Ok(#name::#ident #params)),* }); + let is_ascii_case_insensitive = variant_properties + .ascii_case_insensitive + .unwrap_or(type_properties.ascii_case_insensitive); + + // If we don't have any custom variants, add the default serialized name. + for serialization in variant_properties.get_serializations(type_properties.case_style) { + if type_properties.use_phf { + if !is_ascii_case_insensitive { + phf_exact_match_arms.push(quote! { #serialization => #name::#ident #params, }); + } else { + // In that case we'll store the lowercase values in phf, and lowercase at runtime + // before searching + // Unless there are few such variants, in that case we'll use the standard match with + // eq_ignore_ascii_case to avoid allocating + case_insensitive_arms_alternative.push(quote! { s if s.eq_ignore_ascii_case(#serialization) => #name::#ident #params, }); + + let mut ser_string = serialization.value(); + ser_string.make_ascii_lowercase(); + let serialization = syn::LitStr::new(&ser_string, serialization.span()); + phf_lowercase_arms.push(quote! { #serialization => #name::#ident #params, }); + } + } else { + standard_match_arms.push(if !is_ascii_case_insensitive { + quote! { #serialization => #name::#ident #params, } + } else { + quote! { s if s.eq_ignore_ascii_case(#serialization) => #name::#ident #params, } + }); + } + } + } + + // Probably under that string allocation is more expensive than matching few times + // Proper threshold is not benchmarked - feel free to do so :) + if phf_lowercase_arms.len() <= 3 { + standard_match_arms.extend(case_insensitive_arms_alternative); + phf_lowercase_arms.clear(); } - arms.push(default); + let use_phf = if phf_exact_match_arms.is_empty() && phf_lowercase_arms.is_empty() { + quote!() + } else { + quote! { + use #strum_module_path::_private_phf_reexport_for_macro_if_phf_feature as phf; + } + }; + let phf_body = if phf_exact_match_arms.is_empty() { + quote!() + } else { + quote! { + static PHF: phf::Map<&'static str, #name> = phf::phf_map! { + #(#phf_exact_match_arms)* + }; + if let Some(value) = PHF.get(s).cloned() { + return ::core::result::Result::Ok(value); + } + } + }; + let phf_lowercase_body = if phf_lowercase_arms.is_empty() { + quote!() + } else { + quote! { + static PHF_LOWERCASE: phf::Map<&'static str, #name> = phf::phf_map! { + #(#phf_lowercase_arms)* + }; + if let Some(value) = PHF_LOWERCASE.get(&s.to_ascii_lowercase()).cloned() { + return ::core::result::Result::Ok(value); + } + } + }; + let standard_match_body = if standard_match_arms.is_empty() { + default + } else { + quote! { + ::core::result::Result::Ok(match s { + #(#standard_match_arms)* + _ => return #default, + }) + } + }; let from_str = quote! { #[allow(clippy::use_self)] impl #impl_generics ::core::str::FromStr for #name #ty_generics #where_clause { type Err = #strum_module_path::ParseError; fn from_str(s: &str) -> ::core::result::Result< #name #ty_generics , ::Err> { - match s { - #(#arms),* - } + #use_phf + #phf_body + #phf_lowercase_body + #standard_match_body } } }; diff --git a/strum_tests/Cargo.toml b/strum_tests/Cargo.toml index 796a0699..872a3b0f 100644 --- a/strum_tests/Cargo.toml +++ b/strum_tests/Cargo.toml @@ -4,6 +4,10 @@ version = "0.24.0" edition = "2018" authors = ["Peter Glotfelty "] +[features] +default = ["test_phf"] +test_phf = ["strum/phf"] + [dependencies] strum = { path = "../strum", features = ["derive"] } strum_macros = { path = "../strum_macros", features = [] } diff --git a/strum_tests/tests/phf.rs b/strum_tests/tests/phf.rs new file mode 100644 index 00000000..163418e8 --- /dev/null +++ b/strum_tests/tests/phf.rs @@ -0,0 +1,33 @@ +#[cfg(feature = "test_phf")] +#[test] +fn from_str_with_phf() { + #[derive(Debug, PartialEq, Eq, Clone, strum::EnumString)] + #[strum(use_phf)] + enum Color { + #[strum(ascii_case_insensitive)] + Blue, + Red, + } + assert_eq!("Red".parse::().unwrap(), Color::Red); + assert_eq!("bLuE".parse::().unwrap(), Color::Blue); +} + +#[cfg(feature = "test_phf")] +#[test] +fn from_str_with_phf_big() { + // This tests PHF when there are many case insensitive variants + #[derive(Debug, PartialEq, Eq, Clone, strum::EnumString)] + #[strum(use_phf, ascii_case_insensitive)] + enum Enum { + Var1, + Var2, + Var3, + Var4, + Var5, + Var6, + Var7, + Var8, + Var9, + } + assert_eq!("vAr2".parse::().unwrap(), Enum::Var2); +}