From a4f59bc14e16dfd92024f38d11c2f8d8a7776a2e Mon Sep 17 00:00:00 2001 From: Gabriel Smith Date: Thu, 16 Jun 2022 15:18:55 -0400 Subject: [PATCH] macros: Support #[pyo3(name)] on enum variants --- CHANGELOG.md | 1 + pyo3-macros-backend/src/pyclass.rs | 93 +++++++++++++++++++++++++----- tests/test_enum.rs | 15 +++++ 3 files changed, 94 insertions(+), 15 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 01fdc11497b..78848b55b4e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -18,6 +18,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added `PyCode` and `PyFrame` high level objects. [#2408](https://github.com/PyO3/pyo3/pull/2408) - Add FFI definitions `Py_fstring_input`, `sendfunc`, and `_PyErr_StackItem`. [#2423](https://github.com/PyO3/pyo3/pull/2423) - Add `PyDateTime::new_with_fold`, `PyTime::new_with_fold`, `PyTime::get_fold`, `PyDateTime::get_fold` for PyPy. [#2428](https://github.com/PyO3/pyo3/pull/2428) +- Supprt `#[pyo3(name)]` on enum variants [#2457](https://github.com/PyO3/pyo3/pull/2457) ### Changed diff --git a/pyo3-macros-backend/src/pyclass.rs b/pyo3-macros-backend/src/pyclass.rs index 00b77cf385a..4e3dca9fb1d 100644 --- a/pyo3-macros-backend/src/pyclass.rs +++ b/pyo3-macros-backend/src/pyclass.rs @@ -337,11 +337,11 @@ struct PyClassEnum<'a> { // The underlying #[repr] of the enum, used to implement __int__ and __richcmp__. // This matters when the underlying representation may not fit in `isize`. repr_type: syn::Ident, - variants: Vec>, + variants: Vec<(PyClassEnumVariant<'a>, VariantPyO3Options)>, } impl<'a> PyClassEnum<'a> { - fn new(enum_: &'a syn::ItemEnum) -> syn::Result { + fn new(enum_: &'a mut syn::ItemEnum) -> syn::Result { fn is_numeric_type(t: &syn::Ident) -> bool { [ "u8", "i8", "u16", "i16", "u32", "i32", "u64", "i64", "u128", "i128", "usize", @@ -369,7 +369,7 @@ impl<'a> PyClassEnum<'a> { let variants = enum_ .variants - .iter() + .iter_mut() .map(extract_variant_data) .collect::>()?; Ok(Self { @@ -406,6 +406,57 @@ pub fn build_py_enum( Ok(impl_enum(enum_, &args, doc, method_type)) } +fn get_variant_python_name<'a>( + variant: &'a syn::Ident, + options: &'a VariantPyO3Options, +) -> Cow<'a, syn::Ident> { + options + .name + .as_ref() + .map(|name_attr| Cow::Borrowed(&name_attr.value.0)) + .unwrap_or_else(|| Cow::Owned(variant.unraw())) +} + +/// `#[pyo3()]` options for pyclass enum variants +struct VariantPyO3Options { + name: Option, +} + +enum VariantPyO3Option { + Name(NameAttribute), +} + +impl Parse for VariantPyO3Option { + fn parse(input: ParseStream<'_>) -> Result { + let lookahead = input.lookahead1(); + if lookahead.peek(attributes::kw::name) { + input.parse().map(VariantPyO3Option::Name) + } else { + Err(lookahead.error()) + } + } +} + +impl VariantPyO3Options { + fn take_pyo3_options(attrs: &mut Vec) -> Result { + let mut options = VariantPyO3Options { name: None }; + + for option in take_pyo3_options(attrs)? { + match option { + VariantPyO3Option::Name(name) => { + ensure_spanned!( + options.name.is_none(), + name.span() => "`name` may only be specified once" + ); + options.name = Some(name); + } + } + } + + Ok(options) + } +} + fn impl_enum( enum_: PyClassEnum<'_>, args: &PyClassArgs, @@ -429,10 +480,14 @@ fn impl_enum_class( let pytypeinfo = impl_pytypeinfo(cls, args, None); let (default_repr, default_repr_slot) = { - let variants_repr = variants.iter().map(|variant| { + let variants_repr = variants.iter().map(|(variant, options)| { let variant_name = variant.ident; // Assuming all variants are unit variants because they are the only type we support. - let repr = format!("{}.{}", get_class_python_name(&cls, args), variant_name); + let repr = format!( + "{}.{}", + get_class_python_name(&cls, args), + get_variant_python_name(variant_name, options) + ); quote! { #cls::#variant_name => #repr, } }); let mut repr_impl: syn::ImplItemMethod = syn::parse_quote! { @@ -450,7 +505,7 @@ fn impl_enum_class( let (default_int, default_int_slot) = { // This implementation allows us to convert &T to #repr_type without implementing `Copy` - let variants_to_int = variants.iter().map(|variant| { + let variants_to_int = variants.iter().map(|(variant, _)| { let variant_name = variant.ident; quote! { #cls::#variant_name => #cls::#variant_name as #repr_type, } }); @@ -466,7 +521,7 @@ fn impl_enum_class( }; let (default_richcmp, default_richcmp_slot) = { - let variants_eq = variants.iter().map(|variant| { + let variants_eq = variants.iter().map(|(variant, _)| { let variant_name = variant.ident; quote! { (#cls::#variant_name, #cls::#variant_name) => @@ -510,7 +565,12 @@ fn impl_enum_class( cls, args, methods_type, - enum_default_methods(cls, variants.iter().map(|v| v.ident)), + enum_default_methods( + cls, + variants + .iter() + .map(|(v, o)| (v.ident, get_variant_python_name(v.ident, o))), + ), default_slots, ) .doc(doc) @@ -556,33 +616,36 @@ fn generate_default_protocol_slot( fn enum_default_methods<'a>( cls: &'a syn::Ident, - unit_variant_names: impl IntoIterator, + unit_variant_names: impl IntoIterator)>, ) -> Vec { let cls_type = syn::parse_quote!(#cls); - let variant_to_attribute = |ident: &syn::Ident| ConstSpec { - rust_ident: ident.clone(), + let variant_to_attribute = |var_ident: &syn::Ident, py_ident: &syn::Ident| ConstSpec { + rust_ident: var_ident.clone(), attributes: ConstAttributes { is_class_attr: true, name: Some(NameAttribute { kw: syn::parse_quote! { name }, - value: NameLitStr(ident.clone()), + value: NameLitStr(py_ident.clone()), }), deprecations: Default::default(), }, }; unit_variant_names .into_iter() - .map(|var| gen_py_const(&cls_type, &variant_to_attribute(var))) + .map(|(var, py_name)| gen_py_const(&cls_type, &variant_to_attribute(var, &py_name))) .collect() } -fn extract_variant_data(variant: &syn::Variant) -> syn::Result> { +fn extract_variant_data( + variant: &mut syn::Variant, +) -> syn::Result<(PyClassEnumVariant<'_>, VariantPyO3Options)> { use syn::Fields; let ident = match variant.fields { Fields::Unit => &variant.ident, _ => bail_spanned!(variant.span() => "Currently only support unit variants."), }; - Ok(PyClassEnumVariant { ident }) + let options = VariantPyO3Options::take_pyo3_options(&mut variant.attrs)?; + Ok((PyClassEnumVariant { ident }, options)) } fn descriptors_to_items( diff --git a/tests/test_enum.rs b/tests/test_enum.rs index d40bf33b17e..c40b2f7cde1 100644 --- a/tests/test_enum.rs +++ b/tests/test_enum.rs @@ -175,3 +175,18 @@ fn test_rename_enum_repr_correct() { py_assert!(py, var1, "repr(var1) == 'MyEnum.Variant'"); }) } + +#[pyclass] +#[derive(Debug, PartialEq, Clone)] +pub enum RenameVariantEnum { + #[pyo3(name = "VARIANT")] + Variant, +} + +#[test] +fn test_rename_variant_repr_correct() { + Python::with_gil(|py| { + let var1 = Py::new(py, RenameVariantEnum::Variant).unwrap(); + py_assert!(py, var1, "repr(var1) == 'RenameVariantEnum.VARIANT'"); + }) +}