diff --git a/examples/rustapi_module/src/datetime.rs b/examples/rustapi_module/src/datetime.rs index 3181ae79d97..3ccb7c697f6 100644 --- a/examples/rustapi_module/src/datetime.rs +++ b/examples/rustapi_module/src/datetime.rs @@ -215,29 +215,29 @@ impl TzClass { #[pymodule] fn datetime(_py: Python<'_>, m: &PyModule) -> PyResult<()> { - m.add_wrapped(wrap_pyfunction!(make_date))?; - m.add_wrapped(wrap_pyfunction!(get_date_tuple))?; - m.add_wrapped(wrap_pyfunction!(date_from_timestamp))?; - m.add_wrapped(wrap_pyfunction!(make_time))?; - m.add_wrapped(wrap_pyfunction!(get_time_tuple))?; - m.add_wrapped(wrap_pyfunction!(make_delta))?; - m.add_wrapped(wrap_pyfunction!(get_delta_tuple))?; - m.add_wrapped(wrap_pyfunction!(make_datetime))?; - m.add_wrapped(wrap_pyfunction!(get_datetime_tuple))?; - m.add_wrapped(wrap_pyfunction!(datetime_from_timestamp))?; + m.add_function(wrap_pyfunction!(make_date))?; + m.add_function(wrap_pyfunction!(get_date_tuple))?; + m.add_function(wrap_pyfunction!(date_from_timestamp))?; + m.add_function(wrap_pyfunction!(make_time))?; + m.add_function(wrap_pyfunction!(get_time_tuple))?; + m.add_function(wrap_pyfunction!(make_delta))?; + m.add_function(wrap_pyfunction!(get_delta_tuple))?; + m.add_function(wrap_pyfunction!(make_datetime))?; + m.add_function(wrap_pyfunction!(get_datetime_tuple))?; + m.add_function(wrap_pyfunction!(datetime_from_timestamp))?; // Python 3.6+ functions #[cfg(Py_3_6)] { - m.add_wrapped(wrap_pyfunction!(time_with_fold))?; + m.add_function(wrap_pyfunction!(time_with_fold))?; #[cfg(not(PyPy))] { - m.add_wrapped(wrap_pyfunction!(get_time_tuple_fold))?; - m.add_wrapped(wrap_pyfunction!(get_datetime_tuple_fold))?; + m.add_function(wrap_pyfunction!(get_time_tuple_fold))?; + m.add_function(wrap_pyfunction!(get_datetime_tuple_fold))?; } } - m.add_wrapped(wrap_pyfunction!(issue_219))?; + m.add_function(wrap_pyfunction!(issue_219))?; m.add_class::()?; Ok(()) diff --git a/examples/rustapi_module/src/othermod.rs b/examples/rustapi_module/src/othermod.rs index 20745b29fb6..b9955806186 100644 --- a/examples/rustapi_module/src/othermod.rs +++ b/examples/rustapi_module/src/othermod.rs @@ -31,7 +31,7 @@ fn double(x: i32) -> i32 { #[pymodule] fn othermod(_py: Python<'_>, m: &PyModule) -> PyResult<()> { - m.add_wrapped(wrap_pyfunction!(double))?; + m.add_function(wrap_pyfunction!(double))?; m.add_class::()?; diff --git a/examples/word-count/src/lib.rs b/examples/word-count/src/lib.rs index 06d696e895f..8d65199c8bf 100644 --- a/examples/word-count/src/lib.rs +++ b/examples/word-count/src/lib.rs @@ -55,9 +55,9 @@ fn count_line(line: &str, needle: &str) -> usize { #[pymodule] fn word_count(_py: Python<'_>, m: &PyModule) -> PyResult<()> { - m.add_wrapped(wrap_pyfunction!(search))?; - m.add_wrapped(wrap_pyfunction!(search_sequential))?; - m.add_wrapped(wrap_pyfunction!(search_sequential_allow_threads))?; + m.add_function(wrap_pyfunction!(search))?; + m.add_function(wrap_pyfunction!(search_sequential))?; + m.add_function(wrap_pyfunction!(search_sequential_allow_threads))?; Ok(()) } diff --git a/guide/src/function.md b/guide/src/function.md index 1a12d8ec6f1..b8167c0b7cb 100644 --- a/guide/src/function.md +++ b/guide/src/function.md @@ -36,7 +36,7 @@ fn double(x: usize) -> usize { #[pymodule] fn module_with_functions(py: Python, m: &PyModule) -> PyResult<()> { - m.add_wrapped(wrap_pyfunction!(double)).unwrap(); + m.add_function(wrap_pyfunction!(double)).unwrap(); Ok(()) } @@ -65,7 +65,7 @@ fn num_kwds(kwds: Option<&PyDict>) -> usize { #[pymodule] fn module_with_functions(py: Python, m: &PyModule) -> PyResult<()> { - m.add_wrapped(wrap_pyfunction!(num_kwds)).unwrap(); + m.add_function(wrap_pyfunction!(num_kwds)).unwrap(); Ok(()) } diff --git a/guide/src/module.md b/guide/src/module.md index 4dea21b1b9b..042b11f0178 100644 --- a/guide/src/module.md +++ b/guide/src/module.md @@ -67,13 +67,13 @@ fn subfunction() -> String { #[pymodule] fn submodule(_py: Python, module: &PyModule) -> PyResult<()> { - module.add_wrapped(wrap_pyfunction!(subfunction))?; + module.add_function(wrap_pyfunction!(subfunction))?; Ok(()) } #[pymodule] fn supermodule(_py: Python, module: &PyModule) -> PyResult<()> { - module.add_wrapped(wrap_pymodule!(submodule))?; + module.add_module(wrap_pymodule!(submodule))?; Ok(()) } diff --git a/pyo3-derive-backend/src/module.rs b/pyo3-derive-backend/src/module.rs index 9a51a1e5b95..84ac388ae9a 100644 --- a/pyo3-derive-backend/src/module.rs +++ b/pyo3-derive-backend/src/module.rs @@ -45,7 +45,7 @@ pub fn process_functions_in_module(func: &mut syn::ItemFn) -> syn::Result<()> { let item: syn::ItemFn = syn::parse_quote! { fn block_wrapper() { #function_to_python - #module_name.add_wrapped(&#function_wrapper_ident)?; + #module_name.add_function(&#function_wrapper_ident)?; } }; stmts.extend(item.block.stmts.into_iter()); @@ -193,7 +193,10 @@ pub fn add_fn_to_module( let wrapper = function_c_wrapper(&func.sig.ident, &spec); Ok(quote! { - fn #function_wrapper_ident(py: pyo3::Python) -> pyo3::PyObject { + fn #function_wrapper_ident<'a, 'b>( + args: impl pyo3::derive_utils::WrapPyFunctionArguments<'a, 'b> + ) -> pyo3::PyObject { + let (py, maybe_module) = args.arguments(); #wrapper let _def = pyo3::class::PyMethodDef { @@ -203,12 +206,25 @@ pub fn add_fn_to_module( ml_doc: #doc, }; + let (mod_ptr, name) = if let Some(m) = maybe_module { + let mod_ptr = ::as_ptr(m); + let name = unsafe { pyo3::ffi::PyModule_GetNameObject(mod_ptr) }; + if name.is_null() { + let err = PyErr::fetch(py); + return >::into_py(err, py); + } + (mod_ptr, name) + } else { + (std::ptr::null_mut(), std::ptr::null_mut()) + }; + let function = unsafe { pyo3::PyObject::from_owned_ptr( py, - pyo3::ffi::PyCFunction_New( + pyo3::ffi::PyCFunction_NewEx( Box::into_raw(Box::new(_def.as_method_def())), - ::std::ptr::null_mut() + mod_ptr, + name ) ) }; diff --git a/src/derive_utils.rs b/src/derive_utils.rs index 2a736d7ebcd..00fdd097c92 100644 --- a/src/derive_utils.rs +++ b/src/derive_utils.rs @@ -207,3 +207,24 @@ where >>::try_from(cell) } } + +/// Trait to abstract over the arguments of Python function wrappers. +#[doc(hidden)] +pub trait WrapPyFunctionArguments<'a, 'b> { + fn arguments(self) -> (Python<'b>, Option<&'a PyModule>); +} + +impl<'a, 'b> WrapPyFunctionArguments<'a, 'b> for Python<'b> { + fn arguments(self) -> (Python<'b>, Option<&'a PyModule>) { + (self, None) + } +} + +impl<'a, 'b> WrapPyFunctionArguments<'a, 'b> for &'a PyModule +where + 'a: 'b, +{ + fn arguments(self) -> (Python<'b>, Option<&'a PyModule>) { + (self.py(), Some(self)) + } +} diff --git a/src/lib.rs b/src/lib.rs index 10f3e768f8e..4c2313e3a47 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -71,7 +71,7 @@ //! #[pymodule] //! /// A Python module implemented in Rust. //! fn string_sum(py: Python, m: &PyModule) -> PyResult<()> { -//! m.add_wrapped(wrap_pyfunction!(sum_as_string))?; +//! m.add_function(wrap_pyfunction!(sum_as_string))?; //! //! Ok(()) //! } diff --git a/src/python.rs b/src/python.rs index 901a426bcf0..db4abfe6b1a 100644 --- a/src/python.rs +++ b/src/python.rs @@ -134,7 +134,7 @@ impl<'p> Python<'p> { /// let gil = Python::acquire_gil(); /// let py = gil.python(); /// let m = PyModule::new(py, "pcount").unwrap(); - /// m.add_wrapped(wrap_pyfunction!(parallel_count)).unwrap(); + /// m.add_function(wrap_pyfunction!(parallel_count)).unwrap(); /// let locals = [("pcount", m)].into_py_dict(py); /// py.run(r#" /// s = ["Flow", "my", "tears", "the", "Policeman", "Said"] diff --git a/src/types/module.rs b/src/types/module.rs index b345fcab563..28936368561 100644 --- a/src/types/module.rs +++ b/src/types/module.rs @@ -194,6 +194,9 @@ impl PyModule { /// ```rust,ignore /// m.add("also_double", wrap_pyfunction!(double)(py)); /// ``` + /// + /// **This function will be deprecated in the next release. Please use the specific + /// [add_function] and [add_module] functions instead.** pub fn add_wrapped(&self, wrapper: &impl Fn(Python) -> PyObject) -> PyResult<()> { let function = wrapper(self.py()); let name = function @@ -201,4 +204,40 @@ impl PyModule { .expect("A function or module must have a __name__"); self.add(name.extract(self.py()).unwrap(), function) } + + /// Adds a (sub)module to a module. + /// + /// Use this together with `#[pymodule]` and [wrap_pymodule!]. + /// + /// ```rust,ignore + /// m.add_module(wrap_pymodule!(utils)); + /// ``` + pub fn add_module(&self, wrapper: &impl Fn(Python) -> PyObject) -> PyResult<()> { + let function = wrapper(self.py()); + let name = function + .getattr(self.py(), "__name__") + .expect("A module must have a __name__"); + self.add(name.extract(self.py()).unwrap(), function) + } + + /// Adds a function to a module, using the functions __name__ as name. + /// + /// Use this together with the`#[pyfunction]` and [wrap_pyfunction!]. + /// + /// ```rust,ignore + /// m.add_function(wrap_pyfunction!(double)); + /// ``` + /// + /// You can also add a function with a custom name using [add](PyModule::add): + /// + /// ```rust,ignore + /// m.add("also_double", wrap_pyfunction!(double)(py, m)); + /// ``` + pub fn add_function<'a>(&'a self, wrapper: &impl Fn(&'a Self) -> PyObject) -> PyResult<()> { + let function = wrapper(self); + let name = function + .getattr(self.py(), "__name__") + .expect("A function or module must have a __name__"); + self.add(name.extract(self.py()).unwrap(), function) + } } diff --git a/tests/test_module.rs b/tests/test_module.rs index 0746fb8f868..f3c44667f06 100644 --- a/tests/test_module.rs +++ b/tests/test_module.rs @@ -35,7 +35,7 @@ fn double(x: usize) -> usize { /// This module is implemented in Rust. #[pymodule] -fn module_with_functions(py: Python, m: &PyModule) -> PyResult<()> { +fn module_with_functions(_py: Python, m: &PyModule) -> PyResult<()> { use pyo3::wrap_pyfunction; #[pyfn(m, "sum_as_string")] @@ -60,8 +60,8 @@ fn module_with_functions(py: Python, m: &PyModule) -> PyResult<()> { m.add("foo", "bar").unwrap(); - m.add_wrapped(wrap_pyfunction!(double)).unwrap(); - m.add("also_double", wrap_pyfunction!(double)(py)).unwrap(); + m.add_function(wrap_pyfunction!(double)).unwrap(); + m.add("also_double", wrap_pyfunction!(double)(m)).unwrap(); Ok(()) } @@ -157,7 +157,7 @@ fn r#move() -> usize { fn raw_ident_module(_py: Python, module: &PyModule) -> PyResult<()> { use pyo3::wrap_pyfunction; - module.add_wrapped(wrap_pyfunction!(r#move)) + module.add_function(wrap_pyfunction!(r#move)) } #[test] @@ -182,7 +182,7 @@ fn custom_named_fn() -> usize { fn foobar_module(_py: Python, m: &PyModule) -> PyResult<()> { use pyo3::wrap_pyfunction; - m.add_wrapped(wrap_pyfunction!(custom_named_fn))?; + m.add_function(wrap_pyfunction!(custom_named_fn))?; m.dict().set_item("yay", "me")?; Ok(()) } @@ -216,7 +216,7 @@ fn subfunction() -> String { fn submodule(_py: Python, module: &PyModule) -> PyResult<()> { use pyo3::wrap_pyfunction; - module.add_wrapped(wrap_pyfunction!(subfunction))?; + module.add_function(wrap_pyfunction!(subfunction))?; Ok(()) } @@ -229,8 +229,8 @@ fn superfunction() -> String { fn supermodule(_py: Python, module: &PyModule) -> PyResult<()> { use pyo3::{wrap_pyfunction, wrap_pymodule}; - module.add_wrapped(wrap_pyfunction!(superfunction))?; - module.add_wrapped(wrap_pymodule!(submodule))?; + module.add_function(wrap_pyfunction!(superfunction))?; + module.add_module(wrap_pymodule!(submodule))?; Ok(()) } @@ -268,7 +268,7 @@ fn vararg_module(_py: Python, m: &PyModule) -> PyResult<()> { ext_vararg_fn(py, a, vararg) } - m.add_wrapped(pyo3::wrap_pyfunction!(ext_vararg_fn)) + m.add_function(pyo3::wrap_pyfunction!(ext_vararg_fn)) .unwrap(); Ok(()) }