Skip to content

Commit

Permalink
frompyobject: tidy up generated code
Browse files Browse the repository at this point in the history
  • Loading branch information
davidhewitt committed Jun 2, 2022
1 parent cfb9105 commit e4ec720
Show file tree
Hide file tree
Showing 2 changed files with 133 additions and 53 deletions.
84 changes: 32 additions & 52 deletions pyo3-macros-backend/src/frompyobject.rs
Expand Up @@ -3,7 +3,7 @@ use crate::{
utils::get_pyo3_crate,
};
use proc_macro2::TokenStream;
use quote::{quote, format_ident};
use quote::{format_ident, quote};
use syn::{
parenthesized,
parse::{Parse, ParseStream},
Expand Down Expand Up @@ -213,18 +213,12 @@ impl<'a> Container<'a> {
fn build_newtype_struct(&self, field_ident: Option<&Ident>) -> TokenStream {
let self_ty = &self.path;
if let Some(ident) = field_ident {
let error_msg = format!(
"failed to extract field {}.{}",
quote!(#self_ty),
quote!(#ident)
);
let struct_name = quote!(#self_ty).to_string();
let field_name = ident.to_string();
quote!(
::std::result::Result::Ok(#self_ty{#ident: obj.extract().map_err(|inner| {
let py = _pyo3::PyNativeType::py(obj);
let new_err = _pyo3::exceptions::PyTypeError::new_err(#error_msg);
new_err.set_cause(py, ::std::option::Option::Some(inner));
new_err
})?})
::std::result::Result::Ok(#self_ty{
#ident: _pyo3::impl_::frompyobject::extract_struct_field(obj, #struct_name, #field_name)?
})
)
} else if !self.is_enum_variant {
let error_msg = format!("failed to extract inner field of {}", quote!(#self_ty));
Expand All @@ -244,70 +238,56 @@ impl<'a> Container<'a> {

fn build_tuple_struct(&self, tups: &[FieldPyO3Attributes]) -> TokenStream {
let self_ty = &self.path;
let field_idents: Vec<_> = (0..tups.len()).into_iter().map(|i| format_ident!("arg{}", i)).collect();
let field_idents: Vec<_> = (0..tups.len())
.into_iter()
.map(|i| format_ident!("arg{}", i))
.collect();
let struct_name = &quote!(#self_ty).to_string();
let fields = tups.iter().zip(&field_idents).enumerate().map(|(index, (attrs, ident))| {
let error_msg = format!("failed to extract field {}.{}", quote!(#self_ty), index);

let parsed_item = match &attrs.from_py_with {
match &attrs.from_py_with {
None => quote!(
_pyo3::PyAny::extract(#ident)
_pyo3::impl_::frompyobject::extract_tuple_struct_field(#ident, #struct_name, #index)?
),
Some(FromPyWithAttribute {
value: expr_path, ..
}) => quote! (
#expr_path(#ident)
_pyo3::impl_::frompyobject::extract_tuple_struct_field_with(#expr_path, #ident, #struct_name, #index)?
),
};

quote!(
#parsed_item.map_err(|inner| {
let py = _pyo3::PyNativeType::py(obj);
let new_err = _pyo3::exceptions::PyTypeError::new_err(#error_msg);
new_err.set_cause(py, ::std::option::Option::Some(inner));
new_err
})?
)
}
});
quote!(
let (#(#field_idents),*) = obj.extract()?;
::std::result::Result::Ok(#self_ty(#(#fields),*))
match obj.extract() {
::std::result::Result::Ok((#(#field_idents),*)) => ::std::result::Result::Ok(#self_ty(#(#fields),*)),
::std::result::Result::Err(err) => ::std::result::Result::Err(err),
}
)
}

fn build_struct(&self, tups: &[(&Ident, FieldPyO3Attributes)]) -> TokenStream {
let self_ty = &self.path;
let mut fields: Punctuated<TokenStream, syn::Token![,]> = Punctuated::new();
for (ident, attrs) in tups {
let struct_name = quote!(#self_ty).to_string();
let field_name = ident.to_string();
let getter = match &attrs.getter {
FieldGetter::GetAttr(Some(name)) => quote!(getattr(_pyo3::intern!(py, #name))),
FieldGetter::GetAttr(Some(name)) => {
quote!(getattr(_pyo3::intern!(obj.py(), #name)))
}
FieldGetter::GetAttr(None) => {
quote!(getattr(_pyo3::intern!(py, stringify!(#ident))))
quote!(getattr(_pyo3::intern!(obj.py(), #field_name)))
}
FieldGetter::GetItem(Some(key)) => quote!(get_item(#key)),
FieldGetter::GetItem(None) => quote!(get_item(stringify!(#ident))),
FieldGetter::GetItem(None) => quote!(get_item(#field_name)),
};
let conversion_error_msg =
format!("failed to extract field {}.{}", quote!(#self_ty), ident);
let get_field = quote!(obj.#getter?);
let extractor = match &attrs.from_py_with {
None => quote!({
let py = _pyo3::PyNativeType::py(obj);
#get_field.extract().map_err(|inner| {
let new_err = _pyo3::exceptions::PyTypeError::new_err(#conversion_error_msg);
new_err.set_cause(py, ::std::option::Option::Some(inner));
new_err
})?
}),
None => {
quote!(_pyo3::impl_::frompyobject::extract_struct_field(obj.#getter?, #struct_name, #field_name)?)
}
Some(FromPyWithAttribute {
value: expr_path, ..
}) => quote! (
#expr_path(#get_field).map_err(|inner| {
let py = _pyo3::PyNativeType::py(obj);
let new_err = _pyo3::exceptions::PyTypeError::new_err(#conversion_error_msg);
new_err.set_cause(py, ::std::option::Option::Some(inner));
new_err
})?
),
}) => {
quote! (_pyo3::impl_::frompyobject::extract_struct_field_with(#expr_path, obj.#getter?, #struct_name, #field_name)?)
}
};

fields.push(quote!(#ident: #extractor));
Expand Down
102 changes: 101 additions & 1 deletion src/impl_/frompyobject.rs
@@ -1,4 +1,4 @@
use crate::{exceptions::PyTypeError, PyErr, Python};
use crate::{exceptions::PyTypeError, FromPyObject, PyAny, PyErr, PyResult, Python};

#[cold]
pub fn failed_to_extract_enum(
Expand All @@ -24,3 +24,103 @@ pub fn failed_to_extract_enum(
}
PyTypeError::new_err(err_msg)
}

pub fn extract_struct_field<'py, T>(
obj: &'py PyAny,
struct_name: &str,
field_name: &str,
) -> PyResult<T>
where
T: FromPyObject<'py>,
{
match obj.extract() {
ok @ Ok(_) => ok,
Err(err) => Err(failed_to_extract_struct_field(
obj.py(),
err,
struct_name,
field_name,
)),
}
}

pub fn extract_struct_field_with<'py, T>(
extractor: impl FnOnce(&'py PyAny) -> PyResult<T>,
obj: &'py PyAny,
struct_name: &str,
field_name: &str,
) -> PyResult<T> {
match extractor(obj) {
ok @ Ok(_) => ok,
Err(err) => Err(failed_to_extract_struct_field(
obj.py(),
err,
struct_name,
field_name,
)),
}
}

#[cold]
fn failed_to_extract_struct_field(
py: Python<'_>,
inner_err: PyErr,
struct_name: &str,
field_name: &str,
) -> PyErr {
let new_err = PyTypeError::new_err(format!(
"failed to extract field {}.{}",
struct_name, field_name
));
new_err.set_cause(py, ::std::option::Option::Some(inner_err));
new_err
}

pub fn extract_tuple_struct_field<'py, T>(
obj: &'py PyAny,
struct_name: &str,
index: usize,
) -> PyResult<T>
where
T: FromPyObject<'py>,
{
match obj.extract() {
ok @ Ok(_) => ok,
Err(err) => Err(failed_to_extract_tuple_struct_field(
obj.py(),
err,
struct_name,
index,
)),
}
}

pub fn extract_tuple_struct_field_with<'py, T>(
extractor: impl FnOnce(&'py PyAny) -> PyResult<T>,
obj: &'py PyAny,
struct_name: &str,
index: usize,
) -> PyResult<T> {
match extractor(obj) {
ok @ Ok(_) => ok,
Err(err) => Err(failed_to_extract_tuple_struct_field(
obj.py(),
err,
struct_name,
index,
)),
}
}

#[cold]
fn failed_to_extract_tuple_struct_field(
py: Python<'_>,
inner_err: PyErr,
struct_name: &str,
index: usize,
) -> PyErr {
let new_err =
PyTypeError::new_err(format!("failed to extract field {}.{}", struct_name, index));
new_err.set_cause(py, ::std::option::Option::Some(inner_err));
new_err
}

0 comments on commit e4ec720

Please sign in to comment.