Update PyTryFrom for PyMapping and PySequence to more accurately check types #2477

Merged
merged 15 commits into from
Aug 10, 2022
3 changes: 2 additions & 1 deletion CHANGELOG.md
@@ -53,7 +53,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- `PyCapsule::name` now returns `PyResult<Option<&CStr>>` instead of `&CStr`.
- `FromPyObject::extract` now raises an error if source type is `PyString` and target type is `Vec<T>`. [#2500](https://github.com/PyO3/pyo3/pull/2500)
- Only allow each `#[pymodule]` to be initialized once. [#2523](https://github.com/PyO3/pyo3/pull/2523)
- `pyo3_build_config::add_extension_module_link_args()` now also emits linker arguments for `wasm32-unknown-emscripten`. [#2538](https://github.com/PyO3/pyo3/pull/2538)
- `pyo3_build_config::add_extension_module_link_args()` now also emits linker arguments for `wasm32-unknown-emscripten`. [#2538](https://github.com/PyO3/pyo3/pull/2538)
- Downcasting (`PyTryFrom`) behavior has changed for `PySequence` and `PyMapping`: classes are now required to inherit from (or register with) the corresponding Python standard library abstract base class. See the [migration guide](https://pyo3.rs/latest/migration.html) for information on fixing broken downcasts. [#2477](https://github.com/PyO3/pyo3/pull/2477)

### Removed

52 changes: 52 additions & 0 deletions guide/src/migration.md
@@ -5,6 +5,58 @@ For a detailed list of all changes, see the [CHANGELOG](changelog.md).

## from 0.16.* to 0.17

### Type checks have been changed for `PyMapping` and `PySequence` types

Previously the type checks for `PyMapping` and `PySequence` (implemented in `PyTryFrom`)
used the Python C-API functions `PyMapping_Check` and `PySequence_Check`.
Unfortunately these functions are not sufficient for distinguishing such types,
leading to inconsistent behavior (see
[pyo3/pyo3#2072](https://github.com/PyO3/pyo3/issues/2072)).
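
To make the ambiguity concrete, here is a minimal Python sketch that calls the two C-API checks through `ctypes.pythonapi` and compares them with the abstract base class checks. It assumes CPython, where both functions are exported; the class name `Indexable` is made up for illustration.

```python
import collections.abc
import ctypes

# Assumption: running on CPython, where ctypes.pythonapi exposes the C-API.
mapping_check = ctypes.pythonapi.PyMapping_Check
mapping_check.argtypes = [ctypes.py_object]
mapping_check.restype = ctypes.c_int

sequence_check = ctypes.pythonapi.PySequence_Check
sequence_check.argtypes = [ctypes.py_object]
sequence_check.restype = ctypes.c_int


class Indexable:
    """Hypothetical class that only defines __getitem__."""

    def __getitem__(self, key):
        return key


obj = Indexable()

# Both C-API checks accept the object, so they cannot tell the protocols apart.
c_api_says_mapping = bool(mapping_check(obj))
c_api_says_sequence = bool(sequence_check(obj))

# The abstract base classes reject it until it inherits from or registers with one.
abc_says_mapping = isinstance(obj, collections.abc.Mapping)
abc_says_sequence = isinstance(obj, collections.abc.Sequence)

# A list also passes PyMapping_Check, even though it is not a Mapping.
list_passes_mapping_check = bool(mapping_check([]))
```

Here the C-API checks report the object as both a mapping and a sequence, while the abstract base classes report it as neither.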

PyO3 0.17 changes these downcast checks to explicitly test if the type is a
subclass of the corresponding abstract base class `collections.abc.Mapping` or
`collections.abc.Sequence`. Note this requires calling into Python, which may
incur a performance penalty over the previous method. If this performance
penalty is a problem, you may be able to perform your own checks and use
`try_from_unchecked` (unsafe).

Another side-effect is that a pyclass defined in Rust with PyO3 will need to
be _registered_ with the corresponding Python abstract base class for
downcasting to succeed. `PySequence::register` and `PyMapping::register` have
been added to make it easy to do this from Rust code. These are equivalent to
calling `collections.abc.Mapping.register(MappingPyClass)` or
`collections.abc.Sequence.register(SequencePyClass)` from Python.
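
As a rough illustration of what registration does on the Python side, the sketch below uses a pure-Python stand-in for the pyclass. `MappingPyClass` here is a hypothetical class written for this example, not something exported by PyO3.

```python
import collections.abc


class MappingPyClass:
    """Hypothetical stand-in for a mapping pyclass exported from Rust."""

    def __init__(self):
        self._index = {"Foo": 1}

    def __getitem__(self, key):
        return self._index[key]

    def __len__(self):
        return len(self._index)


m = MappingPyClass()
before = isinstance(m, collections.abc.Mapping)

# register() records the class as a virtual subclass; it adds no methods.
collections.abc.Mapping.register(MappingPyClass)

after = isinstance(m, collections.abc.Mapping)
still_not_sequence = not isinstance(m, collections.abc.Sequence)
has_mixin_get = hasattr(m, "get")
```

Registration only affects `isinstance` and `issubclass` results. The class does not gain mixin methods such as `get`, so any method you want to expose still has to be implemented by the class.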

For example, for a mapping class defined in Rust:
```rust,compile_fail
use pyo3::prelude::*;
use pyo3::types::PyList;
use std::collections::HashMap;

#[pyclass(mapping)]
struct Mapping {
    index: HashMap<String, usize>,
}

#[pymethods]
impl Mapping {
    #[new]
    fn new(elements: Option<&PyList>) -> PyResult<Self> {
        // ...
        // truncated implementation of this mapping pyclass - basically a wrapper around a HashMap
    }
}
```

You must register the class with `collections.abc.Mapping` before the downcast will work:
```rust,compile_fail
let m = Py::new(py, Mapping { index }).unwrap();
assert!(m.as_ref(py).downcast::<PyMapping>().is_err());
PyMapping::register::<Mapping>(py).unwrap();
assert!(m.as_ref(py).downcast::<PyMapping>().is_ok());
```

Note that this requirement may go away in the future when a pyclass is able to inherit from the abstract base class directly (see [pyo3/pyo3#991](https://github.com/PyO3/pyo3/issues/991)).
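
For comparison, this is how direct inheritance already behaves for a class written in Python. It is only a sketch of the Python-side semantics; `InheritedMapping` is a hypothetical name, and nothing here implies how a future pyclass feature would look.

```python
import collections.abc


class InheritedMapping(collections.abc.Mapping):
    """Hypothetical pure-Python mapping that inherits from the ABC."""

    def __init__(self, data):
        self._data = dict(data)

    def __getitem__(self, key):
        return self._data[key]

    def __iter__(self):
        return iter(self._data)

    def __len__(self):
        return len(self._data)


m = InheritedMapping({"Foo": 1, "Bar": 2})

# No register() call is needed: inheritance already makes it a Mapping.
is_mapping = isinstance(m, collections.abc.Mapping)

# Inheriting also provides mixin methods such as get().
bar = m.get("Bar")
missing = m.get("Baz", 0)
```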

### The `multiple-pymethods` feature now requires Rust 1.62

Due to limitations in the `inventory` crate which the `multiple-pymethods` feature depends on, this feature now
47 changes: 40 additions & 7 deletions src/types/mapping.rs
@@ -1,8 +1,14 @@
// Copyright (c) 2017-present PyO3 Project and Contributors

use crate::err::{PyDowncastError, PyErr, PyResult};
use crate::types::{PyAny, PySequence};
use crate::{ffi, AsPyPointer, IntoPyPointer, Py, PyNativeType, PyTryFrom, Python, ToPyObject};
use crate::once_cell::GILOnceCell;
use crate::type_object::PyTypeInfo;
use crate::types::{PyAny, PySequence, PyType};
use crate::{
ffi, AsPyPointer, IntoPy, IntoPyPointer, Py, PyNativeType, PyTryFrom, Python, ToPyObject,
};

static MAPPING_ABC: GILOnceCell<PyResult<Py<PyType>>> = GILOnceCell::new();

/// Represents a reference to a Python object supporting the mapping protocol.
#[repr(transparent)]
@@ -102,18 +108,45 @@ impl PyMapping {
.from_owned_ptr_or_err(ffi::PyMapping_Items(self.as_ptr()))
}
}

/// Register a pyclass as a subclass of `collections.abc.Mapping` (from the Python standard
/// library). This is equivalent to `collections.abc.Mapping.register(T)` in Python.
/// This registration is required for a pyclass to be downcastable from `PyAny` to `PyMapping`.
pub fn register<T: PyTypeInfo>(py: Python<'_>) -> PyResult<()> {
let ty = T::type_object(py);
get_mapping_abc(py)?.call_method1("register", (ty,))?;
Ok(())
}
}

fn get_mapping_abc(py: Python<'_>) -> Result<&PyType, PyErr> {
MAPPING_ABC
.get_or_init(py, || {
Ok(py
.import("collections.abc")?
.getattr("Mapping")?
.downcast::<PyType>()?
.into_py(py))
})
.as_ref()
.map_or_else(|e| Err(e.clone_ref(py)), |t| Ok(t.as_ref(py)))
}

impl<'v> PyTryFrom<'v> for PyMapping {
/// Downcasting to `PyMapping` requires the concrete class to be a subclass (or registered
/// subclass) of `collections.abc.Mapping` (from the Python standard library) - i.e.
/// `issubclass(<class>, collections.abc.Mapping) == True`.
fn try_from<V: Into<&'v PyAny>>(value: V) -> Result<&'v PyMapping, PyDowncastError<'v>> {
let value = value.into();
unsafe {
if ffi::PyMapping_Check(value.as_ptr()) != 0 {
Ok(<PyMapping as PyTryFrom>::try_from_unchecked(value))
} else {
Err(PyDowncastError::new(value, "Mapping"))

// TODO: surface specific errors in this chain to the user
if let Ok(abc) = get_mapping_abc(value.py()) {
if value.is_instance(abc).unwrap_or(false) {
unsafe { return Ok(<PyMapping as PyTryFrom>::try_from_unchecked(value)) }
}
}

Err(PyDowncastError::new(value, "Mapping"))
}

#[inline]
45 changes: 38 additions & 7 deletions src/types/sequence.rs
@@ -2,11 +2,15 @@
use crate::err::{self, PyDowncastError, PyErr, PyResult};
use crate::exceptions::PyValueError;
use crate::internal_tricks::get_ssize_index;
use crate::types::{PyAny, PyList, PyString, PyTuple};
use crate::once_cell::GILOnceCell;
use crate::type_object::PyTypeInfo;
use crate::types::{PyAny, PyList, PyString, PyTuple, PyType};
use crate::{ffi, PyNativeType, ToPyObject};
use crate::{AsPyPointer, IntoPyPointer, Py, Python};
use crate::{AsPyPointer, IntoPy, IntoPyPointer, Py, Python};
use crate::{FromPyObject, PyTryFrom};

static SEQUENCE_ABC: GILOnceCell<PyResult<Py<PyType>>> = GILOnceCell::new();

/// Represents a reference to a Python object supporting the sequence protocol.
#[repr(transparent)]
pub struct PySequence(PyAny);
@@ -250,6 +254,15 @@ impl PySequence {
.from_owned_ptr_or_err(ffi::PySequence_Tuple(self.as_ptr()))
}
}

/// Register a pyclass as a subclass of `collections.abc.Sequence` (from the Python standard
/// library). This is equivalent to `collections.abc.Sequence.register(T)` in Python.
/// This registration is required for a pyclass to be downcastable from `PyAny` to `PySequence`.
pub fn register<T: PyTypeInfo>(py: Python<'_>) -> PyResult<()> {
let ty = T::type_object(py);
get_sequence_abc(py)?.call_method1("register", (ty,))?;
Ok(())
}
}

#[inline]
@@ -289,16 +302,34 @@ where
Ok(v)
}

fn get_sequence_abc(py: Python<'_>) -> Result<&PyType, PyErr> {
SEQUENCE_ABC
.get_or_init(py, || {
Ok(py
.import("collections.abc")?
.getattr("Sequence")?
.downcast::<PyType>()?
.into_py(py))
})
.as_ref()
.map_or_else(|e| Err(e.clone_ref(py)), |t| Ok(t.as_ref(py)))
}

impl<'v> PyTryFrom<'v> for PySequence {
/// Downcasting to `PySequence` requires the concrete class to be a subclass (or registered
/// subclass) of `collections.abc.Sequence` (from the Python standard library) - i.e.
/// `issubclass(<class>, collections.abc.Sequence) == True`.
fn try_from<V: Into<&'v PyAny>>(value: V) -> Result<&'v PySequence, PyDowncastError<'v>> {
let value = value.into();
unsafe {
if ffi::PySequence_Check(value.as_ptr()) != 0 {
Ok(<PySequence as PyTryFrom>::try_from_unchecked(value))
} else {
Err(PyDowncastError::new(value, "Sequence"))

// TODO: surface specific errors in this chain to the user
if let Ok(abc) = get_sequence_abc(value.py()) {
if value.is_instance(abc).unwrap_or(false) {
unsafe { return Ok(<PySequence as PyTryFrom>::try_from_unchecked(value)) }
}
}

Err(PyDowncastError::new(value, "Sequence"))
}

fn try_from_exact<V: Into<&'v PyAny>>(value: V) -> Result<&'v PySequence, PyDowncastError<'v>> {
3 changes: 3 additions & 0 deletions tests/test_mapping.rs
@@ -119,6 +119,9 @@ fn mapping_is_not_sequence() {
index.insert("Foo".into(), 1);
index.insert("Bar".into(), 2);
let m = Py::new(py, Mapping { index }).unwrap();

PyMapping::register::<Mapping>(py).unwrap();

assert!(m.as_ref(py).downcast::<PyMapping>().is_ok());
assert!(m.as_ref(py).downcast::<PySequence>().is_err());
});
5 changes: 4 additions & 1 deletion tests/test_proto_methods.rs
@@ -210,6 +210,8 @@ impl Mapping {
#[test]
fn mapping() {
Python::with_gil(|py| {
PyMapping::register::<Mapping>(py).unwrap();

let inst = Py::new(
py,
Mapping {
@@ -218,7 +220,6 @@ fn mapping() {
)
.unwrap();

//
let mapping: &PyMapping = inst.as_ref(py).downcast().unwrap();

py_assert!(py, inst, "len(inst) == 0");
@@ -317,6 +318,8 @@ impl Sequence {
#[test]
fn sequence() {
Python::with_gil(|py| {
PySequence::register::<Sequence>(py).unwrap();

let inst = Py::new(py, Sequence { values: vec![] }).unwrap();

let sequence: &PySequence = inst.as_ref(py).downcast().unwrap();
21 changes: 20 additions & 1 deletion tests/test_sequence.rs
@@ -2,7 +2,7 @@

use pyo3::exceptions::{PyIndexError, PyValueError};
use pyo3::prelude::*;
use pyo3::types::{IntoPyDict, PyList};
use pyo3::types::{IntoPyDict, PyList, PyMapping, PySequence};

use pyo3::py_run;

@@ -312,3 +312,22 @@ fn test_option_list_get() {
py_expect_exception!(py, list, "list[2]", PyIndexError);
});
}

#[test]
fn sequence_is_not_mapping() {
let gil = Python::acquire_gil();
let py = gil.python();

let list = PyCell::new(
py,
OptionList {
items: vec![Some(1), None],
},
)
.unwrap();

PySequence::register::<OptionList>(py).unwrap();

assert!(list.as_ref().downcast::<PyMapping>().is_err());
assert!(list.as_ref().downcast::<PySequence>().is_ok());
}