From e0219b44e58d1f22ab011c66a164e9e3a8bef742 Mon Sep 17 00:00:00 2001 From: milesgranger Date: Tue, 9 Nov 2021 20:21:40 +0100 Subject: [PATCH] Add PyCapsule API --- src/lib.rs | 1 + src/pycapsule.rs | 232 +++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 233 insertions(+) create mode 100644 src/pycapsule.rs diff --git a/src/lib.rs b/src/lib.rs index 2d0ca39ba70..494d0143994 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -334,6 +334,7 @@ pub mod marshal; pub mod once_cell; pub mod panic; pub mod prelude; +pub mod pycapsule; pub mod pycell; pub mod pyclass; pub mod pyclass_init; diff --git a/src/pycapsule.rs b/src/pycapsule.rs new file mode 100644 index 00000000000..fa983c5cb9e --- /dev/null +++ b/src/pycapsule.rs @@ -0,0 +1,232 @@ +use crate::Python; +use crate::{ffi, AsPyPointer, PyAny}; +use crate::{pyobject_native_type_core, PyErr, PyResult}; +use std::ffi::{c_void, CStr}; +use std::os::raw::c_int; + +/// TODO: docs +/// +#[repr(transparent)] +pub struct PyCapsule(PyAny); + +pyobject_native_type_core!(PyCapsule, ffi::PyCapsule_Type, #checkfunction=ffi::PyCapsule_CheckExact); + +impl PyCapsule { + /// TODO: docs + pub fn new<'py, T>( + py: Python<'py>, + value: T, + name: &CStr, + destructor: Option, + ) -> PyResult<&'py Self> { + let val = Box::new(value); + + let cap_ptr = unsafe { + ffi::PyCapsule_New(Box::into_raw(val) as *mut c_void, name.as_ptr(), destructor) + }; + if cap_ptr.is_null() { + Err(PyErr::fetch(py)) + } else { + Ok(unsafe { py.from_owned_ptr::(cap_ptr) }) + } + } + + /// TODO: docs + pub fn import<'py, T>(py: Python<'py>, name: &CStr, no_block: bool) -> PyResult<&'py T> { + let ptr = unsafe { ffi::PyCapsule_Import(name.as_ptr(), no_block as c_int) }; + if ptr.is_null() { + Err(PyErr::fetch(py)) + } else { + Ok(unsafe { &*(ptr as *const T) }) + } + } + + /// TODO: docs + pub fn set_context(&self, py: Python, context: T) -> PyResult<()> { + let ctx = Box::new(context); + let result = + unsafe { ffi::PyCapsule_SetContext(self.as_ptr(), Box::into_raw(ctx) as _) as u8 }; + if result != 0 { + Err(PyErr::fetch(py)) + } else { + Ok(()) + } + } + + /// TODO: docs + pub fn get_context(&self, py: Python) -> PyResult> { + let ctx = unsafe { ffi::PyCapsule_GetContext(self.as_ptr()) }; + if ctx.is_null() { + if self.is_valid() & PyErr::occurred(py) { + Err(PyErr::fetch(py)) + } else { + Ok(None) + } + } else { + Ok(Some(unsafe { &*(ctx as *const T) })) + } + } + + /// TODO: docs + pub fn reference(&self) -> &T { + unsafe { &*(self.get_pointer() as *const T) } + } + + /// TODO: docs + pub fn get_pointer(&self) -> *mut c_void { + unsafe { ffi::PyCapsule_GetPointer(self.0.as_ptr(), self.name().as_ptr()) } + } + + /// TODO: docs + pub fn is_valid(&self) -> bool { + let r = unsafe { ffi::PyCapsule_IsValid(self.as_ptr(), self.name().as_ptr()) } as u8; + r != 0 + } + + /// TODO: docs + pub fn get_deconstructor(&self, py: Python) -> PyResult> { + match unsafe { ffi::PyCapsule_GetDestructor(self.as_ptr()) } { + Some(deconstructor) => Ok(Some(deconstructor)), + None => { + // A None can mean an error was raised, or there is no deconstructor + // https://docs.python.org/3/c-api/capsule.html#c.PyCapsule_GetDestructor + if self.is_valid() { + Ok(None) + } else { + Err(PyErr::fetch(py)) + } + } + } + } + + /// TODO: docs + pub fn name(&self) -> &CStr { + unsafe { + let ptr = ffi::PyCapsule_GetName(self.as_ptr()); + CStr::from_ptr(ptr) + } + } +} + +#[cfg(test)] +mod tests { + use crate::prelude::PyModule; + use crate::{ffi, pycapsule::PyCapsule, PyResult, Python}; + use std::ffi::{c_void, CString}; + use std::sync::mpsc::{channel, Sender}; + + #[test] + fn test_pycapsule_struct() -> PyResult<()> { + #[repr(C)] + struct Foo { + pub val: u32, + } + + impl Foo { + fn get_val(&self) -> u32 { + self.val + } + } + + Python::with_gil(|py| -> PyResult<()> { + let foo = Foo { val: 123 }; + let name = CString::new("foo").unwrap(); + + let cap = PyCapsule::new(py, foo, &name, None)?; + assert!(cap.is_valid()); + + let foo_capi = cap.reference::(); + assert_eq!(foo_capi.val, 123); + assert_eq!(foo_capi.get_val(), 123); + assert_eq!(cap.name(), name.as_ref()); + Ok(()) + }) + } + + #[test] + fn test_pycapsule_func() -> PyResult<()> { + extern "C" fn foo(x: u32) -> u32 { + x + } + + Python::with_gil(|py| { + let name = CString::new("foo").unwrap(); + + let cap = PyCapsule::new(py, foo as *const c_void, &name, None)?; + let f = cap.reference:: u32>(); + assert_eq!(f(123), 123); + Ok(()) + }) + } + + #[test] + fn test_pycapsule_context() -> PyResult<()> { + Python::with_gil(|py| { + let name = CString::new("foo").unwrap(); + let cap = PyCapsule::new(py, (), &name, None)?; + + let c = cap.get_context::<()>(py)?; + assert!(c.is_none()); + + cap.set_context(py, 123)?; + + let ctx: Option<&u32> = cap.get_context(py)?; + assert_eq!(ctx, Some(&123)); + Ok(()) + }) + } + + #[test] + fn test_pycapsule_import() -> PyResult<()> { + #[repr(C)] + struct Foo { + pub val: u32, + } + + Python::with_gil(|py| -> PyResult<()> { + let foo = Foo { val: 123 }; + let name = CString::new("builtins.capsule").unwrap(); + + let capsule = PyCapsule::new(py, foo, &name, None)?; + + let module = PyModule::import(py, "builtins")?; + module.add("capsule", capsule)?; + + let path = CString::new("builtins.capsule").unwrap(); + let cap: &Foo = PyCapsule::import(py, path.as_ref(), false)?; + assert_eq!(cap.val, 123); + Ok(()) + }) + } + + #[test] + fn test_pycapsule_destructor() { + #[repr(C)] + struct Foo { + called: Sender, + } + + let (tx, rx) = channel(); + + // Setup destructor, call sender to notify of being called + unsafe extern "C" fn destructor(ptr: *mut ffi::PyObject) { + Python::with_gil(|py| { + let cap = py.from_borrowed_ptr::(ptr); + let foo = cap.reference::(); + foo.called.send(true).unwrap(); + }) + } + + // Create a capsule and allow it to be freed. + let r = Python::with_gil(|py| -> PyResult<()> { + let foo = Foo { called: tx }; + let name = CString::new("builtins.capsule").unwrap(); + let _capsule = PyCapsule::new(py, foo, &name, Some(destructor))?; + Ok(()) + }); + assert!(r.is_ok()); + + // Indeed it was + assert_eq!(rx.recv(), Ok(true)); + } +}