diff --git a/CHANGELOG.md b/CHANGELOG.md index 0e3a9e4e56e..68897e3deb1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -32,6 +32,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Add `PyAny::contains` method (`in` operator for `PyAny`). [#2115](https://github.com/PyO3/pyo3/pull/2115) - Add `PyMapping::contains` method (`in` operator for `PyMapping`). [#2133](https://github.com/PyO3/pyo3/pull/2133) - Add garbage collection magic methods `__traverse__` and `__clear__` to `#[pymethods]`. [#2159](https://github.com/PyO3/pyo3/pull/2159) +- Add support for `from_py_with` on struct tuples and enums to override the default from-Python conversion. [#2181](https://github.com/PyO3/pyo3/pull/2181) ### Changed diff --git a/guide/src/conversions/traits.md b/guide/src/conversions/traits.md index 4de449e365b..63fa81d2e70 100644 --- a/guide/src/conversions/traits.md +++ b/guide/src/conversions/traits.md @@ -442,6 +442,10 @@ If the input is neither a string nor an integer, the error message will be: - `pyo3(item)`, `pyo3(item("key"))` - retrieve the field from a mapping, possibly with the custom key specified as an argument. - can be any literal that implements `ToBorrowedObject` +- `pyo3(from_py_with = "...")` + - apply a custom function to convert the field from Python the desired Rust type. + - the argument must be the name of the function as a string. + - the function signature must be `fn(&PyAny) -> PyResult` where `T` is the Rust type of the argument. ### `IntoPy` diff --git a/pyo3-macros-backend/src/frompyobject.rs b/pyo3-macros-backend/src/frompyobject.rs index ef72c069de7..f262df3ecbd 100644 --- a/pyo3-macros-backend/src/frompyobject.rs +++ b/pyo3-macros-backend/src/frompyobject.rs @@ -106,8 +106,9 @@ enum ContainerType<'a> { StructNewtype(&'a Ident), /// Tuple struct, e.g. `struct Foo(String)`. /// - /// Fields are extracted from a tuple. - Tuple(usize), + /// Variant contains a list of conversion methods for each of the fields that are directly + /// extracted from the tuple. + Tuple(Vec), /// Tuple newtype, e.g. `#[transparent] struct Foo(String)` /// /// The wrapped field is directly extracted from the object. @@ -149,7 +150,15 @@ impl<'a> Container<'a> { (Fields::Unnamed(_), true) => ContainerType::TupleNewtype, (Fields::Unnamed(unnamed), false) => match unnamed.unnamed.len() { 1 => ContainerType::TupleNewtype, - len => ContainerType::Tuple(len), + _ => { + let fields = unnamed + .unnamed + .iter() + .map(|field| FieldPyO3Attributes::from_attrs(&field.attrs)) + .collect::>>()?; + + ContainerType::Tuple(fields) + } }, (Fields::Named(named), true) => { let field = named @@ -196,7 +205,7 @@ impl<'a> Container<'a> { match &self.ty { ContainerType::StructNewtype(ident) => self.build_newtype_struct(Some(ident)), ContainerType::TupleNewtype => self.build_newtype_struct(None), - ContainerType::Tuple(len) => self.build_tuple_struct(*len), + ContainerType::Tuple(tups) => self.build_tuple_struct(tups), ContainerType::Struct(tups) => self.build_struct(tups), } } @@ -233,19 +242,33 @@ impl<'a> Container<'a> { } } - fn build_tuple_struct(&self, len: usize) -> TokenStream { + fn build_tuple_struct(&self, tups: &[FieldPyO3Attributes]) -> TokenStream { let self_ty = &self.path; let mut fields: Punctuated = Punctuated::new(); - for i in 0..len { - let error_msg = format!("failed to extract field {}.{}", quote!(#self_ty), i); - fields.push(quote!( - s.get_item(#i).and_then(_pyo3::types::PyAny::extract).map_err(|inner| { - let py = _pyo3::PyNativeType::py(obj); - let new_err = _pyo3::exceptions::PyTypeError::new_err(#error_msg); - new_err.set_cause(py, ::std::option::Option::Some(inner)); - new_err - })?)); + for (index, attrs) in tups.iter().enumerate() { + let error_msg = format!("failed to extract field {}.{}", quote!(#self_ty), index); + + let parsed_item = match &attrs.from_py_with { + None => quote!( + obj.get_item(#index)?.extract() + ), + Some(FromPyWithAttribute(expr_path)) => quote! ( + #expr_path(obj.get_item(#index)?) + ), + }; + + let extractor = quote!( + #parsed_item.map_err(|inner| { + let py = _pyo3::PyNativeType::py(obj); + let new_err = _pyo3::exceptions::PyTypeError::new_err(#error_msg); + new_err.set_cause(py, ::std::option::Option::Some(inner)); + new_err + })? + ); + + fields.push(quote!(#extractor)); } + let len = tups.len(); let msg = if self.is_enum_variant { quote!(::std::format!( "expected tuple of length {}, but got length {}", diff --git a/pyo3-macros-backend/src/pyclass.rs b/pyo3-macros-backend/src/pyclass.rs index ff0085a8fc3..17dc234b58a 100644 --- a/pyo3-macros-backend/src/pyclass.rs +++ b/pyo3-macros-backend/src/pyclass.rs @@ -71,7 +71,7 @@ impl PyClassArgs { } } - /// Adda single expression from the comma separated list in the attribute, which is + /// Add a single expression from the comma separated list in the attribute, which is /// either a single word or an assignment expression fn add_expr(&mut self, expr: &Expr) -> Result<()> { match expr { diff --git a/tests/test_frompyobject.rs b/tests/test_frompyobject.rs index 17428cedde0..41557f36312 100644 --- a/tests/test_frompyobject.rs +++ b/tests/test_frompyobject.rs @@ -481,3 +481,40 @@ fn test_from_py_with() { assert_eq!(zap.some_object_length, 3usize); }); } + +#[derive(Debug, FromPyObject)] +pub struct ZapTuple(String, #[pyo3(from_py_with = "PyAny::len")] usize); + +#[test] +fn test_from_py_with_tuple_struct() { + Python::with_gil(|py| { + let py_zap = py + .eval(r#"("whatever", [1, 2, 3])"#, None, None) + .expect("failed to create tuple"); + + let zap = ZapTuple::extract(py_zap).unwrap(); + + assert_eq!(zap.0, "whatever"); + assert_eq!(zap.1, 3usize); + }); +} + +#[derive(Debug, FromPyObject, PartialEq)] +pub enum ZapEnum { + Zip(#[pyo3(from_py_with = "PyAny::len")] usize), + Zap(String, #[pyo3(from_py_with = "PyAny::len")] usize), +} + +#[test] +fn test_from_py_with_enum() { + Python::with_gil(|py| { + let py_zap = py + .eval(r#"("whatever", [1, 2, 3])"#, None, None) + .expect("failed to create tuple"); + + let zap = ZapEnum::extract(py_zap).unwrap(); + let expected_zap = ZapEnum::Zap(String::from("whatever"), 3usize); + + assert_eq!(zap, expected_zap); + }); +}