Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

handle #[pyo3(from_py_with = ...)] on dunder (__magic__) methods #4117

Merged
merged 2 commits into from May 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions newsfragments/4117.fixed.md
@@ -0,0 +1 @@
Correctly handle `#[pyo3(from_py_with = ...)]` attribute on dunder (`__magic__`) method arguments instead of silently ignoring it.
34 changes: 30 additions & 4 deletions pyo3-macros-backend/src/params.rs
Expand Up @@ -10,7 +10,7 @@ use syn::spanned::Spanned;

pub struct Holders {
holders: Vec<syn::Ident>,
gil_refs_checkers: Vec<syn::Ident>,
gil_refs_checkers: Vec<GilRefChecker>,
}

impl Holders {
Expand All @@ -32,14 +32,28 @@ impl Holders {
&format!("gil_refs_checker_{}", self.gil_refs_checkers.len()),
span,
);
self.gil_refs_checkers.push(gil_refs_checker.clone());
self.gil_refs_checkers
.push(GilRefChecker::FunctionArg(gil_refs_checker.clone()));
gil_refs_checker
}

pub fn push_from_py_with_checker(&mut self, span: Span) -> syn::Ident {
let gil_refs_checker = syn::Ident::new(
&format!("gil_refs_checker_{}", self.gil_refs_checkers.len()),
span,
);
self.gil_refs_checkers
.push(GilRefChecker::FromPyWith(gil_refs_checker.clone()));
gil_refs_checker
}

