From 9ae339d1cf1b3aca6fa5b0a5e373e12c54043563 Mon Sep 17 00:00:00 2001 From: Aviram Hassan Date: Fri, 31 Dec 2021 15:11:02 +0200 Subject: [PATCH] verify py method args count --- guide/src/class/protocols.md | 2 +- pyo3-macros-backend/src/defs.rs | 2 +- pyo3-macros-backend/src/pymethod.rs | 10 ++++++---- src/class/number.rs | 14 +++++++++++--- tests/test_arithmetics.rs | 4 ++-- tests/test_arithmetics_protos.rs | 4 ++-- tests/test_compile_error.rs | 1 + tests/ui/invalid_pymethod_proto_args.rs | 13 +++++++++++++ tests/ui/invalid_pymethod_proto_args.stderr | 5 +++++ 9 files changed, 42 insertions(+), 13 deletions(-) create mode 100644 tests/ui/invalid_pymethod_proto_args.rs create mode 100644 tests/ui/invalid_pymethod_proto_args.stderr diff --git a/guide/src/class/protocols.md b/guide/src/class/protocols.md index 4bdcaa471ce..57be0c70546 100644 --- a/guide/src/class/protocols.md +++ b/guide/src/class/protocols.md @@ -288,7 +288,7 @@ This trait also has support the augmented arithmetic assignments (`+=`, `-=`, * `fn __itruediv__(&'p mut self, other: impl FromPyObject) -> PyResult<()>` * `fn __ifloordiv__(&'p mut self, other: impl FromPyObject) -> PyResult<()>` * `fn __imod__(&'p mut self, other: impl FromPyObject) -> PyResult<()>` - * `fn __ipow__(&'p mut self, other: impl FromPyObject) -> PyResult<()>` + * `fn __ipow__(&'p mut self, other: impl FromPyObject, _modulo: impl FromPyObject) -> PyResult<()>` * `fn __ilshift__(&'p mut self, other: impl FromPyObject) -> PyResult<()>` * `fn __irshift__(&'p mut self, other: impl FromPyObject) -> PyResult<()>` * `fn __iand__(&'p mut self, other: impl FromPyObject) -> PyResult<()>` diff --git a/pyo3-macros-backend/src/defs.rs b/pyo3-macros-backend/src/defs.rs index be90f137974..258029d63f7 100644 --- a/pyo3-macros-backend/src/defs.rs +++ b/pyo3-macros-backend/src/defs.rs @@ -421,7 +421,7 @@ pub const NUM: Proto = Proto { .args(&["Other"]) .has_self(), MethodProto::new("__ipow__", "PyNumberIPowProtocol") - .args(&["Other"]) + .args(&["Other", "Modulo"]) .has_self(), MethodProto::new("__ilshift__", "PyNumberILShiftProtocol") .args(&["Other"]) diff --git a/pyo3-macros-backend/src/pymethod.rs b/pyo3-macros-backend/src/pymethod.rs index c8b4d607e5b..4cc982bde33 100644 --- a/pyo3-macros-backend/src/pymethod.rs +++ b/pyo3-macros-backend/src/pymethod.rs @@ -856,8 +856,11 @@ fn generate_method_body( ) -> Result { let self_conversion = spec.tp.self_conversion(Some(cls), extract_error_mode); let rust_name = spec.name; - let (arg_idents, conversions) = + let (arg_idents, arg_count, conversions) = extract_proto_arguments(cls, py, &spec.args, arguments, extract_error_mode)?; + if arg_count != arguments.len() { + bail_spanned!(spec.name.span() => format!("Expected {} arguments, got {}", arguments.len(), arg_count)); + } let call = quote! { _pyo3::callback::convert(#py, #cls::#rust_name(_slf, #(#arg_idents),*)) }; let body = if let Some(return_mode) = return_mode { return_mode.return_call_output(py, call) @@ -1026,7 +1029,7 @@ fn extract_proto_arguments( method_args: &[FnArg], proto_args: &[Ty], extract_error_mode: ExtractErrorMode, -) -> Result<(Vec, TokenStream)> { +) -> Result<(Vec, usize, TokenStream)> { let mut arg_idents = Vec::with_capacity(method_args.len()); let mut non_python_args = 0; @@ -1045,9 +1048,8 @@ fn extract_proto_arguments( arg_idents.push(ident); } } - let conversions = quote!(#(#args_conversions)*); - Ok((arg_idents, conversions)) + Ok((arg_idents, non_python_args, conversions)) } struct StaticIdent(&'static str); diff --git a/src/class/number.rs b/src/class/number.rs index 2b0c2ade799..69ee2fcf5c6 100644 --- a/src/class/number.rs +++ b/src/class/number.rs @@ -221,7 +221,7 @@ pub trait PyNumberProtocol<'p>: PyClass { { unimplemented!() } - fn __ipow__(&'p mut self, other: Self::Other) -> Self::Result + fn __ipow__(&'p mut self, other: Self::Other, modulo: Option) -> Self::Result where Self: PyNumberIPowProtocol<'p>, { @@ -504,6 +504,7 @@ pub trait PyNumberIDivmodProtocol<'p>: PyNumberProtocol<'p> { pub trait PyNumberIPowProtocol<'p>: PyNumberProtocol<'p> { type Other: FromPyObject<'p>; type Result: IntoPyCallbackOutput<()>; + type Modulo: FromPyObject<'p>; } #[allow(clippy::upper_case_acronyms)] @@ -718,7 +719,7 @@ py_binary_self_func!(imod, PyNumberIModProtocol, T::__imod__); pub unsafe extern "C" fn ipow( slf: *mut ffi::PyObject, other: *mut ffi::PyObject, - _modulo: *mut ffi::PyObject, + modulo: *mut ffi::PyObject, ) -> *mut ffi::PyObject where T: for<'p> PyNumberIPowProtocol<'p>, @@ -728,7 +729,14 @@ where crate::callback_body!(py, { let slf_cell = py.from_borrowed_ptr::>(slf); let other = py.from_borrowed_ptr::(other); - call_operator_mut!(py, slf_cell, __ipow__, other).convert(py)?; + let modulo = py.from_borrowed_ptr::(modulo); + slf_cell + .try_borrow_mut()? + .__ipow__( + extract_or_return_not_implemented!(other), + extract_or_return_not_implemented!(modulo), + ) + .convert(py)?; ffi::Py_INCREF(slf); Ok::<_, PyErr>(slf) }) diff --git a/tests/test_arithmetics.rs b/tests/test_arithmetics.rs index 5d9e2b1a316..67b187e96ce 100644 --- a/tests/test_arithmetics.rs +++ b/tests/test_arithmetics.rs @@ -95,7 +95,7 @@ impl InPlaceOperations { self.value |= other; } - fn __ipow__(&mut self, other: u32) { + fn __ipow__(&mut self, other: u32, _modulo: Option) { self.value = self.value.pow(other); } } @@ -566,7 +566,7 @@ mod return_not_implemented { fn __itruediv__(&mut self, _other: PyRef) {} fn __ifloordiv__(&mut self, _other: PyRef) {} fn __imod__(&mut self, _other: PyRef) {} - fn __ipow__(&mut self, _other: PyRef) {} + fn __ipow__(&mut self, _other: PyRef, _modulo: Option) {} fn __ilshift__(&mut self, _other: PyRef) {} fn __irshift__(&mut self, _other: PyRef) {} fn __iand__(&mut self, _other: PyRef) {} diff --git a/tests/test_arithmetics_protos.rs b/tests/test_arithmetics_protos.rs index b57c5f49268..656a9bc27b4 100644 --- a/tests/test_arithmetics_protos.rs +++ b/tests/test_arithmetics_protos.rs @@ -108,7 +108,7 @@ impl PyNumberProtocol for InPlaceOperations { self.value |= other; } - fn __ipow__(&mut self, other: u32) { + fn __ipow__(&mut self, other: u32, _modulo: Option) { self.value = self.value.pow(other); } } @@ -589,7 +589,7 @@ mod return_not_implemented { fn __itruediv__(&'p mut self, _other: PyRef<'p, Self>) {} fn __ifloordiv__(&'p mut self, _other: PyRef<'p, Self>) {} fn __imod__(&'p mut self, _other: PyRef<'p, Self>) {} - fn __ipow__(&'p mut self, _other: PyRef<'p, Self>) {} + fn __ipow__(&'p mut self, _other: PyRef<'p, Self>, _modulo: Option) {} fn __ilshift__(&'p mut self, _other: PyRef<'p, Self>) {} fn __irshift__(&'p mut self, _other: PyRef<'p, Self>) {} fn __iand__(&'p mut self, _other: PyRef<'p, Self>) {} diff --git a/tests/test_compile_error.rs b/tests/test_compile_error.rs index daade52d8c6..9801f97a1a3 100644 --- a/tests/test_compile_error.rs +++ b/tests/test_compile_error.rs @@ -30,6 +30,7 @@ fn _test_compile_errors() { t.compile_fail("tests/ui/missing_clone.rs"); t.compile_fail("tests/ui/reject_generics.rs"); t.compile_fail("tests/ui/not_send.rs"); + t.compile_fail("tests/ui/invalid_pymethod_proto_args.rs"); tests_rust_1_49(&t); tests_rust_1_55(&t); diff --git a/tests/ui/invalid_pymethod_proto_args.rs b/tests/ui/invalid_pymethod_proto_args.rs new file mode 100644 index 00000000000..3c3e96ab3b6 --- /dev/null +++ b/tests/ui/invalid_pymethod_proto_args.rs @@ -0,0 +1,13 @@ +use pyo3::prelude::*; + +#[pyclass] +struct MyClass {} + +#[pymethods] +impl MyClass { + fn __truediv__(&self) -> PyResult<()> { + Ok(()) + } +} + +fn main() {} diff --git a/tests/ui/invalid_pymethod_proto_args.stderr b/tests/ui/invalid_pymethod_proto_args.stderr new file mode 100644 index 00000000000..04d77c76308 --- /dev/null +++ b/tests/ui/invalid_pymethod_proto_args.stderr @@ -0,0 +1,5 @@ +error: Expected 1 arguments, got 0 + --> tests/ui/invalid_pymethod_proto_args.rs:8:8 + | +8 | fn __truediv__(&self) -> PyResult<()> { + | ^^^^^^^^^^^