From be7c62519ef5717af2c8cbeb1be77f679f381286 Mon Sep 17 00:00:00 2001 From: Will Jones Date: Fri, 17 Nov 2023 20:33:41 -0800 Subject: [PATCH] feat: add PyCapsule::import_mut() --- newsfragments/3585.added.md | 1 + src/types/capsule.rs | 28 +++++++++++++++++++++++++++- 2 files changed, 28 insertions(+), 1 deletion(-) create mode 100644 newsfragments/3585.added.md diff --git a/newsfragments/3585.added.md b/newsfragments/3585.added.md new file mode 100644 index 00000000000..8518c0ceb1a --- /dev/null +++ b/newsfragments/3585.added.md @@ -0,0 +1 @@ +Add `PyCapsule::import_mut()` to import a mutable reference to capsule contents. diff --git a/src/types/capsule.rs b/src/types/capsule.rs index f97320eb474..6593acfc1f3 100644 --- a/src/types/capsule.rs +++ b/src/types/capsule.rs @@ -138,6 +138,26 @@ impl PyCapsule { } } + /// Imports an existing capsule, returning a mutable reference to the value. + /// + /// The `name` should match the path to the module attribute exactly in the form + /// of `"module.attribute"`, which should be the same as the name within the capsule. + /// + /// # Safety + /// + /// It must be known that the capsule imported by `name` contains an item of type `T`. + /// + /// It must be known that the capsule will not be accessed by any other code + /// while the mutable reference is alive. + pub unsafe fn import_mut<'py, T>(py: Python<'py>, name: &CStr) -> PyResult<&'py mut T> { + let ptr = ffi::PyCapsule_Import(name.as_ptr(), false as c_int); + if ptr.is_null() { + Err(PyErr::fetch(py)) + } else { + Ok(&mut *(ptr as *mut T)) + } + } + /// Sets the context pointer in the capsule. /// /// Returns an error if this capsule is not valid. @@ -398,10 +418,16 @@ mod tests { let wrong_name = CString::new("builtins.non_existant").unwrap(); let result: PyResult<&Foo> = unsafe { PyCapsule::import(py, wrong_name.as_ref()) }; assert!(result.is_err()); + let result: PyResult<&mut Foo> = + unsafe { PyCapsule::import_mut(py, wrong_name.as_ref()) }; + assert!(result.is_err()); - // corret name is okay. + // correct name is okay. let cap: &Foo = unsafe { PyCapsule::import(py, name.as_ref())? }; assert_eq!(cap.val, 123); + let cap: &mut Foo = unsafe { PyCapsule::import_mut(py, name.as_ref())? }; + assert_eq!(cap.val, 123); + Ok(()) }) }