Skip to content

Commit

Permalink
Support #[pyo3(name)] on enum variants
Browse files Browse the repository at this point in the history
  • Loading branch information
yodaldevoid committed Jun 16, 2022
1 parent 70574b4 commit 34667fe
Showing 1 changed file with 78 additions and 15 deletions.
93 changes: 78 additions & 15 deletions pyo3-macros-backend/src/pyclass.rs
Expand Up @@ -337,11 +337,11 @@ struct PyClassEnum<'a> {
// The underlying #[repr] of the enum, used to implement __int__ and __richcmp__.
// This matters when the underlying representation may not fit in `isize`.
repr_type: syn::Ident,
variants: Vec<PyClassEnumVariant<'a>>,
variants: Vec<(PyClassEnumVariant<'a>, VariantPyO3Options)>,
}

impl<'a> PyClassEnum<'a> {
fn new(enum_: &'a syn::ItemEnum) -> syn::Result<Self> {
fn new(enum_: &'a mut syn::ItemEnum) -> syn::Result<Self> {
fn is_numeric_type(t: &syn::Ident) -> bool {
[
"u8", "i8", "u16", "i16", "u32", "i32", "u64", "i64", "u128", "i128", "usize",
Expand Down Expand Up @@ -369,7 +369,7 @@ impl<'a> PyClassEnum<'a> {

let variants = enum_
.variants
.iter()
.iter_mut()
.map(extract_variant_data)
.collect::<syn::Result<_>>()?;
Ok(Self {
Expand Down Expand Up @@ -406,6 +406,57 @@ pub fn build_py_enum(
Ok(impl_enum(enum_, &args, doc, method_type))
}

fn get_variant_python_name<'a>(
variant: &'a syn::Ident,
options: &'a VariantPyO3Options,
) -> Cow<'a, syn::Ident> {
options
.name
.as_ref()
.map(|name_attr| Cow::Borrowed(&name_attr.value.0))
.unwrap_or_else(|| Cow::Owned(variant.unraw()))
}

/// `#[pyo3()]` options for pyclass enum variants
struct VariantPyO3Options {
name: Option<NameAttribute>,
}

enum VariantPyO3Option {
Name(NameAttribute),
}

impl Parse for VariantPyO3Option {
fn parse(input: ParseStream<'_>) -> Result<Self> {
let lookahead = input.lookahead1();
if lookahead.peek(attributes::kw::name) {
input.parse().map(VariantPyO3Option::Name)
} else {
Err(lookahead.error())
}
}
}

impl VariantPyO3Options {
fn take_pyo3_options(attrs: &mut Vec<syn::Attribute>) -> Result<Self> {
let mut options = VariantPyO3Options { name: None };

for option in take_pyo3_options(attrs)? {
match option {
VariantPyO3Option::Name(name) => {
ensure_spanned!(
options.name.is_none(),
name.span() => "`name` may only be specified once"
);
options.name = Some(name);
}
}
}

Ok(options)
}
}

fn impl_enum(
enum_: PyClassEnum<'_>,
args: &PyClassArgs,
Expand All @@ -429,10 +480,14 @@ fn impl_enum_class(
let pytypeinfo = impl_pytypeinfo(cls, args, None);

let (default_repr, default_repr_slot) = {
let variants_repr = variants.iter().map(|variant| {
let variants_repr = variants.iter().map(|(variant, options)| {
let variant_name = variant.ident;
// Assuming all variants are unit variants because they are the only type we support.
let repr = format!("{}.{}", get_class_python_name(&cls, args), variant_name);
let repr = format!(
"{}.{}",
get_class_python_name(&cls, args),
get_variant_python_name(variant_name, options)
);
quote! { #cls::#variant_name => #repr, }
});
let mut repr_impl: syn::ImplItemMethod = syn::parse_quote! {
Expand All @@ -450,7 +505,7 @@ fn impl_enum_class(

let (default_int, default_int_slot) = {
// This implementation allows us to convert &T to #repr_type without implementing `Copy`
let variants_to_int = variants.iter().map(|variant| {
let variants_to_int = variants.iter().map(|(variant, _)| {
let variant_name = variant.ident;
quote! { #cls::#variant_name => #cls::#variant_name as #repr_type, }
});
Expand All @@ -466,7 +521,7 @@ fn impl_enum_class(
};

let (default_richcmp, default_richcmp_slot) = {
let variants_eq = variants.iter().map(|variant| {
let variants_eq = variants.iter().map(|(variant, _)| {
let variant_name = variant.ident;
quote! {
(#cls::#variant_name, #cls::#variant_name) =>
Expand Down Expand Up @@ -510,7 +565,12 @@ fn impl_enum_class(
cls,
args,
methods_type,
enum_default_methods(cls, variants.iter().map(|v| v.ident)),
enum_default_methods(
cls,
variants
.iter()
.map(|(v, o)| (v.ident, get_variant_python_name(v.ident, o))),
),
default_slots,
)
.doc(doc)
Expand Down Expand Up @@ -556,33 +616,36 @@ fn generate_default_protocol_slot(

fn enum_default_methods<'a>(
cls: &'a syn::Ident,
unit_variant_names: impl IntoIterator<Item = &'a syn::Ident>,
unit_variant_names: impl IntoIterator<Item = (&'a syn::Ident, Cow<'a, syn::Ident>)>,
) -> Vec<TokenStream> {
let cls_type = syn::parse_quote!(#cls);
let variant_to_attribute = |ident: &syn::Ident| ConstSpec {
rust_ident: ident.clone(),
let variant_to_attribute = |var_ident: &syn::Ident, py_ident: &syn::Ident| ConstSpec {
rust_ident: var_ident.clone(),
attributes: ConstAttributes {
is_class_attr: true,
name: Some(NameAttribute {
kw: syn::parse_quote! { name },
value: NameLitStr(ident.clone()),
value: NameLitStr(py_ident.clone()),
}),
deprecations: Default::default(),
},
};
unit_variant_names
.into_iter()
.map(|var| gen_py_const(&cls_type, &variant_to_attribute(var)))
.map(|(var, py_name)| gen_py_const(&cls_type, &variant_to_attribute(var, &py_name)))
.collect()
}

fn extract_variant_data(variant: &syn::Variant) -> syn::Result<PyClassEnumVariant<'_>> {
fn extract_variant_data(
variant: &mut syn::Variant,
) -> syn::Result<(PyClassEnumVariant<'_>, VariantPyO3Options)> {
use syn::Fields;
let ident = match variant.fields {
Fields::Unit => &variant.ident,
_ => bail_spanned!(variant.span() => "Currently only support unit variants."),
};
Ok(PyClassEnumVariant { ident })
let options = VariantPyO3Options::take_pyo3_options(&mut variant.attrs)?;
Ok((PyClassEnumVariant { ident }, options))
}

fn descriptors_to_items(
Expand Down

0 comments on commit 34667fe

Please sign in to comment.