Skip to content

Commit

Permalink
Merge pull request #2457 from yodaldevoid/enum_naming_improvements
Browse files Browse the repository at this point in the history
Enum naming improvements
  • Loading branch information
davidhewitt committed Jun 22, 2022
2 parents cdb3b6f + 845be04 commit 510c126
Show file tree
Hide file tree
Showing 4 changed files with 118 additions and 12 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Expand Up @@ -21,6 +21,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Allow `#[classattr]` take `Python` argument. [#2383](https://github.com/PyO3/pyo3/issues/2383)
- Add `CompareOp::matches` to easily implement `__richcmp__` as the result of a
Rust `std::cmp::Ordering` comparison. [#2460](https://github.com/PyO3/pyo3/pull/2460)
- Supprt `#[pyo3(name)]` on enum variants [#2457](https://github.com/PyO3/pyo3/pull/2457)

### Changed

Expand Down Expand Up @@ -59,6 +60,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Fix FFI definition `PyGetSetDef` to have `*const c_char` for `doc` member (not `*mut c_char`). [#2439](https://github.com/PyO3/pyo3/pull/2439)
- Fix `#[pyo3(from_py_with = "...")]` being ignored for 1-element tuple structs and transparent structs. [#2440](https://github.com/PyO3/pyo3/pull/2440)
- Use `memoffset` for computing PyCell offsets [#2450](https://github.com/PyO3/pyo3/pull/2450)
- Fix incorrect enum names being returned by `repr` for enums renamed by `#[pyclass(name)]` [#2457](https://github.com/PyO3/pyo3/pull/2457)

## [0.16.5] - 2022-05-15

Expand Down
20 changes: 20 additions & 0 deletions guide/src/class.md
Expand Up @@ -897,6 +897,26 @@ Python::with_gil(|py| {
})
```

Enums and their variants can also be renamed using `#[pyo3(name)]`.

```rust
# use pyo3::prelude::*;
#[pyclass(name = "RenamedEnum")]
enum MyEnum {
#[pyo3(name = "UPPERCASE")]
Variant,
}

Python::with_gil(|py| {
let x = Py::new(py, MyEnum::Variant).unwrap();
let cls = py.get_type::<MyEnum>();
pyo3::py_run!(py, x cls, r#"
assert repr(x) == 'RenamedEnum.UPPERCASE'
assert x == cls.UPPERCASE
"#)
})
```

You may not use enums as a base class or let enums inherit from other classes.

```rust,compile_fail
Expand Down
79 changes: 67 additions & 12 deletions pyo3-macros-backend/src/pyclass.rs
Expand Up @@ -330,7 +330,17 @@ fn impl_class(

struct PyClassEnumVariant<'a> {
ident: &'a syn::Ident,
/* currently have no more options */
options: EnumVariantPyO3Options,
}

impl<'a> PyClassEnumVariant<'a> {
fn python_name(&self) -> Cow<'_, syn::Ident> {
self.options
.name
.as_ref()
.map(|name_attr| Cow::Borrowed(&name_attr.value.0))
.unwrap_or_else(|| Cow::Owned(self.ident.unraw()))
}
}

struct PyClassEnum<'a> {
Expand All @@ -342,7 +352,7 @@ struct PyClassEnum<'a> {
}

impl<'a> PyClassEnum<'a> {
fn new(enum_: &'a syn::ItemEnum) -> syn::Result<Self> {
fn new(enum_: &'a mut syn::ItemEnum) -> syn::Result<Self> {
fn is_numeric_type(t: &syn::Ident) -> bool {
[
"u8", "i8", "u16", "i16", "u32", "i32", "u64", "i64", "u128", "i128", "usize",
Expand Down Expand Up @@ -370,7 +380,7 @@ impl<'a> PyClassEnum<'a> {

let variants = enum_
.variants
.iter()
.iter_mut()
.map(extract_variant_data)
.collect::<syn::Result<_>>()?;
Ok(Self {
Expand Down Expand Up @@ -407,6 +417,46 @@ pub fn build_py_enum(
Ok(impl_enum(enum_, &args, doc, method_type))
}

/// `#[pyo3()]` options for pyclass enum variants
struct EnumVariantPyO3Options {
name: Option<NameAttribute>,
}

enum EnumVariantPyO3Option {
Name(NameAttribute),
}

impl Parse for EnumVariantPyO3Option {
fn parse(input: ParseStream<'_>) -> Result<Self> {
let lookahead = input.lookahead1();
if lookahead.peek(attributes::kw::name) {
input.parse().map(EnumVariantPyO3Option::Name)
} else {
Err(lookahead.error())
}
}
}

impl EnumVariantPyO3Options {
fn take_pyo3_options(attrs: &mut Vec<syn::Attribute>) -> Result<Self> {
let mut options = EnumVariantPyO3Options { name: None };

for option in take_pyo3_options(attrs)? {
match option {
EnumVariantPyO3Option::Name(name) => {
ensure_spanned!(
options.name.is_none(),
name.span() => "`name` may only be specified once"
);
options.name = Some(name);
}
}
}

Ok(options)
}
}

fn impl_enum(
enum_: PyClassEnum<'_>,
args: &PyClassArgs,
Expand All @@ -433,7 +483,11 @@ fn impl_enum_class(
let variants_repr = variants.iter().map(|variant| {
let variant_name = variant.ident;
// Assuming all variants are unit variants because they are the only type we support.
let repr = format!("{}.{}", cls, variant_name);
let repr = format!(
"{}.{}",
get_class_python_name(cls, args),
variant.python_name(),
);
quote! { #cls::#variant_name => #repr, }
});
let mut repr_impl: syn::ImplItemMethod = syn::parse_quote! {
Expand Down Expand Up @@ -511,7 +565,7 @@ fn impl_enum_class(
cls,
args,
methods_type,
enum_default_methods(cls, variants.iter().map(|v| v.ident)),
enum_default_methods(cls, variants.iter().map(|v| (v.ident, v.python_name()))),
default_slots,
)
.doc(doc)
Expand Down Expand Up @@ -557,33 +611,34 @@ fn generate_default_protocol_slot(

fn enum_default_methods<'a>(
cls: &'a syn::Ident,
unit_variant_names: impl IntoIterator<Item = &'a syn::Ident>,
unit_variant_names: impl IntoIterator<Item = (&'a syn::Ident, Cow<'a, syn::Ident>)>,
) -> Vec<MethodAndMethodDef> {
let cls_type = syn::parse_quote!(#cls);
let variant_to_attribute = |ident: &syn::Ident| ConstSpec {
rust_ident: ident.clone(),
let variant_to_attribute = |var_ident: &syn::Ident, py_ident: &syn::Ident| ConstSpec {
rust_ident: var_ident.clone(),
attributes: ConstAttributes {
is_class_attr: true,
name: Some(NameAttribute {
kw: syn::parse_quote! { name },
value: NameLitStr(ident.clone()),
value: NameLitStr(py_ident.clone()),
}),
deprecations: Default::default(),
},
};
unit_variant_names
.into_iter()
.map(|var| gen_py_const(&cls_type, &variant_to_attribute(var)))
.map(|(var, py_name)| gen_py_const(&cls_type, &variant_to_attribute(var, &py_name)))
.collect()
}

fn extract_variant_data(variant: &syn::Variant) -> syn::Result<PyClassEnumVariant<'_>> {
fn extract_variant_data(variant: &mut syn::Variant) -> syn::Result<PyClassEnumVariant<'_>> {
use syn::Fields;
let ident = match variant.fields {
Fields::Unit => &variant.ident,
_ => bail_spanned!(variant.span() => "Currently only support unit variants."),
};
Ok(PyClassEnumVariant { ident })
let options = EnumVariantPyO3Options::take_pyo3_options(&mut variant.attrs)?;
Ok(PyClassEnumVariant { ident, options })
}

fn descriptors_to_items(
Expand Down
29 changes: 29 additions & 0 deletions tests/test_enum.rs
Expand Up @@ -161,3 +161,32 @@ enum TestReprParse {
fn test_repr_parse() {
assert_eq!(std::mem::align_of::<TestReprParse>(), 8);
}

#[pyclass(name = "MyEnum")]
#[derive(Debug, PartialEq, Clone)]
pub enum RenameEnum {
Variant,
}

#[test]
fn test_rename_enum_repr_correct() {
Python::with_gil(|py| {
let var1 = Py::new(py, RenameEnum::Variant).unwrap();
py_assert!(py, var1, "repr(var1) == 'MyEnum.Variant'");
})
}

#[pyclass]
#[derive(Debug, PartialEq, Clone)]
pub enum RenameVariantEnum {
#[pyo3(name = "VARIANT")]
Variant,
}

#[test]
fn test_rename_variant_repr_correct() {
Python::with_gil(|py| {
let var1 = Py::new(py, RenameVariantEnum::Variant).unwrap();
py_assert!(py, var1, "repr(var1) == 'RenameVariantEnum.VARIANT'");
})
}

0 comments on commit 510c126

Please sign in to comment.