diff --git a/CHANGELOG.md b/CHANGELOG.md index 73b1244ea6a..b06bff83a06 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Fixed - Fix visibility of `PyDictItems`, `PyDictKeys`, and `PyDictValues` types added in PyO3 0.17.0. +- Fix compile failure when using `#[pyo3(from_py_with = "...")]` attribute on an argument of type `Option`. [#2592](https://github.com/PyO3/pyo3/pull/2592) ## [0.17.0] - 2022-08-23 diff --git a/pyo3-macros-backend/src/params.rs b/pyo3-macros-backend/src/params.rs index c6d7c0a8492..c64de72895e 100644 --- a/pyo3-macros-backend/src/params.rs +++ b/pyo3-macros-backend/src/params.rs @@ -247,7 +247,7 @@ fn impl_arg_param( } (None, true) => { quote_arg_span! { - _pyo3::impl_::extract_argument::from_py_with_with_default(#arg_value, #name_str, #expr_path, || Some(None))? + _pyo3::impl_::extract_argument::from_py_with_with_default(#arg_value, #name_str, #expr_path, || None)? } } (None, false) => { diff --git a/tests/test_pyfunction.rs b/tests/test_pyfunction.rs index 49dba14b61f..f42c6f7e922 100644 --- a/tests/test_pyfunction.rs +++ b/tests/test_pyfunction.rs @@ -165,6 +165,42 @@ fn test_function_with_custom_conversion_error() { }); } +#[test] +fn test_from_py_with_defaults() { + fn optional_int(x: &PyAny) -> PyResult> { + if x.is_none() { + Ok(None) + } else { + Some(x.extract()).transpose() + } + } + + // issue 2280 combination of from_py_with and Option did not compile + #[pyfunction] + fn from_py_with_option(#[pyo3(from_py_with = "optional_int")] int: Option) -> i32 { + int.unwrap_or(0) + } + + #[pyfunction(len = "0")] + fn from_py_with_default(#[pyo3(from_py_with = "PyAny::len")] len: usize) -> usize { + len + } + + Python::with_gil(|py| { + let f = wrap_pyfunction!(from_py_with_option)(py).unwrap(); + + assert_eq!(f.call0().unwrap().extract::().unwrap(), 0); + assert_eq!(f.call1((123,)).unwrap().extract::().unwrap(), 123); + assert_eq!(f.call1((999,)).unwrap().extract::().unwrap(), 999); + + let f2 = wrap_pyfunction!(from_py_with_default)(py).unwrap(); + + assert_eq!(f2.call0().unwrap().extract::().unwrap(), 0); + assert_eq!(f2.call1(("123",)).unwrap().extract::().unwrap(), 3); + assert_eq!(f2.call1(("1234",)).unwrap().extract::().unwrap(), 4); + }); +} + #[pyclass] #[derive(Debug, FromPyObject)] struct ValueClass {