diff --git a/CHANGELOG.md b/CHANGELOG.md index 7fc75ba1303..ba83944ca67 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -21,6 +21,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Allow `#[classattr]` take `Python` argument. [#2383](https://github.com/PyO3/pyo3/issues/2383) - Add `CompareOp::matches` to easily implement `__richcmp__` as the result of a Rust `std::cmp::Ordering` comparison. [#2460](https://github.com/PyO3/pyo3/pull/2460) +- Supprt `#[pyo3(name)]` on enum variants [#2457](https://github.com/PyO3/pyo3/pull/2457) ### Changed @@ -59,6 +60,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Fix FFI definition `PyGetSetDef` to have `*const c_char` for `doc` member (not `*mut c_char`). [#2439](https://github.com/PyO3/pyo3/pull/2439) - Fix `#[pyo3(from_py_with = "...")]` being ignored for 1-element tuple structs and transparent structs. [#2440](https://github.com/PyO3/pyo3/pull/2440) - Use `memoffset` for computing PyCell offsets [#2450](https://github.com/PyO3/pyo3/pull/2450) +- Fix incorrect enum names being returned by `repr` for enums renamed by `#[pyclass(name)]` [#2457](https://github.com/PyO3/pyo3/pull/2457) ## [0.16.5] - 2022-05-15 diff --git a/guide/src/class.md b/guide/src/class.md index 7124114a479..b87d8b0b4d8 100644 --- a/guide/src/class.md +++ b/guide/src/class.md @@ -897,6 +897,26 @@ Python::with_gil(|py| { }) ``` +Enums and their variants can also be renamed using `#[pyo3(name)]`. + +```rust +# use pyo3::prelude::*; +#[pyclass(name = "RenamedEnum")] +enum MyEnum { + #[pyo3(name = "UPPERCASE")] + Variant, +} + +Python::with_gil(|py| { + let x = Py::new(py, MyEnum::Variant).unwrap(); + let cls = py.get_type::(); + pyo3::py_run!(py, x cls, r#" + assert repr(x) == 'RenamedEnum.UPPERCASE' + assert x == cls.UPPERCASE + "#) +}) +``` + You may not use enums as a base class or let enums inherit from other classes. ```rust,compile_fail diff --git a/pyo3-macros-backend/src/pyclass.rs b/pyo3-macros-backend/src/pyclass.rs index e3ac571c342..bc1f5fa247b 100644 --- a/pyo3-macros-backend/src/pyclass.rs +++ b/pyo3-macros-backend/src/pyclass.rs @@ -330,7 +330,17 @@ fn impl_class( struct PyClassEnumVariant<'a> { ident: &'a syn::Ident, - /* currently have no more options */ + options: EnumVariantPyO3Options, +} + +impl<'a> PyClassEnumVariant<'a> { + fn python_name(&self) -> Cow<'_, syn::Ident> { + self.options + .name + .as_ref() + .map(|name_attr| Cow::Borrowed(&name_attr.value.0)) + .unwrap_or_else(|| Cow::Owned(self.ident.unraw())) + } } struct PyClassEnum<'a> { @@ -342,7 +352,7 @@ struct PyClassEnum<'a> { } impl<'a> PyClassEnum<'a> { - fn new(enum_: &'a syn::ItemEnum) -> syn::Result { + fn new(enum_: &'a mut syn::ItemEnum) -> syn::Result { fn is_numeric_type(t: &syn::Ident) -> bool { [ "u8", "i8", "u16", "i16", "u32", "i32", "u64", "i64", "u128", "i128", "usize", @@ -370,7 +380,7 @@ impl<'a> PyClassEnum<'a> { let variants = enum_ .variants - .iter() + .iter_mut() .map(extract_variant_data) .collect::>()?; Ok(Self { @@ -407,6 +417,46 @@ pub fn build_py_enum( Ok(impl_enum(enum_, &args, doc, method_type)) } +/// `#[pyo3()]` options for pyclass enum variants +struct EnumVariantPyO3Options { + name: Option, +} + +enum EnumVariantPyO3Option { + Name(NameAttribute), +} + +impl Parse for EnumVariantPyO3Option { + fn parse(input: ParseStream<'_>) -> Result { + let lookahead = input.lookahead1(); + if lookahead.peek(attributes::kw::name) { + input.parse().map(EnumVariantPyO3Option::Name) + } else { + Err(lookahead.error()) + } + } +} + +impl EnumVariantPyO3Options { + fn take_pyo3_options(attrs: &mut Vec) -> Result { + let mut options = EnumVariantPyO3Options { name: None }; + + for option in take_pyo3_options(attrs)? { + match option { + EnumVariantPyO3Option::Name(name) => { + ensure_spanned!( + options.name.is_none(), + name.span() => "`name` may only be specified once" + ); + options.name = Some(name); + } + } + } + + Ok(options) + } +} + fn impl_enum( enum_: PyClassEnum<'_>, args: &PyClassArgs, @@ -433,7 +483,11 @@ fn impl_enum_class( let variants_repr = variants.iter().map(|variant| { let variant_name = variant.ident; // Assuming all variants are unit variants because they are the only type we support. - let repr = format!("{}.{}", cls, variant_name); + let repr = format!( + "{}.{}", + get_class_python_name(cls, args), + variant.python_name(), + ); quote! { #cls::#variant_name => #repr, } }); let mut repr_impl: syn::ImplItemMethod = syn::parse_quote! { @@ -511,7 +565,7 @@ fn impl_enum_class( cls, args, methods_type, - enum_default_methods(cls, variants.iter().map(|v| v.ident)), + enum_default_methods(cls, variants.iter().map(|v| (v.ident, v.python_name()))), default_slots, ) .doc(doc) @@ -557,33 +611,34 @@ fn generate_default_protocol_slot( fn enum_default_methods<'a>( cls: &'a syn::Ident, - unit_variant_names: impl IntoIterator, + unit_variant_names: impl IntoIterator)>, ) -> Vec { let cls_type = syn::parse_quote!(#cls); - let variant_to_attribute = |ident: &syn::Ident| ConstSpec { - rust_ident: ident.clone(), + let variant_to_attribute = |var_ident: &syn::Ident, py_ident: &syn::Ident| ConstSpec { + rust_ident: var_ident.clone(), attributes: ConstAttributes { is_class_attr: true, name: Some(NameAttribute { kw: syn::parse_quote! { name }, - value: NameLitStr(ident.clone()), + value: NameLitStr(py_ident.clone()), }), deprecations: Default::default(), }, }; unit_variant_names .into_iter() - .map(|var| gen_py_const(&cls_type, &variant_to_attribute(var))) + .map(|(var, py_name)| gen_py_const(&cls_type, &variant_to_attribute(var, &py_name))) .collect() } -fn extract_variant_data(variant: &syn::Variant) -> syn::Result> { +fn extract_variant_data(variant: &mut syn::Variant) -> syn::Result> { use syn::Fields; let ident = match variant.fields { Fields::Unit => &variant.ident, _ => bail_spanned!(variant.span() => "Currently only support unit variants."), }; - Ok(PyClassEnumVariant { ident }) + let options = EnumVariantPyO3Options::take_pyo3_options(&mut variant.attrs)?; + Ok(PyClassEnumVariant { ident, options }) } fn descriptors_to_items( diff --git a/tests/test_enum.rs b/tests/test_enum.rs index d1f66a243d1..c40b2f7cde1 100644 --- a/tests/test_enum.rs +++ b/tests/test_enum.rs @@ -161,3 +161,32 @@ enum TestReprParse { fn test_repr_parse() { assert_eq!(std::mem::align_of::(), 8); } + +#[pyclass(name = "MyEnum")] +#[derive(Debug, PartialEq, Clone)] +pub enum RenameEnum { + Variant, +} + +#[test] +fn test_rename_enum_repr_correct() { + Python::with_gil(|py| { + let var1 = Py::new(py, RenameEnum::Variant).unwrap(); + py_assert!(py, var1, "repr(var1) == 'MyEnum.Variant'"); + }) +} + +#[pyclass] +#[derive(Debug, PartialEq, Clone)] +pub enum RenameVariantEnum { + #[pyo3(name = "VARIANT")] + Variant, +} + +#[test] +fn test_rename_variant_repr_correct() { + Python::with_gil(|py| { + let var1 = Py::new(py, RenameVariantEnum::Variant).unwrap(); + py_assert!(py, var1, "repr(var1) == 'RenameVariantEnum.VARIANT'"); + }) +}