Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Protect iterators against concurrent modification #2380

Merged
merged 8 commits into from May 31, 2022
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
199 changes: 160 additions & 39 deletions src/types/dict.rs
Expand Up @@ -213,13 +213,13 @@ impl PyDict {

/// Returns an iterator of `(key, value)` pairs in this dictionary.
///
/// Note that it's unsafe to use when the dictionary might be changed by
/// other code.
/// # Panics
///
/// If PyO3 detects that the dictionary is mutated during iteration, it will panic.
/// It is allowed to modify the values of the keys as you iterate over the dictionary, but only
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I know "modify the values of the keys" is copied verbatim from the C-API docs, but I think it's quite confusing.

/// so long as the set of keys does not change.
pub fn iter(&self) -> PyDictIterator<'_> {
PyDictIterator {
dict: self.as_ref(),
pos: 0,
}
IntoIterator::into_iter(self)
}

/// Returns `self` cast as a `PyMapping`.
Expand All @@ -228,51 +228,129 @@ impl PyDict {
}
}

pub struct PyDictIterator<'py> {
dict: &'py PyAny,
pos: isize,
}
#[cfg(all(Py_3_8, not(Py_LIMITED_API)))]
mod impl_ {
use super::*;
use std::marker::PhantomData;

pub struct PyDictIterator<'py> {
di_dict: *mut ffi::PyDictObject,
di_pos: ffi::Py_ssize_t,
di_used: ffi::Py_ssize_t,
marker: PhantomData<&'py PyDict>,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a risk that the dict refcount can fall to 0 and can be accidentally released? Should this be Py<PyDict>?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The iterator borrows from the Pydict, so this should be fine.

}

