Skip to content

Commit

Permalink
Set the module of #[pyfunction]s.
Browse files Browse the repository at this point in the history
Previously neither the module nor the name of the module of
pyfunctions were registered. This commit passes the module and
its name when creating a new pyfunction.

PyModule::add_function and PyModule::add_module have been added and are
set to replace `add_wrapped` in a future release. `add_wrapped` is kept
for compatibility reasons during the transition.

Depending on whether a `PyModule` or `Python` is the argument for the
Python function-wrapper, the module will be registered with the function.
  • Loading branch information
sebpuetz committed Sep 3, 2020
1 parent 21ad52a commit e34294a
Show file tree
Hide file tree
Showing 11 changed files with 113 additions and 37 deletions.
28 changes: 14 additions & 14 deletions examples/rustapi_module/src/datetime.rs
Expand Up @@ -215,29 +215,29 @@ impl TzClass {

#[pymodule]
fn datetime(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
m.add_wrapped(wrap_pyfunction!(make_date))?;
m.add_wrapped(wrap_pyfunction!(get_date_tuple))?;
m.add_wrapped(wrap_pyfunction!(date_from_timestamp))?;
m.add_wrapped(wrap_pyfunction!(make_time))?;
m.add_wrapped(wrap_pyfunction!(get_time_tuple))?;
m.add_wrapped(wrap_pyfunction!(make_delta))?;
m.add_wrapped(wrap_pyfunction!(get_delta_tuple))?;
m.add_wrapped(wrap_pyfunction!(make_datetime))?;
m.add_wrapped(wrap_pyfunction!(get_datetime_tuple))?;
m.add_wrapped(wrap_pyfunction!(datetime_from_timestamp))?;
m.add_function(wrap_pyfunction!(make_date))?;
m.add_function(wrap_pyfunction!(get_date_tuple))?;
m.add_function(wrap_pyfunction!(date_from_timestamp))?;
m.add_function(wrap_pyfunction!(make_time))?;
m.add_function(wrap_pyfunction!(get_time_tuple))?;
m.add_function(wrap_pyfunction!(make_delta))?;
m.add_function(wrap_pyfunction!(get_delta_tuple))?;
m.add_function(wrap_pyfunction!(make_datetime))?;
m.add_function(wrap_pyfunction!(get_datetime_tuple))?;
m.add_function(wrap_pyfunction!(datetime_from_timestamp))?;

// Python 3.6+ functions
#[cfg(Py_3_6)]
{
m.add_wrapped(wrap_pyfunction!(time_with_fold))?;
m.add_function(wrap_pyfunction!(time_with_fold))?;
#[cfg(not(PyPy))]
{
m.add_wrapped(wrap_pyfunction!(get_time_tuple_fold))?;
m.add_wrapped(wrap_pyfunction!(get_datetime_tuple_fold))?;
m.add_function(wrap_pyfunction!(get_time_tuple_fold))?;
m.add_function(wrap_pyfunction!(get_datetime_tuple_fold))?;
}
}

m.add_wrapped(wrap_pyfunction!(issue_219))?;
m.add_function(wrap_pyfunction!(issue_219))?;
m.add_class::<TzClass>()?;

Ok(())
Expand Down
2 changes: 1 addition & 1 deletion examples/rustapi_module/src/othermod.rs
Expand Up @@ -31,7 +31,7 @@ fn double(x: i32) -> i32 {

#[pymodule]
fn othermod(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
m.add_wrapped(wrap_pyfunction!(double))?;
m.add_function(wrap_pyfunction!(double))?;

m.add_class::<ModClass>()?;

Expand Down
6 changes: 3 additions & 3 deletions examples/word-count/src/lib.rs
Expand Up @@ -55,9 +55,9 @@ fn count_line(line: &str, needle: &str) -> usize {

#[pymodule]
fn word_count(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
m.add_wrapped(wrap_pyfunction!(search))?;
m.add_wrapped(wrap_pyfunction!(search_sequential))?;
m.add_wrapped(wrap_pyfunction!(search_sequential_allow_threads))?;
m.add_function(wrap_pyfunction!(search))?;
m.add_function(wrap_pyfunction!(search_sequential))?;
m.add_function(wrap_pyfunction!(search_sequential_allow_threads))?;

Ok(())
}
4 changes: 2 additions & 2 deletions guide/src/function.md
Expand Up @@ -36,7 +36,7 @@ fn double(x: usize) -> usize {

#[pymodule]
fn module_with_functions(py: Python, m: &PyModule) -> PyResult<()> {
m.add_wrapped(wrap_pyfunction!(double)).unwrap();
m.add_function(wrap_pyfunction!(double)).unwrap();

Ok(())
}
Expand Down Expand Up @@ -65,7 +65,7 @@ fn num_kwds(kwds: Option<&PyDict>) -> usize {

#[pymodule]
fn module_with_functions(py: Python, m: &PyModule) -> PyResult<()> {
m.add_wrapped(wrap_pyfunction!(num_kwds)).unwrap();
m.add_function(wrap_pyfunction!(num_kwds)).unwrap();
Ok(())
}

Expand Down
4 changes: 2 additions & 2 deletions guide/src/module.md
Expand Up @@ -67,13 +67,13 @@ fn subfunction() -> String {

#[pymodule]
fn submodule(_py: Python, module: &PyModule) -> PyResult<()> {
module.add_wrapped(wrap_pyfunction!(subfunction))?;
module.add_function(wrap_pyfunction!(subfunction))?;
Ok(())
}

#[pymodule]
fn supermodule(_py: Python, module: &PyModule) -> PyResult<()> {
module.add_wrapped(wrap_pymodule!(submodule))?;
module.add_module(wrap_pymodule!(submodule))?;
Ok(())
}

Expand Down
24 changes: 20 additions & 4 deletions pyo3-derive-backend/src/module.rs
Expand Up @@ -45,7 +45,7 @@ pub fn process_functions_in_module(func: &mut syn::ItemFn) -> syn::Result<()> {
let item: syn::ItemFn = syn::parse_quote! {
fn block_wrapper() {
#function_to_python
#module_name.add_wrapped(&#function_wrapper_ident)?;
#module_name.add_function(&#function_wrapper_ident)?;
}
};
stmts.extend(item.block.stmts.into_iter());
Expand Down Expand Up @@ -193,7 +193,10 @@ pub fn add_fn_to_module(
let wrapper = function_c_wrapper(&func.sig.ident, &spec);

Ok(quote! {
fn #function_wrapper_ident(py: pyo3::Python) -> pyo3::PyObject {
fn #function_wrapper_ident<'a, 'b>(
args: impl pyo3::derive_utils::WrapPyFunctionArguments<'a, 'b>
) -> pyo3::PyObject {
let (py, maybe_module) = args.arguments();
#wrapper

let _def = pyo3::class::PyMethodDef {
Expand All @@ -203,12 +206,25 @@ pub fn add_fn_to_module(
ml_doc: #doc,
};

let (mod_ptr, name) = if let Some(m) = maybe_module {
let mod_ptr = <pyo3::types::PyModule as ::pyo3::conversion::AsPyPointer>::as_ptr(m);
let name = unsafe { pyo3::ffi::PyModule_GetNameObject(mod_ptr) };
if name.is_null() {
let err = PyErr::fetch(py);
return <PyErr as pyo3::conversion::IntoPy<PyObject>>::into_py(err, py);
}
(mod_ptr, name)
} else {
(std::ptr::null_mut(), std::ptr::null_mut())
};

let function = unsafe {
pyo3::PyObject::from_owned_ptr(
py,
pyo3::ffi::PyCFunction_New(
pyo3::ffi::PyCFunction_NewEx(
Box::into_raw(Box::new(_def.as_method_def())),
::std::ptr::null_mut()
mod_ptr,
name
)
)
};
Expand Down
21 changes: 21 additions & 0 deletions src/derive_utils.rs
Expand Up @@ -207,3 +207,24 @@ where
<R as std::convert::TryFrom<&'a PyCell<T>>>::try_from(cell)
}
}

/// Trait to abstract over the arguments of Python function wrappers.
#[doc(hidden)]
pub trait WrapPyFunctionArguments<'a, 'b> {
fn arguments(self) -> (Python<'b>, Option<&'a PyModule>);
}

impl<'a, 'b> WrapPyFunctionArguments<'a, 'b> for Python<'b> {
fn arguments(self) -> (Python<'b>, Option<&'a PyModule>) {
(self, None)
}
}

impl<'a, 'b> WrapPyFunctionArguments<'a, 'b> for &'a PyModule
where
'a: 'b,
{
fn arguments(self) -> (Python<'b>, Option<&'a PyModule>) {
(self.py(), Some(self))
}
}
2 changes: 1 addition & 1 deletion src/lib.rs
Expand Up @@ -71,7 +71,7 @@
//! #[pymodule]
//! /// A Python module implemented in Rust.
//! fn string_sum(py: Python, m: &PyModule) -> PyResult<()> {
//! m.add_wrapped(wrap_pyfunction!(sum_as_string))?;
//! m.add_function(wrap_pyfunction!(sum_as_string))?;
//!
//! Ok(())
//! }
Expand Down
2 changes: 1 addition & 1 deletion src/python.rs
Expand Up @@ -134,7 +134,7 @@ impl<'p> Python<'p> {
/// let gil = Python::acquire_gil();
/// let py = gil.python();
/// let m = PyModule::new(py, "pcount").unwrap();
/// m.add_wrapped(wrap_pyfunction!(parallel_count)).unwrap();
/// m.add_function(wrap_pyfunction!(parallel_count)).unwrap();
/// let locals = [("pcount", m)].into_py_dict(py);
/// py.run(r#"
/// s = ["Flow", "my", "tears", "the", "Policeman", "Said"]
Expand Down
39 changes: 39 additions & 0 deletions src/types/module.rs
Expand Up @@ -194,11 +194,50 @@ impl PyModule {
/// ```rust,ignore
/// m.add("also_double", wrap_pyfunction!(double)(py));
/// ```
///
/// **This function will be deprecated in the next release. Please use the specific
/// [add_function] and [add_module] functions instead.**
pub fn add_wrapped(&self, wrapper: &impl Fn(Python) -> PyObject) -> PyResult<()> {
let function = wrapper(self.py());
let name = function
.getattr(self.py(), "__name__")
.expect("A function or module must have a __name__");
self.add(name.extract(self.py()).unwrap(), function)
}

/// Adds a (sub)module to a module.
///
/// Use this together with `#[pymodule]` and [wrap_pymodule!].
///
/// ```rust,ignore
/// m.add_module(wrap_pymodule!(utils));
/// ```
pub fn add_module(&self, wrapper: &impl Fn(Python) -> PyObject) -> PyResult<()> {
let function = wrapper(self.py());
let name = function
.getattr(self.py(), "__name__")
.expect("A module must have a __name__");
self.add(name.extract(self.py()).unwrap(), function)
}

/// Adds a function to a module, using the functions __name__ as name.
///
/// Use this together with the`#[pyfunction]` and [wrap_pyfunction!].
///
/// ```rust,ignore
/// m.add_function(wrap_pyfunction!(double));
/// ```
///
/// You can also add a function with a custom name using [add](PyModule::add):
///
/// ```rust,ignore
/// m.add("also_double", wrap_pyfunction!(double)(py, m));
/// ```
pub fn add_function<'a>(&'a self, wrapper: &impl Fn(&'a Self) -> PyObject) -> PyResult<()> {
let function = wrapper(self);
let name = function
.getattr(self.py(), "__name__")
.expect("A function or module must have a __name__");
self.add(name.extract(self.py()).unwrap(), function)
}
}
18 changes: 9 additions & 9 deletions tests/test_module.rs
Expand Up @@ -35,7 +35,7 @@ fn double(x: usize) -> usize {

/// This module is implemented in Rust.
#[pymodule]
fn module_with_functions(py: Python, m: &PyModule) -> PyResult<()> {
fn module_with_functions(_py: Python, m: &PyModule) -> PyResult<()> {
use pyo3::wrap_pyfunction;

#[pyfn(m, "sum_as_string")]
Expand All @@ -60,8 +60,8 @@ fn module_with_functions(py: Python, m: &PyModule) -> PyResult<()> {

m.add("foo", "bar").unwrap();

m.add_wrapped(wrap_pyfunction!(double)).unwrap();
m.add("also_double", wrap_pyfunction!(double)(py)).unwrap();
m.add_function(wrap_pyfunction!(double)).unwrap();
m.add("also_double", wrap_pyfunction!(double)(m)).unwrap();

Ok(())
}
Expand Down Expand Up @@ -157,7 +157,7 @@ fn r#move() -> usize {
fn raw_ident_module(_py: Python, module: &PyModule) -> PyResult<()> {
use pyo3::wrap_pyfunction;

module.add_wrapped(wrap_pyfunction!(r#move))
module.add_function(wrap_pyfunction!(r#move))
}

#[test]
Expand All @@ -182,7 +182,7 @@ fn custom_named_fn() -> usize {
fn foobar_module(_py: Python, m: &PyModule) -> PyResult<()> {
use pyo3::wrap_pyfunction;

m.add_wrapped(wrap_pyfunction!(custom_named_fn))?;
m.add_function(wrap_pyfunction!(custom_named_fn))?;
m.dict().set_item("yay", "me")?;
Ok(())
}
Expand Down Expand Up @@ -216,7 +216,7 @@ fn subfunction() -> String {
fn submodule(_py: Python, module: &PyModule) -> PyResult<()> {
use pyo3::wrap_pyfunction;

module.add_wrapped(wrap_pyfunction!(subfunction))?;
module.add_function(wrap_pyfunction!(subfunction))?;
Ok(())
}

Expand All @@ -229,8 +229,8 @@ fn superfunction() -> String {
fn supermodule(_py: Python, module: &PyModule) -> PyResult<()> {
use pyo3::{wrap_pyfunction, wrap_pymodule};

module.add_wrapped(wrap_pyfunction!(superfunction))?;
module.add_wrapped(wrap_pymodule!(submodule))?;
module.add_function(wrap_pyfunction!(superfunction))?;
module.add_module(wrap_pymodule!(submodule))?;
Ok(())
}

Expand Down Expand Up @@ -268,7 +268,7 @@ fn vararg_module(_py: Python, m: &PyModule) -> PyResult<()> {
ext_vararg_fn(py, a, vararg)
}

m.add_wrapped(pyo3::wrap_pyfunction!(ext_vararg_fn))
m.add_function(pyo3::wrap_pyfunction!(ext_vararg_fn))
.unwrap();
Ok(())
}
Expand Down

0 comments on commit e34294a

Please sign in to comment.