Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

EnumIndex implementation #185

Merged
merged 18 commits into from Nov 6, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions strum/src/lib.rs
Expand Up @@ -217,6 +217,7 @@ DocumentMacroRexports! {
EnumProperty,
EnumString,
EnumVariantNames,
FromRepr,
IntoStaticStr,
ToString
}
1 change: 1 addition & 0 deletions strum_macros/Cargo.toml
Expand Up @@ -22,6 +22,7 @@ name = "strum_macros"
heck = "0.3"
proc-macro2 = "1.0"
quote = "1.0"
rustversion = "1.0"
syn = { version = "1.0", features = ["parsing", "extra-traits"] }

[dev-dependencies]
Expand Down
80 changes: 80 additions & 0 deletions strum_macros/src/lib.rs
Expand Up @@ -374,6 +374,86 @@ pub fn enum_iter(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
toks.into()
}

/// Add a function to enum that allows accessing variants by its discriminant
///
/// This macro adds a standalone function to obtain an enum variant by its discriminant. The macro adds
/// `from_repr(discriminant: usize) -> Option<YourEnum>` as a standalone function on the enum. For
/// variants with additional data, the returned variant will use the `Default` trait to fill the
/// data. The discriminant follows the same rules as `rustc`. The first discriminant is zero and each
/// successive variant has a discriminant of one greater than the previous variant, expect where an
/// explicit discriminant is specified. The type of the discriminant will match the `repr` type if
/// it is specifed.
///
/// When the macro is applied using rustc >= 1.46 and when there is no additional data on any of
/// the variants, the `from_repr` function is marked `const`. rustc >= 1.46 is required
/// to allow `match` statements in `const fn`. The no additional data requirement is due to the
/// inability to use `Default::default()` in a `const fn`.
///
/// You cannot derive `FromRepr` on any type with a lifetime bound (`<'a>`) because the function would surely
/// create [unbounded lifetimes](https://doc.rust-lang.org/nightly/nomicon/unbounded-lifetimes.html).
///
/// ```
///
/// use strum_macros::FromRepr;
///
/// #[derive(FromRepr, Debug, PartialEq)]
/// enum Color {
/// Red,
/// Green { range: usize },
/// Blue(usize),
/// Yellow,
/// }
///
/// assert_eq!(Some(Color::Red), Color::from_repr(0));
/// assert_eq!(Some(Color::Green {range: 0}), Color::from_repr(1));
/// assert_eq!(Some(Color::Blue(0)), Color::from_repr(2));
/// assert_eq!(Some(Color::Yellow), Color::from_repr(3));
/// assert_eq!(None, Color::from_repr(4));
///
/// // Custom discriminant tests
/// #[derive(FromRepr, Debug, PartialEq)]
/// #[repr(u8)]
/// enum Vehicle {
/// Car = 1,
/// Truck = 3,
/// }
///
/// assert_eq!(None, Vehicle::from_repr(0));
/// ```
#[rustversion::attr(since(1.46),doc="
`const` tests (only works in rust >= 1.46)
```
use strum_macros::FromRepr;

#[derive(FromRepr, Debug, PartialEq)]
#[repr(u8)]
enum Number {
One = 1,
Three = 3,
}

// This test confirms that the function works in a `const` context
const fn number_from_repr(d: u8) -> Option<Number> {
Number::from_repr(d)
}
assert_eq!(None, number_from_repr(0));
assert_eq!(Some(Number::One), number_from_repr(1));
assert_eq!(None, number_from_repr(2));
assert_eq!(Some(Number::Three), number_from_repr(3));
assert_eq!(None, number_from_repr(4));
```
")]

#[proc_macro_derive(FromRepr, attributes(strum))]
pub fn from_repr(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
let ast = syn::parse_macro_input!(input as DeriveInput);

let toks = macros::from_repr::from_repr_inner(&ast)
.unwrap_or_else(|err| err.to_compile_error());
debug_print_generated(&ast, &toks);
toks.into()
}

/// Add a verbose message to an enum variant.
///
/// Encode strings into the enum itself. The `strum_macros::EmumMessage` macro implements the `strum::EnumMessage` trait.
Expand Down
140 changes: 140 additions & 0 deletions strum_macros/src/macros/from_repr.rs
@@ -0,0 +1,140 @@
use proc_macro2::{Span, TokenStream};
use quote::{format_ident, quote};
use syn::{Data, DeriveInput, PathArguments, Type, TypeParen};

use crate::helpers::{non_enum_error, HasStrumVariantProperties};

