diff --git a/pyo3-macros-backend/src/frompyobject.rs b/pyo3-macros-backend/src/frompyobject.rs index 604fd6b0f50..ffa4dc6d2d8 100644 --- a/pyo3-macros-backend/src/frompyobject.rs +++ b/pyo3-macros-backend/src/frompyobject.rs @@ -3,7 +3,7 @@ use crate::{ utils::get_pyo3_crate, }; use proc_macro2::TokenStream; -use quote::{quote, format_ident}; +use quote::{format_ident, quote}; use syn::{ parenthesized, parse::{Parse, ParseStream}, @@ -213,18 +213,12 @@ impl<'a> Container<'a> { fn build_newtype_struct(&self, field_ident: Option<&Ident>) -> TokenStream { let self_ty = &self.path; if let Some(ident) = field_ident { - let error_msg = format!( - "failed to extract field {}.{}", - quote!(#self_ty), - quote!(#ident) - ); + let struct_name = quote!(#self_ty).to_string(); + let field_name = ident.to_string(); quote!( - ::std::result::Result::Ok(#self_ty{#ident: obj.extract().map_err(|inner| { - let py = _pyo3::PyNativeType::py(obj); - let new_err = _pyo3::exceptions::PyTypeError::new_err(#error_msg); - new_err.set_cause(py, ::std::option::Option::Some(inner)); - new_err - })?}) + ::std::result::Result::Ok(#self_ty{ + #ident: _pyo3::impl_::frompyobject::extract_struct_field(obj, #struct_name, #field_name)? + }) ) } else if !self.is_enum_variant { let error_msg = format!("failed to extract inner field of {}", quote!(#self_ty)); @@ -244,33 +238,28 @@ impl<'a> Container<'a> { fn build_tuple_struct(&self, tups: &[FieldPyO3Attributes]) -> TokenStream { let self_ty = &self.path; - let field_idents: Vec<_> = (0..tups.len()).into_iter().map(|i| format_ident!("arg{}", i)).collect(); + let field_idents: Vec<_> = (0..tups.len()) + .into_iter() + .map(|i| format_ident!("arg{}", i)) + .collect(); + let struct_name = "e!(#self_ty).to_string(); let fields = tups.iter().zip(&field_idents).enumerate().map(|(index, (attrs, ident))| { - let error_msg = format!("failed to extract field {}.{}", quote!(#self_ty), index); - - let parsed_item = match &attrs.from_py_with { + match &attrs.from_py_with { None => quote!( - _pyo3::PyAny::extract(#ident) + _pyo3::impl_::frompyobject::extract_tuple_struct_field(#ident, #struct_name, #index)? ), Some(FromPyWithAttribute { value: expr_path, .. }) => quote! ( - #expr_path(#ident) + _pyo3::impl_::frompyobject::extract_tuple_struct_field_with(#expr_path, #ident, #struct_name, #index)? ), - }; - - quote!( - #parsed_item.map_err(|inner| { - let py = _pyo3::PyNativeType::py(obj); - let new_err = _pyo3::exceptions::PyTypeError::new_err(#error_msg); - new_err.set_cause(py, ::std::option::Option::Some(inner)); - new_err - })? - ) + } }); quote!( - let (#(#field_idents),*) = obj.extract()?; - ::std::result::Result::Ok(#self_ty(#(#fields),*)) + match obj.extract() { + ::std::result::Result::Ok((#(#field_idents),*)) => ::std::result::Result::Ok(#self_ty(#(#fields),*)), + ::std::result::Result::Err(err) => ::std::result::Result::Err(err), + } ) } @@ -278,36 +267,27 @@ impl<'a> Container<'a> { let self_ty = &self.path; let mut fields: Punctuated = Punctuated::new(); for (ident, attrs) in tups { + let struct_name = quote!(#self_ty).to_string(); + let field_name = ident.to_string(); let getter = match &attrs.getter { - FieldGetter::GetAttr(Some(name)) => quote!(getattr(_pyo3::intern!(py, #name))), + FieldGetter::GetAttr(Some(name)) => { + quote!(getattr(_pyo3::intern!(obj.py(), #name))) + } FieldGetter::GetAttr(None) => { - quote!(getattr(_pyo3::intern!(py, stringify!(#ident)))) + quote!(getattr(_pyo3::intern!(obj.py(), #field_name))) } FieldGetter::GetItem(Some(key)) => quote!(get_item(#key)), - FieldGetter::GetItem(None) => quote!(get_item(stringify!(#ident))), + FieldGetter::GetItem(None) => quote!(get_item(#field_name)), }; - let conversion_error_msg = - format!("failed to extract field {}.{}", quote!(#self_ty), ident); - let get_field = quote!(obj.#getter?); let extractor = match &attrs.from_py_with { - None => quote!({ - let py = _pyo3::PyNativeType::py(obj); - #get_field.extract().map_err(|inner| { - let new_err = _pyo3::exceptions::PyTypeError::new_err(#conversion_error_msg); - new_err.set_cause(py, ::std::option::Option::Some(inner)); - new_err - })? - }), + None => { + quote!(_pyo3::impl_::frompyobject::extract_struct_field(obj.#getter?, #struct_name, #field_name)?) + } Some(FromPyWithAttribute { value: expr_path, .. - }) => quote! ( - #expr_path(#get_field).map_err(|inner| { - let py = _pyo3::PyNativeType::py(obj); - let new_err = _pyo3::exceptions::PyTypeError::new_err(#conversion_error_msg); - new_err.set_cause(py, ::std::option::Option::Some(inner)); - new_err - })? - ), + }) => { + quote! (_pyo3::impl_::frompyobject::extract_struct_field_with(#expr_path, obj.#getter?, #struct_name, #field_name)?) + } }; fields.push(quote!(#ident: #extractor)); diff --git a/src/impl_/frompyobject.rs b/src/impl_/frompyobject.rs index 8c0db79833a..3d68ba55507 100644 --- a/src/impl_/frompyobject.rs +++ b/src/impl_/frompyobject.rs @@ -1,4 +1,4 @@ -use crate::{exceptions::PyTypeError, PyErr, Python}; +use crate::{exceptions::PyTypeError, FromPyObject, PyAny, PyErr, PyResult, Python}; #[cold] pub fn failed_to_extract_enum( @@ -24,3 +24,103 @@ pub fn failed_to_extract_enum( } PyTypeError::new_err(err_msg) } + +pub fn extract_struct_field<'py, T>( + obj: &'py PyAny, + struct_name: &str, + field_name: &str, +) -> PyResult +where + T: FromPyObject<'py>, +{ + match obj.extract() { + ok @ Ok(_) => ok, + Err(err) => Err(failed_to_extract_struct_field( + obj.py(), + err, + struct_name, + field_name, + )), + } +} + +pub fn extract_struct_field_with<'py, T>( + extractor: impl FnOnce(&'py PyAny) -> PyResult, + obj: &'py PyAny, + struct_name: &str, + field_name: &str, +) -> PyResult { + match extractor(obj) { + ok @ Ok(_) => ok, + Err(err) => Err(failed_to_extract_struct_field( + obj.py(), + err, + struct_name, + field_name, + )), + } +} + +#[cold] +fn failed_to_extract_struct_field( + py: Python<'_>, + inner_err: PyErr, + struct_name: &str, + field_name: &str, +) -> PyErr { + let new_err = PyTypeError::new_err(format!( + "failed to extract field {}.{}", + struct_name, field_name + )); + new_err.set_cause(py, ::std::option::Option::Some(inner_err)); + new_err +} + +pub fn extract_tuple_struct_field<'py, T>( + obj: &'py PyAny, + struct_name: &str, + index: usize, +) -> PyResult +where + T: FromPyObject<'py>, +{ + match obj.extract() { + ok @ Ok(_) => ok, + Err(err) => Err(failed_to_extract_tuple_struct_field( + obj.py(), + err, + struct_name, + index, + )), + } +} + +pub fn extract_tuple_struct_field_with<'py, T>( + extractor: impl FnOnce(&'py PyAny) -> PyResult, + obj: &'py PyAny, + struct_name: &str, + index: usize, +) -> PyResult { + match extractor(obj) { + ok @ Ok(_) => ok, + Err(err) => Err(failed_to_extract_tuple_struct_field( + obj.py(), + err, + struct_name, + index, + )), + } +} + +#[cold] +fn failed_to_extract_tuple_struct_field( + py: Python<'_>, + inner_err: PyErr, + struct_name: &str, + index: usize, +) -> PyErr { + let new_err = + PyTypeError::new_err(format!("failed to extract field {}.{}", struct_name, index)); + new_err.set_cause(py, ::std::option::Option::Some(inner_err)); + new_err +}