Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added eq, ne, gt etc. methods. #2175

Merged
merged 8 commits into from Feb 25, 2022
1 change: 1 addition & 0 deletions CHANGELOG.md
Expand Up @@ -33,6 +33,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Add `PyMapping::contains` method (`in` operator for `PyMapping`). [#2133](https://github.com/PyO3/pyo3/pull/2133)
- Add garbage collection magic methods `__traverse__` and `__clear__` to `#[pymethods]`. [#2159](https://github.com/PyO3/pyo3/pull/2159)
- Add support for `from_py_with` on struct tuples and enums to override the default from-Python conversion. [#2181](https://github.com/PyO3/pyo3/pull/2181)
- Add `eq`, `ne`, `lt`, `le`, `gt`, `ge` methods to `PyAny` that wrap `rich_compare`.

### Changed

Expand Down
141 changes: 140 additions & 1 deletion src/types/any.rs
Expand Up @@ -259,6 +259,66 @@ impl PyAny {
}
}

/// Tests whether this object is less than another.
///
/// This is equivalent to the Python expression `self < other`.
pub fn lt<O>(&self, other: O) -> PyResult<&PyAny>
where
O: ToPyObject,
{
self.rich_compare(other, CompareOp::Lt)
}

/// Tests whether this object is less than or equal to another.
///
/// This is equivalent to the Python expression `self <= other`.
pub fn le<O>(&self, other: O) -> PyResult<&PyAny>
where
O: ToPyObject,
{
self.rich_compare(other, CompareOp::Le)
}

/// Tests whether this object is equal to another.
///
/// This is equivalent to the Python expression `self == other`.
pub fn eq<O>(&self, other: O) -> PyResult<&PyAny>
where
O: ToPyObject,
{
self.rich_compare(other, CompareOp::Eq)
}

/// Tests whether this object is not equal to another.
///
/// This is equivalent to the Python expression `self != other`.
pub fn ne<O>(&self, other: O) -> PyResult<&PyAny>
where
O: ToPyObject,
{
self.rich_compare(other, CompareOp::Ne)
}

/// Tests whether this object is greater than another.
///
/// This is equivalent to the Python expression `self > other`.
pub fn gt<O>(&self, other: O) -> PyResult<&PyAny>
where
O: ToPyObject,
{
self.rich_compare(other, CompareOp::Gt)
}

/// Tests whether this object is greater than or equal to another.
///
/// This is equivalent to the Python expression `self >= other`.
pub fn ge<O>(&self, other: O) -> PyResult<&PyAny>
where
O: ToPyObject,
{
self.rich_compare(other, CompareOp::Ge)
}

/// Determines whether this object appears callable.
///
/// This is equivalent to Python's [`callable()`][1] function.
Expand Down Expand Up @@ -709,7 +769,7 @@ mod tests {
use crate::{
type_object::PyTypeObject,
types::{IntoPyDict, PyList, PyLong, PyModule},
Python, ToPyObject,
PyAny, PyResult, Python, ToPyObject,
};

#[test]
Expand Down Expand Up @@ -834,4 +894,83 @@ class SimpleClass:
assert!(bad_haystack.contains(&irrelevant_needle).is_err());
});
}

// This is intentionally not a test, it's a generic function used by the tests below.
fn test_eq_methods_generic<T>(list: &[T])
where
T: PartialEq + PartialOrd + ToPyObject,
{
Python::with_gil(|py| {
for a in list {
for b in list {
let a_py = a.to_object(py).into_ref(py);
let b_py = b.to_object(py).into_ref(py);
let unwrap_cmp = |cmp: PyResult<&PyAny>| cmp.unwrap().is_true().unwrap();
assert_eq!(
a.lt(b),
unwrap_cmp(a_py.lt(b_py)),
"{a_py} should be less than {b_py}"
Tom1380 marked this conversation as resolved.
Show resolved Hide resolved
);
assert_eq!(
a.le(b),
unwrap_cmp(a_py.le(b_py)),
"{a_py} should be less than or equal to {b_py}"
);
assert_eq!(
a.eq(b),
unwrap_cmp(a_py.eq(b_py)),
"{a_py} should be equal to {b_py}"
);
assert_eq!(
a.ne(b),
unwrap_cmp(a_py.ne(b_py)),
"{a_py} should not be equal to {b_py}"
);
assert_eq!(
a.gt(b),
unwrap_cmp(a_py.gt(b_py)),
"{a_py} should be greater than {b_py}"
);
assert_eq!(
a.ge(b),
unwrap_cmp(a_py.ge(b_py)),
"{a_py} should be greater than or equal to {b_py}"
);
}
}
});
}

#[test]
fn test_eq_methods_integers() {
let ints = [-4, -4, 1, 2, 0, -100, 1_000_000];
test_eq_methods_generic(&ints);
}

#[test]
fn test_eq_methods_strings() {
let strings = ["Let's", "test", "some", "eq", "methods"];
test_eq_methods_generic(&strings);
}

#[test]
fn test_eq_methods_floats() {
let floats = [
-1.0,
2.5,
0.0,
3.0,
std::f64::consts::PI,
10.0,
10.0 / 3.0,
-1_000_000.0,
];
test_eq_methods_generic(&floats);
}

#[test]
fn test_eq_methods_bools() {
let bools = [true, false];
test_eq_methods_generic(&bools);
}
}