diff --git a/CHANGELOG.md b/CHANGELOG.md index fdf29abb00e..02213a509bc 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -33,6 +33,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Add `PyMapping::contains` method (`in` operator for `PyMapping`). [#2133](https://github.com/PyO3/pyo3/pull/2133) - Add garbage collection magic methods `__traverse__` and `__clear__` to `#[pymethods]`. [#2159](https://github.com/PyO3/pyo3/pull/2159) - Add support for `from_py_with` on struct tuples and enums to override the default from-Python conversion. [#2181](https://github.com/PyO3/pyo3/pull/2181) +- Add `eq`, `ne`, `lt`, `le`, `gt`, `ge` methods to `PyAny` that wrap `rich_compare`. ### Changed diff --git a/src/types/any.rs b/src/types/any.rs index b54aa540d34..fb5d9b8ad2d 100644 --- a/src/types/any.rs +++ b/src/types/any.rs @@ -259,6 +259,66 @@ impl PyAny { } } + /// Tests whether this object is less than another. + /// + /// This is equivalent to the Python expression `self < other`. + pub fn lt(&self, other: O) -> PyResult + where + O: ToPyObject, + { + self.rich_compare(other, CompareOp::Lt)?.is_true() + } + + /// Tests whether this object is less than or equal to another. + /// + /// This is equivalent to the Python expression `self <= other`. + pub fn le(&self, other: O) -> PyResult + where + O: ToPyObject, + { + self.rich_compare(other, CompareOp::Le)?.is_true() + } + + /// Tests whether this object is equal to another. + /// + /// This is equivalent to the Python expression `self == other`. + pub fn eq(&self, other: O) -> PyResult + where + O: ToPyObject, + { + self.rich_compare(other, CompareOp::Eq)?.is_true() + } + + /// Tests whether this object is not equal to another. + /// + /// This is equivalent to the Python expression `self != other`. + pub fn ne(&self, other: O) -> PyResult + where + O: ToPyObject, + { + self.rich_compare(other, CompareOp::Ne)?.is_true() + } + + /// Tests whether this object is greater than another. + /// + /// This is equivalent to the Python expression `self > other`. + pub fn gt(&self, other: O) -> PyResult + where + O: ToPyObject, + { + self.rich_compare(other, CompareOp::Gt)?.is_true() + } + + /// Tests whether this object is greater than or equal to another. + /// + /// This is equivalent to the Python expression `self >= other`. + pub fn ge(&self, other: O) -> PyResult + where + O: ToPyObject, + { + self.rich_compare(other, CompareOp::Ge)?.is_true() + } + /// Determines whether this object appears callable. /// /// This is equivalent to Python's [`callable()`][1] function. @@ -711,7 +771,6 @@ mod tests { types::{IntoPyDict, PyList, PyLong, PyModule}, Python, ToPyObject, }; - #[test] fn test_call_for_non_existing_method() { Python::with_gil(|py| { @@ -834,4 +893,101 @@ class SimpleClass: assert!(bad_haystack.contains(&irrelevant_needle).is_err()); }); } + + // This is intentionally not a test, it's a generic function used by the tests below. + fn test_eq_methods_generic(list: &[T]) + where + T: PartialEq + PartialOrd + ToPyObject, + { + Python::with_gil(|py| { + for a in list { + for b in list { + let a_py = a.to_object(py).into_ref(py); + let b_py = b.to_object(py).into_ref(py); + + assert_eq!( + a.lt(b), + a_py.lt(b_py).unwrap(), + "{} < {} should be {}.", + a_py, + b_py, + a.lt(b) + ); + assert_eq!( + a.le(b), + a_py.le(b_py).unwrap(), + "{} <= {} should be {}.", + a_py, + b_py, + a.le(b) + ); + assert_eq!( + a.eq(b), + a_py.eq(b_py).unwrap(), + "{} == {} should be {}.", + a_py, + b_py, + a.eq(b) + ); + assert_eq!( + a.ne(b), + a_py.ne(b_py).unwrap(), + "{} != {} should be {}.", + a_py, + b_py, + a.ne(b) + ); + assert_eq!( + a.gt(b), + a_py.gt(b_py).unwrap(), + "{} > {} should be {}.", + a_py, + b_py, + a.gt(b) + ); + assert_eq!( + a.ge(b), + a_py.ge(b_py).unwrap(), + "{} >= {} should be {}.", + a_py, + b_py, + a.ge(b) + ); + } + } + }); + } + + #[test] + fn test_eq_methods_integers() { + let ints = [-4, -4, 1, 2, 0, -100, 1_000_000]; + test_eq_methods_generic(&ints); + } + + #[test] + fn test_eq_methods_strings() { + let strings = ["Let's", "test", "some", "eq", "methods"]; + test_eq_methods_generic(&strings); + } + + #[test] + fn test_eq_methods_floats() { + let floats = [ + -1.0, + 2.5, + 0.0, + 3.0, + std::f64::consts::PI, + 10.0, + 10.0 / 3.0, + -1_000_000.0, + ]; + test_eq_methods_generic(&floats); + } + + #[test] + fn test_eq_methods_bools() { + let bools = [true, false]; + test_eq_methods_generic(&bools); + } }