impl<'py> Iterator for PyDictIterator<'py> {
type Item = (&'py PyAny, &'py PyAny);

/// Advances the iterator and returns the next value.
///
/// # Panics
///
/// If PyO3 detects that the dictionary is mutated during iteration, it will panic.
/// It is allowed to modify the values of the keys as you iterate over the dictionary, but only
/// so long as the set of keys does not change.
#[inline]
fn next(&mut self) -> Option<Self::Item> {
unsafe {
if self.di_used != (*(self.di_dict)).ma_used {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's the difference between ma_used and PyDict_Size used below?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm...I read the cpython source some more: actually none. Let me reconsider the approach.

self.di_used = -1;
panic!("dictionary changed size during iteration");
mejrs marked this conversation as resolved.
Show resolved Hide resolved
};

let mut key: *mut ffi::PyObject = std::ptr::null_mut();
let mut value: *mut ffi::PyObject = std::ptr::null_mut();
if ffi::PyDict_Next(self.di_dict.cast(), &mut self.di_pos, &mut key, &mut value)
!= 0
{
let py = Python::assume_gil_acquired();
// PyDict_Next returns borrowed values; for safety must make them owned (see #890)
Some((
py.from_owned_ptr(ffi::_Py_NewRef(key)),
py.from_owned_ptr(ffi::_Py_NewRef(value)),
))
} else {
None
}
}
}

impl<'py> Iterator for PyDictIterator<'py> {
type Item = (&'py PyAny, &'py PyAny);
#[inline]
fn size_hint(&self) -> (usize, Option<usize>) {
let len = unsafe { ffi::PyDict_Size(self.di_dict.cast()) as usize };
(
len.saturating_sub(self.di_pos as usize),
Some(len.saturating_sub(self.di_pos as usize)),
)
}
}

#[inline]
fn next(&mut self) -> Option<Self::Item> {
unsafe {
let mut key: *mut ffi::PyObject = std::ptr::null_mut();
let mut value: *mut ffi::PyObject = std::ptr::null_mut();
if ffi::PyDict_Next(self.dict.as_ptr(), &mut self.pos, &mut key, &mut value) != 0 {
let py = self.dict.py();
// PyDict_Next returns borrowed values; for safety must make them owned (see #890)
Some((
py.from_owned_ptr(ffi::_Py_NewRef(key)),
py.from_owned_ptr(ffi::_Py_NewRef(value)),
))
} else {
None
impl<'a> std::iter::IntoIterator for &'a PyDict {
type Item = (&'a PyAny, &'a PyAny);
type IntoIter = PyDictIterator<'a>;

fn into_iter(self) -> Self::IntoIter {
unsafe {
let di_dict: *mut ffi::PyDictObject = self.as_ptr().cast();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've always wondered if there should be an as_raw() -> *mut ffi::PyDictObject method on PyDict (and similar for other types)? Available only on not (limited_api / PyPy).

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds interesting, but I'm not sure how often this is useful in practice. You only need it to access fields directly, which we don't tend to do.


PyDictIterator {
di_dict,
di_pos: 0,
di_used: (*di_dict).ma_used,
marker: PhantomData,
}
}
}
}
}

#[inline]
fn size_hint(&self) -> (usize, Option<usize>) {
let len = self.dict.len().unwrap_or_default();
(
len.saturating_sub(self.pos as usize),
Some(len.saturating_sub(self.pos as usize)),
)
#[cfg(any(not(Py_3_8), Py_LIMITED_API))]
mod impl_ {
use super::*;
use crate::types::PyIterator;

pub struct PyDictIterator<'py> {
iter: &'py PyIterator,
}

impl<'py> Iterator for PyDictIterator<'py> {
type Item = (&'py PyAny, &'py PyAny);

/// Advances the iterator and returns the next value.
///
/// # Panics
///
/// If PyO3 detects that the dictionary is mutated during iteration, it will panic.
/// It is allowed to modify the values of the keys as you iterate over the dictionary, but only
/// so long as the set of keys does not change.
#[inline]
fn next(&mut self) -> Option<Self::Item> {
self.iter
.next()
.map(Result::unwrap)
.map(FromPyObject::extract)
.map(Result::unwrap)
}
}
}

impl<'a> std::iter::IntoIterator for &'a PyDict {
type Item = (&'a PyAny, &'a PyAny);
type IntoIter = PyDictIterator<'a>;
impl<'a> std::iter::IntoIterator for &'a PyDict {
type Item = (&'a PyAny, &'a PyAny);
type IntoIter = PyDictIterator<'a>;

fn into_iter(self) -> Self::IntoIter {
let py = self.py();

// `_PyDictView_New` is not available, so just do the `items()` call
let items = self.call_method0(intern!(py, "items")).unwrap();
let iter = PyIterator::from_object(py, items).unwrap();

fn into_iter(self) -> Self::IntoIter {
self.iter()
PyDictIterator { iter }
}
}
}

use impl_::*;

impl<K, V, H> ToPyObject for collections::HashMap<K, V, H>
where
K: hash::Hash + cmp::Eq + ToPyObject,
Expand Down Expand Up @@ -660,6 +738,49 @@ mod tests {
});
}

#[test]
fn test_iter_value_mutated() {
Python::with_gil(|py| {
let mut v = HashMap::new();
v.insert(7, 32);
v.insert(8, 42);
v.insert(9, 123);

let ob = v.to_object(py);
let dict = <PyDict as PyTryFrom>::try_from(ob.as_ref(py)).unwrap();

for (key, value) in dict.iter() {
dict.set_item(key, value.extract::<i32>().unwrap() + 7)
.unwrap();
}
});
}

#[test]
#[should_panic]
fn test_iter_key_mutated() {
Python::with_gil(|py| {
let mut v = HashMap::new();
for i in 0..1000 {
v.insert(i * 2, i * 2);
}
let ob = v.to_object(py);
let dict = <PyDict as PyTryFrom>::try_from(ob.as_ref(py)).unwrap();

for (i, (key, value)) in dict.iter().enumerate() {
let key = key.extract::<i32>().unwrap();
let value = value.extract::<i32>().unwrap();

dict.set_item(key + 1, value + 1).unwrap();

if i > 10000 {
break;
};
}
});
}

#[cfg(not(Py_LIMITED_API))]
#[test]
fn test_iter_size_hint() {
Python::with_gil(|py| {
Expand Down