Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Protect iterators against concurrent modification #2380

Merged
merged 8 commits into from May 31, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Expand Up @@ -26,6 +26,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Allow `#[classattr]` methods to be fallible. [#2385](https://github.com/PyO3/pyo3/pull/2385)
- Prevent multiple `#[pymethods]` with the same name for a single `#[pyclass]`. [#2399](https://github.com/PyO3/pyo3/pull/2399)
- Fixup `lib_name` when using `PYO3_CONFIG_FILE`. [#2404](https://github.com/PyO3/pyo3/pull/2404)
- Iterators over `PySet` and `PyDict` will now panic if the underlying collection is mutated during the iteration. [#2380](https://github.com/PyO3/pyo3/pull/2380)

### Fixed

Expand Down
189 changes: 158 additions & 31 deletions src/types/dict.rs
@@ -1,6 +1,8 @@
// Copyright (c) 2017-present PyO3 Project and Contributors

use super::PyMapping;
use crate::err::{self, PyErr, PyResult};
use crate::ffi::Py_ssize_t;
use crate::types::{PyAny, PyList};
#[cfg(not(PyPy))]
use crate::IntoPyPointer;
Expand All @@ -9,8 +11,6 @@ use std::collections::{BTreeMap, HashMap};
use std::ptr::NonNull;
use std::{cmp, collections, hash};

use super::PyMapping;

/// Represents a Python `dict`.
#[repr(transparent)]
pub struct PyDict(PyAny);
Expand Down Expand Up @@ -102,7 +102,19 @@ impl PyDict {
///
/// This is equivalent to the Python expression `len(self)`.
pub fn len(&self) -> usize {
unsafe { ffi::PyDict_Size(self.as_ptr()) as usize }
self._len() as usize
}

fn _len(&self) -> Py_ssize_t {
#[cfg(any(not(Py_3_8), PyPy, Py_LIMITED_API))]
unsafe {
ffi::PyDict_Size(self.as_ptr())
}

#[cfg(all(Py_3_8, not(PyPy), not(Py_LIMITED_API)))]
unsafe {
(*self.as_ptr().cast::<ffi::PyDictObject>()).ma_used
}
}

/// Checks if the dict is empty, i.e. `len(self) == 0`.
Expand Down Expand Up @@ -213,13 +225,13 @@ impl PyDict {

/// Returns an iterator of `(key, value)` pairs in this dictionary.
///
/// Note that it's unsafe to use when the dictionary might be changed by
/// other code.
/// # Panics
///
/// If PyO3 detects that the dictionary is mutated during iteration, it will panic.
/// It is allowed to modify values as you iterate over the dictionary, but only
/// so long as the set of keys does not change.
pub fn iter(&self) -> PyDictIterator<'_> {
PyDictIterator {
dict: self.as_ref(),
pos: 0,
}
IntoIterator::into_iter(self)
}

/// Returns `self` cast as a `PyMapping`.
Expand All @@ -229,38 +241,54 @@ impl PyDict {
}

pub struct PyDictIterator<'py> {
dict: &'py PyAny,
pos: isize,
dict: &'py PyDict,
ppos: ffi::Py_ssize_t,
di_used: ffi::Py_ssize_t,
len: ffi::Py_ssize_t,
}

impl<'py> Iterator for PyDictIterator<'py> {
type Item = (&'py PyAny, &'py PyAny);

#[inline]
fn next(&mut self) -> Option<Self::Item> {
unsafe {
let mut key: *mut ffi::PyObject = std::ptr::null_mut();
let mut value: *mut ffi::PyObject = std::ptr::null_mut();
if ffi::PyDict_Next(self.dict.as_ptr(), &mut self.pos, &mut key, &mut value) != 0 {
let py = self.dict.py();
// PyDict_Next returns borrowed values; for safety must make them owned (see #890)
Some((
py.from_owned_ptr(ffi::_Py_NewRef(key)),
py.from_owned_ptr(ffi::_Py_NewRef(value)),
))
} else {
None
}
let ma_used = self.dict._len();

// These checks are similar to what CPython does.
//
// If the dimension of the dict changes e.g. key-value pairs are removed
// or added during iteration, this will panic next time when `next` is called
if self.di_used != ma_used {
self.di_used = -1;
panic!("dictionary changed size during iteration");
};

// If the dict is changed in such a way that the length remains constant
// then this will panic at the end of iteration - similar to this:
//
// d = {"a":1, "b":2, "c": 3}
//
// for k, v in d.items():
// d[f"{k}_"] = 4
// del d[k]
// print(k)
//
if self.len == -1 {
self.di_used = -1;
panic!("dictionary keys changed during iteration");
};

let ret = unsafe { self.next_unchecked() };
if ret.is_some() {
self.len -= 1
}
ret
}

#[inline]
fn size_hint(&self) -> (usize, Option<usize>) {
let len = self.dict.len().unwrap_or_default();
(
len.saturating_sub(self.pos as usize),
Some(len.saturating_sub(self.pos as usize)),
)
let len = self.len as usize;
(len, Some(len))
}
}

Expand All @@ -269,7 +297,34 @@ impl<'a> std::iter::IntoIterator for &'a PyDict {
type IntoIter = PyDictIterator<'a>;

fn into_iter(self) -> Self::IntoIter {
self.iter()
PyDictIterator {
dict: self,
ppos: 0,
di_used: self._len(),
len: self._len(),
}
}
}

impl<'py> PyDictIterator<'py> {
/// Advances the iterator without checking for concurrent modification.
///
/// See [`PyDict_Next`](https://docs.python.org/3/c-api/dict.html#c.PyDict_Next)
/// for more information.
unsafe fn next_unchecked(&mut self) -> Option<(&'py PyAny, &'py PyAny)> {
let mut key: *mut ffi::PyObject = std::ptr::null_mut();
let mut value: *mut ffi::PyObject = std::ptr::null_mut();

if ffi::PyDict_Next(self.dict.as_ptr(), &mut self.ppos, &mut key, &mut value) != 0 {
let py = self.dict.py();
// PyDict_Next returns borrowed values; for safety must make them owned (see #890)
Some((
py.from_owned_ptr(ffi::_Py_NewRef(key)),
py.from_owned_ptr(ffi::_Py_NewRef(value)),
))
} else {
None
}
}
}

Expand Down Expand Up @@ -660,6 +715,74 @@ mod tests {
});
}

#[test]
fn test_iter_value_mutated() {
Python::with_gil(|py| {
let mut v = HashMap::new();
v.insert(7, 32);
v.insert(8, 42);
v.insert(9, 123);

let ob = v.to_object(py);
let dict = <PyDict as PyTryFrom>::try_from(ob.as_ref(py)).unwrap();

for (key, value) in dict.iter() {
dict.set_item(key, value.extract::<i32>().unwrap() + 7)
.unwrap();
}
});
}

#[test]
#[should_panic]
fn test_iter_key_mutated() {
Python::with_gil(|py| {
let mut v = HashMap::new();
for i in 0..10 {
v.insert(i * 2, i * 2);
}
let ob = v.to_object(py);
let dict = <PyDict as PyTryFrom>::try_from(ob.as_ref(py)).unwrap();

for (i, (key, value)) in dict.iter().enumerate() {
let key = key.extract::<i32>().unwrap();
let value = value.extract::<i32>().unwrap();

dict.set_item(key + 1, value + 1).unwrap();

if i > 1000 {
// avoid this test just running out of memory if it fails
break;
};
}
});
}

#[test]
#[should_panic]
fn test_iter_key_mutated_constant_len() {
Python::with_gil(|py| {
let mut v = HashMap::new();
for i in 0..10 {
v.insert(i * 2, i * 2);
}
let ob = v.to_object(py);
let dict = <PyDict as PyTryFrom>::try_from(ob.as_ref(py)).unwrap();

for (i, (key, value)) in dict.iter().enumerate() {
let key = key.extract::<i32>().unwrap();
let value = value.extract::<i32>().unwrap();
dict.del_item(key).unwrap();
dict.set_item(key + 1, value + 1).unwrap();

if i > 1000 {
// avoid this test just running out of memory if it fails
break;
};
}
});
}

#[test]
fn test_iter_size_hint() {
Python::with_gil(|py| {
Expand All @@ -675,10 +798,14 @@ mod tests {
iter.next();
assert_eq!(iter.size_hint(), (v.len() - 1, Some(v.len() - 1)));

// Exhust iterator.
// Exhaust iterator.
for _ in &mut iter {}

assert_eq!(iter.size_hint(), (0, Some(0)));

assert!(iter.next().is_none());

assert_eq!(iter.size_hint(), (0, Some(0)));
});
}

Expand Down