diff --git a/strum/src/lib.rs b/strum/src/lib.rs index 9f7b9de6..fe12b92c 100644 --- a/strum/src/lib.rs +++ b/strum/src/lib.rs @@ -217,6 +217,7 @@ DocumentMacroRexports! { EnumProperty, EnumString, EnumVariantNames, + FromRepr, IntoStaticStr, ToString } diff --git a/strum_macros/Cargo.toml b/strum_macros/Cargo.toml index ee84210a..fcf643c6 100644 --- a/strum_macros/Cargo.toml +++ b/strum_macros/Cargo.toml @@ -22,6 +22,7 @@ name = "strum_macros" heck = "0.3" proc-macro2 = "1.0" quote = "1.0" +rustversion = "1.0" syn = { version = "1.0", features = ["parsing", "extra-traits"] } [dev-dependencies] diff --git a/strum_macros/src/lib.rs b/strum_macros/src/lib.rs index 38eb0b39..ad3dabe5 100644 --- a/strum_macros/src/lib.rs +++ b/strum_macros/src/lib.rs @@ -374,6 +374,86 @@ pub fn enum_iter(input: proc_macro::TokenStream) -> proc_macro::TokenStream { toks.into() } +/// Add a function to enum that allows accessing variants by its discriminant +/// +/// This macro adds a standalone function to obtain an enum variant by its discriminant. The macro adds +/// `from_repr(discriminant: usize) -> Option` as a standalone function on the enum. For +/// variants with additional data, the returned variant will use the `Default` trait to fill the +/// data. The discriminant follows the same rules as `rustc`. The first discriminant is zero and each +/// successive variant has a discriminant of one greater than the previous variant, expect where an +/// explicit discriminant is specified. The type of the discriminant will match the `repr` type if +/// it is specifed. +/// +/// When the macro is applied using rustc >= 1.46 and when there is no additional data on any of +/// the variants, the `from_repr` function is marked `const`. rustc >= 1.46 is required +/// to allow `match` statements in `const fn`. The no additional data requirement is due to the +/// inability to use `Default::default()` in a `const fn`. +/// +/// You cannot derive `FromRepr` on any type with a lifetime bound (`<'a>`) because the function would surely +/// create [unbounded lifetimes](https://doc.rust-lang.org/nightly/nomicon/unbounded-lifetimes.html). +/// +/// ``` +/// +/// use strum_macros::FromRepr; +/// +/// #[derive(FromRepr, Debug, PartialEq)] +/// enum Color { +/// Red, +/// Green { range: usize }, +/// Blue(usize), +/// Yellow, +/// } +/// +/// assert_eq!(Some(Color::Red), Color::from_repr(0)); +/// assert_eq!(Some(Color::Green {range: 0}), Color::from_repr(1)); +/// assert_eq!(Some(Color::Blue(0)), Color::from_repr(2)); +/// assert_eq!(Some(Color::Yellow), Color::from_repr(3)); +/// assert_eq!(None, Color::from_repr(4)); +/// +/// // Custom discriminant tests +/// #[derive(FromRepr, Debug, PartialEq)] +/// #[repr(u8)] +/// enum Vehicle { +/// Car = 1, +/// Truck = 3, +/// } +/// +/// assert_eq!(None, Vehicle::from_repr(0)); +/// ``` +#[rustversion::attr(since(1.46),doc=" +`const` tests (only works in rust >= 1.46) +``` +use strum_macros::FromRepr; + +#[derive(FromRepr, Debug, PartialEq)] +#[repr(u8)] +enum Number { + One = 1, + Three = 3, +} + +// This test confirms that the function works in a `const` context +const fn number_from_repr(d: u8) -> Option { + Number::from_repr(d) +} +assert_eq!(None, number_from_repr(0)); +assert_eq!(Some(Number::One), number_from_repr(1)); +assert_eq!(None, number_from_repr(2)); +assert_eq!(Some(Number::Three), number_from_repr(3)); +assert_eq!(None, number_from_repr(4)); +``` +")] + +#[proc_macro_derive(FromRepr, attributes(strum))] +pub fn from_repr(input: proc_macro::TokenStream) -> proc_macro::TokenStream { + let ast = syn::parse_macro_input!(input as DeriveInput); + + let toks = macros::from_repr::from_repr_inner(&ast) + .unwrap_or_else(|err| err.to_compile_error()); + debug_print_generated(&ast, &toks); + toks.into() +} + /// Add a verbose message to an enum variant. /// /// Encode strings into the enum itself. The `strum_macros::EmumMessage` macro implements the `strum::EnumMessage` trait. diff --git a/strum_macros/src/macros/from_repr.rs b/strum_macros/src/macros/from_repr.rs new file mode 100644 index 00000000..3ea6d1b8 --- /dev/null +++ b/strum_macros/src/macros/from_repr.rs @@ -0,0 +1,140 @@ +use proc_macro2::{Span, TokenStream}; +use quote::{format_ident, quote}; +use syn::{Data, DeriveInput, PathArguments, Type, TypeParen}; + +use crate::helpers::{non_enum_error, HasStrumVariantProperties}; + +pub fn from_repr_inner(ast: &DeriveInput) -> syn::Result { + let name = &ast.ident; + let gen = &ast.generics; + let (impl_generics, ty_generics, where_clause) = gen.split_for_impl(); + let vis = &ast.vis; + let attrs = &ast.attrs; + + let mut discriminant_type: Type = syn::parse("usize".parse().unwrap()).unwrap(); + for attr in attrs { + let path = &attr.path; + let tokens = &attr.tokens; + if path.leading_colon.is_some() { + continue; + } + if path.segments.len() != 1 { + continue; + } + let segment = path.segments.first().unwrap(); + if segment.ident != "repr" { + continue; + } + if segment.arguments != PathArguments::None { + continue; + } + let typ_paren = match syn::parse2::(tokens.clone()) { + Ok(Type::Paren(TypeParen { elem, .. })) => *elem, + _ => continue, + }; + let inner_path = match &typ_paren { + Type::Path(t) => t, + _ => continue, + }; + if let Some(seg) = inner_path.path.segments.last() { + for t in &[ + "u8", "u16", "u32", "u64", "usize", "i8", "i16", "i32", "i64", "isize", + ] { + if seg.ident == t { + discriminant_type = typ_paren; + break; + } + } + } + } + + if gen.lifetimes().count() > 0 { + return Err(syn::Error::new( + Span::call_site(), + "This macro doesn't support enums with lifetimes. \ + The resulting enums would be unbounded.", + )); + } + + let variants = match &ast.data { + Data::Enum(v) => &v.variants, + _ => return Err(non_enum_error()), + }; + + let mut arms = Vec::new(); + let mut constant_defs = Vec::new(); + let mut has_additional_data = false; + let mut prev_const_var_ident = None; + for variant in variants { + use syn::Fields::*; + + if variant.get_variant_properties()?.disabled.is_some() { + continue; + } + + let ident = &variant.ident; + let params = match &variant.fields { + Unit => quote! {}, + Unnamed(fields) => { + has_additional_data = true; + let defaults = ::std::iter::repeat(quote!(::core::default::Default::default())) + .take(fields.unnamed.len()); + quote! { (#(#defaults),*) } + } + Named(fields) => { + has_additional_data = true; + let fields = fields + .named + .iter() + .map(|field| field.ident.as_ref().unwrap()); + quote! { {#(#fields: ::core::default::Default::default()),*} } + } + }; + + use heck::ShoutySnakeCase; + let const_var_str = format!("{}_DISCRIMINANT", variant.ident).to_shouty_snake_case(); + let const_var_ident = format_ident!("{}", const_var_str); + + let const_val_expr = match &variant.discriminant { + Some((_, expr)) => quote! { #expr }, + None => match &prev_const_var_ident { + Some(prev) => quote! { #prev + 1 }, + None => quote! { 0 }, + }, + }; + + constant_defs + .push(quote! {const #const_var_ident: #discriminant_type = #const_val_expr;}); + arms.push(quote! {v if v == #const_var_ident => ::core::option::Option::Some(#name::#ident #params)}); + + prev_const_var_ident = Some(const_var_ident); + } + + arms.push(quote! { _ => ::core::option::Option::None }); + + let const_if_possible = if has_additional_data { + quote! {} + } else { + #[rustversion::before(1.46)] + fn filter_by_rust_version(s: TokenStream) -> TokenStream { + quote! {} + } + + #[rustversion::since(1.46)] + fn filter_by_rust_version(s: TokenStream) -> TokenStream { + s + } + filter_by_rust_version(quote! { const }) + }; + + Ok(quote! { + impl #impl_generics #name #ty_generics #where_clause { + #vis #const_if_possible fn from_repr(discriminant: #discriminant_type) -> Option<#name #ty_generics> { + #(#constant_defs)* + match discriminant { + #(#arms),* + } + } + } + }) +} diff --git a/strum_macros/src/macros/mod.rs b/strum_macros/src/macros/mod.rs index 9b707fbb..b4129697 100644 --- a/strum_macros/src/macros/mod.rs +++ b/strum_macros/src/macros/mod.rs @@ -4,6 +4,7 @@ pub mod enum_iter; pub mod enum_messages; pub mod enum_properties; pub mod enum_variant_names; +pub mod from_repr; mod strings; diff --git a/strum_tests/Cargo.toml b/strum_tests/Cargo.toml index a4b3470e..28664a28 100644 --- a/strum_tests/Cargo.toml +++ b/strum_tests/Cargo.toml @@ -13,4 +13,7 @@ structopt = "0.2.18" bitflags = "=1.2" [build-dependencies] -version_check = "0.9.2" \ No newline at end of file +version_check = "0.9.2" + +[dev-dependencies] +rustversion = "1.0" diff --git a/strum_tests/tests/from_repr.rs b/strum_tests/tests/from_repr.rs new file mode 100644 index 00000000..bb7263f6 --- /dev/null +++ b/strum_tests/tests/from_repr.rs @@ -0,0 +1,63 @@ +use strum::FromRepr; + +#[derive(Debug, FromRepr, PartialEq)] +#[repr(u8)] +enum Week { + Sunday, + Monday, + Tuesday, + Wednesday, + Thursday, + Friday = 4 + 3, + Saturday = 8, +} + +#[test] +fn simple_test() { + assert_eq!(Week::from_repr(0), Some(Week::Sunday)); + assert_eq!(Week::from_repr(1), Some(Week::Monday)); + assert_eq!(Week::from_repr(6), None); + assert_eq!(Week::from_repr(7), Some(Week::Friday)); + assert_eq!(Week::from_repr(8), Some(Week::Saturday)); + assert_eq!(Week::from_repr(9), None); +} + +#[rustversion::since(1.46)] +#[test] +fn const_test() { + // This is to test that it works in a const fn + const fn from_repr(discriminant: u8) -> Option { + Week::from_repr(discriminant) + } + assert_eq!(from_repr(0), Some(Week::Sunday)); + assert_eq!(from_repr(1), Some(Week::Monday)); + assert_eq!(from_repr(6), None); + assert_eq!(from_repr(7), Some(Week::Friday)); + assert_eq!(from_repr(8), Some(Week::Saturday)); + assert_eq!(from_repr(9), None); +} + +#[test] +fn crate_module_path_test() { + pub mod nested { + pub mod module { + pub use strum; + } + } + + #[derive(Debug, FromRepr, PartialEq)] + #[strum(crate = "nested::module::strum")] + enum Week { + Sunday, + Monday, + Tuesday, + Wednesday, + Thursday, + Friday, + Saturday, + } + + assert_eq!(Week::from_repr(0), Some(Week::Sunday)); + assert_eq!(Week::from_repr(6), Some(Week::Saturday)); + assert_eq!(Week::from_repr(7), None); +}