From d22d62318b3ed01fee6694ee8f4f5818b38fb3bd Mon Sep 17 00:00:00 2001 From: David Hewitt <1939362+davidhewitt@users.noreply.github.com> Date: Sat, 28 Aug 2021 23:10:33 +0100 Subject: [PATCH] types: add PyMapping --- CHANGELOG.md | 1 + src/ffi/abstract_.rs | 3 + src/types/dict.rs | 28 +++++ src/types/mapping.rs | 259 +++++++++++++++++++++++++++++++++++++++++++ src/types/mod.rs | 2 + 5 files changed, 293 insertions(+) create mode 100644 src/types/mapping.rs diff --git a/CHANGELOG.md b/CHANGELOG.md index ebf8bc2a6ab..6f52a3c6d10 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,6 +15,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0. - Add `PyAny::py` as a convenience for `PyNativeType::py`. [#1751](https://github.com/PyO3/pyo3/pull/1751) - Add implementation of `std::ops::Index` for `PyList`, `PyTuple` and `PySequence`. [#1825](https://github.com/PyO3/pyo3/pull/1825) - Add range indexing implementations of `std::ops::Index` for `PyList`, `PyTuple` and `PySequence`. [#1829](https://github.com/PyO3/pyo3/pull/1829) +- Add `PyMapping` type to represent the Python mapping protocol. [#1844](https://github.com/PyO3/pyo3/pull/1844) - Add commonly-used sequence methods to `PyList` and `PyTuple`. [#1849](https://github.com/PyO3/pyo3/pull/1849) - Add `as_sequence` methods to `PyList` and `PyTuple`. [#1860](https://github.com/PyO3/pyo3/pull/1860) diff --git a/src/ffi/abstract_.rs b/src/ffi/abstract_.rs index bb38f8ae367..2efe1d3c5d7 100644 --- a/src/ffi/abstract_.rs +++ b/src/ffi/abstract_.rs @@ -78,7 +78,9 @@ extern "C" { pub fn PyObject_GetItem(o: *mut PyObject, key: *mut PyObject) -> *mut PyObject; #[cfg_attr(PyPy, link_name = "PyPyObject_SetItem")] pub fn PyObject_SetItem(o: *mut PyObject, key: *mut PyObject, v: *mut PyObject) -> c_int; + #[cfg_attr(PyPy, link_name = "PyPyObject_DelItemString")] pub fn PyObject_DelItemString(o: *mut PyObject, key: *const c_char) -> c_int; + #[cfg_attr(PyPy, link_name = "PyPyObject_DelItem")] pub fn PyObject_DelItem(o: *mut PyObject, key: *mut PyObject) -> c_int; } @@ -300,6 +302,7 @@ pub unsafe fn PyMapping_DelItem(o: *mut PyObject, key: *mut PyObject) -> c_int { extern "C" { #[cfg_attr(PyPy, link_name = "PyPyMapping_HasKeyString")] pub fn PyMapping_HasKeyString(o: *mut PyObject, key: *const c_char) -> c_int; + #[cfg_attr(PyPy, link_name = "PyPyMapping_HasKey")] pub fn PyMapping_HasKey(o: *mut PyObject, key: *mut PyObject) -> c_int; #[cfg_attr(PyPy, link_name = "PyPyMapping_Keys")] pub fn PyMapping_Keys(o: *mut PyObject) -> *mut PyObject; diff --git a/src/types/dict.rs b/src/types/dict.rs index d4297584be5..f3c3d7c2003 100644 --- a/src/types/dict.rs +++ b/src/types/dict.rs @@ -12,6 +12,8 @@ use std::collections::{BTreeMap, HashMap}; use std::ptr::NonNull; use std::{cmp, collections, hash}; +use super::PyMapping; + /// Represents a Python `dict`. #[repr(transparent)] pub struct PyDict(PyAny); @@ -178,6 +180,11 @@ impl PyDict { pos: 0, } } + + /// Returns `self` cast as a `PyMapping`. + pub fn as_mapping(&self) -> &PyMapping { + unsafe { PyMapping::try_from_unchecked(self) } + } } pub struct PyDictIterator<'py> { @@ -850,4 +857,25 @@ mod tests { assert_eq!(py_map.get_item("b").unwrap().extract::().unwrap(), 2); }); } + + #[test] + fn dict_as_mapping() { + Python::with_gil(|py| { + let mut map = HashMap::::new(); + map.insert(1, 1); + + let py_map = map.into_py_dict(py); + + assert_eq!(py_map.as_mapping().len().unwrap(), 1); + assert_eq!( + py_map + .as_mapping() + .get_item(1) + .unwrap() + .extract::() + .unwrap(), + 1 + ); + }); + } } diff --git a/src/types/mapping.rs b/src/types/mapping.rs new file mode 100644 index 00000000000..ca907f67a4c --- /dev/null +++ b/src/types/mapping.rs @@ -0,0 +1,259 @@ +// Copyright (c) 2017-present PyO3 Project and Contributors + +use crate::err::{PyDowncastError, PyErr, PyResult}; +use crate::types::{PyAny, PySequence}; +use crate::AsPyPointer; +use crate::{ffi, ToPyObject}; +use crate::{PyTryFrom, ToBorrowedObject}; + +/// Represents a reference to a Python object supporting the mapping protocol. +#[repr(transparent)] +pub struct PyMapping(PyAny); +pyobject_native_type_named!(PyMapping); +pyobject_native_type_extract!(PyMapping); + +impl PyMapping { + /// Returns the number of objects in the mapping. + /// + /// This is equivalent to the Python expression `len(self)`. + #[inline] + pub fn len(&self) -> PyResult { + let v = unsafe { ffi::PyMapping_Size(self.as_ptr()) }; + if v == -1 { + Err(PyErr::api_call_failed(self.py())) + } else { + Ok(v as usize) + } + } + + /// Returns whether the mapping is empty. + #[inline] + pub fn is_empty(&self) -> PyResult { + self.len().map(|l| l == 0) + } + + /// Gets the item in self with key `key`. + /// + /// Returns an `Err` if the item with specified key is not found, usually `KeyError`. + /// + /// This is equivalent to the Python expression `self[key]`. + #[inline] + pub fn get_item(&self, key: K) -> PyResult<&PyAny> + where + K: ToBorrowedObject, + { + PyAny::get_item(self, key) + } + + /// Sets the item in self with key `key`. + /// + /// This is equivalent to the Python expression `self[key] = value`. + #[inline] + pub fn set_item(&self, key: K, value: V) -> PyResult<()> + where + K: ToPyObject, + V: ToPyObject, + { + PyAny::set_item(self, key, value) + } + + /// Deletes the item with key `key`. + /// + /// This is equivalent to the Python statement `del self[key]`. + #[inline] + pub fn del_item(&self, key: K) -> PyResult<()> + where + K: ToBorrowedObject, + { + PyAny::del_item(self, key) + } + + /// Returns a sequence containing all keys in the mapping. + #[inline] + pub fn keys(&self) -> PyResult<&PySequence> { + unsafe { + self.py() + .from_owned_ptr_or_err(ffi::PyMapping_Keys(self.as_ptr())) + } + } + + /// Returns a sequence containing all values in the mapping. + #[inline] + pub fn values(&self) -> PyResult<&PySequence> { + unsafe { + self.py() + .from_owned_ptr_or_err(ffi::PyMapping_Values(self.as_ptr())) + } + } + + /// Returns a sequence of tuples of all (key, value) pairs in the mapping. + #[inline] + pub fn items(&self) -> PyResult<&PySequence> { + unsafe { + self.py() + .from_owned_ptr_or_err(ffi::PyMapping_Items(self.as_ptr())) + } + } +} + +impl<'v> PyTryFrom<'v> for PyMapping { + fn try_from>(value: V) -> Result<&'v PyMapping, PyDowncastError<'v>> { + let value = value.into(); + unsafe { + if ffi::PyMapping_Check(value.as_ptr()) != 0 { + Ok(::try_from_unchecked(value)) + } else { + Err(PyDowncastError::new(value, "Mapping")) + } + } + } + + #[inline] + fn try_from_exact>(value: V) -> Result<&'v PyMapping, PyDowncastError<'v>> { + ::try_from(value) + } + + #[inline] + unsafe fn try_from_unchecked>(value: V) -> &'v PyMapping { + let ptr = value.into() as *const _ as *const PyMapping; + &*ptr + } +} + +#[cfg(test)] +mod tests { + use std::collections::HashMap; + + use crate::{exceptions::PyKeyError, types::PyTuple, Python}; + + use super::*; + + #[test] + fn test_len() { + Python::with_gil(|py| { + let mut v = HashMap::new(); + let ob = v.to_object(py); + let mapping = ::try_from(ob.as_ref(py)).unwrap(); + assert_eq!(0, mapping.len().unwrap()); + assert!(mapping.is_empty().unwrap()); + + v.insert(7, 32); + let ob = v.to_object(py); + let mapping2 = ::try_from(ob.as_ref(py)).unwrap(); + assert_eq!(1, mapping2.len().unwrap()); + assert!(!mapping2.is_empty().unwrap()); + }); + } + + #[test] + fn test_get_item() { + Python::with_gil(|py| { + let mut v = HashMap::new(); + v.insert(7, 32); + let ob = v.to_object(py); + let mapping = ::try_from(ob.as_ref(py)).unwrap(); + assert_eq!( + 32, + mapping.get_item(7i32).unwrap().extract::().unwrap() + ); + assert!(mapping + .get_item(8i32) + .unwrap_err() + .is_instance::(py)); + }); + } + + #[test] + fn test_set_item() { + Python::with_gil(|py| { + let mut v = HashMap::new(); + v.insert(7, 32); + let ob = v.to_object(py); + let mapping = ::try_from(ob.as_ref(py)).unwrap(); + assert!(mapping.set_item(7i32, 42i32).is_ok()); // change + assert!(mapping.set_item(8i32, 123i32).is_ok()); // insert + assert_eq!( + 42i32, + mapping.get_item(7i32).unwrap().extract::().unwrap() + ); + assert_eq!( + 123i32, + mapping.get_item(8i32).unwrap().extract::().unwrap() + ); + }); + } + + #[test] + fn test_del_item() { + Python::with_gil(|py| { + let mut v = HashMap::new(); + v.insert(7, 32); + let ob = v.to_object(py); + let mapping = ::try_from(ob.as_ref(py)).unwrap(); + assert!(mapping.del_item(7i32).is_ok()); + assert_eq!(0, mapping.len().unwrap()); + assert!(mapping + .get_item(7i32) + .unwrap_err() + .is_instance::(py)); + }); + } + + #[test] + fn test_items() { + Python::with_gil(|py| { + let mut v = HashMap::new(); + v.insert(7, 32); + v.insert(8, 42); + v.insert(9, 123); + let ob = v.to_object(py); + let mapping = ::try_from(ob.as_ref(py)).unwrap(); + // Can't just compare against a vector of tuples since we don't have a guaranteed ordering. + let mut key_sum = 0; + let mut value_sum = 0; + for el in mapping.items().unwrap().iter().unwrap() { + let tuple = el.unwrap().cast_as::().unwrap(); + key_sum += tuple.get_item(0).unwrap().extract::().unwrap(); + value_sum += tuple.get_item(1).unwrap().extract::().unwrap(); + } + assert_eq!(7 + 8 + 9, key_sum); + assert_eq!(32 + 42 + 123, value_sum); + }); + } + + #[test] + fn test_keys() { + Python::with_gil(|py| { + let mut v = HashMap::new(); + v.insert(7, 32); + v.insert(8, 42); + v.insert(9, 123); + let ob = v.to_object(py); + let mapping = ::try_from(ob.as_ref(py)).unwrap(); + // Can't just compare against a vector of tuples since we don't have a guaranteed ordering. + let mut key_sum = 0; + for el in mapping.keys().unwrap().iter().unwrap() { + key_sum += el.unwrap().extract::().unwrap(); + } + assert_eq!(7 + 8 + 9, key_sum); + }); + } + + #[test] + fn test_values() { + Python::with_gil(|py| { + let mut v = HashMap::new(); + v.insert(7, 32); + v.insert(8, 42); + v.insert(9, 123); + let ob = v.to_object(py); + let mapping = ::try_from(ob.as_ref(py)).unwrap(); + // Can't just compare against a vector of tuples since we don't have a guaranteed ordering. + let mut values_sum = 0; + for el in mapping.values().unwrap().iter().unwrap() { + values_sum += el.unwrap().extract::().unwrap(); + } + assert_eq!(32 + 42 + 123, values_sum); + }); + } +} diff --git a/src/types/mod.rs b/src/types/mod.rs index b5cb6d91423..ae326830e34 100644 --- a/src/types/mod.rs +++ b/src/types/mod.rs @@ -17,6 +17,7 @@ pub use self::floatob::PyFloat; pub use self::function::{PyCFunction, PyFunction}; pub use self::iterator::PyIterator; pub use self::list::PyList; +pub use self::mapping::PyMapping; pub use self::module::PyModule; pub use self::num::PyLong; pub use self::num::PyLong as PyInt; @@ -224,6 +225,7 @@ mod floatob; mod function; mod iterator; mod list; +mod mapping; mod module; mod num; mod sequence;