diff --git a/newsfragments/2630.fixed.md b/newsfragments/2630.fixed.md new file mode 100644 index 00000000000..dc0a424f54a --- /dev/null +++ b/newsfragments/2630.fixed.md @@ -0,0 +1 @@ +Fix compile error since 0.17.0 with `Option<&SomePyClass>` argument with a default. diff --git a/pyo3-macros-backend/src/params.rs b/pyo3-macros-backend/src/params.rs index e54a6d5ff72..a1c20e1af3f 100644 --- a/pyo3-macros-backend/src/params.rs +++ b/pyo3-macros-backend/src/params.rs @@ -1,7 +1,6 @@ // Copyright (c) 2017-present PyO3 Project and Contributors use crate::{ - attributes::FromPyWithAttribute, method::{FnArg, FnSpec}, pyfunction::Argument, }; @@ -222,7 +221,8 @@ fn impl_arg_param( _pyo3::impl_::extract_argument::extract_optional_argument( _kwargs.map(::std::convert::AsRef::as_ref), &mut { _pyo3::impl_::extract_argument::FunctionArgumentHolder::INIT }, - #name_str + #name_str, + || None )? }); } @@ -230,76 +230,58 @@ fn impl_arg_param( let arg_value = quote_arg_span!(#args_array[#option_pos]); *option_pos += 1; - let tokens = if let Some(FromPyWithAttribute { - value: expr_path, .. - }) = &arg.attrs.from_py_with - { - match (spec.default_value(name), arg.optional.is_some()) { - (Some(default), true) if default.to_string() != "None" => { - quote_arg_span! { - _pyo3::impl_::extract_argument::from_py_with_with_default(#arg_value, #name_str, #expr_path, || Some(#default))? - } - } - (Some(default), _) => { - quote_arg_span! { - _pyo3::impl_::extract_argument::from_py_with_with_default(#arg_value, #name_str, #expr_path, || #default)? - } + let mut default = spec.default_value(name); + + // Option arguments have special treatment: the default should be specified _without_ the + // Some() wrapper. Maybe this should be changed in future?! + if arg.optional.is_some() { + default = Some(match &default { + Some(expression) if expression.to_string() != "None" => { + quote!(::std::option::Option::Some(#expression)) } - (None, true) => { - quote_arg_span! { - _pyo3::impl_::extract_argument::from_py_with_with_default(#arg_value, #name_str, #expr_path, || None)? - } + _ => quote!(::std::option::Option::None), + }) + } + + let tokens = if let Some(expr_path) = arg.attrs.from_py_with.as_ref().map(|attr| &attr.value) { + if let Some(default) = default { + quote_arg_span! { + _pyo3::impl_::extract_argument::from_py_with_with_default(#arg_value, #name_str, #expr_path, || #default)? } - (None, false) => { - quote_arg_span! { - _pyo3::impl_::extract_argument::from_py_with( - _pyo3::impl_::extract_argument::unwrap_required_argument(#arg_value), - #name_str, - #expr_path, - )? - } + } else { + quote_arg_span! { + _pyo3::impl_::extract_argument::from_py_with( + _pyo3::impl_::extract_argument::unwrap_required_argument(#arg_value), + #name_str, + #expr_path, + )? } } + } else if arg.optional.is_some() { + quote_arg_span! { + _pyo3::impl_::extract_argument::extract_optional_argument( + #arg_value, + &mut { _pyo3::impl_::extract_argument::FunctionArgumentHolder::INIT }, + #name_str, + || #default + )? + } + } else if let Some(default) = default { + quote_arg_span! { + _pyo3::impl_::extract_argument::extract_argument_with_default( + #arg_value, + &mut { _pyo3::impl_::extract_argument::FunctionArgumentHolder::INIT }, + #name_str, + || #default + )? + } } else { - match (spec.default_value(name), arg.optional.is_some()) { - (Some(default), true) if default.to_string() != "None" => { - quote_arg_span! { - _pyo3::impl_::extract_argument::extract_argument_with_default( - #arg_value, - &mut { _pyo3::impl_::extract_argument::FunctionArgumentHolder::INIT }, - #name_str, - || Some(#default) - )? - } - } - (Some(default), _) => { - quote_arg_span! { - _pyo3::impl_::extract_argument::extract_argument_with_default( - #arg_value, - &mut { _pyo3::impl_::extract_argument::FunctionArgumentHolder::INIT }, - #name_str, - || #default - )? - } - } - (None, true) => { - quote_arg_span! { - _pyo3::impl_::extract_argument::extract_optional_argument( - #arg_value, - &mut { _pyo3::impl_::extract_argument::FunctionArgumentHolder::INIT }, - #name_str - )? - } - } - (None, false) => { - quote_arg_span! { - _pyo3::impl_::extract_argument::extract_argument( - _pyo3::impl_::extract_argument::unwrap_required_argument(#arg_value), - &mut { _pyo3::impl_::extract_argument::FunctionArgumentHolder::INIT }, - #name_str - )? - } - } + quote_arg_span! { + _pyo3::impl_::extract_argument::extract_argument( + _pyo3::impl_::extract_argument::unwrap_required_argument(#arg_value), + &mut { _pyo3::impl_::extract_argument::FunctionArgumentHolder::INIT }, + #name_str + )? } }; Ok(tokens) diff --git a/src/impl_/extract_argument.rs b/src/impl_/extract_argument.rs index f74aa280ce2..1ce60c14a90 100644 --- a/src/impl_/extract_argument.rs +++ b/src/impl_/extract_argument.rs @@ -91,13 +91,14 @@ where } } -/// Alternative to [`extract_argument`] used for `Option` arguments (because they are implicitly treated -/// as optional if at the end of the positional parameters). +/// Alternative to [`extract_argument`] used for `Option` arguments. This is necessary because Option<&T> +/// does not implement `PyFunctionArgument` for `T: PyClass`. #[doc(hidden)] pub fn extract_optional_argument<'a, 'py, T>( obj: Option<&'py PyAny>, holder: &'a mut T::Holder, arg_name: &str, + default: fn() -> Option, ) -> PyResult> where T: PyFunctionArgument<'a, 'py>, @@ -105,12 +106,13 @@ where match obj { Some(obj) => { if obj.is_none() { + // Explicit `None` will result in None being used as the function argument Ok(None) } else { extract_argument(obj, holder, arg_name).map(Some) } } - None => Ok(None), + _ => Ok(default()), } } @@ -132,16 +134,12 @@ where } /// Alternative to [`extract_argument`] used when the argument has a `#[pyo3(from_py_with)]` annotation. -/// -/// # Safety -/// - `obj` must not be None (this helper is only used for required function arguments). #[doc(hidden)] pub fn from_py_with<'py, T>( obj: &'py PyAny, arg_name: &str, extractor: fn(&'py PyAny) -> PyResult, ) -> PyResult { - // Safety: obj is not None (see safety match extractor(obj) { Ok(value) => Ok(value), Err(e) => Err(argument_extraction_error(obj.py(), arg_name, e)), diff --git a/tests/test_methods.rs b/tests/test_methods.rs index d25ea71b7cd..04d2d721d15 100644 --- a/tests/test_methods.rs +++ b/tests/test_methods.rs @@ -1029,3 +1029,28 @@ pymethods!( fn issue_1696(&self, _x: &InstanceMethod) {} } ); + +#[test] +fn test_option_pyclass_arg() { + // Option<&PyClass> argument with a default set in a signature regressed to a compile + // error in PyO3 0.17.0 - this test it continues to be accepted. + + #[pyclass] + struct SomePyClass {} + + #[pyfunction(arg = "None")] + fn option_class_arg(arg: Option<&SomePyClass>) -> Option { + arg.map(|_| SomePyClass {}) + } + + Python::with_gil(|py| { + let f = wrap_pyfunction!(option_class_arg, py).unwrap(); + assert!(f.call0().unwrap().is_none()); + let obj = Py::new(py, SomePyClass {}).unwrap(); + assert!(f + .call1((obj,)) + .unwrap() + .extract::>() + .is_ok()); + }) +}