Skip to content

Commit

Permalink
Updates based on review
Browse files Browse the repository at this point in the history
* Rename register_abc_subclass -> register
* Store PyResult in MAPPING_ABC and SEQUENCE_ABC statics
  • Loading branch information
aganders3 committed Aug 9, 2022
1 parent eb87b2e commit aea20f6
Show file tree
Hide file tree
Showing 6 changed files with 48 additions and 47 deletions.
9 changes: 4 additions & 5 deletions guide/src/migration.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,9 @@ penatly is a problem, you may be able to perform your own checks and use

Another side-effect is that a pyclass defined in Rust with PyO3 will need to
be _registered_ with the corresponding Python abstract base class for
downcasting to succeed. `PySequence::register_abc_subclass` and
`PyMapping:register_abc_subclass` have been added to make it easy to do this
from Rust code. These are equivalent to calling
`collections.abc.Mapping.register(MappingPyClass)` or
downcasting to succeed. `PySequence::register` and `PyMapping:register` have
been added to make it easy to do this from Rust code. These are equivalent to
calling `collections.abc.Mapping.register(MappingPyClass)` or
`collections.abc.Sequence.register(SequencePyClass)` from Python.

For example, for a mapping class defined in Rust:
Expand All @@ -52,7 +51,7 @@ You must register the class with `collections.abc.Mapping` before the downcast w
```rust,compile_fail
let m = Py::new(py, Mapping { index }).unwrap();
assert!(m.as_ref(py).downcast::<PyMapping>().is_err());
PyMapping::register_abc_subclass::<Mapping>(py).unwrap();
PyMapping::register::<Mapping>(py).unwrap();
assert!(m.as_ref(py).downcast::<PyMapping>().is_ok());
```

Expand Down
39 changes: 20 additions & 19 deletions src/types/mapping.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use crate::{
ffi, AsPyPointer, IntoPy, IntoPyPointer, Py, PyNativeType, PyTryFrom, Python, ToPyObject,
};

static MAPPING_ABC: GILOnceCell<Py<PyType>> = GILOnceCell::new();
static MAPPING_ABC: GILOnceCell<PyResult<Py<PyType>>> = GILOnceCell::new();

