From 1aa1e91ce6790468515674917ea055c75e5e8abd Mon Sep 17 00:00:00 2001 From: David Hewitt <1939362+davidhewitt@users.noreply.github.com> Date: Fri, 26 Feb 2021 00:55:49 +0000 Subject: [PATCH] pycfunction: take &'static str arguments to new Co-authored-by: messense --- CHANGELOG.md | 1 + pyo3-macros-backend/src/module.rs | 15 ++++---- src/class/methods.rs | 53 +++++++++++++++---------- src/pyclass.rs | 2 +- src/types/function.rs | 64 +++++++++++-------------------- 5 files changed, 65 insertions(+), 70 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index f67ed39d0f1..99b7959e0d0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -17,6 +17,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0. - Deprecate FFI definition `PyCFunction_Call` for Python 3.9 and later. [#1425](https://github.com/PyO3/pyo3/pull/1425) - Deprecate FFI definitions `PyModule_GetFilename`. [#1425](https://github.com/PyO3/pyo3/pull/1425) - The `auto-initialize` feature is no longer enabled by default. [#1443](https://github.com/PyO3/pyo3/pull/1443) +- Change `PyCFunction::new()` and `PyCFunction::new_with_keywords()` to take `&'static str` arguments rather than implicitly copying (and leaking) them. [#1450](https://github.com/PyO3/pyo3/pull/1450) ### Removed - Remove deprecated exception names `BaseException` etc. [#1426](https://github.com/PyO3/pyo3/pull/1426) diff --git a/pyo3-macros-backend/src/module.rs b/pyo3-macros-backend/src/module.rs index bb1e1bda918..d0fac0f18e2 100644 --- a/pyo3-macros-backend/src/module.rs +++ b/pyo3-macros-backend/src/module.rs @@ -192,8 +192,7 @@ pub fn add_fn_to_module( doc, }; - let doc = syn::LitByteStr::new(spec.doc.value().as_bytes(), spec.doc.span()); - + let doc = &spec.doc; let python_name = &spec.python_name; let name = &func.sig.ident; @@ -205,13 +204,13 @@ pub fn add_fn_to_module( args: impl Into> ) -> pyo3::PyResult<&'a pyo3::types::PyCFunction> { let name = concat!(stringify!(#python_name), "\0"); - let name = std::ffi::CStr::from_bytes_with_nul(name.as_bytes()).unwrap(); - let doc = std::ffi::CStr::from_bytes_with_nul(#doc).unwrap(); pyo3::types::PyCFunction::internal_new( - name, - doc, - unsafe { std::mem::transmute(#wrapper_ident as *const std::os::raw::c_void) }, - pyo3::ffi::METH_VARARGS | pyo3::ffi::METH_KEYWORDS, + pyo3::class::methods::PyMethodDef::cfunction_with_keywords( + name, + pyo3::class::methods::PyCFunctionWithKeywords(#wrapper_ident), + 0, + #doc, + ), args.into(), ) } diff --git a/src/class/methods.rs b/src/class/methods.rs index 6198183ca02..df0d5a012c6 100644 --- a/src/class/methods.rs +++ b/src/class/methods.rs @@ -1,7 +1,7 @@ // Copyright (c) 2017-present PyO3 Project and Contributors use crate::{ffi, PyObject, Python}; -use std::ffi::CStr; +use std::ffi::{CStr, CString}; use std::fmt; use std::os::raw::c_int; @@ -73,15 +73,6 @@ unsafe impl Sync for PyGetterDef {} unsafe impl Sync for PySetterDef {} -fn get_name(name: &str) -> &CStr { - CStr::from_bytes_with_nul(name.as_bytes()) - .expect("Method name must be terminated with NULL byte") -} - -fn get_doc(doc: &str) -> &CStr { - CStr::from_bytes_with_nul(doc.as_bytes()).expect("Document must be terminated with NULL byte") -} - impl PyMethodDef { /// Define a function with no `*args` and `**kwargs`. pub const fn cfunction(name: &'static str, cfunction: PyCFunction, doc: &'static str) -> Self { @@ -109,18 +100,18 @@ impl PyMethodDef { } /// Convert `PyMethodDef` to Python method definition struct `ffi::PyMethodDef` - pub fn as_method_def(&self) -> ffi::PyMethodDef { + pub(crate) fn as_method_def(&self) -> Result { let meth = match self.ml_meth { PyMethodType::PyCFunction(meth) => meth.0, PyMethodType::PyCFunctionWithKeywords(meth) => unsafe { std::mem::transmute(meth.0) }, }; - ffi::PyMethodDef { - ml_name: get_name(self.ml_name).as_ptr(), + Ok(ffi::PyMethodDef { + ml_name: get_name(self.ml_name)?.as_ptr(), ml_meth: Some(meth), ml_flags: self.ml_flags, - ml_doc: get_doc(self.ml_doc).as_ptr(), - } + ml_doc: get_doc(self.ml_doc)?.as_ptr(), + }) } } @@ -128,7 +119,7 @@ impl PyClassAttributeDef { /// Define a class attribute. pub fn new(name: &'static str, meth: for<'p> fn(Python<'p>) -> PyObject) -> Self { Self { - name: get_name(name), + name: get_name(name).unwrap(), meth, } } @@ -148,9 +139,9 @@ impl PyGetterDef { /// Define a getter. pub fn new(name: &'static str, getter: ffi::getter, doc: &'static str) -> Self { Self { - name: get_name(name), + name: get_name(name).unwrap(), meth: getter, - doc: get_doc(doc), + doc: get_doc(doc).unwrap(), } } @@ -170,9 +161,9 @@ impl PySetterDef { /// Define a setter. pub fn new(name: &'static str, setter: ffi::setter, doc: &'static str) -> Self { Self { - name: get_name(name), + name: get_name(name).unwrap(), meth: setter, - doc: get_doc(doc), + doc: get_doc(doc).unwrap(), } } @@ -209,3 +200,25 @@ pub trait PyMethodsInventory: inventory::Collect { pub trait HasMethodsInventory { type Methods: PyMethodsInventory; } + +#[derive(Debug)] +pub(crate) struct NulByteInString(pub(crate) &'static str); + +fn get_name(name: &'static str) -> Result<&'static CStr, NulByteInString> { + extract_cstr_or_leak_cstring(name, "Function name cannot contain NUL byte.") +} + +fn get_doc(doc: &'static str) -> Result<&'static CStr, NulByteInString> { + extract_cstr_or_leak_cstring(doc, "Document cannot contain NUL byte.") +} + +fn extract_cstr_or_leak_cstring( + src: &'static str, + err_msg: &'static str, +) -> Result<&'static CStr, NulByteInString> { + CStr::from_bytes_with_nul(src.as_bytes()) + .or_else(|_| { + CString::new(src.as_bytes()).map(|c_string| &*Box::leak(c_string.into_boxed_c_str())) + }) + .map_err(|_| NulByteInString(err_msg)) +} diff --git a/src/pyclass.rs b/src/pyclass.rs index 3a0a20b683b..f613b1a1c84 100644 --- a/src/pyclass.rs +++ b/src/pyclass.rs @@ -318,7 +318,7 @@ fn py_class_method_defs() -> Vec { PyMethodDefType::Method(def) | PyMethodDefType::Class(def) | PyMethodDefType::Static(def) => { - defs.push(def.as_method_def()); + defs.push(def.as_method_def().unwrap()); } _ => (), }); diff --git a/src/types/function.rs b/src/types/function.rs index 5a9face247d..6ea1cb1cc27 100644 --- a/src/types/function.rs +++ b/src/types/function.rs @@ -1,9 +1,10 @@ -use std::ffi::{CStr, CString}; - -use crate::derive_utils::PyFunctionArguments; use crate::exceptions::PyValueError; use crate::prelude::*; -use crate::{ffi, AsPyPointer}; +use crate::{ + class::methods::{self, PyMethodDef}, + ffi, AsPyPointer, +}; +use crate::{derive_utils::PyFunctionArguments, methods::NulByteInString}; /// Represents a builtin Python function object. #[repr(transparent)] @@ -11,33 +12,23 @@ pub struct PyCFunction(PyAny); pyobject_native_var_type!(PyCFunction, ffi::PyCFunction_Type, ffi::PyCFunction_Check); -fn get_name(name: &str) -> PyResult<&'static CStr> { - let cstr = CString::new(name) - .map_err(|_| PyValueError::new_err("Function name cannot contain contain NULL byte."))?; - Ok(Box::leak(cstr.into_boxed_c_str())) -} - -fn get_doc(doc: &str) -> PyResult<&'static CStr> { - let cstr = CString::new(doc) - .map_err(|_| PyValueError::new_err("Document cannot contain contain NULL byte."))?; - Ok(Box::leak(cstr.into_boxed_c_str())) -} - impl PyCFunction { /// Create a new built-in function with keywords. /// /// See [raw_pycfunction] for documentation on how to get the `fun` argument. pub fn new_with_keywords<'a>( fun: ffi::PyCFunctionWithKeywords, - name: &str, - doc: &str, + name: &'static str, + doc: &'static str, py_or_module: PyFunctionArguments<'a>, ) -> PyResult<&'a Self> { Self::internal_new( - get_name(name)?, - get_doc(doc)?, - unsafe { std::mem::transmute(fun) }, - ffi::METH_VARARGS | ffi::METH_KEYWORDS, + PyMethodDef::cfunction_with_keywords( + name, + methods::PyCFunctionWithKeywords(fun), + 0, + doc, + ), py_or_module, ) } @@ -45,34 +36,25 @@ impl PyCFunction { /// Create a new built-in function without keywords. pub fn new<'a>( fun: ffi::PyCFunction, - name: &str, - doc: &str, + name: &'static str, + doc: &'static str, py_or_module: PyFunctionArguments<'a>, ) -> PyResult<&'a Self> { Self::internal_new( - get_name(name)?, - get_doc(doc)?, - fun, - ffi::METH_NOARGS, + PyMethodDef::cfunction(name, methods::PyCFunction(fun), doc), py_or_module, ) } #[doc(hidden)] - pub fn internal_new<'a>( - name: &'static CStr, - doc: &'static CStr, - method: ffi::PyCFunction, - flags: std::os::raw::c_int, - py_or_module: PyFunctionArguments<'a>, - ) -> PyResult<&'a Self> { + pub fn internal_new( + method_def: PyMethodDef, + py_or_module: PyFunctionArguments, + ) -> PyResult<&Self> { let (py, module) = py_or_module.into_py_and_maybe_module(); - let def = ffi::PyMethodDef { - ml_name: name.as_ptr(), - ml_meth: Some(method), - ml_flags: flags, - ml_doc: doc.as_ptr(), - }; + let def = method_def + .as_method_def() + .map_err(|NulByteInString(err)| PyValueError::new_err(err))?; let (mod_ptr, module_name) = if let Some(m) = module { let mod_ptr = m.as_ptr(); let name = m.name()?.into_py(py);