diff --git a/newsfragments/3585.added.md b/newsfragments/3585.added.md new file mode 100644 index 00000000000..8518c0ceb1a --- /dev/null +++ b/newsfragments/3585.added.md @@ -0,0 +1 @@ +Add `PyCapsule::import_mut()` to import a mutable reference to capsule contents. diff --git a/src/types/capsule.rs b/src/types/capsule.rs index f97320eb474..6380197a496 100644 --- a/src/types/capsule.rs +++ b/src/types/capsule.rs @@ -123,8 +123,9 @@ impl PyCapsule { /// Imports an existing capsule. /// - /// The `name` should match the path to the module attribute exactly in the form - /// of `"module.attribute"`, which should be the same as the name within the capsule. + /// If this capsule represents a module attribute, the `name` should match the path + /// to the module attribute exactly in the form of `"module.attribute"`, which should + /// be the same as the name within the capsule. /// /// # Safety /// @@ -138,6 +139,23 @@ impl PyCapsule { } } + /// Imports an existing capsule, returning a mutable reference to the value. + /// + /// # Safety + /// + /// It must be known that the capsule imported by `name` contains an item of type `T`. + /// + /// It must be known that the capsule will not be accessed by any other code + /// while the mutable reference is alive. + pub unsafe fn import_mut<'py, T>(py: Python<'py>, name: &CStr) -> PyResult<&'py mut T> { + let ptr = ffi::PyCapsule_Import(name.as_ptr(), false as c_int); + if ptr.is_null() { + Err(PyErr::fetch(py)) + } else { + Ok(&mut *(ptr as *mut T)) + } + } + /// Sets the context pointer in the capsule. /// /// Returns an error if this capsule is not valid. @@ -398,10 +416,16 @@ mod tests { let wrong_name = CString::new("builtins.non_existant").unwrap(); let result: PyResult<&Foo> = unsafe { PyCapsule::import(py, wrong_name.as_ref()) }; assert!(result.is_err()); + let result: PyResult<&mut Foo> = + unsafe { PyCapsule::import_mut(py, wrong_name.as_ref()) }; + assert!(result.is_err()); - // corret name is okay. + // correct name is okay. let cap: &Foo = unsafe { PyCapsule::import(py, name.as_ref())? }; assert_eq!(cap.val, 123); + let cap: &mut Foo = unsafe { PyCapsule::import_mut(py, name.as_ref())? }; + assert_eq!(cap.val, 123); + Ok(()) }) }