Skip to content

Commit

Permalink
fix: update PyTryFrom for PyMapping to more accurately check for mapp…
Browse files Browse the repository at this point in the history
…ing types
  • Loading branch information
aganders3 committed Jun 24, 2022
1 parent 2ee15de commit ec41d8a
Show file tree
Hide file tree
Showing 4 changed files with 73 additions and 5 deletions.
21 changes: 19 additions & 2 deletions src/pyclass.rs
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,12 @@ unsafe fn create_type_object_impl(
name: py_class_qualified_name(module_name, name)?,
basicsize: basicsize as c_int,
itemsize: 0,
flags: py_class_flags(has_traverse, is_basetype),
flags: py_class_flags(
has_traverse,
is_basetype,
#[cfg(all(Py_3_10, not(Py_LIMITED_API)))]
is_mapping,
),
slots: slots.as_mut_ptr(),
};

Expand Down Expand Up @@ -298,7 +303,11 @@ fn py_class_qualified_name(module_name: Option<&str>, class_name: &str) -> PyRes
.into_raw())
}

fn py_class_flags(is_gc: bool, is_basetype: bool) -> c_uint {
fn py_class_flags(
is_gc: bool,
is_basetype: bool,
#[cfg(all(Py_3_10, not(Py_LIMITED_API)))] is_mapping: bool,
) -> c_uint {
let mut flags = ffi::Py_TPFLAGS_DEFAULT;

if is_gc {
Expand All @@ -309,6 +318,14 @@ fn py_class_flags(is_gc: bool, is_basetype: bool) -> c_uint {
flags |= ffi::Py_TPFLAGS_BASETYPE;
}

#[cfg(all(Py_3_10, not(Py_LIMITED_API)))]
{
if is_mapping {
flags |= ffi::Py_TPFLAGS_MAPPING;
flags &= !ffi::Py_TPFLAGS_SEQUENCE;
}
}

// `c_ulong` and `c_uint` have the same size
// on some platforms (like windows)
#[allow(clippy::useless_conversion)]
Expand Down
28 changes: 26 additions & 2 deletions src/types/mapping.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
// Copyright (c) 2017-present PyO3 Project and Contributors

use crate::err::{PyDowncastError, PyErr, PyResult};
use crate::types::{PyAny, PySequence};
use crate::types::{PyAny, PyModule, PySequence};
use crate::{ffi, AsPyPointer, IntoPyPointer, Py, PyNativeType, PyTryFrom, Python, ToPyObject};

/// Represents a reference to a Python object supporting the mapping protocol.
Expand Down Expand Up @@ -105,10 +105,34 @@ impl PyMapping {
}

impl<'v> PyTryFrom<'v> for PyMapping {
// for Python < 3.10 or if using the Py_LIMITED_API, call into python to check
// isinstance(value, collections.abc.Mapping) to determine downcastability
#[cfg(any(Py_LIMITED_API, Py_3_6, Py_3_7, Py_3_8, Py_3_9))]
fn try_from<V: Into<&'v PyAny>>(value: V) -> Result<&'v PyMapping, PyDowncastError<'v>> {
let value = value.into();
let is_mapping = Python::with_gil(|py| {
let builtins = PyModule::import(py, "builtins")?;
let mapping_abc = PyModule::import(py, "collections.abc")?.getattr("Mapping")?;
builtins
.getattr("isinstance")?
.call1((value, mapping_abc))?
.extract::<bool>()
});
if is_mapping.unwrap_or(false) {
unsafe { Ok(<PyMapping as PyTryFrom>::try_from_unchecked(value)) }
} else {
Err(PyDowncastError::new(value, "Mapping"))
}
}

// for Python >= 3.10 and not using the Py_LIMITED_API, check Py_TPFLAGS_MAPPING to determine
// downcastability
#[cfg(not(any(Py_LIMITED_API, Py_3_6, Py_3_7, Py_3_8, Py_3_9)))]
fn try_from<V: Into<&'v PyAny>>(value: V) -> Result<&'v PyMapping, PyDowncastError<'v>> {
let value = value.into();
unsafe {
if ffi::PyMapping_Check(value.as_ptr()) != 0 {
let ty = ffi::Py_TYPE(value);
if ffi::PyType_HasFeature(ty, fft::Py_TPFLAGS_MAPPING) != 0 {
Ok(<PyMapping as PyTryFrom>::try_from_unchecked(value))
} else {
Err(PyDowncastError::new(value, "Mapping"))
Expand Down
11 changes: 11 additions & 0 deletions tests/test_mapping.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ use pyo3::types::IntoPyDict;
use pyo3::types::PyList;
use pyo3::types::PyMapping;
use pyo3::types::PySequence;
use pyo3::PyTypeInfo;

mod common;

Expand Down Expand Up @@ -116,6 +117,16 @@ fn test_delitem() {
#[test]
fn mapping_is_not_sequence() {
Python::with_gil(|py| {
// downcast to PyMapping requires isinstance(<cls>, collections.abc.Mapping) to pass, so we
// have to register the class first
PyModule::import(py, "collections.abc")
.unwrap()
.getattr("Mapping")
.unwrap()
.getattr("register")
.unwrap()
.call1((Mapping::type_object(py),))
.unwrap();
let mut index = HashMap::new();
index.insert("Foo".into(), 1);
index.insert("Bar".into(), 2);
Expand Down
18 changes: 17 additions & 1 deletion tests/test_sequence.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

use pyo3::exceptions::{PyIndexError, PyValueError};
use pyo3::prelude::*;
use pyo3::types::{IntoPyDict, PyList};
use pyo3::types::{IntoPyDict, PyList, PyMapping, PySequence};

use pyo3::py_run;

Expand Down Expand Up @@ -321,3 +321,19 @@ fn test_option_list_get() {
py_assert!(py, list, "list[1] == None");
py_expect_exception!(py, list, "list[2]", PyIndexError);
}

#[test]
fn sequence_is_not_mapping() {
let gil = Python::acquire_gil();
let py = gil.python();

let list = PyCell::new(
py,
OptionList {
items: vec![Some(1), None],
},
)
.unwrap();
assert!(list.as_ref().downcast::<PyMapping>().is_err());
assert!(list.as_ref().downcast::<PySequence>().is_ok());
}

0 comments on commit ec41d8a

Please sign in to comment.