Skip to content

Commit

Permalink
Update PyTryFrom for PyMapping and PySequence to more accurately chec…
Browse files Browse the repository at this point in the history
…k types (#2477)

Co-authored-by: David Hewitt <1939362+davidhewitt@users.noreply.github.com>
  • Loading branch information
aganders3 and davidhewitt committed Aug 10, 2022
1 parent 8181160 commit 5d88e1d
Show file tree
Hide file tree
Showing 7 changed files with 159 additions and 17 deletions.
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- `PyCapsule::name` now returns `PyResult<Option<&CStr>>` instead of `&CStr`.
- `FromPyObject::extract` now raises an error if source type is `PyString` and target type is `Vec<T>`. [#2500](https://github.com/PyO3/pyo3/pull/2500)
- Only allow each `#[pymodule]` to be initialized once. [#2523](https://github.com/PyO3/pyo3/pull/2523)
- `pyo3_build_config::add_extension_module_link_args()` now also emits linker arguments for `wasm32-unknown-emscripten`. [#2538](https://github.com/PyO3/pyo3/pull/2538)
- `pyo3_build_config::add_extension_module_link_args()` now also emits linker arguments for `wasm32-unknown-emscripten`. [#2500](https://github.com/PyO3/pyo3/pull/2500)
- Downcasting (`PyTryFrom`) behavior has changed for `PySequence` and `PyMapping`: classes are now required to inherit from (or register with) the corresponding Python standard library abstract base class. See the [migration guide](https://pyo3.rs/latest/migration.html) for information on fixing broken downcasts. [#2477](https://github.com/PyO3/pyo3/pull/2477)

### Removed

Expand Down
52 changes: 52 additions & 0 deletions guide/src/migration.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,58 @@ For a detailed list of all changes, see the [CHANGELOG](changelog.md).

## from 0.16.* to 0.17

### Type checks have been changed for `PyMapping` and `PySequence` types

Previously the type checks for `PyMapping` and `PySequence` (implemented in `PyTryFrom`)
used the Python C-API functions `PyMapping_Check` and `PySequence_Check`.
Unfortunately these functions are not sufficient for distinguishing such types,
leading to inconsistent behavior (see
[pyo3/pyo3#2072](https://github.com/PyO3/pyo3/issues/2072)).

PyO3 0.17 changes these downcast checks to explicityly test if the type is a
subclass of the corresponding abstract base class `collections.abc.Mapping` or
`collections.abc.Sequence`. Note this requires calling into Python, which may
incur a performance penalty over the previous method. If this performance
penatly is a problem, you may be able to perform your own checks and use
`try_from_unchecked` (unsafe).

Another side-effect is that a pyclass defined in Rust with PyO3 will need to
be _registered_ with the corresponding Python abstract base class for
downcasting to succeed. `PySequence::register` and `PyMapping:register` have
been added to make it easy to do this from Rust code. These are equivalent to
calling `collections.abc.Mapping.register(MappingPyClass)` or
`collections.abc.Sequence.register(SequencePyClass)` from Python.

For example, for a mapping class defined in Rust:
```rust,compile_fail
use pyo3::prelude::*;
use std::collections::HashMap;
#[pyclass(mapping)]
struct Mapping {
index: HashMap<String, usize>,
}
#[pymethods]
impl Mapping {
#[new]
fn new(elements: Option<&PyList>) -> PyResult<Self> {
// ...
// truncated implementation of this mapping pyclass - basically a wrapper around a HashMap
}
```

You must register the class with `collections.abc.Mapping` before the downcast will work:
```rust,compile_fail
let m = Py::new(py, Mapping { index }).unwrap();
assert!(m.as_ref(py).downcast::<PyMapping>().is_err());
PyMapping::register::<Mapping>(py).unwrap();
assert!(m.as_ref(py).downcast::<PyMapping>().is_ok());
```

Note that this requirement may go away in the future when a pyclass is able to inherit from the abstract base class directly (see [pyo3/pyo3#991](https://github.com/PyO3/pyo3/issues/991)).

### The `multiple-pymethods` feature now requires Rust 1.62

Due to limitations in the `inventory` crate which the `multiple-pymethods` feature depends on, this feature now
Expand Down
47 changes: 40 additions & 7 deletions src/types/mapping.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,14 @@
// Copyright (c) 2017-present PyO3 Project and Contributors

use crate::err::{PyDowncastError, PyErr, PyResult};
use crate::types::{PyAny, PySequence};
use crate::{ffi, AsPyPointer, IntoPyPointer, Py, PyNativeType, PyTryFrom, Python, ToPyObject};
use crate::once_cell::GILOnceCell;
use crate::type_object::PyTypeInfo;
use crate::types::{PyAny, PySequence, PyType};
use crate::{
ffi, AsPyPointer, IntoPy, IntoPyPointer, Py, PyNativeType, PyTryFrom, Python, ToPyObject,
};

static MAPPING_ABC: GILOnceCell<PyResult<Py<PyType>>> = GILOnceCell::new();

/// Represents a reference to a Python object supporting the mapping protocol.
#[repr(transparent)]
Expand Down Expand Up @@ -102,18 +108,45 @@ impl PyMapping {
.from_owned_ptr_or_err(ffi::PyMapping_Items(self.as_ptr()))
}
}

/// Register a pyclass as a subclass of `collections.abc.Mapping` (from the Python standard
/// library). This is equvalent to `collections.abc.Mapping.register(T)` in Python.
/// This registration is required for a pyclass to be downcastable from `PyAny` to `PyMapping`.
pub fn register<T: PyTypeInfo>(py: Python<'_>) -> PyResult<()> {
let ty = T::type_object(py);
get_mapping_abc(py)?.call_method1("register", (ty,))?;
Ok(())
}
}

fn get_mapping_abc(py: Python<'_>) -> Result<&PyType, PyErr> {
MAPPING_ABC
.get_or_init(py, || {
Ok(py
.import("collections.abc")?
.getattr("Mapping")?
.downcast::<PyType>()?
.into_py(py))
})
.as_ref()
.map_or_else(|e| Err(e.clone_ref(py)), |t| Ok(t.as_ref(py)))
}

impl<'v> PyTryFrom<'v> for PyMapping {
/// Downcasting to `PyMapping` requires the concrete class to be a subclass (or registered
/// subclass) of `collections.abc.Mapping` (from the Python standard library) - i.e.
/// `isinstance(<class>, collections.abc.Mapping) == True`.
fn try_from<V: Into<&'v PyAny>>(value: V) -> Result<&'v PyMapping, PyDowncastError<'v>> {
let value = value.into();
unsafe {
if ffi::PyMapping_Check(value.as_ptr()) != 0 {
Ok(<PyMapping as PyTryFrom>::try_from_unchecked(value))
} else {
Err(PyDowncastError::new(value, "Mapping"))

// TODO: surface specific errors in this chain to the user
if let Ok(abc) = get_mapping_abc(value.py()) {
if value.is_instance(abc).unwrap_or(false) {
unsafe { return Ok(<PyMapping as PyTryFrom>::try_from_unchecked(value)) }
}
}

Err(PyDowncastError::new(value, "Mapping"))
}

#[inline]
Expand Down
45 changes: 38 additions & 7 deletions src/types/sequence.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,15 @@
use crate::err::{self, PyDowncastError, PyErr, PyResult};
use crate::exceptions::PyValueError;
use crate::internal_tricks::get_ssize_index;
use crate::types::{PyAny, PyList, PyString, PyTuple};
use crate::once_cell::GILOnceCell;
use crate::type_object::PyTypeInfo;
use crate::types::{PyAny, PyList, PyString, PyTuple, PyType};
use crate::{ffi, PyNativeType, ToPyObject};
use crate::{AsPyPointer, IntoPyPointer, Py, Python};
use crate::{AsPyPointer, IntoPy, IntoPyPointer, Py, Python};
use crate::{FromPyObject, PyTryFrom};

static SEQUENCE_ABC: GILOnceCell<PyResult<Py<PyType>>> = GILOnceCell::new();

/// Represents a reference to a Python object supporting the sequence protocol.
#[repr(transparent)]
pub struct PySequence(PyAny);
Expand Down Expand Up @@ -250,6 +254,15 @@ impl PySequence {
.from_owned_ptr_or_err(ffi::PySequence_Tuple(self.as_ptr()))
}
}

/// Register a pyclass as a subclass of `collections.abc.Sequence` (from the Python standard
/// library). This is equvalent to `collections.abc.Sequence.register(T)` in Python.
/// This registration is required for a pyclass to be downcastable from `PyAny` to `PySequence`.
pub fn register<T: PyTypeInfo>(py: Python<'_>) -> PyResult<()> {
let ty = T::type_object(py);
get_sequence_abc(py)?.call_method1("register", (ty,))?;
Ok(())
}
}

#[inline]
Expand Down Expand Up @@ -289,16 +302,34 @@ where
Ok(v)
}

fn get_sequence_abc(py: Python<'_>) -> Result<&PyType, PyErr> {
SEQUENCE_ABC
.get_or_init(py, || {
Ok(py
.import("collections.abc")?
.getattr("Sequence")?
.downcast::<PyType>()?
.into_py(py))
})
.as_ref()
.map_or_else(|e| Err(e.clone_ref(py)), |t| Ok(t.as_ref(py)))
}

impl<'v> PyTryFrom<'v> for PySequence {
/// Downcasting to `PySequence` requires the concrete class to be a subclass (or registered
/// subclass) of `collections.abc.Sequence` (from the Python standard library) - i.e.
/// `isinstance(<class>, collections.abc.Sequence) == True`.
fn try_from<V: Into<&'v PyAny>>(value: V) -> Result<&'v PySequence, PyDowncastError<'v>> {
let value = value.into();
unsafe {
if ffi::PySequence_Check(value.as_ptr()) != 0 {
Ok(<PySequence as PyTryFrom>::try_from_unchecked(value))
} else {
Err(PyDowncastError::new(value, "Sequence"))

// TODO: surface specific errors in this chain to the user
if let Ok(abc) = get_sequence_abc(value.py()) {
if value.is_instance(abc).unwrap_or(false) {
unsafe { return Ok(<PySequence as PyTryFrom>::try_from_unchecked(value)) }
}
}

Err(PyDowncastError::new(value, "Sequence"))
}

fn try_from_exact<V: Into<&'v PyAny>>(value: V) -> Result<&'v PySequence, PyDowncastError<'v>> {
Expand Down
3 changes: 3 additions & 0 deletions tests/test_mapping.rs
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,9 @@ fn mapping_is_not_sequence() {
index.insert("Foo".into(), 1);
index.insert("Bar".into(), 2);
let m = Py::new(py, Mapping { index }).unwrap();

PyMapping::register::<Mapping>(py).unwrap();

assert!(m.as_ref(py).downcast::<PyMapping>().is_ok());
assert!(m.as_ref(py).downcast::<PySequence>().is_err());
});
Expand Down
5 changes: 4 additions & 1 deletion tests/test_proto_methods.rs
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,8 @@ impl Mapping {
#[test]
fn mapping() {
Python::with_gil(|py| {
PyMapping::register::<Mapping>(py).unwrap();

let inst = Py::new(
py,
Mapping {
Expand All @@ -218,7 +220,6 @@ fn mapping() {
)
.unwrap();

//
let mapping: &PyMapping = inst.as_ref(py).downcast().unwrap();

py_assert!(py, inst, "len(inst) == 0");
Expand Down Expand Up @@ -317,6 +318,8 @@ impl Sequence {
#[test]
fn sequence() {
Python::with_gil(|py| {
PySequence::register::<Sequence>(py).unwrap();

let inst = Py::new(py, Sequence { values: vec![] }).unwrap();

let sequence: &PySequence = inst.as_ref(py).downcast().unwrap();
Expand Down
21 changes: 20 additions & 1 deletion tests/test_sequence.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

use pyo3::exceptions::{PyIndexError, PyValueError};
use pyo3::prelude::*;
use pyo3::types::{IntoPyDict, PyList};
use pyo3::types::{IntoPyDict, PyList, PyMapping, PySequence};

use pyo3::py_run;

Expand Down Expand Up @@ -312,3 +312,22 @@ fn test_option_list_get() {
py_expect_exception!(py, list, "list[2]", PyIndexError);
});
}

#[test]
fn sequence_is_not_mapping() {
let gil = Python::acquire_gil();
let py = gil.python();

let list = PyCell::new(
py,
OptionList {
items: vec![Some(1), None],
},
)
.unwrap();

PySequence::register::<OptionList>(py).unwrap();

assert!(list.as_ref().downcast::<PyMapping>().is_err());
assert!(list.as_ref().downcast::<PySequence>().is_ok());
}

0 comments on commit 5d88e1d

Please sign in to comment.