Skip to content

Commit

Permalink
Merge pull request #1901 from LaurentMazare/closures
Browse files Browse the repository at this point in the history
Support for wrapping rust closures as python functions
  • Loading branch information
davidhewitt committed Oct 17, 2021
2 parents 3b94f4b + 2042906 commit fbb5e3c
Show file tree
Hide file tree
Showing 6 changed files with 209 additions and 13 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Expand Up @@ -25,6 +25,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Add commonly-used sequence methods to `PyList` and `PyTuple`. [#1849](https://github.com/PyO3/pyo3/pull/1849)
- Add `as_sequence` methods to `PyList` and `PyTuple`. [#1860](https://github.com/PyO3/pyo3/pull/1860)
- Add `abi3-py310` feature. [#1889](https://github.com/PyO3/pyo3/pull/1889)
- Add `PyCFunction::new_closure` to create a Python function from a Rust closure. [#1901](https://github.com/PyO3/pyo3/pull/1901)

### Changed

Expand Down
117 changes: 105 additions & 12 deletions src/types/function.rs
Expand Up @@ -3,15 +3,58 @@ use crate::exceptions::PyValueError;
use crate::prelude::*;
use crate::{
class::methods::{self, PyMethodDef},
ffi, AsPyPointer,
ffi, types, AsPyPointer,
};
use std::os::raw::c_void;

/// Represents a builtin Python function object.
#[repr(transparent)]
pub struct PyCFunction(PyAny);

pyobject_native_type_core!(PyCFunction, ffi::PyCFunction_Type, #checkfunction=ffi::PyCFunction_Check);

const CLOSURE_CAPSULE_NAME: &[u8] = b"pyo3-closure\0";

unsafe extern "C" fn run_closure<F, R>(
capsule_ptr: *mut ffi::PyObject,
args: *mut ffi::PyObject,
kwargs: *mut ffi::PyObject,
) -> *mut ffi::PyObject
where
F: Fn(&types::PyTuple, Option<&types::PyDict>) -> R + Send + 'static,
R: crate::callback::IntoPyCallbackOutput<*mut ffi::PyObject>,
{
crate::callback_body!(py, {
let boxed_fn: &F =
&*(ffi::PyCapsule_GetPointer(capsule_ptr, CLOSURE_CAPSULE_NAME.as_ptr() as *const _)
as *mut F);
let args = py.from_borrowed_ptr::<types::PyTuple>(args);
let kwargs = py.from_borrowed_ptr_or_opt::<types::PyDict>(kwargs);
boxed_fn(args, kwargs)
})
}

unsafe extern "C" fn drop_closure<F, R>(capsule_ptr: *mut ffi::PyObject)
where
F: Fn(&types::PyTuple, Option<&types::PyDict>) -> R + Send + 'static,
R: crate::callback::IntoPyCallbackOutput<*mut ffi::PyObject>,
{
let result = std::panic::catch_unwind(|| {
let boxed_fn: Box<F> = Box::from_raw(ffi::PyCapsule_GetPointer(
capsule_ptr,
CLOSURE_CAPSULE_NAME.as_ptr() as *const _,
) as *mut F);
drop(boxed_fn)
});
if let Err(err) = result {
// This second layer of catch_unwind is useful as eprintln! can also panic.
let _ = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
eprintln!("--- PyO3 intercepted a panic when dropping a closure");
eprintln!("{:?}", err);
}));
}
}

impl PyCFunction {
/// Create a new built-in function with keywords.
pub fn new_with_keywords<'a>(
Expand Down Expand Up @@ -39,23 +82,57 @@ impl PyCFunction {
)
}

/// Create a new function from a closure.
///
/// # Examples
///
/// ```
/// # use pyo3::prelude::*;
/// # use pyo3::{py_run, types};
///
/// Python::with_gil(|py| {
/// let add_one = |args: &types::PyTuple, _kwargs: Option<&types::PyDict>| -> PyResult<_> {
/// let i = args.extract::<(i64,)>()?.0;
/// Ok(i+1)
/// };
/// let add_one = types::PyCFunction::new_closure(add_one, py).unwrap();
/// py_run!(py, add_one, "assert add_one(42) == 43");
/// });
/// ```
pub fn new_closure<F, R>(f: F, py: Python) -> PyResult<&PyCFunction>
where
F: Fn(&types::PyTuple, Option<&types::PyDict>) -> R + Send + 'static,
R: crate::callback::IntoPyCallbackOutput<*mut ffi::PyObject>,
{
let function_ptr = Box::into_raw(Box::new(f));
let capsule = unsafe {
PyObject::from_owned_ptr_or_err(
py,
ffi::PyCapsule_New(
function_ptr as *mut c_void,
CLOSURE_CAPSULE_NAME.as_ptr() as *const _,
Some(drop_closure::<F, R>),
),
)?
};
let method_def = methods::PyMethodDef::cfunction_with_keywords(
"pyo3-closure",
methods::PyCFunctionWithKeywords(run_closure::<F, R>),
"",
);
Self::internal_new_from_pointers(method_def, py, capsule.as_ptr(), std::ptr::null_mut())
}

#[doc(hidden)]
pub fn internal_new(
fn internal_new_from_pointers(
method_def: PyMethodDef,
py_or_module: PyFunctionArguments,
py: Python,
mod_ptr: *mut ffi::PyObject,
module_name: *mut ffi::PyObject,
) -> PyResult<&Self> {
let (py, module) = py_or_module.into_py_and_maybe_module();
let def = method_def
.as_method_def()
.map_err(|err| PyValueError::new_err(err.0))?;
let (mod_ptr, module_name) = if let Some(m) = module {
let mod_ptr = m.as_ptr();
let name = m.name()?.into_py(py);
(mod_ptr, name.as_ptr())
} else {
(std::ptr::null_mut(), std::ptr::null_mut())
};

unsafe {
py.from_owned_ptr_or_err::<PyCFunction>(ffi::PyCFunction_NewEx(
Box::into_raw(Box::new(def)),
Expand All @@ -64,6 +141,22 @@ impl PyCFunction {
))
}
}

#[doc(hidden)]
pub fn internal_new(
method_def: PyMethodDef,
py_or_module: PyFunctionArguments,
) -> PyResult<&Self> {
let (py, module) = py_or_module.into_py_and_maybe_module();
let (mod_ptr, module_name) = if let Some(m) = module {
let mod_ptr = m.as_ptr();
let name = m.name()?.into_py(py);
(mod_ptr, name.as_ptr())
} else {
(std::ptr::null_mut(), std::ptr::null_mut())
};
Self::internal_new_from_pointers(method_def, py, mod_ptr, module_name)
}
}

/// Represents a Python function object.
Expand Down
1 change: 1 addition & 0 deletions tests/test_compile_error.rs
Expand Up @@ -23,6 +23,7 @@ fn _test_compile_errors() {
t.compile_fail("tests/ui/invalid_pymethods.rs");
t.compile_fail("tests/ui/invalid_pymethod_names.rs");
t.compile_fail("tests/ui/invalid_argument_attributes.rs");
t.compile_fail("tests/ui/invalid_closure.rs");
t.compile_fail("tests/ui/reject_generics.rs");

tests_rust_1_48(&t);
Expand Down
56 changes: 55 additions & 1 deletion tests/test_pyfunction.rs
@@ -1,7 +1,7 @@
#[cfg(not(Py_LIMITED_API))]
use pyo3::buffer::PyBuffer;
use pyo3::prelude::*;
use pyo3::types::PyCFunction;
use pyo3::types::{self, PyCFunction};
#[cfg(not(Py_LIMITED_API))]
use pyo3::types::{PyDateTime, PyFunction};

Expand Down Expand Up @@ -213,3 +213,57 @@ fn test_conversion_error() {
"argument 'option_arg': 'str' object cannot be interpreted as an integer"
);
}

#[test]
fn test_closure() {
let gil = Python::acquire_gil();
let py = gil.python();

let f = |args: &types::PyTuple, _kwargs: Option<&types::PyDict>| -> PyResult<_> {
let gil = Python::acquire_gil();
let py = gil.python();
let res: Vec<_> = args
.iter()
.map(|elem| {
if let Ok(i) = elem.extract::<i64>() {
(i + 1).into_py(py)
} else if let Ok(f) = elem.extract::<f64>() {
(2. * f).into_py(py)
} else if let Ok(mut s) = elem.extract::<String>() {
s.push_str("-py");
s.into_py(py)
} else {
panic!("unexpected argument type for {:?}", elem)
}
})
.collect();
Ok(res)
};
let closure_py = PyCFunction::new_closure(f, py).unwrap();

py_assert!(py, closure_py, "closure_py(42) == [43]");
py_assert!(
py,
closure_py,
"closure_py(42, 3.14, 'foo') == [43, 6.28, 'foo-py']"
);
}

#[test]
fn test_closure_counter() {
let gil = Python::acquire_gil();
let py = gil.python();

let counter = std::cell::RefCell::new(0);
let counter_fn =
move |_args: &types::PyTuple, _kwargs: Option<&types::PyDict>| -> PyResult<i32> {
let mut counter = counter.borrow_mut();
*counter += 1;
Ok(*counter)
};
let counter_py = PyCFunction::new_closure(counter_fn, py).unwrap();

py_assert!(py, counter_py, "counter_py() == 1");
py_assert!(py, counter_py, "counter_py() == 2");
py_assert!(py, counter_py, "counter_py() == 3");
}
19 changes: 19 additions & 0 deletions tests/ui/invalid_closure.rs
@@ -0,0 +1,19 @@
use pyo3::prelude::*;
use pyo3::types::{PyCFunction, PyDict, PyTuple};

fn main() {
let fun: Py<PyCFunction> = Python::with_gil(|py| {
let local_data = vec![0, 1, 2, 3, 4];
let ref_: &[u8] = &local_data;

let closure_fn = |_args: &PyTuple, _kwargs: Option<&PyDict>| -> PyResult<()> {
println!("This is five: {:?}", ref_.len());
Ok(())
};
PyCFunction::new_closure(closure_fn, py).unwrap().into()
});

Python::with_gil(|py| {
fun.call0(py).unwrap();
});
}
28 changes: 28 additions & 0 deletions tests/ui/invalid_closure.stderr
@@ -0,0 +1,28 @@
error[E0597]: `local_data` does not live long enough
--> tests/ui/invalid_closure.rs:7:27
|
7 | let ref_: &[u8] = &local_data;
| ^^^^^^^^^^^ borrowed value does not live long enough
...
13 | PyCFunction::new_closure(closure_fn, py).unwrap().into()
| ---------------------------------------- argument requires that `local_data` is borrowed for `'static`
14 | });
| - `local_data` dropped here while still borrowed

error[E0373]: closure may outlive the current function, but it borrows `ref_`, which is owned by the current function
--> tests/ui/invalid_closure.rs:9:26
|
9 | let closure_fn = |_args: &PyTuple, _kwargs: Option<&PyDict>| -> PyResult<()> {
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ may outlive borrowed value `ref_`
10 | println!("This is five: {:?}", ref_.len());
| ---- `ref_` is borrowed here
|
note: function requires argument type to outlive `'static`
--> tests/ui/invalid_closure.rs:13:9
|
13 | PyCFunction::new_closure(closure_fn, py).unwrap().into()
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
help: to force the closure to take ownership of `ref_` (and any other referenced variables), use the `move` keyword
|
9 | let closure_fn = move |_args: &PyTuple, _kwargs: Option<&PyDict>| -> PyResult<()> {
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

0 comments on commit fbb5e3c

Please sign in to comment.