Skip to content

Commit

Permalink
verify py method args count
Browse files Browse the repository at this point in the history
  • Loading branch information
aviramha committed Jan 2, 2022
1 parent 2503a2d commit 9ae339d
Show file tree
Hide file tree
Showing 9 changed files with 42 additions and 13 deletions.
2 changes: 1 addition & 1 deletion guide/src/class/protocols.md
Expand Up @@ -288,7 +288,7 @@ This trait also has support the augmented arithmetic assignments (`+=`, `-=`,
* `fn __itruediv__(&'p mut self, other: impl FromPyObject) -> PyResult<()>`
* `fn __ifloordiv__(&'p mut self, other: impl FromPyObject) -> PyResult<()>`
* `fn __imod__(&'p mut self, other: impl FromPyObject) -> PyResult<()>`
* `fn __ipow__(&'p mut self, other: impl FromPyObject) -> PyResult<()>`
* `fn __ipow__(&'p mut self, other: impl FromPyObject, _modulo: impl FromPyObject) -> PyResult<()>`
* `fn __ilshift__(&'p mut self, other: impl FromPyObject) -> PyResult<()>`
* `fn __irshift__(&'p mut self, other: impl FromPyObject) -> PyResult<()>`
* `fn __iand__(&'p mut self, other: impl FromPyObject) -> PyResult<()>`
Expand Down
2 changes: 1 addition & 1 deletion pyo3-macros-backend/src/defs.rs
Expand Up @@ -421,7 +421,7 @@ pub const NUM: Proto = Proto {
.args(&["Other"])
.has_self(),
MethodProto::new("__ipow__", "PyNumberIPowProtocol")
.args(&["Other"])
.args(&["Other", "Modulo"])
.has_self(),
MethodProto::new("__ilshift__", "PyNumberILShiftProtocol")
.args(&["Other"])
Expand Down
10 changes: 6 additions & 4 deletions pyo3-macros-backend/src/pymethod.rs
Expand Up @@ -856,8 +856,11 @@ fn generate_method_body(
) -> Result<TokenStream> {
let self_conversion = spec.tp.self_conversion(Some(cls), extract_error_mode);
let rust_name = spec.name;
let (arg_idents, conversions) =
let (arg_idents, arg_count, conversions) =
extract_proto_arguments(cls, py, &spec.args, arguments, extract_error_mode)?;
if arg_count != arguments.len() {
bail_spanned!(spec.name.span() => format!("Expected {} arguments, got {}", arguments.len(), arg_count));
}
let call = quote! { _pyo3::callback::convert(#py, #cls::#rust_name(_slf, #(#arg_idents),*)) };
let body = if let Some(return_mode) = return_mode {
return_mode.return_call_output(py, call)
Expand Down Expand Up @@ -1026,7 +1029,7 @@ fn extract_proto_arguments(
method_args: &[FnArg],
proto_args: &[Ty],
extract_error_mode: ExtractErrorMode,
) -> Result<(Vec<Ident>, TokenStream)> {
) -> Result<(Vec<Ident>, usize, TokenStream)> {
let mut arg_idents = Vec::with_capacity(method_args.len());
let mut non_python_args = 0;

Expand All @@ -1045,9 +1048,8 @@ fn extract_proto_arguments(
arg_idents.push(ident);
}
}

let conversions = quote!(#(#args_conversions)*);
Ok((arg_idents, conversions))
Ok((arg_idents, non_python_args, conversions))
}

struct StaticIdent(&'static str);
Expand Down
14 changes: 11 additions & 3 deletions src/class/number.rs
Expand Up @@ -221,7 +221,7 @@ pub trait PyNumberProtocol<'p>: PyClass {
{
unimplemented!()
}
fn __ipow__(&'p mut self, other: Self::Other) -> Self::Result
fn __ipow__(&'p mut self, other: Self::Other, modulo: Option<Self::Modulo>) -> Self::Result
where
Self: PyNumberIPowProtocol<'p>,
{
Expand Down Expand Up @@ -504,6 +504,7 @@ pub trait PyNumberIDivmodProtocol<'p>: PyNumberProtocol<'p> {
pub trait PyNumberIPowProtocol<'p>: PyNumberProtocol<'p> {
type Other: FromPyObject<'p>;
type Result: IntoPyCallbackOutput<()>;
type Modulo: FromPyObject<'p>;
}

#[allow(clippy::upper_case_acronyms)]
Expand Down Expand Up @@ -718,7 +719,7 @@ py_binary_self_func!(imod, PyNumberIModProtocol, T::__imod__);
pub unsafe extern "C" fn ipow<T>(
slf: *mut ffi::PyObject,
other: *mut ffi::PyObject,
_modulo: *mut ffi::PyObject,
modulo: *mut ffi::PyObject,
) -> *mut ffi::PyObject
where
T: for<'p> PyNumberIPowProtocol<'p>,
Expand All @@ -728,7 +729,14 @@ where
crate::callback_body!(py, {
let slf_cell = py.from_borrowed_ptr::<crate::PyCell<T>>(slf);
let other = py.from_borrowed_ptr::<crate::PyAny>(other);
call_operator_mut!(py, slf_cell, __ipow__, other).convert(py)?;
let modulo = py.from_borrowed_ptr::<crate::PyAny>(modulo);
slf_cell
.try_borrow_mut()?
.__ipow__(
extract_or_return_not_implemented!(other),
extract_or_return_not_implemented!(modulo),
)
.convert(py)?;
ffi::Py_INCREF(slf);
Ok::<_, PyErr>(slf)
})
Expand Down
4 changes: 2 additions & 2 deletions tests/test_arithmetics.rs
Expand Up @@ -95,7 +95,7 @@ impl InPlaceOperations {
self.value |= other;
}

fn __ipow__(&mut self, other: u32) {
fn __ipow__(&mut self, other: u32, _modulo: Option<u32>) {
self.value = self.value.pow(other);
}
}
Expand Down Expand Up @@ -566,7 +566,7 @@ mod return_not_implemented {
fn __itruediv__(&mut self, _other: PyRef<Self>) {}
fn __ifloordiv__(&mut self, _other: PyRef<Self>) {}
fn __imod__(&mut self, _other: PyRef<Self>) {}
fn __ipow__(&mut self, _other: PyRef<Self>) {}
fn __ipow__(&mut self, _other: PyRef<Self>, _modulo: Option<u8>) {}
fn __ilshift__(&mut self, _other: PyRef<Self>) {}
fn __irshift__(&mut self, _other: PyRef<Self>) {}
fn __iand__(&mut self, _other: PyRef<Self>) {}
Expand Down
4 changes: 2 additions & 2 deletions tests/test_arithmetics_protos.rs
Expand Up @@ -108,7 +108,7 @@ impl PyNumberProtocol for InPlaceOperations {
self.value |= other;
}

fn __ipow__(&mut self, other: u32) {
fn __ipow__(&mut self, other: u32, _modulo: Option<u32>) {
self.value = self.value.pow(other);
}
}
Expand Down Expand Up @@ -589,7 +589,7 @@ mod return_not_implemented {
fn __itruediv__(&'p mut self, _other: PyRef<'p, Self>) {}
fn __ifloordiv__(&'p mut self, _other: PyRef<'p, Self>) {}
fn __imod__(&'p mut self, _other: PyRef<'p, Self>) {}
fn __ipow__(&'p mut self, _other: PyRef<'p, Self>) {}
fn __ipow__(&'p mut self, _other: PyRef<'p, Self>, _modulo: Option<u8>) {}
fn __ilshift__(&'p mut self, _other: PyRef<'p, Self>) {}
fn __irshift__(&'p mut self, _other: PyRef<'p, Self>) {}
fn __iand__(&'p mut self, _other: PyRef<'p, Self>) {}
Expand Down
1 change: 1 addition & 0 deletions tests/test_compile_error.rs
Expand Up @@ -30,6 +30,7 @@ fn _test_compile_errors() {
t.compile_fail("tests/ui/missing_clone.rs");
t.compile_fail("tests/ui/reject_generics.rs");
t.compile_fail("tests/ui/not_send.rs");
t.compile_fail("tests/ui/invalid_pymethod_proto_args.rs");

tests_rust_1_49(&t);
tests_rust_1_55(&t);
Expand Down
13 changes: 13 additions & 0 deletions tests/ui/invalid_pymethod_proto_args.rs
@@ -0,0 +1,13 @@
use pyo3::prelude::*;

#[pyclass]
struct MyClass {}

#[pymethods]
impl MyClass {
fn __truediv__(&self) -> PyResult<()> {
Ok(())
}
}

fn main() {}
5 changes: 5 additions & 0 deletions tests/ui/invalid_pymethod_proto_args.stderr
@@ -0,0 +1,5 @@
error: Expected 1 arguments, got 0
--> tests/ui/invalid_pymethod_proto_args.rs:8:8
|
8 | fn __truediv__(&self) -> PyResult<()> {
| ^^^^^^^^^^^

0 comments on commit 9ae339d

Please sign in to comment.