Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Py/PyAny: remove PartialEq impl and add is() #2183

Merged
merged 1 commit into from Feb 25, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
6 changes: 4 additions & 2 deletions CHANGELOG.md
Expand Up @@ -35,7 +35,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Add `PyMapping::contains` method (`in` operator for `PyMapping`). [#2133](https://github.com/PyO3/pyo3/pull/2133)
- Add garbage collection magic methods `__traverse__` and `__clear__` to `#[pymethods]`. [#2159](https://github.com/PyO3/pyo3/pull/2159)
- Add support for `from_py_with` on struct tuples and enums to override the default from-Python conversion. [#2181](https://github.com/PyO3/pyo3/pull/2181)
- Add `eq`, `ne`, `lt`, `le`, `gt`, `ge` methods to `PyAny` that wrap `rich_compare`.
- Add `eq`, `ne`, `lt`, `le`, `gt`, `ge` methods to `PyAny` that wrap `rich_compare`. [#2175](https://github.com/PyO3/pyo3/pull/2175)
- Add `Py::is` and `PyAny::is` methods to check for object identity. [#2183](https://github.com/PyO3/pyo3/pull/2183)

### Changed

Expand Down Expand Up @@ -81,7 +82,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Removed

- Remove all functionality deprecated in PyO3 0.14. [#2007](https://github.com/PyO3/pyo3/pull/2007)
- Remove `Default` impl for `PyMethodDef` [2166](https://github.com/PyO3/pyo3/pull/2166)
- Remove `Default` impl for `PyMethodDef`. [#2166](https://github.com/PyO3/pyo3/pull/2166)
- Remove `PartialEq` impl for `Py` and `PyAny` (use the new `is()` instead). [#2183](https://github.com/PyO3/pyo3/pull/2183)

### Fixed

Expand Down
13 changes: 13 additions & 0 deletions guide/src/migration.md
Expand Up @@ -62,6 +62,19 @@ impl MyClass {
}
```

### Removed `PartialEq` for object wrappers

The Python object wrappers `Py` and `PyAny` had implementations of `PartialEq`
so that `object_a == object_b` would compare the Python objects for pointer
equality, which corresponds to the `is` operator, not the `==` operator in
Python. This has been removed in favor of a new method: use
`object_a.is(object_b)`. This also has the advantage of not requiring the same
wrapper type for `object_a` and `object_b`; you can now directly compare a
`Py<T>` with a `&PyAny` without having to convert.

To check for Python object equality (the Python `==` operator), use the new
method `eq()`.

### Container magic methods now match Python behavior

In PyO3 0.15, `__getitem__`, `__setitem__` and `__delitem__` in `#[pymethods]` would generate only the _mapping_ implementation for a `#[pyclass]`. To match the Python behavior, these methods now generate both the _mapping_ **and** _sequence_ implementations.
Expand Down
2 changes: 1 addition & 1 deletion src/conversion.rs
Expand Up @@ -561,7 +561,7 @@ mod tests {
Python::with_gil(|py| {
let list = PyList::new(py, &[1, 2, 3]);
let val = unsafe { <PyList as PyTryFrom>::try_from_unchecked(list.as_ref()) };
assert_eq!(list, val);
assert!(list.is(val));
});
}

Expand Down
35 changes: 22 additions & 13 deletions src/err/mod.rs
Expand Up @@ -187,7 +187,7 @@ impl PyErr {
///
/// Python::with_gil(|py| {
/// let err: PyErr = PyTypeError::new_err(("some type error",));
/// assert_eq!(err.get_type(py), PyType::new::<PyTypeError>(py));
/// assert!(err.get_type(py).is(PyType::new::<PyTypeError>(py)));
/// });
/// ```
pub fn get_type<'py>(&'py self, py: Python<'py>) -> &'py PyType {
Expand Down Expand Up @@ -231,7 +231,7 @@ impl PyErr {
///
/// Python::with_gil(|py| {
/// let err = PyTypeError::new_err(("some type error",));
/// assert_eq!(err.traceback(py), None);
/// assert!(err.traceback(py).is_none());
/// });
/// ```
pub fn traceback<'py>(&'py self, py: Python<'py>) -> Option<&'py PyTraceback> {
Expand Down Expand Up @@ -469,9 +469,12 @@ impl PyErr {
/// Python::with_gil(|py| {
/// let err: PyErr = PyTypeError::new_err(("some type error",));
/// let err_clone = err.clone_ref(py);
/// assert_eq!(err.get_type(py), err_clone.get_type(py));
/// assert_eq!(err.value(py), err_clone.value(py));
/// assert_eq!(err.traceback(py), err_clone.traceback(py));
/// assert!(err.get_type(py).is(err_clone.get_type(py)));
/// assert!(err.value(py).is(err_clone.value(py)));
/// match err.traceback(py) {
/// None => assert!(err_clone.traceback(py).is_none()),
/// Some(tb) => assert!(err_clone.traceback(py).unwrap().is(tb)),
/// }
/// });
/// ```
#[inline]
Expand Down Expand Up @@ -706,7 +709,7 @@ fn exceptions_must_derive_from_base_exception(py: Python) -> PyErr {
mod tests {
use super::PyErrState;
use crate::exceptions;
use crate::{PyErr, Python};
use crate::{AsPyPointer, PyErr, Python};

#[test]
fn no_error() {
Expand Down Expand Up @@ -857,16 +860,22 @@ mod tests {
fn deprecations() {
let err = exceptions::PyValueError::new_err("an error");
Python::with_gil(|py| {
assert_eq!(err.ptype(py), err.get_type(py));
assert_eq!(err.pvalue(py), err.value(py));
assert_eq!(err.instance(py), err.value(py));
assert_eq!(err.ptraceback(py), err.traceback(py));
assert_eq!(err.ptype(py).as_ptr(), err.get_type(py).as_ptr());
assert_eq!(err.pvalue(py).as_ptr(), err.value(py).as_ptr());
assert_eq!(err.instance(py).as_ptr(), err.value(py).as_ptr());
assert_eq!(
err.ptraceback(py).map(|t| t.as_ptr()),
err.traceback(py).map(|t| t.as_ptr())
);

assert_eq!(
err.clone_ref(py).into_instance(py).as_ref(py),
err.value(py)
err.clone_ref(py).into_instance(py).as_ref(py).as_ptr(),
err.value(py).as_ptr()
);
assert_eq!(
PyErr::from_instance(err.value(py)).value(py).as_ptr(),
err.value(py).as_ptr()
);
assert_eq!(PyErr::from_instance(err.value(py)).value(py), err.value(py));
});
}
}
2 changes: 1 addition & 1 deletion src/impl_/extract_argument.rs
Expand Up @@ -102,7 +102,7 @@ pub fn from_py_with_with_default<'py, T>(
#[doc(hidden)]
#[cold]
pub fn argument_extraction_error(py: Python, arg_name: &str, error: PyErr) -> PyErr {
if error.get_type(py) == PyTypeError::type_object(py) {
if error.get_type(py).is(PyTypeError::type_object(py)) {
let remapped_error =
PyTypeError::new_err(format!("argument '{}': {}", arg_name, error.value(py)));
remapped_error.set_cause(py, error.cause(py));
Expand Down
16 changes: 9 additions & 7 deletions src/instance.rs
Expand Up @@ -463,6 +463,15 @@ where
}

impl<T> Py<T> {
/// Returns whether `self` and `other` point to the same object. To compare
/// the equality of two objects (the `==` operator), use [`eq`](PyAny::eq).
///
/// This is equivalent to the Python expression `self is other`.
#[inline]
pub fn is<U: AsPyPointer>(&self, o: &U) -> bool {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this inherit the #[inline] attribute from the PartialEq method? (Similarly for the method on PyAny.)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think probably makes sense for this to be #[inline].

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

will do

self.as_ptr() == o.as_ptr()
}

/// Gets the reference count of the `ffi::PyObject` pointer.
#[inline]
pub fn get_refcnt(&self, _py: Python) -> isize {
Expand Down Expand Up @@ -829,13 +838,6 @@ where
}
}

impl<T> PartialEq for Py<T> {
#[inline]
fn eq(&self, o: &Py<T>) -> bool {
self.0 == o.0
}
}

/// If the GIL is held this increments `self`'s reference count.
/// Otherwise this registers the [`Py`]`<T>` instance to have its reference count
/// incremented the next time PyO3 acquires the GIL.
Expand Down
9 changes: 9 additions & 0 deletions src/types/any.rs
Expand Up @@ -87,6 +87,15 @@ impl PyAny {
<T as PyTryFrom>::try_from(self)
}

/// Returns whether `self` and `other` point to the same object. To compare
/// the equality of two objects (the `==` operator), use [`eq`](PyAny::eq).
///
/// This is equivalent to the Python expression `self is other`.
#[inline]
pub fn is<T: AsPyPointer>(&self, other: &T) -> bool {
self.as_ptr() == other.as_ptr()
}

/// Determines whether this object has the given attribute.
///
/// This is equivalent to the Python expression `hasattr(self, attr_name)`.
Expand Down
4 changes: 2 additions & 2 deletions src/types/boolobject.rs
Expand Up @@ -69,7 +69,7 @@ mod tests {
assert!(PyBool::new(py, true).is_true());
let t: &PyAny = PyBool::new(py, true).into();
assert!(t.extract::<bool>().unwrap());
assert_eq!(true.to_object(py), PyBool::new(py, true).into());
assert!(true.to_object(py).is(PyBool::new(py, true)));
});
}

Expand All @@ -79,7 +79,7 @@ mod tests {
assert!(!PyBool::new(py, false).is_true());
let t: &PyAny = PyBool::new(py, false).into();
assert!(!t.extract::<bool>().unwrap());
assert_eq!(false.to_object(py), PyBool::new(py, false).into());
assert!(false.to_object(py).is(PyBool::new(py, false)));
});
}
}
8 changes: 4 additions & 4 deletions src/types/dict.rs
Expand Up @@ -387,7 +387,7 @@ mod tests {
Python::with_gil(|py| {
let dict = [(7, 32)].into_py_dict(py);
assert_eq!(32, dict.get_item(7i32).unwrap().extract::<i32>().unwrap());
assert_eq!(None, dict.get_item(8i32));
assert!(dict.get_item(8i32).is_none());
let map: HashMap<i32, i32> = [(7, 32)].iter().cloned().collect();
assert_eq!(map, dict.extract().unwrap());
let map: BTreeMap<i32, i32> = [(7, 32)].iter().cloned().collect();
Expand Down Expand Up @@ -426,7 +426,7 @@ mod tests {

let ndict = dict.copy().unwrap();
assert_eq!(32, ndict.get_item(7i32).unwrap().extract::<i32>().unwrap());
assert_eq!(None, ndict.get_item(8i32));
assert!(ndict.get_item(8i32).is_none());
});
}

Expand Down Expand Up @@ -464,7 +464,7 @@ mod tests {
let ob = v.to_object(py);
let dict = <PyDict as PyTryFrom>::try_from(ob.as_ref(py)).unwrap();
assert_eq!(32, dict.get_item(7i32).unwrap().extract::<i32>().unwrap());
assert_eq!(None, dict.get_item(8i32));
assert!(dict.get_item(8i32).is_none());
});
}

Expand Down Expand Up @@ -527,7 +527,7 @@ mod tests {
let dict = <PyDict as PyTryFrom>::try_from(ob.as_ref(py)).unwrap();
assert!(dict.del_item(7i32).is_ok());
assert_eq!(0, dict.len());
assert_eq!(None, dict.get_item(7i32));
assert!(dict.get_item(7i32).is_none());
});
}

Expand Down
2 changes: 1 addition & 1 deletion src/types/iterator.rs
Expand Up @@ -213,7 +213,7 @@ def fibonacci(target):
Python::with_gil(|py| {
let obj: Py<PyAny> = vec![10, 20].to_object(py).as_ref(py).iter().unwrap().into();
let iter: &PyIterator = PyIterator::try_from(obj.as_ref(py)).unwrap();
assert_eq!(obj, iter.into());
assert!(obj.is(iter));
});
}

Expand Down
9 changes: 0 additions & 9 deletions src/types/mod.rs
Expand Up @@ -64,15 +64,6 @@ macro_rules! pyobject_native_type_base(
unsafe { $crate::PyObject::from_borrowed_ptr(py, self.as_ptr()) }
}
}

impl<$($generics,)*> ::std::cmp::PartialEq for $name {
#[inline]
fn eq(&self, o: &$name) -> bool {
use $crate::AsPyPointer;

self.as_ptr() == o.as_ptr()
}
}
};
);

Expand Down
4 changes: 2 additions & 2 deletions src/types/sequence.rs
Expand Up @@ -725,11 +725,11 @@ mod tests {
let seq = ob.cast_as::<PySequence>(py).unwrap();
let rep_seq = seq.in_place_repeat(3).unwrap();
assert_eq!(6, seq.len().unwrap());
assert_eq!(seq, rep_seq);
assert!(seq.is(rep_seq));

let conc_seq = seq.in_place_concat(seq).unwrap();
assert_eq!(12, seq.len().unwrap());
assert_eq!(seq, conc_seq);
assert!(seq.is(conc_seq));
});
}

Expand Down
6 changes: 3 additions & 3 deletions src/types/string.rs
Expand Up @@ -504,7 +504,7 @@ mod tests {
let data = unsafe { s.data().unwrap() };
assert_eq!(data, PyStringData::Ucs1(b"f\xfe"));
let err = data.to_string(py).unwrap_err();
assert_eq!(err.get_type(py), PyUnicodeDecodeError::type_object(py));
assert!(err.get_type(py).is(PyUnicodeDecodeError::type_object(py)));
assert!(err
.to_string()
.contains("'utf-8' codec can't decode byte 0xfe in position 1"));
Expand Down Expand Up @@ -546,7 +546,7 @@ mod tests {
let data = unsafe { s.data().unwrap() };
assert_eq!(data, PyStringData::Ucs2(&[0xff22, 0xd800]));
let err = data.to_string(py).unwrap_err();
assert_eq!(err.get_type(py), PyUnicodeDecodeError::type_object(py));
assert!(err.get_type(py).is(PyUnicodeDecodeError::type_object(py)));
assert!(err
.to_string()
.contains("'utf-16' codec can't decode bytes in position 0-3"));
Expand Down Expand Up @@ -585,7 +585,7 @@ mod tests {
let data = unsafe { s.data().unwrap() };
assert_eq!(data, PyStringData::Ucs4(&[0x20000, 0xd800]));
let err = data.to_string(py).unwrap_err();
assert_eq!(err.get_type(py), PyUnicodeDecodeError::type_object(py));
assert!(err.get_type(py).is(PyUnicodeDecodeError::type_object(py)));
assert!(err
.to_string()
.contains("'utf-32' codec can't decode bytes in position 0-7"));
Expand Down
10 changes: 6 additions & 4 deletions tests/test_sequence.rs
Expand Up @@ -279,10 +279,12 @@ fn test_generic_list_set() {
let list = PyCell::new(py, GenericList { items: vec![] }).unwrap();

py_run!(py, list, "list.items = [1, 2, 3]");
assert_eq!(
list.borrow().items,
vec![1.to_object(py), 2.to_object(py), 3.to_object(py)]
);
assert!(list
.borrow()
.items
.iter()
.zip(&[1u32, 2, 3])
.all(|(a, b)| a.as_ref(py).eq(&b.into_py(py)).unwrap()));
}

#[pyclass]
Expand Down
10 changes: 6 additions & 4 deletions tests/test_sequence_pyproto.rs
Expand Up @@ -263,10 +263,12 @@ fn test_generic_list_set() {
let list = PyCell::new(py, GenericList { items: vec![] }).unwrap();

py_run!(py, list, "list.items = [1, 2, 3]");
assert_eq!(
list.borrow().items,
vec![1.to_object(py), 2.to_object(py), 3.to_object(py)]
);
assert!(list
.borrow()
.items
.iter()
.zip(&[1u32, 2, 3])
.all(|(a, b)| a.as_ref(py).eq(&b.into_py(py)).unwrap()));
}

#[pyclass]
Expand Down
4 changes: 2 additions & 2 deletions tests/test_serde.rs
Expand Up @@ -59,12 +59,12 @@ mod test_serde {

#[test]
fn test_deserialize() {
let serialized = r#"{"username": "danya", "friends":
let serialized = r#"{"username": "danya", "friends":
[{"username": "friend", "group": {"name": "danya's friends"}, "friends": []}]}"#;
let user: User = serde_json::from_str(serialized).expect("failed to deserialize");

assert_eq!(user.username, "danya");
assert_eq!(user.group, None);
assert!(user.group.is_none());
assert_eq!(user.friends.len(), 1usize);
let friend = user.friends.get(0).unwrap();

Expand Down