Skip to content

Commit

Permalink
Implement __int__ and __richcmp__.
Browse files Browse the repository at this point in the history
  • Loading branch information
jovenlin0527 committed Dec 10, 2021
1 parent 27b00d7 commit 0c3c952
Show file tree
Hide file tree
Showing 2 changed files with 140 additions and 15 deletions.
64 changes: 62 additions & 2 deletions pyo3-macros-backend/src/pyclass.rs
Original file line number Diff line number Diff line change
Expand Up @@ -410,7 +410,6 @@ struct PyClassEnum<'a> {
ident: &'a syn::Ident,
// The underlying #[repr] of the enum, used to implement __int__ and __richcmp__.
// This matters when the underlying representation may not fit in `isize`.
#[allow(unused, dead_code)]
repr_type: syn::Ident,
variants: Vec<PyClassEnumVariant<'a>>,
}
Expand Down Expand Up @@ -522,7 +521,68 @@ fn impl_enum_class(
}
};

let default_impls = gen_default_slot_impls(cls, vec![default_repr_impl]);
let repr_type = &enum_.repr_type;

let default_int = {
// This implementation allows us to convert &T to #repr_type without implementing `Copy`
let variants_to_int = variants.iter().map(|variant| {
let variant_name = variant.ident;
quote! { #cls::#variant_name => #cls::#variant_name as #repr_type, }
});
quote! {
#[doc(hidden)]
#[allow(non_snake_case)]
#[pyo3(name = "__int__")]
fn __pyo3__int__(&self) -> #repr_type {
match self {
#(#variants_to_int)*
}
}
}
};

let default_richcmp = {
let variants_eq = variants.iter().map(|variant| {
let variant_name = variant.ident;
quote! {
(#cls::#variant_name, #cls::#variant_name) =>
Ok(true.to_object(py)),
}
});
quote! {
#[doc(hidden)]
#[allow(non_snake_case)]
#[pyo3(name = "__richcmp__")]
fn __pyo3__richcmp__(
&self,
py: _pyo3::Python,
other: &_pyo3::PyAny,
op: _pyo3::basic::CompareOp
) -> _pyo3::PyResult<_pyo3::PyObject> {
use _pyo3::conversion::ToPyObject;
use ::core::result::Result::*;
match op {
_pyo3::basic::CompareOp::Eq => {
if let Ok(i) = other.extract::<#repr_type>() {
let self_val = self.__pyo3__int__();
return Ok((self_val == i).to_object(py));
}
let other = other.extract::<_pyo3::PyRef<Self>>()?;
let other = &*other;
match (self, other) {
#(#variants_eq)*
_ => Ok(false.to_object(py)),
}
}
_ => Ok(py.NotImplemented()),
}
}
}
};

let default_impls =
gen_default_slot_impls(cls, vec![default_repr_impl, default_richcmp, default_int]);

Ok(quote! {
const _: () = {
use #krate as _pyo3;
Expand Down
91 changes: 78 additions & 13 deletions tests/test_enum.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,11 @@ pub enum MyEnum {

#[test]
fn test_enum_class_attr() {
let gil = Python::acquire_gil();
let py = gil.python();
let my_enum = py.get_type::<MyEnum>();
py_assert!(py, my_enum, "getattr(my_enum, 'Variant', None) is not None");
py_assert!(py, my_enum, "getattr(my_enum, 'foobar', None) is None");
py_run!(py, my_enum, "my_enum.Variant = None");
Python::with_gil(|py| {
let my_enum = py.get_type::<MyEnum>();
let var = Py::new(py, MyEnum::Variant).unwrap();
py_assert!(py, my_enum var, "my_enum.Variant == var");
})
}

#[pyfunction]
Expand All @@ -28,7 +27,6 @@ fn return_enum() -> MyEnum {
}

#[test]
#[ignore] // need to implement __eq__
fn test_return_enum() {
let gil = Python::acquire_gil();
let py = gil.python();
Expand All @@ -44,14 +42,24 @@ fn enum_arg(e: MyEnum) {
}

#[test]
#[ignore] // need to implement __eq__
fn test_enum_arg() {
let gil = Python::acquire_gil();
let py = gil.python();
let f = wrap_pyfunction!(enum_arg)(py).unwrap();
let mynum = py.get_type::<MyEnum>();
Python::with_gil(|py| {
let f = wrap_pyfunction!(enum_arg)(py).unwrap();
let mynum = py.get_type::<MyEnum>();

py_run!(py, f mynum, "f(mynum.OtherVariant)")
})
}

py_run!(py, f mynum, "f(mynum.Variant)")
#[test]
fn test_enum_eq() {
Python::with_gil(|py| {
let var1 = Py::new(py, MyEnum::Variant).unwrap();
let var2 = Py::new(py, MyEnum::Variant).unwrap();
let other_var = Py::new(py, MyEnum::OtherVariant).unwrap();
py_assert!(py, var1 var2, "var1 == var2");
py_assert!(py, var1 other_var, "var1 != other_var");
})
}

#[test]
Expand Down Expand Up @@ -85,6 +93,63 @@ fn test_custom_discriminant() {
})
}

#[test]
fn test_enum_to_int() {
Python::with_gil(|py| {
let one = Py::new(py, CustomDiscriminant::One).unwrap();
py_assert!(py, one, "int(one) == 1");
let v = Py::new(py, MyEnum::Variant).unwrap();
let v_value = MyEnum::Variant as isize;
py_run!(py, v v_value, "int(v) == v_value");
})
}

#[test]
fn test_enum_compare_int() {
Python::with_gil(|py| {
let one = Py::new(py, CustomDiscriminant::One).unwrap();
py_run!(
py,
one,
r#"
assert one == 1
assert 1 == one
assert one != 2
"#
)
})
}

#[pyclass]
#[repr(u8)]
enum SmallEnum {
V = 1,
}

#[test]
fn test_enum_compare_int_no_throw_when_overflow() {
Python::with_gil(|py| {
let v = Py::new(py, SmallEnum::V).unwrap();
py_assert!(py, v, "v != 1<<30")
})
}

#[pyclass]
#[repr(usize)]
enum BigEnum {
V = usize::MAX,
}

#[test]
fn test_big_enum_no_overflow() {
Python::with_gil(|py| {
let usize_max = usize::MAX;
let v = Py::new(py, BigEnum::V).unwrap();
py_assert!(py, usize_max v, "v == usize_max");
py_assert!(py, usize_max v, "int(v) == usize_max");
})
}

#[pyclass]
#[repr(u16, align(8))]
enum TestReprParse {
Expand Down

0 comments on commit 0c3c952

Please sign in to comment.