pub fn init_holders(&self, ctx: &Ctx) -> TokenStream {
let Ctx { pyo3_path } = ctx;
let holders = &self.holders;
let gil_refs_checkers = &self.gil_refs_checkers;
let gil_refs_checkers = self.gil_refs_checkers.iter().map(|checker| match checker {
GilRefChecker::FunctionArg(ident) => ident,
GilRefChecker::FromPyWith(ident) => ident,
});
quote! {
#[allow(clippy::let_unit_value)]
#(let mut #holders = #pyo3_path::impl_::extract_argument::FunctionArgumentHolder::INIT;)*
Expand All @@ -50,11 +64,23 @@ impl Holders {
pub fn check_gil_refs(&self) -> TokenStream {
self.gil_refs_checkers
.iter()
.map(|e| quote_spanned! { e.span() => #e.function_arg(); })
.map(|checker| match checker {
GilRefChecker::FunctionArg(ident) => {
quote_spanned! { ident.span() => #ident.function_arg(); }
}
GilRefChecker::FromPyWith(ident) => {
quote_spanned! { ident.span() => #ident.from_py_with_arg(); }
}
})
.collect()
}
}

enum GilRefChecker {
FunctionArg(syn::Ident),
FromPyWith(syn::Ident),
}

/// Return true if the argument list is simply (*args, **kwds).
pub fn is_forwarded_args(signature: &FunctionSignature<'_>) -> bool {
matches!(
Expand Down
42 changes: 25 additions & 17 deletions pyo3-macros-backend/src/pymethod.rs
Expand Up @@ -1053,44 +1053,39 @@ impl Ty {
ctx: &Ctx,
) -> TokenStream {
let Ctx { pyo3_path } = ctx;
let name_str = arg.name().unraw().to_string();
match self {
Ty::Object => extract_object(
extract_error_mode,
holders,
&name_str,
arg,
quote! { #ident },
arg.ty().span(),
ctx
),
Ty::MaybeNullObject => extract_object(
extract_error_mode,
holders,
&name_str,
arg,
quote! {
if #ident.is_null() {
#pyo3_path::ffi::Py_None()
} else {
#ident
}
},
arg.ty().span(),
ctx
),
Ty::NonNullObject => extract_object(
extract_error_mode,
holders,
&name_str,
arg,
quote! { #ident.as_ptr() },
arg.ty().span(),
ctx
),
Ty::IPowModulo => extract_object(
extract_error_mode,
holders,
&name_str,
arg,
quote! { #ident.as_ptr() },
arg.ty().span(),
ctx
),
Ty::CompareOp => extract_error_mode.handle_error(
Expand Down Expand Up @@ -1118,24 +1113,37 @@ impl Ty {
fn extract_object(
extract_error_mode: ExtractErrorMode,
holders: &mut Holders,
name: &str,
arg: &FnArg<'_>,
source_ptr: TokenStream,
span: Span,
ctx: &Ctx,
) -> TokenStream {
let Ctx { pyo3_path } = ctx;
let holder = holders.push_holder(Span::call_site());
let gil_refs_checker = holders.push_gil_refs_checker(span);
let extracted = extract_error_mode.handle_error(
let gil_refs_checker = holders.push_gil_refs_checker(arg.ty().span());
let name = arg.name().unraw().to_string();

let extract = if let Some(from_py_with) =
arg.from_py_with().map(|from_py_with| &from_py_with.value)
{
let from_py_with_checker = holders.push_from_py_with_checker(from_py_with.span());
quote! {
#pyo3_path::impl_::extract_argument::from_py_with(
#pyo3_path::impl_::pymethods::BoundRef::ref_from_ptr(py, &#source_ptr).0,
#name,
#pyo3_path::impl_::deprecations::inspect_fn(#from_py_with, &#from_py_with_checker) as fn(_) -> _,
)
}
} else {
let holder = holders.push_holder(Span::call_site());
quote! {
#pyo3_path::impl_::extract_argument::extract_argument(
#pyo3_path::impl_::pymethods::BoundRef::ref_from_ptr(py, &#source_ptr).0,
&mut #holder,
#name
)
},
ctx,
);
}
};

let extracted = extract_error_mode.handle_error(extract, ctx);
quote! {
#pyo3_path::impl_::deprecations::inspect_type(#extracted, &#gil_refs_checker)
}
Expand Down
11 changes: 11 additions & 0 deletions tests/test_class_basics.rs
Expand Up @@ -290,6 +290,10 @@ fn get_length(obj: &Bound<'_, PyAny>) -> PyResult<usize> {
Ok(length)
}

fn is_even(obj: &Bound<'_, PyAny>) -> PyResult<bool> {
obj.extract::<i32>().map(|i| i % 2 == 0)
}

#[pyclass]
struct ClassWithFromPyWithMethods {}

Expand Down Expand Up @@ -319,6 +323,10 @@ impl ClassWithFromPyWithMethods {
fn staticmethod(#[pyo3(from_py_with = "get_length")] argument: usize) -> usize {
argument
}

fn __contains__(&self, #[pyo3(from_py_with = "is_even")] obj: bool) -> bool {
obj
}
}

#[test]
Expand All @@ -339,6 +347,9 @@ fn test_pymethods_from_py_with() {
if has_gil_refs:
assert instance.classmethod_gil_ref(arg) == 2
assert instance.staticmethod(arg) == 2

assert 42 in instance
assert 73 not in instance
"#
);
})
Expand Down
8 changes: 8 additions & 0 deletions tests/ui/deprecations.rs
Expand Up @@ -38,6 +38,14 @@ impl MyClass {

#[setter]
fn set_bar_bound(&self, _value: &Bound<'_, PyAny>) {}

fn __eq__(&self, #[pyo3(from_py_with = "extract_gil_ref")] _other: i32) -> bool {
true
}

fn __contains__(&self, #[pyo3(from_py_with = "extract_bound")] _value: i32) -> bool {
true
}
}

fn main() {}
Expand Down
50 changes: 28 additions & 22 deletions tests/ui/deprecations.stderr
Expand Up @@ -16,6 +16,12 @@ error: use of deprecated struct `pyo3::PyCell`: `PyCell` was merged into `Bound`
23 | fn method_gil_ref(_slf: &PyCell<Self>) {}
| ^^^^^^

error: use of deprecated method `pyo3::deprecations::GilRefs::<T>::from_py_with_arg`: use `&Bound<'_, PyAny>` as the argument for this `from_py_with` extractor
--> tests/ui/deprecations.rs:42:44
|
42 | fn __eq__(&self, #[pyo3(from_py_with = "extract_gil_ref")] _other: i32) -> bool {
| ^^^^^^^^^^^^^^^^^

error: use of deprecated method `pyo3::deprecations::GilRefs::<T>::function_arg`: use `&Bound<'_, T>` instead for this function argument
--> tests/ui/deprecations.rs:18:33
|
Expand Down Expand Up @@ -47,69 +53,69 @@ error: use of deprecated method `pyo3::deprecations::GilRefs::<T>::function_arg`
| ^

error: use of deprecated method `pyo3::deprecations::GilRefs::<T>::function_arg`: use `&Bound<'_, T>` instead for this function argument
--> tests/ui/deprecations.rs:53:43
--> tests/ui/deprecations.rs:61:43
|
53 | fn pyfunction_with_module_gil_ref(module: &PyModule) -> PyResult<&str> {
61 | fn pyfunction_with_module_gil_ref(module: &PyModule) -> PyResult<&str> {
| ^

error: use of deprecated method `pyo3::deprecations::GilRefs::<T>::function_arg`: use `&Bound<'_, T>` instead for this function argument
--> tests/ui/deprecations.rs:63:19
--> tests/ui/deprecations.rs:71:19
|
63 | fn module_gil_ref(m: &PyModule) -> PyResult<()> {
71 | fn module_gil_ref(m: &PyModule) -> PyResult<()> {
| ^

error: use of deprecated method `pyo3::deprecations::GilRefs::<T>::function_arg`: use `&Bound<'_, T>` instead for this function argument
--> tests/ui/deprecations.rs:69:57
--> tests/ui/deprecations.rs:77:57
|
69 | fn module_gil_ref_with_explicit_py_arg(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
77 | fn module_gil_ref_with_explicit_py_arg(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
| ^

error: use of deprecated method `pyo3::deprecations::GilRefs::<T>::from_py_with_arg`: use `&Bound<'_, PyAny>` as the argument for this `from_py_with` extractor
--> tests/ui/deprecations.rs:102:27
--> tests/ui/deprecations.rs:110:27
|
102 | #[pyo3(from_py_with = "extract_gil_ref")] _gil_ref: i32,
110 | #[pyo3(from_py_with = "extract_gil_ref")] _gil_ref: i32,
| ^^^^^^^^^^^^^^^^^

error: use of deprecated method `pyo3::deprecations::GilRefs::<T>::function_arg`: use `&Bound<'_, T>` instead for this function argument
--> tests/ui/deprecations.rs:108:29
--> tests/ui/deprecations.rs:116:29
|
108 | fn pyfunction_gil_ref(_any: &PyAny) {}
116 | fn pyfunction_gil_ref(_any: &PyAny) {}
| ^

error: use of deprecated method `pyo3::deprecations::OptionGilRefs::<std::option::Option<T>>::function_arg`: use `Option<&Bound<'_, T>>` instead for this function argument
--> tests/ui/deprecations.rs:111:36
--> tests/ui/deprecations.rs:119:36
|
111 | fn pyfunction_option_gil_ref(_any: Option<&PyAny>) {}
119 | fn pyfunction_option_gil_ref(_any: Option<&PyAny>) {}
| ^^^^^^

error: use of deprecated method `pyo3::deprecations::GilRefs::<T>::from_py_with_arg`: use `&Bound<'_, PyAny>` as the argument for this `from_py_with` extractor
--> tests/ui/deprecations.rs:118:27
--> tests/ui/deprecations.rs:126:27
|
118 | #[pyo3(from_py_with = "PyAny::len", item("my_object"))]
126 | #[pyo3(from_py_with = "PyAny::len", item("my_object"))]
| ^^^^^^^^^^^^

error: use of deprecated method `pyo3::deprecations::GilRefs::<T>::from_py_with_arg`: use `&Bound<'_, PyAny>` as the argument for this `from_py_with` extractor
--> tests/ui/deprecations.rs:128:27
--> tests/ui/deprecations.rs:136:27
|
128 | #[pyo3(from_py_with = "PyAny::len")] usize,
136 | #[pyo3(from_py_with = "PyAny::len")] usize,
| ^^^^^^^^^^^^

error: use of deprecated method `pyo3::deprecations::GilRefs::<T>::from_py_with_arg`: use `&Bound<'_, PyAny>` as the argument for this `from_py_with` extractor
--> tests/ui/deprecations.rs:134:31
--> tests/ui/deprecations.rs:142:31
|
134 | Zip(#[pyo3(from_py_with = "extract_gil_ref")] i32),
142 | Zip(#[pyo3(from_py_with = "extract_gil_ref")] i32),
| ^^^^^^^^^^^^^^^^^

error: use of deprecated method `pyo3::deprecations::GilRefs::<T>::from_py_with_arg`: use `&Bound<'_, PyAny>` as the argument for this `from_py_with` extractor
--> tests/ui/deprecations.rs:141:27
--> tests/ui/deprecations.rs:149:27
|
141 | #[pyo3(from_py_with = "extract_gil_ref")]
149 | #[pyo3(from_py_with = "extract_gil_ref")]
| ^^^^^^^^^^^^^^^^^

error: use of deprecated method `pyo3::deprecations::GilRefs::<pyo3::Python<'_>>::is_python`: use `wrap_pyfunction_bound!` instead
--> tests/ui/deprecations.rs:154:13
--> tests/ui/deprecations.rs:162:13
|
154 | let _ = wrap_pyfunction!(double, py);
162 | let _ = wrap_pyfunction!(double, py);
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
= note: this error originates in the macro `wrap_pyfunction` (in Nightly builds, run with -Z macro-backtrace for more info)