Skip to content

Commit

Permalink
Protect iterators against concurrent modification (#2380)
Browse files Browse the repository at this point in the history
  • Loading branch information
mejrs committed May 31, 2022
1 parent f84c740 commit 4f9d3d7
Show file tree
Hide file tree
Showing 5 changed files with 504 additions and 198 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Expand Up @@ -26,6 +26,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Allow `#[classattr]` methods to be fallible. [#2385](https://github.com/PyO3/pyo3/pull/2385)
- Prevent multiple `#[pymethods]` with the same name for a single `#[pyclass]`. [#2399](https://github.com/PyO3/pyo3/pull/2399)
- Fixup `lib_name` when using `PYO3_CONFIG_FILE`. [#2404](https://github.com/PyO3/pyo3/pull/2404)
- Iterators over `PySet` and `PyDict` will now panic if the underlying collection is mutated during the iteration. [#2380](https://github.com/PyO3/pyo3/pull/2380)

### Fixed

Expand Down
189 changes: 158 additions & 31 deletions src/types/dict.rs
@@ -1,6 +1,8 @@
// Copyright (c) 2017-present PyO3 Project and Contributors

use super::PyMapping;
use crate::err::{self, PyErr, PyResult};
use crate::ffi::Py_ssize_t;
use crate::types::{PyAny, PyList};
#[cfg(not(PyPy))]
use crate::IntoPyPointer;
Expand All @@ -9,8 +11,6 @@ use std::collections::{BTreeMap, HashMap};
use std::ptr::NonNull;
use std::{cmp, collections, hash};

use super::PyMapping;

/// Represents a Python `dict`.
#[repr(transparent)]
pub struct PyDict(PyAny);
Expand Down Expand Up @@ -102,7 +102,19 @@ impl PyDict {
///
/// This is equivalent to the Python expression `len(self)`.
pub fn len(&self) -> usize {
unsafe { ffi::PyDict_Size(self.as_ptr()) as usize }
self._len() as usize
}

/// Returns the number of entries as the raw `Py_ssize_t` used by the C API.
///
/// Exactly one of the two cfg-gated blocks below survives compilation and
/// becomes the function's tail expression.
fn _len(&self) -> Py_ssize_t {
    let dict_ptr = self.as_ptr();

    // Py_3_8 and later, outside PyPy and the limited API: read the `ma_used`
    // field of the dict struct directly.
    // NOTE(review): relies on `dict_ptr` pointing at a live dict object - confirm.
    #[cfg(all(Py_3_8, not(PyPy), not(Py_LIMITED_API)))]
    unsafe {
        (*dict_ptr.cast::<ffi::PyDictObject>()).ma_used
    }

    // Every other configuration: ask the interpreter through `PyDict_Size`.
    #[cfg(any(not(Py_3_8), PyPy, Py_LIMITED_API))]
    unsafe {
        ffi::PyDict_Size(dict_ptr)
    }
}

/// Checks if the dict is empty, i.e. `len(self) == 0`.
Expand Down Expand Up @@ -213,13 +225,13 @@ impl PyDict {

/// Returns an iterator of `(key, value)` pairs in this dictionary.
///
/// # Panics
///
/// If PyO3 detects that the dictionary is mutated during iteration, it will panic.
/// It is allowed to modify values as you iterate over the dictionary, but only
/// so long as the set of keys does not change.
pub fn iter(&self) -> PyDictIterator<'_> {
    // The iterator's bookkeeping fields are initialised in the
    // `IntoIterator for &PyDict` impl; go through it rather than duplicating
    // the struct literal here.
    IntoIterator::into_iter(self)
}

/// Returns `self` cast as a `PyMapping`.
Expand All @@ -229,38 +241,54 @@ impl PyDict {
}

pub struct PyDictIterator<'py> {
dict: &'py PyAny,
pos: isize,
dict: &'py PyDict,
ppos: ffi::Py_ssize_t,
di_used: ffi::Py_ssize_t,
len: ffi::Py_ssize_t,
}

impl<'py> Iterator for PyDictIterator<'py> {
type Item = (&'py PyAny, &'py PyAny);

#[inline]
fn next(&mut self) -> Option<Self::Item> {
unsafe {
let mut key: *mut ffi::PyObject = std::ptr::null_mut();
let mut value: *mut ffi::PyObject = std::ptr::null_mut();
if ffi::PyDict_Next(self.dict.as_ptr(), &mut self.pos, &mut key, &mut value) != 0 {
let py = self.dict.py();
// PyDict_Next returns borrowed values; for safety must make them owned (see #890)
Some((
py.from_owned_ptr(ffi::_Py_NewRef(key)),
py.from_owned_ptr(ffi::_Py_NewRef(value)),
))
} else {
None
}
let ma_used = self.dict._len();

// These checks are similar to what CPython does.
//
// If the dimension of the dict changes e.g. key-value pairs are removed
// or added during iteration, this will panic next time when `next` is called
if self.di_used != ma_used {
self.di_used = -1;
panic!("dictionary changed size during iteration");
};

// If the dict is changed in such a way that the length remains constant
// then this will panic at the end of iteration - similar to this:
//
// d = {"a":1, "b":2, "c": 3}
//
// for k, v in d.items():
// d[f"{k}_"] = 4
// del d[k]
// print(k)
//
if self.len == -1 {
self.di_used = -1;
panic!("dictionary keys changed during iteration");
};

let ret = unsafe { self.next_unchecked() };
if ret.is_some() {
self.len -= 1
}
ret
}

#[inline]
fn size_hint(&self) -> (usize, Option<usize>) {
let len = self.dict.len().unwrap_or_default();
(
len.saturating_sub(self.pos as usize),
Some(len.saturating_sub(self.pos as usize)),
)
let len = self.len as usize;
(len, Some(len))
}
}

Expand All @@ -269,7 +297,34 @@ impl<'a> std::iter::IntoIterator for &'a PyDict {
type IntoIter = PyDictIterator<'a>;

/// Builds an iterator positioned at the first entry, recording the current
/// size of the dict so later steps can detect modification.
fn into_iter(self) -> Self::IntoIter {
    // Read the size once; it seeds both the reference size used to detect
    // resizing (`di_used`) and the count of entries still expected (`len`).
    let len = self._len();
    PyDictIterator {
        dict: self,
        ppos: 0,
        di_used: len,
        len,
    }
}
}

impl<'py> PyDictIterator<'py> {
    /// Steps the underlying `PyDict_Next` cursor, performing none of the
    /// concurrent-modification checks.
    ///
    /// Returns `None` once `PyDict_Next` reports that no entry was produced.
    ///
    /// See [`PyDict_Next`](https://docs.python.org/3/c-api/dict.html#c.PyDict_Next)
    /// for more information.
    unsafe fn next_unchecked(&mut self) -> Option<(&'py PyAny, &'py PyAny)> {
        let mut key_ptr = std::ptr::null_mut::<ffi::PyObject>();
        let mut value_ptr = std::ptr::null_mut::<ffi::PyObject>();

        let produced =
            ffi::PyDict_Next(self.dict.as_ptr(), &mut self.ppos, &mut key_ptr, &mut value_ptr);
        if produced == 0 {
            return None;
        }

        let py = self.dict.py();
        // PyDict_Next returns borrowed values; for safety must make them owned (see #890)
        let key = py.from_owned_ptr(ffi::_Py_NewRef(key_ptr));
        let value = py.from_owned_ptr(ffi::_Py_NewRef(value_ptr));
        Some((key, value))
    }
}

Expand Down Expand Up @@ -660,6 +715,74 @@ mod tests {
});
}

#[test]
fn test_iter_value_mutated() {
    // Overwriting the value stored under an existing key leaves the key set
    // untouched, so the iteration has to finish without panicking.
    Python::with_gil(|py| {
        let source: HashMap<i32, i32> = [(7, 32), (8, 42), (9, 123)].iter().cloned().collect();

        let object = source.to_object(py);
        let dict = <PyDict as PyTryFrom>::try_from(object.as_ref(py)).unwrap();

        for (key, value) in dict.iter() {
            let bumped = value.extract::<i32>().unwrap() + 7;
            dict.set_item(key, bumped).unwrap();
        }
    });
}

#[test]
#[should_panic]
fn test_iter_key_mutated() {
    // Inserting a key during iteration must make the iterator panic.
    Python::with_gil(|py| {
        // Keys 0, 2, ..., 18: all even.
        let source: HashMap<i32, i32> = (0..10).map(|i| (i * 2, i * 2)).collect();
        let object = source.to_object(py);
        let dict = <PyDict as PyTryFrom>::try_from(object.as_ref(py)).unwrap();

        for (step, (key, value)) in dict.iter().enumerate() {
            let k = key.extract::<i32>().unwrap();
            let v = value.extract::<i32>().unwrap();

            // Store an entry under `k + 1` while the iterator is live.
            dict.set_item(k + 1, v + 1).unwrap();

            // avoid this test just running out of memory if it fails
            if step > 1000 {
                break;
            }
        }
    });
}

#[test]
#[should_panic]
fn test_iter_key_mutated_constant_len() {
    // Each step deletes the current key and stores an entry under `key + 1`;
    // the iterator is still expected to panic.
    Python::with_gil(|py| {
        let source: HashMap<i32, i32> = (0..10).map(|i| (i * 2, i * 2)).collect();
        let object = source.to_object(py);
        let dict = <PyDict as PyTryFrom>::try_from(object.as_ref(py)).unwrap();

        for (step, (key, value)) in dict.iter().enumerate() {
            let k = key.extract::<i32>().unwrap();
            let v = value.extract::<i32>().unwrap();
            dict.del_item(k).unwrap();
            dict.set_item(k + 1, v + 1).unwrap();

            // avoid this test just running out of memory if it fails
            if step > 1000 {
                break;
            }
        }
    });
}

#[test]
fn test_iter_size_hint() {
Python::with_gil(|py| {
Expand All @@ -675,10 +798,14 @@ mod tests {
iter.next();
assert_eq!(iter.size_hint(), (v.len() - 1, Some(v.len() - 1)));

// Exhust iterator.
// Exhaust iterator.
for _ in &mut iter {}

assert_eq!(iter.size_hint(), (0, Some(0)));

assert!(iter.next().is_none());

assert_eq!(iter.size_hint(), (0, Some(0)));
});
}

Expand Down

0 comments on commit 4f9d3d7

Please sign in to comment.