From 4f9d3d73069e34ed9c3e35e8ba18f845ddca9b39 Mon Sep 17 00:00:00 2001 From: Bruno Kolenbrander <59372212+mejrs@users.noreply.github.com> Date: Tue, 31 May 2022 21:13:04 +0200 Subject: [PATCH] Protect iterators against concurrent modification (#2380) --- CHANGELOG.md | 1 + src/types/dict.rs | 189 +++++++++++++++++++++++----- src/types/frozenset.rs | 194 ++++++++++++++++++++++++++++ src/types/mod.rs | 40 +++++- src/types/set.rs | 278 +++++++++++++++++------------------------ 5 files changed, 504 insertions(+), 198 deletions(-) create mode 100644 src/types/frozenset.rs diff --git a/CHANGELOG.md b/CHANGELOG.md index aa8da90cecc..0e1716ead39 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -26,6 +26,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Allow `#[classattr]` methods to be fallible. [#2385](https://github.com/PyO3/pyo3/pull/2385) - Prevent multiple `#[pymethods]` with the same name for a single `#[pyclass]`. [#2399](https://github.com/PyO3/pyo3/pull/2399) - Fixup `lib_name` when using `PYO3_CONFIG_FILE`. [#2404](https://github.com/PyO3/pyo3/pull/2404) +- Iterators over `PySet` and `PyDict` will now panic if the underlying collection is mutated during the iteration. [#2380](https://github.com/PyO3/pyo3/pull/2380) ### Fixed diff --git a/src/types/dict.rs b/src/types/dict.rs index 18a7d4b4f09..bd75d89722b 100644 --- a/src/types/dict.rs +++ b/src/types/dict.rs @@ -1,6 +1,8 @@ // Copyright (c) 2017-present PyO3 Project and Contributors +use super::PyMapping; use crate::err::{self, PyErr, PyResult}; +use crate::ffi::Py_ssize_t; use crate::types::{PyAny, PyList}; #[cfg(not(PyPy))] use crate::IntoPyPointer; @@ -9,8 +11,6 @@ use std::collections::{BTreeMap, HashMap}; use std::ptr::NonNull; use std::{cmp, collections, hash}; -use super::PyMapping; - /// Represents a Python `dict`. #[repr(transparent)] pub struct PyDict(PyAny); @@ -102,7 +102,19 @@ impl PyDict { /// /// This is equivalent to the Python expression `len(self)`. pub fn len(&self) -> usize { - unsafe { ffi::PyDict_Size(self.as_ptr()) as usize } + self._len() as usize + } + + fn _len(&self) -> Py_ssize_t { + #[cfg(any(not(Py_3_8), PyPy, Py_LIMITED_API))] + unsafe { + ffi::PyDict_Size(self.as_ptr()) + } + + #[cfg(all(Py_3_8, not(PyPy), not(Py_LIMITED_API)))] + unsafe { + (*self.as_ptr().cast::()).ma_used + } } /// Checks if the dict is empty, i.e. `len(self) == 0`. @@ -213,13 +225,13 @@ impl PyDict { /// Returns an iterator of `(key, value)` pairs in this dictionary. /// - /// Note that it's unsafe to use when the dictionary might be changed by - /// other code. + /// # Panics + /// + /// If PyO3 detects that the dictionary is mutated during iteration, it will panic. + /// It is allowed to modify values as you iterate over the dictionary, but only + /// so long as the set of keys does not change. pub fn iter(&self) -> PyDictIterator<'_> { - PyDictIterator { - dict: self.as_ref(), - pos: 0, - } + IntoIterator::into_iter(self) } /// Returns `self` cast as a `PyMapping`. @@ -229,8 +241,10 @@ impl PyDict { } pub struct PyDictIterator<'py> { - dict: &'py PyAny, - pos: isize, + dict: &'py PyDict, + ppos: ffi::Py_ssize_t, + di_used: ffi::Py_ssize_t, + len: ffi::Py_ssize_t, } impl<'py> Iterator for PyDictIterator<'py> { @@ -238,29 +252,43 @@ impl<'py> Iterator for PyDictIterator<'py> { #[inline] fn next(&mut self) -> Option { - unsafe { - let mut key: *mut ffi::PyObject = std::ptr::null_mut(); - let mut value: *mut ffi::PyObject = std::ptr::null_mut(); - if ffi::PyDict_Next(self.dict.as_ptr(), &mut self.pos, &mut key, &mut value) != 0 { - let py = self.dict.py(); - // PyDict_Next returns borrowed values; for safety must make them owned (see #890) - Some(( - py.from_owned_ptr(ffi::_Py_NewRef(key)), - py.from_owned_ptr(ffi::_Py_NewRef(value)), - )) - } else { - None - } + let ma_used = self.dict._len(); + + // These checks are similar to what CPython does. + // + // If the dimension of the dict changes e.g. key-value pairs are removed + // or added during iteration, this will panic next time when `next` is called + if self.di_used != ma_used { + self.di_used = -1; + panic!("dictionary changed size during iteration"); + }; + + // If the dict is changed in such a way that the length remains constant + // then this will panic at the end of iteration - similar to this: + // + // d = {"a":1, "b":2, "c": 3} + // + // for k, v in d.items(): + // d[f"{k}_"] = 4 + // del d[k] + // print(k) + // + if self.len == -1 { + self.di_used = -1; + panic!("dictionary keys changed during iteration"); + }; + + let ret = unsafe { self.next_unchecked() }; + if ret.is_some() { + self.len -= 1 } + ret } #[inline] fn size_hint(&self) -> (usize, Option) { - let len = self.dict.len().unwrap_or_default(); - ( - len.saturating_sub(self.pos as usize), - Some(len.saturating_sub(self.pos as usize)), - ) + let len = self.len as usize; + (len, Some(len)) } } @@ -269,7 +297,34 @@ impl<'a> std::iter::IntoIterator for &'a PyDict { type IntoIter = PyDictIterator<'a>; fn into_iter(self) -> Self::IntoIter { - self.iter() + PyDictIterator { + dict: self, + ppos: 0, + di_used: self._len(), + len: self._len(), + } + } +} + +impl<'py> PyDictIterator<'py> { + /// Advances the iterator without checking for concurrent modification. + /// + /// See [`PyDict_Next`](https://docs.python.org/3/c-api/dict.html#c.PyDict_Next) + /// for more information. + unsafe fn next_unchecked(&mut self) -> Option<(&'py PyAny, &'py PyAny)> { + let mut key: *mut ffi::PyObject = std::ptr::null_mut(); + let mut value: *mut ffi::PyObject = std::ptr::null_mut(); + + if ffi::PyDict_Next(self.dict.as_ptr(), &mut self.ppos, &mut key, &mut value) != 0 { + let py = self.dict.py(); + // PyDict_Next returns borrowed values; for safety must make them owned (see #890) + Some(( + py.from_owned_ptr(ffi::_Py_NewRef(key)), + py.from_owned_ptr(ffi::_Py_NewRef(value)), + )) + } else { + None + } } } @@ -660,6 +715,74 @@ mod tests { }); } + #[test] + fn test_iter_value_mutated() { + Python::with_gil(|py| { + let mut v = HashMap::new(); + v.insert(7, 32); + v.insert(8, 42); + v.insert(9, 123); + + let ob = v.to_object(py); + let dict = ::try_from(ob.as_ref(py)).unwrap(); + + for (key, value) in dict.iter() { + dict.set_item(key, value.extract::().unwrap() + 7) + .unwrap(); + } + }); + } + + #[test] + #[should_panic] + fn test_iter_key_mutated() { + Python::with_gil(|py| { + let mut v = HashMap::new(); + for i in 0..10 { + v.insert(i * 2, i * 2); + } + let ob = v.to_object(py); + let dict = ::try_from(ob.as_ref(py)).unwrap(); + + for (i, (key, value)) in dict.iter().enumerate() { + let key = key.extract::().unwrap(); + let value = value.extract::().unwrap(); + + dict.set_item(key + 1, value + 1).unwrap(); + + if i > 1000 { + // avoid this test just running out of memory if it fails + break; + }; + } + }); + } + + #[test] + #[should_panic] + fn test_iter_key_mutated_constant_len() { + Python::with_gil(|py| { + let mut v = HashMap::new(); + for i in 0..10 { + v.insert(i * 2, i * 2); + } + let ob = v.to_object(py); + let dict = ::try_from(ob.as_ref(py)).unwrap(); + + for (i, (key, value)) in dict.iter().enumerate() { + let key = key.extract::().unwrap(); + let value = value.extract::().unwrap(); + dict.del_item(key).unwrap(); + dict.set_item(key + 1, value + 1).unwrap(); + + if i > 1000 { + // avoid this test just running out of memory if it fails + break; + }; + } + }); + } + #[test] fn test_iter_size_hint() { Python::with_gil(|py| { @@ -675,10 +798,14 @@ mod tests { iter.next(); assert_eq!(iter.size_hint(), (v.len() - 1, Some(v.len() - 1))); - // Exhust iterator. + // Exhaust iterator. for _ in &mut iter {} assert_eq!(iter.size_hint(), (0, Some(0))); + + assert!(iter.next().is_none()); + + assert_eq!(iter.size_hint(), (0, Some(0))); }); } diff --git a/src/types/frozenset.rs b/src/types/frozenset.rs new file mode 100644 index 00000000000..d85842940c4 --- /dev/null +++ b/src/types/frozenset.rs @@ -0,0 +1,194 @@ +// Copyright (c) 2017-present PyO3 Project and Contributors +// + +use crate::err::{PyErr, PyResult}; +#[cfg(Py_LIMITED_API)] +use crate::types::PyIterator; +use crate::{ffi, AsPyPointer, PyAny, Python, ToPyObject}; + +use std::ptr; + +/// Represents a Python `frozenset` +#[repr(transparent)] +pub struct PyFrozenSet(PyAny); + +pyobject_native_type!( + PyFrozenSet, + ffi::PySetObject, + ffi::PyFrozenSet_Type, + #checkfunction=ffi::PyFrozenSet_Check +); + +impl PyFrozenSet { + /// Creates a new frozenset. + /// + /// May panic when running out of memory. + pub fn new<'p, T: ToPyObject>(py: Python<'p>, elements: &[T]) -> PyResult<&'p PyFrozenSet> { + let list = elements.to_object(py); + unsafe { py.from_owned_ptr_or_err(ffi::PyFrozenSet_New(list.as_ptr())) } + } + + /// Creates a new empty frozen set + pub fn empty(py: Python<'_>) -> PyResult<&PyFrozenSet> { + unsafe { py.from_owned_ptr_or_err(ffi::PyFrozenSet_New(ptr::null_mut())) } + } + + /// Return the number of items in the set. + /// This is equivalent to len(p) on a set. + #[inline] + pub fn len(&self) -> usize { + unsafe { ffi::PySet_Size(self.as_ptr()) as usize } + } + + /// Check if set is empty. + pub fn is_empty(&self) -> bool { + self.len() == 0 + } + + /// Determine if the set contains the specified key. + /// This is equivalent to the Python expression `key in self`. + pub fn contains(&self, key: K) -> PyResult + where + K: ToPyObject, + { + unsafe { + match ffi::PySet_Contains(self.as_ptr(), key.to_object(self.py()).as_ptr()) { + 1 => Ok(true), + 0 => Ok(false), + _ => Err(PyErr::fetch(self.py())), + } + } + } + + /// Returns an iterator of values in this frozen set. + pub fn iter(&self) -> PyFrozenSetIterator<'_> { + IntoIterator::into_iter(self) + } +} + +#[cfg(Py_LIMITED_API)] +mod impl_ { + use super::*; + + impl<'a> std::iter::IntoIterator for &'a PyFrozenSet { + type Item = &'a PyAny; + type IntoIter = PyFrozenSetIterator<'a>; + + fn into_iter(self) -> Self::IntoIter { + PyFrozenSetIterator { + it: PyIterator::from_object(self.py(), self).unwrap(), + } + } + } + + pub struct PyFrozenSetIterator<'p> { + it: &'p PyIterator, + } + + impl<'py> Iterator for PyFrozenSetIterator<'py> { + type Item = &'py super::PyAny; + + #[inline] + fn next(&mut self) -> Option { + self.it.next().map(Result::unwrap) + } + } +} + +#[cfg(not(Py_LIMITED_API))] +mod impl_ { + use super::*; + + impl<'a> std::iter::IntoIterator for &'a PyFrozenSet { + type Item = &'a PyAny; + type IntoIter = PyFrozenSetIterator<'a>; + + fn into_iter(self) -> Self::IntoIter { + PyFrozenSetIterator { set: self, pos: 0 } + } + } + + pub struct PyFrozenSetIterator<'py> { + set: &'py PyAny, + pos: ffi::Py_ssize_t, + } + + impl<'py> Iterator for PyFrozenSetIterator<'py> { + type Item = &'py PyAny; + + #[inline] + fn next(&mut self) -> Option { + unsafe { + let mut key: *mut ffi::PyObject = std::ptr::null_mut(); + let mut hash: ffi::Py_hash_t = 0; + if ffi::_PySet_NextEntry(self.set.as_ptr(), &mut self.pos, &mut key, &mut hash) != 0 + { + // _PySet_NextEntry returns borrowed object; for safety must make owned (see #890) + Some(self.set.py().from_owned_ptr(ffi::_Py_NewRef(key))) + } else { + None + } + } + } + + #[inline] + fn size_hint(&self) -> (usize, Option) { + let len = self.set.len().unwrap_or_default(); + ( + len.saturating_sub(self.pos as usize), + Some(len.saturating_sub(self.pos as usize)), + ) + } + } +} + +pub use impl_::*; + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_frozenset_new_and_len() { + Python::with_gil(|py| { + let set = PyFrozenSet::new(py, &[1]).unwrap(); + assert_eq!(1, set.len()); + + let v = vec![1]; + assert!(PyFrozenSet::new(py, &[v]).is_err()); + }); + } + + #[test] + fn test_frozenset_empty() { + Python::with_gil(|py| { + let set = PyFrozenSet::empty(py).unwrap(); + assert_eq!(0, set.len()); + }); + } + + #[test] + fn test_frozenset_contains() { + Python::with_gil(|py| { + let set = PyFrozenSet::new(py, &[1]).unwrap(); + assert!(set.contains(1).unwrap()); + }); + } + + #[test] + fn test_frozenset_iter() { + Python::with_gil(|py| { + let set = PyFrozenSet::new(py, &[1]).unwrap(); + + // iter method + for el in set.iter() { + assert_eq!(1i32, el.extract::().unwrap()); + } + + // intoiterator iteration + for el in set { + assert_eq!(1i32, el.extract::().unwrap()); + } + }); + } +} diff --git a/src/types/mod.rs b/src/types/mod.rs index f0419cbe177..861a4a3ebfe 100644 --- a/src/types/mod.rs +++ b/src/types/mod.rs @@ -15,6 +15,7 @@ pub use self::datetime::{ }; pub use self::dict::{IntoPyDict, PyDict}; pub use self::floatob::PyFloat; +pub use self::frozenset::PyFrozenSet; pub use self::function::{PyCFunction, PyFunction}; pub use self::iterator::PyIterator; pub use self::list::PyList; @@ -23,7 +24,7 @@ pub use self::module::PyModule; pub use self::num::PyLong; pub use self::num::PyLong as PyInt; pub use self::sequence::PySequence; -pub use self::set::{PyFrozenSet, PySet}; +pub use self::set::PySet; pub use self::slice::{PySlice, PySliceIndices}; #[cfg(all(not(Py_LIMITED_API), target_endian = "little"))] pub use self::string::PyStringData; @@ -32,6 +33,42 @@ pub use self::traceback::PyTraceback; pub use self::tuple::PyTuple; pub use self::typeobject::PyType; +/// Iteration over Python collections. +/// +/// When working with a Python collection, one approach is to convert it to a Rust collection such +/// as `Vec` or `HashMap`. However this is a relatively expensive operation. If you just want to +/// visit all their items, consider iterating over the collections directly: +/// +/// # Examples +/// +/// ```rust +/// use pyo3::prelude::*; +/// use pyo3::types::PyDict; +/// +/// # pub fn main() -> PyResult<()> { +/// Python::with_gil(|py| { +/// let dict: &PyDict = py.eval("{'a':'b', 'c':'d'}", None, None)?.cast_as()?; +/// +/// for (key, value) in dict { +/// println!("key: {}, value: {}", key, value); +/// } +/// +/// Ok(()) +/// }) +/// # } +/// ``` +/// +/// If PyO3 detects that the collection is mutated during iteration, it will panic. +/// +/// These iterators use Python's C-API directly. However in certain cases, like when compiling for +/// the Limited API and PyPy, the underlying structures are opaque and that may not be possible. +/// In these cases the iterators are implemented by forwarding to [`PyIterator`]. +pub mod iter { + pub use super::dict::PyDictIterator; + pub use super::frozenset::PyFrozenSetIterator; + pub use super::set::PySetIterator; +} + // Implementations core to all native types #[doc(hidden)] #[macro_export] @@ -225,6 +262,7 @@ mod complex; mod datetime; mod dict; mod floatob; +mod frozenset; mod function; mod iterator; mod list; diff --git a/src/types/set.rs b/src/types/set.rs index d8445c9107b..7390807736c 100644 --- a/src/types/set.rs +++ b/src/types/set.rs @@ -13,17 +13,7 @@ use std::{collections, hash, ptr}; #[repr(transparent)] pub struct PySet(PyAny); -/// Represents a Python `frozenset` -#[repr(transparent)] -pub struct PyFrozenSet(PyAny); - pyobject_native_type!(PySet, ffi::PySetObject, ffi::PySet_Type, #checkfunction=ffi::PySet_Check); -pyobject_native_type!( - PyFrozenSet, - ffi::PySetObject, - ffi::PyFrozenSet_Type, - #checkfunction=ffi::PyFrozenSet_Check -); impl PySet { /// Creates a new set with elements from the given slice. @@ -111,86 +101,118 @@ impl PySet { /// Returns an iterator of values in this set. /// - /// Note that it can be unsafe to use when the set might be changed by other code. + /// # Panics + /// + /// If PyO3 detects that the set is mutated during iteration, it will panic. pub fn iter(&self) -> PySetIterator<'_> { - PySetIterator::new(self) + IntoIterator::into_iter(self) } } #[cfg(Py_LIMITED_API)] -pub struct PySetIterator<'p> { - it: &'p PyIterator, -} - -#[cfg(Py_LIMITED_API)] -impl PySetIterator<'_> { - fn new(set: &PyAny) -> PySetIterator<'_> { - PySetIterator { - it: PyIterator::from_object(set.py(), set).unwrap(), +mod impl_ { + use super::*; + + impl<'a> std::iter::IntoIterator for &'a PySet { + type Item = &'a PyAny; + type IntoIter = PySetIterator<'a>; + + /// Returns an iterator of values in this set. + /// + /// # Panics + /// + /// If PyO3 detects that the set is mutated during iteration, it will panic. + fn into_iter(self) -> Self::IntoIter { + PySetIterator { + it: PyIterator::from_object(self.py(), self).unwrap(), + } } } -} - -#[cfg(Py_LIMITED_API)] -impl<'py> Iterator for PySetIterator<'py> { - type Item = &'py super::PyAny; - #[inline] - fn next(&mut self) -> Option { - self.it.next().map(|p| p.unwrap()) + pub struct PySetIterator<'p> { + it: &'p PyIterator, } -} -#[cfg(not(Py_LIMITED_API))] -pub struct PySetIterator<'py> { - set: &'py super::PyAny, - pos: isize, -} + impl<'py> Iterator for PySetIterator<'py> { + type Item = &'py super::PyAny; -#[cfg(not(Py_LIMITED_API))] -impl PySetIterator<'_> { - fn new(set: &PyAny) -> PySetIterator<'_> { - PySetIterator { set, pos: 0 } + /// Advances the iterator and returns the next value. + /// + /// # Panics + /// + /// If PyO3 detects that the set is mutated during iteration, it will panic. + #[inline] + fn next(&mut self) -> Option { + self.it.next().map(Result::unwrap) + } } } #[cfg(not(Py_LIMITED_API))] -impl<'py> Iterator for PySetIterator<'py> { - type Item = &'py super::PyAny; - - #[inline] - fn next(&mut self) -> Option { - unsafe { - let mut key: *mut ffi::PyObject = std::ptr::null_mut(); - let mut hash: ffi::Py_hash_t = 0; - if ffi::_PySet_NextEntry(self.set.as_ptr(), &mut self.pos, &mut key, &mut hash) != 0 { - // _PySet_NextEntry returns borrowed object; for safety must make owned (see #890) - Some(self.set.py().from_owned_ptr(ffi::_Py_NewRef(key))) - } else { - None +mod impl_ { + use super::*; + pub struct PySetIterator<'py> { + set: &'py super::PyAny, + pos: ffi::Py_ssize_t, + used: ffi::Py_ssize_t, + } + + impl<'a> std::iter::IntoIterator for &'a PySet { + type Item = &'a PyAny; + type IntoIter = PySetIterator<'a>; + /// Returns an iterator of values in this set. + /// + /// # Panics + /// + /// If PyO3 detects that the set is mutated during iteration, it will panic. + fn into_iter(self) -> Self::IntoIter { + PySetIterator { + set: self, + pos: 0, + used: unsafe { ffi::PySet_Size(self.as_ptr()) }, } } } - #[inline] - fn size_hint(&self) -> (usize, Option) { - let len = self.set.len().unwrap_or_default(); - ( - len.saturating_sub(self.pos as usize), - Some(len.saturating_sub(self.pos as usize)), - ) - } -} - -impl<'a> std::iter::IntoIterator for &'a PySet { - type Item = &'a PyAny; - type IntoIter = PySetIterator<'a>; + impl<'py> Iterator for PySetIterator<'py> { + type Item = &'py super::PyAny; + + /// Advances the iterator and returns the next value. + /// + /// # Panics + /// + /// If PyO3 detects that the set is mutated during iteration, it will panic. + #[inline] + fn next(&mut self) -> Option { + unsafe { + let len = ffi::PySet_Size(self.set.as_ptr()); + assert_eq!(self.used, len, "Set changed size during iteration"); + + let mut key: *mut ffi::PyObject = std::ptr::null_mut(); + let mut hash: ffi::Py_hash_t = 0; + if ffi::_PySet_NextEntry(self.set.as_ptr(), &mut self.pos, &mut key, &mut hash) != 0 + { + // _PySet_NextEntry returns borrowed object; for safety must make owned (see #890) + Some(self.set.py().from_owned_ptr(ffi::_Py_NewRef(key))) + } else { + None + } + } + } - fn into_iter(self) -> Self::IntoIter { - self.iter() + #[inline] + fn size_hint(&self) -> (usize, Option) { + let len = self.set.len().unwrap_or_default(); + ( + len.saturating_sub(self.pos as usize), + Some(len.saturating_sub(self.pos as usize)), + ) + } } } +pub use impl_::*; + impl ToPyObject for collections::HashSet where T: hash::Hash + Eq + ToPyObject, @@ -274,67 +296,9 @@ where } } -impl PyFrozenSet { - /// Creates a new frozenset. - /// - /// May panic when running out of memory. - pub fn new<'p, T: ToPyObject>(py: Python<'p>, elements: &[T]) -> PyResult<&'p PyFrozenSet> { - let list = elements.to_object(py); - unsafe { py.from_owned_ptr_or_err(ffi::PyFrozenSet_New(list.as_ptr())) } - } - - /// Creates a new empty frozen set - pub fn empty(py: Python<'_>) -> PyResult<&PyFrozenSet> { - unsafe { py.from_owned_ptr_or_err(ffi::PyFrozenSet_New(ptr::null_mut())) } - } - - /// Return the number of items in the set. - /// This is equivalent to len(p) on a set. - #[inline] - pub fn len(&self) -> usize { - unsafe { ffi::PySet_Size(self.as_ptr()) as usize } - } - - /// Check if set is empty. - pub fn is_empty(&self) -> bool { - self.len() == 0 - } - - /// Determine if the set contains the specified key. - /// This is equivalent to the Python expression `key in self`. - pub fn contains(&self, key: K) -> PyResult - where - K: ToPyObject, - { - unsafe { - match ffi::PySet_Contains(self.as_ptr(), key.to_object(self.py()).as_ptr()) { - 1 => Ok(true), - 0 => Ok(false), - _ => Err(PyErr::fetch(self.py())), - } - } - } - - /// Returns an iterator of values in this frozen set. - /// - /// Note that it can be unsafe to use when the set might be changed by other code. - pub fn iter(&self) -> PySetIterator<'_> { - PySetIterator::new(self.as_ref()) - } -} - -impl<'a> std::iter::IntoIterator for &'a PyFrozenSet { - type Item = &'a PyAny; - type IntoIter = PySetIterator<'a>; - - fn into_iter(self) -> Self::IntoIter { - self.iter() - } -} - #[cfg(test)] mod tests { - use super::{PyFrozenSet, PySet}; + use super::PySet; use crate::{IntoPy, PyObject, PyTryFrom, Python, ToPyObject}; use std::collections::{BTreeSet, HashSet}; @@ -441,62 +405,44 @@ mod tests { } #[test] - fn test_set_iter_size_hint() { + #[should_panic] + fn test_set_iter_mutation() { Python::with_gil(|py| { - let set = PySet::new(py, &[1]).unwrap(); - - let mut iter = set.iter(); + let set = PySet::new(py, &[1, 2, 3, 4, 5]).unwrap(); - if cfg!(Py_LIMITED_API) { - assert_eq!(iter.size_hint(), (0, None)); - } else { - assert_eq!(iter.size_hint(), (1, Some(1))); - iter.next(); - assert_eq!(iter.size_hint(), (0, Some(0))); + for _ in set { + let _ = set.add(42); } }); } #[test] - fn test_frozenset_new_and_len() { - Python::with_gil(|py| { - let set = PyFrozenSet::new(py, &[1]).unwrap(); - assert_eq!(1, set.len()); - - let v = vec![1]; - assert!(PyFrozenSet::new(py, &[v]).is_err()); - }); - } - - #[test] - fn test_frozenset_empty() { + #[should_panic] + fn test_set_iter_mutation_same_len() { Python::with_gil(|py| { - let set = PyFrozenSet::empty(py).unwrap(); - assert_eq!(0, set.len()); - }); - } + let set = PySet::new(py, &[1, 2, 3, 4, 5]).unwrap(); - #[test] - fn test_frozenset_contains() { - Python::with_gil(|py| { - let set = PyFrozenSet::new(py, &[1]).unwrap(); - assert!(set.contains(1).unwrap()); + for item in set { + let item: i32 = item.extract().unwrap(); + let _ = set.del_item(item); + let _ = set.add(item + 10); + } }); } #[test] - fn test_frozenset_iter() { + fn test_set_iter_size_hint() { Python::with_gil(|py| { - let set = PyFrozenSet::new(py, &[1]).unwrap(); + let set = PySet::new(py, &[1]).unwrap(); - // iter method - for el in set.iter() { - assert_eq!(1i32, el.extract::().unwrap()); - } + let mut iter = set.iter(); - // intoiterator iteration - for el in set { - assert_eq!(1i32, el.extract::().unwrap()); + if cfg!(Py_LIMITED_API) { + assert_eq!(iter.size_hint(), (0, None)); + } else { + assert_eq!(iter.size_hint(), (1, Some(1))); + iter.next(); + assert_eq!(iter.size_hint(), (0, Some(0))); } }); }