Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Re-export macros #170

Merged
merged 6 commits into from Jul 17, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
17 changes: 17 additions & 0 deletions strum_macros/src/helpers/metadata.rs
Expand Up @@ -2,6 +2,7 @@ use proc_macro2::{Span, TokenStream};
use syn::{
parenthesized,
parse::{Parse, ParseStream},
parse2, parse_str,
punctuated::Punctuated,
spanned::Spanned,
Attribute, DeriveInput, Ident, LitBool, LitStr, Path, Token, Variant, Visibility,
Expand All @@ -11,6 +12,7 @@ use super::case_style::CaseStyle;

pub mod kw {
use syn::custom_keyword;
pub use syn::token::Crate;

// enum metadata
custom_keyword!(serialize_all);
Expand All @@ -37,6 +39,10 @@ pub enum EnumMeta {
case_style: CaseStyle,
},
AsciiCaseInsensitive(kw::ascii_case_insensitive),
Crate {
kw: kw::Crate,
crate_module_path: Path,
},
}

impl Parse for EnumMeta {
Expand All @@ -47,6 +53,16 @@ impl Parse for EnumMeta {
input.parse::<Token![=]>()?;
let case_style = input.parse()?;
Ok(EnumMeta::SerializeAll { kw, case_style })
} else if lookahead.peek(kw::Crate) {
let kw = input.parse::<kw::Crate>()?;
input.parse::<Token![=]>()?;
let path_str: LitStr = input.parse()?;
let path_tokens = parse_str(&path_str.value())?;
let crate_module_path = parse2(path_tokens)?;
Ok(EnumMeta::Crate {
kw,
crate_module_path,
})
} else if lookahead.peek(kw::ascii_case_insensitive) {
let kw = input.parse()?;
Ok(EnumMeta::AsciiCaseInsensitive(kw))
Expand All @@ -61,6 +77,7 @@ impl Spanned for EnumMeta {
match self {
EnumMeta::SerializeAll { kw, .. } => kw.span(),
EnumMeta::AsciiCaseInsensitive(kw) => kw.span(),
EnumMeta::Crate { kw, .. } => kw.span(),
}
}
}
Expand Down
25 changes: 24 additions & 1 deletion strum_macros/src/helpers/type_props.rs
@@ -1,7 +1,7 @@
use proc_macro2::TokenStream;
use quote::quote;
use std::default::Default;
use syn::{DeriveInput, Ident, Path, Visibility};
use syn::{parse_quote, DeriveInput, Ident, Path, Visibility};

use super::case_style::CaseStyle;
use super::metadata::{DeriveInputExt, EnumDiscriminantsMeta, EnumMeta};
Expand All @@ -15,6 +15,7 @@ pub trait HasTypeProperties {
pub struct StrumTypeProperties {
pub case_style: Option<CaseStyle>,
pub ascii_case_insensitive: bool,
pub crate_module_path: Option<Path>,
pub discriminant_derives: Vec<Path>,
pub discriminant_name: Option<Ident>,
pub discriminant_others: Vec<TokenStream>,
Expand All @@ -30,6 +31,7 @@ impl HasTypeProperties for DeriveInput {

let mut serialize_all_kw = None;
let mut ascii_case_insensitive_kw = None;
let mut crate_module_path_kw = None;
for meta in strum_meta {
match meta {
EnumMeta::SerializeAll { case_style, kw } => {
Expand All @@ -48,6 +50,17 @@ impl HasTypeProperties for DeriveInput {
ascii_case_insensitive_kw = Some(kw);
output.ascii_case_insensitive = true;
}
EnumMeta::Crate {
crate_module_path,
kw,
} => {
if let Some(fst_kw) = crate_module_path_kw {
return Err(occurrence_error(fst_kw, kw, "Crate"));
}

crate_module_path_kw = Some(kw);
output.crate_module_path = Some(crate_module_path);
}
}
}

Expand Down Expand Up @@ -83,3 +96,13 @@ impl HasTypeProperties for DeriveInput {
Ok(output)
}
}

impl StrumTypeProperties {
pub fn crate_module_path(&self) -> Path {
if let Some(path) = &self.crate_module_path {
parse_quote!(#path)
} else {
parse_quote!(::strum)
}
}
}
6 changes: 4 additions & 2 deletions strum_macros/src/macros/enum_count.rs
Expand Up @@ -2,13 +2,15 @@ use proc_macro2::TokenStream;
use quote::quote;
use syn::{Data, DeriveInput};

use crate::helpers::non_enum_error;
use crate::helpers::{non_enum_error, HasTypeProperties};

pub(crate) fn enum_count_inner(ast: &DeriveInput) -> syn::Result<TokenStream> {
let n = match &ast.data {
Data::Enum(v) => v.variants.len(),
_ => return Err(non_enum_error()),
};
let type_properties = ast.get_type_properties()?;
let strum_module_path = type_properties.crate_module_path();

// Used in the quasi-quotation below as `#name`
let name = &ast.ident;
Expand All @@ -18,7 +20,7 @@ pub(crate) fn enum_count_inner(ast: &DeriveInput) -> syn::Result<TokenStream> {

Ok(quote! {
// Implementation
impl #impl_generics ::strum::EnumCount for #name #ty_generics #where_clause {
impl #impl_generics #strum_module_path::EnumCount for #name #ty_generics #where_clause {
const COUNT: usize = #n;
}
})
Expand Down
6 changes: 4 additions & 2 deletions strum_macros/src/macros/enum_iter.rs
Expand Up @@ -2,13 +2,15 @@ use proc_macro2::{Span, TokenStream};
use quote::quote;
use syn::{Data, DeriveInput, Ident};

use crate::helpers::{non_enum_error, HasStrumVariantProperties};
use crate::helpers::{non_enum_error, HasStrumVariantProperties, HasTypeProperties};

pub fn enum_iter_inner(ast: &DeriveInput) -> syn::Result<TokenStream> {
let name = &ast.ident;
let gen = &ast.generics;
let (impl_generics, ty_generics, where_clause) = gen.split_for_impl();
let vis = &ast.vis;
let type_properties = ast.get_type_properties()?;
let strum_module_path = type_properties.crate_module_path();

if gen.lifetimes().count() > 0 {
return Err(syn::Error::new(
Expand Down Expand Up @@ -80,7 +82,7 @@ pub fn enum_iter_inner(ast: &DeriveInput) -> syn::Result<TokenStream> {
}
}

impl #impl_generics ::strum::IntoEnumIterator for #name #ty_generics #where_clause {
impl #impl_generics #strum_module_path::IntoEnumIterator for #name #ty_generics #where_clause {
type Iterator = #iter_name #ty_generics;
fn iter() -> #iter_name #ty_generics {
#iter_name {
Expand Down
3 changes: 2 additions & 1 deletion strum_macros/src/macros/enum_messages.rs
Expand Up @@ -13,6 +13,7 @@ pub fn enum_message_inner(ast: &DeriveInput) -> syn::Result<TokenStream> {
};

let type_properties = ast.get_type_properties()?;
let strum_module_path = type_properties.crate_module_path();

let mut arms = Vec::new();
let mut detailed_arms = Vec::new();
Expand Down Expand Up @@ -79,7 +80,7 @@ pub fn enum_message_inner(ast: &DeriveInput) -> syn::Result<TokenStream> {
}

Ok(quote! {
impl #impl_generics ::strum::EnumMessage for #name #ty_generics #where_clause {
impl #impl_generics #strum_module_path::EnumMessage for #name #ty_generics #where_clause {
fn get_message(&self) -> ::core::option::Option<&'static str> {
match self {
#(#arms),*
Expand Down
6 changes: 4 additions & 2 deletions strum_macros/src/macros/enum_properties.rs
Expand Up @@ -2,7 +2,7 @@ use proc_macro2::TokenStream;
use quote::quote;
use syn::{Data, DeriveInput};

use crate::helpers::{non_enum_error, HasStrumVariantProperties};
use crate::helpers::{non_enum_error, HasStrumVariantProperties, HasTypeProperties};

pub fn enum_properties_inner(ast: &DeriveInput) -> syn::Result<TokenStream> {
let name = &ast.ident;
Expand All @@ -11,6 +11,8 @@ pub fn enum_properties_inner(ast: &DeriveInput) -> syn::Result<TokenStream> {
Data::Enum(v) => &v.variants,
_ => return Err(non_enum_error()),
};
let type_properties = ast.get_type_properties()?;
let strum_module_path = type_properties.crate_module_path();

let mut arms = Vec::new();
for variant in variants {
Expand Down Expand Up @@ -53,7 +55,7 @@ pub fn enum_properties_inner(ast: &DeriveInput) -> syn::Result<TokenStream> {
}

Ok(quote! {
impl #impl_generics ::strum::EnumProperty for #name #ty_generics #where_clause {
impl #impl_generics #strum_module_path::EnumProperty for #name #ty_generics #where_clause {
fn get_str(&self, prop: &str) -> ::core::option::Option<&'static str> {
match self {
#(#arms),*
Expand Down
3 changes: 2 additions & 1 deletion strum_macros/src/macros/enum_variant_names.rs
Expand Up @@ -16,6 +16,7 @@ pub fn enum_variant_names_inner(ast: &DeriveInput) -> syn::Result<TokenStream> {

// Derives for the generated enum
let type_properties = ast.get_type_properties()?;
let strum_module_path = type_properties.crate_module_path();

let names = variants
.iter()
Expand All @@ -26,7 +27,7 @@ pub fn enum_variant_names_inner(ast: &DeriveInput) -> syn::Result<TokenStream> {
.collect::<syn::Result<Vec<_>>>()?;

Ok(quote! {
impl #impl_generics ::strum::VariantNames for #name #ty_generics #where_clause {
impl #impl_generics #strum_module_path::VariantNames for #name #ty_generics #where_clause {
const VARIANTS: &'static [&'static str] = &[ #(#names),* ];
}
})
Expand Down
4 changes: 3 additions & 1 deletion strum_macros/src/macros/strings/as_ref_str.rs
Expand Up @@ -75,6 +75,8 @@ pub fn as_static_str_inner(
let name = &ast.ident;
let (impl_generics, ty_generics, where_clause) = ast.generics.split_for_impl();
let arms = get_arms(ast)?;
let type_properties = ast.get_type_properties()?;
let strum_module_path = type_properties.crate_module_path();

let mut generics = ast.generics.clone();
generics
Expand All @@ -88,7 +90,7 @@ pub fn as_static_str_inner(

Ok(match trait_variant {
GenerateTraitVariant::AsStaticStr => quote! {
impl #impl_generics ::strum::AsStaticRef<str> for #name #ty_generics #where_clause {
impl #impl_generics #strum_module_path::AsStaticRef<str> for #name #ty_generics #where_clause {
fn as_static(&self) -> &'static str {
match *self {
#(#arms),*
Expand Down
5 changes: 3 additions & 2 deletions strum_macros/src/macros/strings/from_string.rs
Expand Up @@ -15,10 +15,11 @@ pub fn from_string_inner(ast: &DeriveInput) -> syn::Result<TokenStream> {
};

let type_properties = ast.get_type_properties()?;
let strum_module_path = type_properties.crate_module_path();

let mut default_kw = None;
let mut default =
quote! { _ => ::std::result::Result::Err(::strum::ParseError::VariantNotFound) };
quote! { _ => ::std::result::Result::Err(#strum_module_path::ParseError::VariantNotFound) };
let mut arms = Vec::new();
for variant in variants {
let ident = &variant.ident;
Expand Down Expand Up @@ -89,7 +90,7 @@ pub fn from_string_inner(ast: &DeriveInput) -> syn::Result<TokenStream> {
Ok(quote! {
#[allow(clippy::use_self)]
impl #impl_generics ::std::str::FromStr for #name #ty_generics #where_clause {
type Err = ::strum::ParseError;
type Err = #strum_module_path::ParseError;
fn from_str(s: &str) -> ::std::result::Result< #name #ty_generics , Self::Err> {
match s {
#(#arms),*
Expand Down
24 changes: 24 additions & 0 deletions strum_tests/tests/enum_count.rs
Expand Up @@ -16,3 +16,27 @@ fn simple_test() {
assert_eq!(7, Week::COUNT);
assert_eq!(Week::iter().count(), Week::COUNT);
}

#[test]
fn crate_module_path_test() {
pub mod nested {
pub mod module {
pub use strum;
}
}

#[derive(Debug, EnumCount, EnumIter)]
#[strum(crate = "nested::module::strum")]
enum Week {
Sunday,
Monday,
Tuesday,
Wednesday,
Thursday,
Friday,
Saturday,
}

assert_eq!(7, Week::COUNT);
assert_eq!(Week::iter().count(), Week::COUNT);
}
23 changes: 23 additions & 0 deletions strum_tests/tests/enum_discriminants.rs
Expand Up @@ -279,3 +279,26 @@ fn override_visibility() {
private::PubDiscriminants::VariantB,
);
}

#[test]
fn crate_module_path_test() {
pub mod nested {
pub mod module {
pub use strum;
}
}

#[allow(dead_code)]
#[derive(Debug, Eq, PartialEq, EnumDiscriminants)]
#[strum_discriminants(derive(EnumIter))]
#[strum(crate = "nested::module::strum")]
enum Simple {
Variant0,
Variant1,
}

let discriminants = SimpleDiscriminants::iter().collect::<Vec<_>>();
let expected = vec![SimpleDiscriminants::Variant0, SimpleDiscriminants::Variant1];

assert_eq!(expected, discriminants);
}
34 changes: 34 additions & 0 deletions strum_tests/tests/enum_iter.rs
Expand Up @@ -175,3 +175,37 @@ fn take_nth_test() {
assert_eq!(None, iter.next());
assert_eq!(None, iter.next_back());
}

#[test]
fn crate_module_path_test() {
pub mod nested {
pub mod module {
pub use strum;
}
}

#[derive(Debug, Eq, PartialEq, EnumIter)]
#[strum(crate = "nested::module::strum")]
enum Week {
Sunday,
Monday,
Tuesday,
Wednesday,
Thursday,
Friday,
Saturday,
}

let results = Week::iter().collect::<Vec<_>>();
let expected = vec![
Week::Sunday,
Week::Monday,
Week::Tuesday,
Week::Wednesday,
Week::Thursday,
Week::Friday,
Week::Saturday,
];

assert_eq!(expected, results);
}
28 changes: 28 additions & 0 deletions strum_tests/tests/enum_message.rs
Expand Up @@ -76,3 +76,31 @@ fn get_serializations() {
(Brightness::BrightWhite).get_serializations()
);
}

#[test]
fn crate_module_path_test() {
pub mod nested {
pub mod module {
pub use strum;
}
}

#[allow(dead_code)]
#[derive(Debug, Eq, PartialEq, EnumMessage)]
#[strum(crate = "nested::module::strum")]
enum Pets {
#[strum(message = "I'm a dog")]
Dog,
#[strum(message = "I'm a cat")]
#[strum(detailed_message = "I'm a very exquisite striped cat")]
Cat,
#[strum(detailed_message = "My fish is named Charles McFish")]
Fish,
Bird,
#[strum(disabled)]
Hamster,
}

assert_eq!("I'm a dog", (Pets::Dog).get_message().unwrap());
assert_eq!("I'm a dog", (Pets::Dog).get_detailed_message().unwrap());
}