/// Represents a reference to a Python object supporting the mapping protocol.
#[repr(transparent)]
Expand Down Expand Up @@ -112,25 +112,25 @@ impl PyMapping {
/// Register a pyclass as a subclass of `collections.abc.Mapping` (from the Python standard
/// library). This is equvalent to `collections.abc.Mapping.register(T)` in Python.
/// This registration is required for a pyclass to be downcastable from `PyAny` to `PyMapping`.
pub fn register_abc_subclass<T: PyTypeInfo>(py: Python<'_>) -> PyResult<()> {
pub fn register<T: PyTypeInfo>(py: Python<'_>) -> PyResult<()> {
let ty = T::type_object(py);
get_mapping_abc(py).call_method1("register", (ty,))?;
get_mapping_abc(py)?.call_method1("register", (ty,))?;
Ok(())
}
}

fn get_mapping_abc(py: Python<'_>) -> &PyType {
fn get_mapping_abc(py: Python<'_>) -> Result<&PyType, PyErr> {
MAPPING_ABC
.get_or_init(py, || {
py.import("collections.abc")
.expect("coud not import 'collections.abc'")
.getattr("Mapping")
.expect("coud not access 'Mapping' from 'collections.abc'")
.downcast::<PyType>()
.expect("could not access 'collections.abc.Mapping'")
.into_py(py)
Ok(py
.import("collections.abc")?
.getattr("Mapping")?
.downcast::<PyType>()?
.into_py(py))
})
.as_ref(py)
.as_ref()
.map(|t| t.as_ref(py))
.map_err(|e| e.clone_ref(py))
}

impl<'v> PyTryFrom<'v> for PyMapping {
Expand All @@ -139,14 +139,15 @@ impl<'v> PyTryFrom<'v> for PyMapping {
/// `isinstance(<class>, collections.abc.Mapping) == True`.
fn try_from<V: Into<&'v PyAny>>(value: V) -> Result<&'v PyMapping, PyDowncastError<'v>> {
let value = value.into();
if value
.is_instance(get_mapping_abc(value.py()))
.unwrap_or(false)
{
unsafe { Ok(<PyMapping as PyTryFrom>::try_from_unchecked(value)) }
} else {
Err(PyDowncastError::new(value, "Mapping"))

// TODO: surface specific errors in this chain to the user
if let Ok(abc) = get_mapping_abc(value.py()) {
if value.is_instance(abc).unwrap_or(false) {
unsafe { return Ok(<PyMapping as PyTryFrom>::try_from_unchecked(value)) }
}
}

Err(PyDowncastError::new(value, "Mapping"))
}

#[inline]
Expand Down
39 changes: 20 additions & 19 deletions src/types/sequence.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use crate::{ffi, PyNativeType, ToPyObject};
use crate::{AsPyPointer, IntoPy, IntoPyPointer, Py, Python};
use crate::{FromPyObject, PyTryFrom};

static SEQUENCE_ABC: GILOnceCell<Py<PyType>> = GILOnceCell::new();
static SEQUENCE_ABC: GILOnceCell<PyResult<Py<PyType>>> = GILOnceCell::new();

/// Represents a reference to a Python object supporting the sequence protocol.
#[repr(transparent)]
Expand Down Expand Up @@ -258,9 +258,9 @@ impl PySequence {
/// Register a pyclass as a subclass of `collections.abc.Sequence` (from the Python standard
/// library). This is equvalent to `collections.abc.Sequence.register(T)` in Python.
/// This registration is required for a pyclass to be downcastable from `PyAny` to `PySequence`.
pub fn register_abc_subclass<T: PyTypeInfo>(py: Python<'_>) -> PyResult<()> {
pub fn register<T: PyTypeInfo>(py: Python<'_>) -> PyResult<()> {
let ty = T::type_object(py);
get_sequence_abc(py).call_method1("register", (ty,))?;
get_sequence_abc(py)?.call_method1("register", (ty,))?;
Ok(())
}
}
Expand Down Expand Up @@ -302,18 +302,18 @@ where
Ok(v)
}

fn get_sequence_abc(py: Python<'_>) -> &PyType {
fn get_sequence_abc(py: Python<'_>) -> Result<&PyType, PyErr> {
SEQUENCE_ABC
.get_or_init(py, || {
py.import("collections.abc")
.expect("coud not import 'collections.abc'")
.getattr("Sequence")
.expect("coud not access 'Sequence' from 'collections.abc'")
.downcast::<PyType>()
.expect("could not access 'collections.abc.Sequence'")
.into_py(py)
Ok(py
.import("collections.abc")?
.getattr("Sequence")?
.downcast::<PyType>()?
.into_py(py))
})
.as_ref(py)
.as_ref()
.map(|t| t.as_ref(py))
.map_err(|e| e.clone_ref(py))
}

impl<'v> PyTryFrom<'v> for PySequence {
Expand All @@ -322,14 +322,15 @@ impl<'v> PyTryFrom<'v> for PySequence {
/// `isinstance(<class>, collections.abc.Sequence) == True`.
fn try_from<V: Into<&'v PyAny>>(value: V) -> Result<&'v PySequence, PyDowncastError<'v>> {
let value = value.into();
if value
.is_instance(get_sequence_abc(value.py()))
.unwrap_or(false)
{
unsafe { Ok(<PySequence as PyTryFrom>::try_from_unchecked(value)) }
} else {
Err(PyDowncastError::new(value, "Sequence"))

// TODO: surface specific errors in this chain to the user
if let Ok(abc) = get_sequence_abc(value.py()) {
if value.is_instance(abc).unwrap_or(false) {
unsafe { return Ok(<PySequence as PyTryFrom>::try_from_unchecked(value)) }
}
}

Err(PyDowncastError::new(value, "Sequence"))
}

fn try_from_exact<V: Into<&'v PyAny>>(value: V) -> Result<&'v PySequence, PyDowncastError<'v>> {
Expand Down
2 changes: 1 addition & 1 deletion tests/test_mapping.rs
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ fn mapping_is_not_sequence() {
index.insert("Bar".into(), 2);
let m = Py::new(py, Mapping { index }).unwrap();

PyMapping::register_abc_subclass::<Mapping>(py).unwrap();
PyMapping::register::<Mapping>(py).unwrap();

assert!(m.as_ref(py).downcast::<PyMapping>().is_ok());
assert!(m.as_ref(py).downcast::<PySequence>().is_err());
Expand Down
4 changes: 2 additions & 2 deletions tests/test_proto_methods.rs
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ impl Mapping {
#[test]
fn mapping() {
Python::with_gil(|py| {
PyMapping::register_abc_subclass::<Mapping>(py).unwrap();
PyMapping::register::<Mapping>(py).unwrap();

let inst = Py::new(
py,
Expand Down Expand Up @@ -318,7 +318,7 @@ impl Sequence {
#[test]
fn sequence() {
Python::with_gil(|py| {
PySequence::register_abc_subclass::<Sequence>(py).unwrap();
PySequence::register::<Sequence>(py).unwrap();

let inst = Py::new(py, Sequence { values: vec![] }).unwrap();

Expand Down
2 changes: 1 addition & 1 deletion tests/test_sequence.rs
Original file line number Diff line number Diff line change
Expand Up @@ -326,7 +326,7 @@ fn sequence_is_not_mapping() {
)
.unwrap();

PySequence::register_abc_subclass::<OptionList>(py).unwrap();
PySequence::register::<OptionList>(py).unwrap();

assert!(list.as_ref().downcast::<PyMapping>().is_err());
assert!(list.as_ref().downcast::<PySequence>().is_ok());
Expand Down

0 comments on commit aea20f6

Please sign in to comment.