Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add CompareOp::matches #2460

Merged
merged 1 commit into from Jun 21, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Expand Up @@ -19,6 +19,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Add FFI definitions `Py_fstring_input`, `sendfunc`, and `_PyErr_StackItem`. [#2423](https://github.com/PyO3/pyo3/pull/2423)
- Add `PyDateTime::new_with_fold`, `PyTime::new_with_fold`, `PyTime::get_fold`, `PyDateTime::get_fold` for PyPy. [#2428](https://github.com/PyO3/pyo3/pull/2428)
- Allow `#[classattr]` take `Python` argument. [#2383](https://github.com/PyO3/pyo3/issues/2383)
- Add `CompareOp::matches` to easily implement `__richcmp__` as the result of a
Rust `std::cmp::Ordering` comparison. [#2460](https://github.com/PyO3/pyo3/pull/2460)

### Changed

Expand Down
30 changes: 26 additions & 4 deletions guide/src/class/object.md
Expand Up @@ -128,15 +128,15 @@ impl Number {
Unlike in Python, PyO3 does not provide the magic comparison methods you might expect like `__eq__`,
`__lt__` and so on. Instead you have to implement all six operations at once with `__richcmp__`.
This method will be called with a value of `CompareOp` depending on the operation.

```rust
use pyo3::class::basic::CompareOp;

# use pyo3::prelude::*;
#
#
# #[pyclass]
# struct Number(i32);
#
#
#[pymethods]
impl Number {
fn __richcmp__(&self, other: &Self, op: CompareOp) -> PyResult<bool> {
Expand All @@ -152,6 +152,28 @@ impl Number {
}
```

If you obtain the result by comparing two Rust values, as in this example, you
can take a shortcut using `CompareOp::matches`:

```rust
use pyo3::class::basic::CompareOp;

# use pyo3::prelude::*;
#
# #[pyclass]
# struct Number(i32);
#
#[pymethods]
impl Number {
fn __richcmp__(&self, other: &Self, op: CompareOp) -> bool {
op.matches(self.0.cmp(&other.0))
}
}
```

It checks that the `std::cmp::Ordering` obtained from Rust's `Ord` matches
the given `CompareOp`.

### Truthyness

We'll consider `Number` to be `True` if it is nonzero:
Expand Down Expand Up @@ -229,4 +251,4 @@ fn my_module(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
[`Hash`]: https://doc.rust-lang.org/std/hash/trait.Hash.html
[`Hasher`]: https://doc.rust-lang.org/std/hash/trait.Hasher.html
[`DefaultHasher`]: https://doc.rust-lang.org/std/collections/hash_map/struct.DefaultHasher.html
[SipHash]: https://en.wikipedia.org/wiki/SipHash
[SipHash]: https://en.wikipedia.org/wiki/SipHash
8 changes: 6 additions & 2 deletions guide/src/class/protocols.md
Expand Up @@ -70,8 +70,11 @@ given signatures should be interpreted as follows:
<details>
<summary>Return type</summary>
The return type will normally be `PyResult<bool>`, but any Python object can be returned.
If the `object` is not of the type specified in the signature, the generated code will
automatically `return NotImplemented`.
If the second argument `object` is not of the type specified in the
signature, the generated code will automatically `return NotImplemented`.

You can use [`CompareOp::matches`] to adapt a Rust `std::cmp::Ordering` result
to the requested comparison.
</details>

- `__getattr__(<self>, object) -> object`
Expand Down Expand Up @@ -611,3 +614,4 @@ For details, look at the `#[pymethods]` regarding GC methods.
[`PySequenceProtocol`]: {{#PYO3_DOCS_URL}}/pyo3/class/sequence/trait.PySequenceProtocol.html
[`PyIterProtocol`]: {{#PYO3_DOCS_URL}}/pyo3/class/iter/trait.PyIterProtocol.html
[`PySequence`]: {{#PYO3_DOCS_URL}}/pyo3/types/struct.PySequence.html
[`CompareOp::matches`]: {{#PYO3_DOCS_URL}}/pyo3/pyclass/enum.CompareOp.html#method.matches
48 changes: 48 additions & 0 deletions src/pyclass.rs
Expand Up @@ -10,6 +10,7 @@ use crate::{
IntoPy, IntoPyPointer, PyCell, PyErr, PyMethodDefType, PyObject, PyResult, PyTypeInfo, Python,
};
use std::{
cmp::Ordering,
convert::TryInto,
ffi::{CStr, CString},
os::raw::{c_char, c_int, c_uint, c_void},
Expand Down Expand Up @@ -452,6 +453,7 @@ pub enum CompareOp {
}

impl CompareOp {
/// Conversion from the C enum.
pub fn from_raw(op: c_int) -> Option<Self> {
match op {
ffi::Py_LT => Some(CompareOp::Lt),
Expand All @@ -463,6 +465,37 @@ impl CompareOp {
_ => None,
}
}

/// Returns if a Rust [`std::cmp::Ordering`] matches this ordering query.
///
/// Usage example:
///
/// ```rust
/// # use pyo3::prelude::*;
/// # use pyo3::class::basic::CompareOp;
///
/// #[pyclass]
/// struct Size {
/// size: usize
/// }
///
/// #[pymethods]
/// impl Size {
/// fn __richcmp__(&self, other: &Size, op: CompareOp) -> bool {
/// op.matches(self.size.cmp(&other.size))
/// }
/// }
/// ```
pub fn matches(&self, result: Ordering) -> bool {
match self {
CompareOp::Eq => result == Ordering::Equal,
CompareOp::Ne => result != Ordering::Equal,
CompareOp::Lt => result == Ordering::Less,
CompareOp::Le => result != Ordering::Greater,
CompareOp::Gt => result == Ordering::Greater,
CompareOp::Ge => result != Ordering::Less,
}
}
}

/// Output of `__next__` which can either `yield` the next value in the iteration, or
Expand Down Expand Up @@ -597,3 +630,18 @@ pub trait Frozen: boolean_struct::private::Boolean {}

impl Frozen for boolean_struct::True {}
impl Frozen for boolean_struct::False {}

mod tests {
#[test]
fn test_compare_op_matches() {
use super::CompareOp;
use std::cmp::Ordering;

assert!(CompareOp::Eq.matches(Ordering::Equal));
assert!(CompareOp::Ne.matches(Ordering::Less));
assert!(CompareOp::Ge.matches(Ordering::Greater));
assert!(CompareOp::Gt.matches(Ordering::Greater));
assert!(CompareOp::Le.matches(Ordering::Equal));
assert!(CompareOp::Lt.matches(Ordering::Less));
}
}