Skip to content

Commit

Permalink
add CompareOp::matches
Browse files Browse the repository at this point in the history
to easily implement `__richcmp__` as the result of a Rust `std::cmp::Ordering` comparison.
  • Loading branch information
birkenfeld committed Jun 21, 2022
1 parent 920fa93 commit e70a874
Show file tree
Hide file tree
Showing 4 changed files with 82 additions and 6 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Expand Up @@ -19,6 +19,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Add FFI definitions `Py_fstring_input`, `sendfunc`, and `_PyErr_StackItem`. [#2423](https://github.com/PyO3/pyo3/pull/2423)
- Add `PyDateTime::new_with_fold`, `PyTime::new_with_fold`, `PyTime::get_fold`, `PyDateTime::get_fold` for PyPy. [#2428](https://github.com/PyO3/pyo3/pull/2428)
- Allow `#[classattr]` take `Python` argument. [#2383](https://github.com/PyO3/pyo3/issues/2383)
- Add `CompareOp::matches` to easily implement `__richcmp__` as the result of a
Rust `std::cmp::Ordering` comparison. [#2460](https://github.com/PyO3/pyo3/pull/2460)

### Changed

Expand Down
30 changes: 26 additions & 4 deletions guide/src/class/object.md
Expand Up @@ -128,15 +128,15 @@ impl Number {
Unlike in Python, PyO3 does not provide the magic comparison methods you might expect like `__eq__`,
`__lt__` and so on. Instead you have to implement all six operations at once with `__richcmp__`.
This method will be called with a value of `CompareOp` depending on the operation.

```rust
use pyo3::class::basic::CompareOp;

# use pyo3::prelude::*;
#
#
# #[pyclass]
# struct Number(i32);
#
#
#[pymethods]
impl Number {
fn __richcmp__(&self, other: &Self, op: CompareOp) -> PyResult<bool> {
Expand All @@ -152,6 +152,28 @@ impl Number {
}
```

If you obtain the result by comparing two Rust values, as in this example, you
can take a shortcut using `CompareOp::matches`:

```rust
use pyo3::class::basic::CompareOp;

# use pyo3::prelude::*;
#
# #[pyclass]
# struct Number(i32);
#
#[pymethods]
impl Number {
fn __richcmp__(&self, other: &Self, op: CompareOp) -> bool {
op.matches(self.0.cmp(&other.0))
}
}
```

It checks that the `std::cmp::Ordering` obtained from Rust's `Ord` matches
the given `CompareOp`.

### Truthyness

We'll consider `Number` to be `True` if it is nonzero:
Expand Down Expand Up @@ -229,4 +251,4 @@ fn my_module(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
[`Hash`]: https://doc.rust-lang.org/std/hash/trait.Hash.html
[`Hasher`]: https://doc.rust-lang.org/std/hash/trait.Hasher.html
[`DefaultHasher`]: https://doc.rust-lang.org/std/collections/hash_map/struct.DefaultHasher.html
[SipHash]: https://en.wikipedia.org/wiki/SipHash
[SipHash]: https://en.wikipedia.org/wiki/SipHash
8 changes: 6 additions & 2 deletions guide/src/class/protocols.md
Expand Up @@ -70,8 +70,11 @@ given signatures should be interpreted as follows:
<details>
<summary>Return type</summary>
The return type will normally be `PyResult<bool>`, but any Python object can be returned.
If the `object` is not of the type specified in the signature, the generated code will
automatically `return NotImplemented`.
If the second argument `object` is not of the type specified in the
signature, the generated code will automatically `return NotImplemented`.

You can use [`CompareOp::matches`] to adapt a Rust `std::cmp::Ordering` result
to the requested comparison.
</details>

- `__getattr__(<self>, object) -> object`
Expand Down Expand Up @@ -611,3 +614,4 @@ For details, look at the `#[pymethods]` regarding GC methods.
[`PySequenceProtocol`]: {{#PYO3_DOCS_URL}}/pyo3/class/sequence/trait.PySequenceProtocol.html
[`PyIterProtocol`]: {{#PYO3_DOCS_URL}}/pyo3/class/iter/trait.PyIterProtocol.html
[`PySequence`]: {{#PYO3_DOCS_URL}}/pyo3/types/struct.PySequence.html
[`CompareOp::matches`]: {{#PYO3_DOCS_URL}}/pyo3/pyclass/enum.CompareOp.html#method.matches
48 changes: 48 additions & 0 deletions src/pyclass.rs
Expand Up @@ -10,6 +10,7 @@ use crate::{
IntoPy, IntoPyPointer, PyCell, PyErr, PyMethodDefType, PyObject, PyResult, PyTypeInfo, Python,
};
use std::{
cmp::Ordering,
convert::TryInto,
ffi::{CStr, CString},
os::raw::{c_char, c_int, c_uint, c_void},
Expand Down Expand Up @@ -452,6 +453,7 @@ pub enum CompareOp {
}

impl CompareOp {
/// Conversion from the C enum.
pub fn from_raw(op: c_int) -> Option<Self> {
match op {
ffi::Py_LT => Some(CompareOp::Lt),
Expand All @@ -463,6 +465,37 @@ impl CompareOp {
_ => None,
}
}

/// Returns if a Rust [`std::cmp::Ordering`] matches this ordering query.
///
/// Usage example:
///
/// ```rust
/// # use pyo3::prelude::*;
/// # use pyo3::class::basic::CompareOp;
///
/// #[pyclass]
/// struct Size {
/// size: usize
/// }
///
/// #[pymethods]
/// impl Size {
/// fn __richcmp__(&self, other: &Size, op: CompareOp) -> bool {
/// op.matches(self.size.cmp(&other.size))
/// }
/// }
/// ```
pub fn matches(&self, result: Ordering) -> bool {
match self {
CompareOp::Eq => result == Ordering::Equal,
CompareOp::Ne => result != Ordering::Equal,
CompareOp::Lt => result == Ordering::Less,
CompareOp::Le => result != Ordering::Greater,
CompareOp::Gt => result == Ordering::Greater,
CompareOp::Ge => result != Ordering::Less,
}
}
}

/// Output of `__next__` which can either `yield` the next value in the iteration, or
Expand Down Expand Up @@ -597,3 +630,18 @@ pub trait Frozen: boolean_struct::private::Boolean {}

impl Frozen for boolean_struct::True {}
impl Frozen for boolean_struct::False {}

mod tests {
#[test]
fn test_compare_op_matches() {
use super::CompareOp;
use std::cmp::Ordering;

assert!(CompareOp::Eq.matches(Ordering::Equal));
assert!(CompareOp::Ne.matches(Ordering::Less));
assert!(CompareOp::Ge.matches(Ordering::Greater));
assert!(CompareOp::Gt.matches(Ordering::Greater));
assert!(CompareOp::Le.matches(Ordering::Equal));
assert!(CompareOp::Lt.matches(Ordering::Less));
}
}

0 comments on commit e70a874

Please sign in to comment.