Skip to content

Commit

Permalink
verify py method args count
Browse files Browse the repository at this point in the history
  • Loading branch information
aviramha committed Jan 2, 2022
1 parent 2503a2d commit 8f271fd
Show file tree
Hide file tree
Showing 12 changed files with 711 additions and 8 deletions.
2 changes: 1 addition & 1 deletion guide/src/class/protocols.md
Expand Up @@ -288,7 +288,7 @@ This trait also has support the augmented arithmetic assignments (`+=`, `-=`,
* `fn __itruediv__(&'p mut self, other: impl FromPyObject) -> PyResult<()>`
* `fn __ifloordiv__(&'p mut self, other: impl FromPyObject) -> PyResult<()>`
* `fn __imod__(&'p mut self, other: impl FromPyObject) -> PyResult<()>`
* `fn __ipow__(&'p mut self, other: impl FromPyObject) -> PyResult<()>`
* `fn __ipow__(&'p mut self, other: impl FromPyObject, _modulo: impl FromPyObject) -> PyResult<()>`
* `fn __ilshift__(&'p mut self, other: impl FromPyObject) -> PyResult<()>`
* `fn __irshift__(&'p mut self, other: impl FromPyObject) -> PyResult<()>`
* `fn __iand__(&'p mut self, other: impl FromPyObject) -> PyResult<()>`
Expand Down
3 changes: 3 additions & 0 deletions pyo3-macros-backend/Cargo.toml
Expand Up @@ -18,6 +18,9 @@ quote = { version = "1", default-features = false }
proc-macro2 = { version = "1", default-features = false }
pyo3-build-config = { path = "../pyo3-build-config", version = "0.15.1", features = ["resolve-config"] }

[build-dependencies]
pyo3-build-config = { path = "../pyo3-build-config", version = "0.15.1", features = ["resolve-config"] }

[dependencies.syn]
version = "1"
default-features = false
Expand Down
14 changes: 14 additions & 0 deletions pyo3-macros-backend/build.rs
@@ -0,0 +1,14 @@
use pyo3_build_config::pyo3_build_script_impl::{errors::Result, resolve_interpreter_config};

fn configure_pyo3() -> Result<()> {
let interpreter_config = resolve_interpreter_config()?;
interpreter_config.emit_pyo3_cfgs();
Ok(())
}

fn main() {
if let Err(e) = configure_pyo3() {
eprintln!("error: {}", e.report());
std::process::exit(1)
}
}
6 changes: 6 additions & 0 deletions pyo3-macros-backend/src/defs.rs
Expand Up @@ -420,6 +420,12 @@ pub const NUM: Proto = Proto {
MethodProto::new("__imod__", "PyNumberIModProtocol")
.args(&["Other"])
.has_self(),
// See https://bugs.python.org/issue36379
#[cfg(Py_3_8)]
MethodProto::new("__ipow__", "PyNumberIPowProtocol")
.args(&["Other", "Modulo"])
.has_self(),
#[cfg(not(Py_3_8))]
MethodProto::new("__ipow__", "PyNumberIPowProtocol")
.args(&["Other"])
.has_self(),
Expand Down
16 changes: 12 additions & 4 deletions pyo3-macros-backend/src/pymethod.rs
Expand Up @@ -532,10 +532,16 @@ const __IMOD__: SlotDef = SlotDef::new("Py_nb_inplace_remainder", "binaryfunc")
.arguments(&[Ty::Object])
.extract_error_mode(ExtractErrorMode::NotImplemented)
.return_self();
#[cfg(Py_3_8)]
const __IPOW__: SlotDef = SlotDef::new("Py_nb_inplace_power", "ternaryfunc")
.arguments(&[Ty::Object, Ty::Object])
.extract_error_mode(ExtractErrorMode::NotImplemented)
.return_self();
#[cfg(not(Py_3_8))]
const __IPOW__: SlotDef = SlotDef::new("Py_nb_inplace_power", "binaryfunc")
.arguments(&[Ty::Object])
.extract_error_mode(ExtractErrorMode::NotImplemented)
.return_self();
const __ILSHIFT__: SlotDef = SlotDef::new("Py_nb_inplace_lshift", "binaryfunc")
.arguments(&[Ty::Object])
.extract_error_mode(ExtractErrorMode::NotImplemented)
Expand Down Expand Up @@ -856,8 +862,11 @@ fn generate_method_body(
) -> Result<TokenStream> {
let self_conversion = spec.tp.self_conversion(Some(cls), extract_error_mode);
let rust_name = spec.name;
let (arg_idents, conversions) =
let (arg_idents, arg_count, conversions) =
extract_proto_arguments(cls, py, &spec.args, arguments, extract_error_mode)?;
if arg_count != arguments.len() {
bail_spanned!(spec.name.span() => format!("Expected {} arguments, got {}", arguments.len(), arg_count));
}
let call = quote! { _pyo3::callback::convert(#py, #cls::#rust_name(_slf, #(#arg_idents),*)) };
let body = if let Some(return_mode) = return_mode {
return_mode.return_call_output(py, call)
Expand Down Expand Up @@ -1026,7 +1035,7 @@ fn extract_proto_arguments(
method_args: &[FnArg],
proto_args: &[Ty],
extract_error_mode: ExtractErrorMode,
) -> Result<(Vec<Ident>, TokenStream)> {
) -> Result<(Vec<Ident>, usize, TokenStream)> {
let mut arg_idents = Vec::with_capacity(method_args.len());
let mut non_python_args = 0;

Expand All @@ -1045,9 +1054,8 @@ fn extract_proto_arguments(
arg_idents.push(ident);
}
}

let conversions = quote!(#(#args_conversions)*);
Ok((arg_idents, conversions))
Ok((arg_idents, non_python_args, conversions))
}

struct StaticIdent(&'static str);
Expand Down
44 changes: 42 additions & 2 deletions src/class/number.rs
Expand Up @@ -221,6 +221,14 @@ pub trait PyNumberProtocol<'p>: PyClass {
{
unimplemented!()
}
#[cfg(Py_3_8)]
fn __ipow__(&'p mut self, other: Self::Other, modulo: Option<Self::Modulo>) -> Self::Result
where
Self: PyNumberIPowProtocol<'p>,
{
unimplemented!()
}
#[cfg(not(Py_3_8))]
fn __ipow__(&'p mut self, other: Self::Other) -> Self::Result
where
Self: PyNumberIPowProtocol<'p>,
Expand Down Expand Up @@ -504,6 +512,8 @@ pub trait PyNumberIDivmodProtocol<'p>: PyNumberProtocol<'p> {
pub trait PyNumberIPowProtocol<'p>: PyNumberProtocol<'p> {
type Other: FromPyObject<'p>;
type Result: IntoPyCallbackOutput<()>;
#[cfg(Py_3_8)]
type Modulo: FromPyObject<'p>;
}

#[allow(clippy::upper_case_acronyms)]
Expand Down Expand Up @@ -714,11 +724,13 @@ py_binary_self_func!(isub, PyNumberISubProtocol, T::__isub__);
py_binary_self_func!(imul, PyNumberIMulProtocol, T::__imul__);
py_binary_self_func!(imod, PyNumberIModProtocol, T::__imod__);

// See https://bugs.python.org/issue36379
#[doc(hidden)]
#[cfg(Py_3_8)]
pub unsafe extern "C" fn ipow<T>(
slf: *mut ffi::PyObject,
other: *mut ffi::PyObject,
_modulo: *mut ffi::PyObject,
modulo: *mut ffi::PyObject,
) -> *mut ffi::PyObject
where
T: for<'p> PyNumberIPowProtocol<'p>,
Expand All @@ -728,7 +740,35 @@ where
crate::callback_body!(py, {
let slf_cell = py.from_borrowed_ptr::<crate::PyCell<T>>(slf);
let other = py.from_borrowed_ptr::<crate::PyAny>(other);
call_operator_mut!(py, slf_cell, __ipow__, other).convert(py)?;
let modulo = py.from_borrowed_ptr::<crate::PyAny>(modulo);
slf_cell
.try_borrow_mut()?
.__ipow__(
extract_or_return_not_implemented!(other),
extract_or_return_not_implemented!(modulo),
)
.convert(py)?;
ffi::Py_INCREF(slf);
Ok::<_, PyErr>(slf)
})
}

#[doc(hidden)]
#[cfg(not(Py_3_8))]
pub unsafe extern "C" fn ipow<T>(
slf: *mut ffi::PyObject,
other: *mut ffi::PyObject,
) -> *mut ffi::PyObject
where
T: for<'p> PyNumberIPowProtocol<'p>,
{
crate::callback_body!(py, {
let slf_cell = py.from_borrowed_ptr::<crate::PyCell<T>>(slf);
let other = py.from_borrowed_ptr::<crate::PyAny>(other);
slf_cell
.try_borrow_mut()?
.__ipow__(extract_or_return_not_implemented!(other))
.convert(py)?;
ffi::Py_INCREF(slf);
Ok::<_, PyErr>(slf)
})
Expand Down

0 comments on commit 8f271fd

Please sign in to comment.