pub fn from_repr_inner(ast: &DeriveInput) -> syn::Result<TokenStream> {
let name = &ast.ident;
let gen = &ast.generics;
let (impl_generics, ty_generics, where_clause) = gen.split_for_impl();
let vis = &ast.vis;
let attrs = &ast.attrs;

let mut discriminant_type: Type = syn::parse("usize".parse().unwrap()).unwrap();
for attr in attrs {
let path = &attr.path;
let tokens = &attr.tokens;
if path.leading_colon.is_some() {
continue;
}
if path.segments.len() != 1 {
continue;
}
let segment = path.segments.first().unwrap();
if segment.ident != "repr" {
continue;
}
if segment.arguments != PathArguments::None {
continue;
}
let typ_paren = match syn::parse2::<Type>(tokens.clone()) {
Ok(Type::Paren(TypeParen { elem, .. })) => *elem,
_ => continue,
};
let inner_path = match &typ_paren {
Type::Path(t) => t,
_ => continue,
};
if let Some(seg) = inner_path.path.segments.last() {
for t in &[
"u8", "u16", "u32", "u64", "usize", "i8", "i16", "i32", "i64", "isize",
] {
if seg.ident == t {
discriminant_type = typ_paren;
break;
}
}
}
}

if gen.lifetimes().count() > 0 {
return Err(syn::Error::new(
Span::call_site(),
"This macro doesn't support enums with lifetimes. \
The resulting enums would be unbounded.",
));
}

let variants = match &ast.data {
Data::Enum(v) => &v.variants,
_ => return Err(non_enum_error()),
};

let mut arms = Vec::new();
let mut constant_defs = Vec::new();
let mut has_additional_data = false;
let mut prev_const_var_ident = None;
for variant in variants {
use syn::Fields::*;

if variant.get_variant_properties()?.disabled.is_some() {
continue;
}

let ident = &variant.ident;
let params = match &variant.fields {
Unit => quote! {},
Unnamed(fields) => {
has_additional_data = true;
let defaults = ::std::iter::repeat(quote!(::core::default::Default::default()))
.take(fields.unnamed.len());
quote! { (#(#defaults),*) }
}
Named(fields) => {
has_additional_data = true;
let fields = fields
.named
.iter()
.map(|field| field.ident.as_ref().unwrap());
quote! { {#(#fields: ::core::default::Default::default()),*} }
}
};

use heck::ShoutySnakeCase;
let const_var_str = format!("{}_DISCRIMINANT", variant.ident).to_shouty_snake_case();
let const_var_ident = format_ident!("{}", const_var_str);

let const_val_expr = match &variant.discriminant {
Some((_, expr)) => quote! { #expr },
None => match &prev_const_var_ident {
Some(prev) => quote! { #prev + 1 },
None => quote! { 0 },
},
};

constant_defs
.push(quote! {const #const_var_ident: #discriminant_type = #const_val_expr;});
arms.push(quote! {v if v == #const_var_ident => ::core::option::Option::Some(#name::#ident #params)});

prev_const_var_ident = Some(const_var_ident);
}

arms.push(quote! { _ => ::core::option::Option::None });

let const_if_possible = if has_additional_data {
quote! {}
} else {
#[rustversion::before(1.46)]
fn filter_by_rust_version(s: TokenStream) -> TokenStream {
quote! {}
}

#[rustversion::since(1.46)]
fn filter_by_rust_version(s: TokenStream) -> TokenStream {
s
}
filter_by_rust_version(quote! { const })
};

Ok(quote! {
impl #impl_generics #name #ty_generics #where_clause {
#vis #const_if_possible fn from_repr(discriminant: #discriminant_type) -> Option<#name #ty_generics> {
#(#constant_defs)*
match discriminant {
#(#arms),*
}
}
}
})
}
1 change: 1 addition & 0 deletions strum_macros/src/macros/mod.rs
Expand Up @@ -4,6 +4,7 @@ pub mod enum_iter;
pub mod enum_messages;
pub mod enum_properties;
pub mod enum_variant_names;
pub mod from_repr;

mod strings;

Expand Down
5 changes: 4 additions & 1 deletion strum_tests/Cargo.toml
Expand Up @@ -13,4 +13,7 @@ structopt = "0.2.18"
bitflags = "=1.2"

[build-dependencies]
version_check = "0.9.2"
version_check = "0.9.2"

[dev-dependencies]
rustversion = "1.0"
63 changes: 63 additions & 0 deletions strum_tests/tests/from_repr.rs
@@ -0,0 +1,63 @@
use strum::FromRepr;

#[derive(Debug, FromRepr, PartialEq)]
#[repr(u8)]
enum Week {
Sunday,
Monday,
Tuesday,
Wednesday,
Thursday,
Friday = 4 + 3,
Saturday = 8,
}

#[test]
fn simple_test() {
assert_eq!(Week::from_repr(0), Some(Week::Sunday));
assert_eq!(Week::from_repr(1), Some(Week::Monday));
assert_eq!(Week::from_repr(6), None);
assert_eq!(Week::from_repr(7), Some(Week::Friday));
assert_eq!(Week::from_repr(8), Some(Week::Saturday));
assert_eq!(Week::from_repr(9), None);
}

#[rustversion::since(1.46)]
#[test]
fn const_test() {
// This is to test that it works in a const fn
const fn from_repr(discriminant: u8) -> Option<Week> {
Week::from_repr(discriminant)
}
assert_eq!(from_repr(0), Some(Week::Sunday));
assert_eq!(from_repr(1), Some(Week::Monday));
assert_eq!(from_repr(6), None);
assert_eq!(from_repr(7), Some(Week::Friday));
assert_eq!(from_repr(8), Some(Week::Saturday));
assert_eq!(from_repr(9), None);
}

#[test]
fn crate_module_path_test() {
pub mod nested {
pub mod module {
pub use strum;
}
}

#[derive(Debug, FromRepr, PartialEq)]
#[strum(crate = "nested::module::strum")]
enum Week {
Sunday,
Monday,
Tuesday,
Wednesday,
Thursday,
Friday,
Saturday,
}

assert_eq!(Week::from_repr(0), Some(Week::Sunday));
assert_eq!(Week::from_repr(6), Some(Week::Saturday));
assert_eq!(Week::from_repr(7), None);
}