From 1ba20847653e6e6b500c62d939940e47368d445b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebastian=20P=C3=BCtz?= Date: Tue, 25 Aug 2020 00:00:12 +0200 Subject: [PATCH] Derive FromPyObject --- CHANGELOG.md | 1 + pyo3-derive-backend/src/frompy.rs | 429 ++++++++++++++++++++++++++++++ pyo3-derive-backend/src/lib.rs | 2 + pyo3cls/src/lib.rs | 14 +- src/prelude.rs | 2 +- tests/test_frompyobject.rs | 287 ++++++++++++++++++++ 6 files changed, 732 insertions(+), 3 deletions(-) create mode 100644 pyo3-derive-backend/src/frompy.rs create mode 100644 tests/test_frompyobject.rs diff --git a/CHANGELOG.md b/CHANGELOG.md index 442a2365605..6823c134745 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0. - Implement type information for conversion failures. [#1050](https://github.com/PyO3/pyo3/pull/1050) - Add `PyBytes::new_with` and `PyByteArray::new_with` for initialising Python-allocated bytes and bytearrays using a closure. [#1074](https://github.com/PyO3/pyo3/pull/1074) - Add `Py::as_ref` and `Py::into_ref`. [#1098](https://github.com/PyO3/pyo3/pull/1098) +- Add derivations for `FromPyObject` for enums and structs. [#1065](https://github.com/PyO3/pyo3/pull/1065) ### Changed - Exception types have been renamed from e.g. `RuntimeError` to `PyRuntimeError`, and are now only accessible by `&T` or `Py` similar to other Python-native types. The old names continue to exist but are deprecated. [#1024](https://github.com/PyO3/pyo3/pull/1024) diff --git a/pyo3-derive-backend/src/frompy.rs b/pyo3-derive-backend/src/frompy.rs new file mode 100644 index 00000000000..fc2a950b780 --- /dev/null +++ b/pyo3-derive-backend/src/frompy.rs @@ -0,0 +1,429 @@ +use proc_macro2::{Span, TokenStream}; +use quote::quote; +use syn::punctuated::Punctuated; +use syn::token::Paren; +use syn::{ + parse_quote, Attribute, DataEnum, DeriveInput, Expr, ExprCall, Fields, Ident, PatTuple, Result, + Variant, +}; + +/// Describes derivation input of an enum. +#[derive(Debug)] +struct Enum<'a> { + enum_ident: &'a Ident, + vars: Vec>, +} + +impl<'a> Enum<'a> { + /// Construct a new enum representation. + /// + /// `data_enum` is the `syn` representation of the input enum, `ident` is the + /// `Identifier` of the enum. + fn new(data_enum: &'a DataEnum, ident: &'a Ident) -> Result { + if data_enum.variants.is_empty() { + return Err(syn::Error::new_spanned( + &data_enum.variants, + "Cannot derive FromPyObject for empty enum.", + )); + } + let vars = data_enum + .variants + .iter() + .map(Container::from_variant) + .collect::>>()?; + + Ok(Enum { + enum_ident: ident, + vars, + }) + } + + /// Build derivation body for enums. + fn derive_enum(&self) -> TokenStream { + let mut var_extracts = Vec::new(); + let mut error_names = String::new(); + for (i, var) in self.vars.iter().enumerate() { + let ext = match &var.style { + Style::Struct(tups) => self.build_struct_variant(tups, var.ident), + Style::StructNewtype(ident) => { + self.build_transparent_variant(var.ident, Some(ident)) + } + Style::Tuple(len) => self.build_tuple_variant(var.ident, *len), + Style::TupleNewtype => self.build_transparent_variant(var.ident, None), + }; + var_extracts.push(ext); + error_names.push_str(&var.err_name); + if i < self.vars.len() - 1 { + error_names.push_str(", "); + } + } + quote!( + #(#var_extracts)* + let type_name = obj.get_type().name(); + let from = obj + .repr() + .map(|s| format!("{} ({})", s.to_string_lossy(), type_name)) + .unwrap_or_else(|_| type_name.to_string()); + let err_msg = format!("Can't convert {} to {}", from, #error_names); + Err(::pyo3::exceptions::PyTypeError::py_err(err_msg)) + ) + } + + /// Build match for tuple struct variant. + fn build_tuple_variant(&self, var_ident: &Ident, len: usize) -> TokenStream { + let enum_ident = self.enum_ident; + let mut ext: Punctuated = Punctuated::new(); + let mut fields: Punctuated = Punctuated::new(); + let mut field_pats = PatTuple { + attrs: vec![], + paren_token: Paren::default(), + elems: Default::default(), + }; + for i in 0..len { + ext.push(parse_quote!(slice[#i].extract())); + let ident = Ident::new(&format!("_field{}", i), Span::call_site()); + field_pats.elems.push(parse_quote!(Ok(#ident))); + fields.push(ident); + } + + quote!( + match <::pyo3::types::PyTuple as ::pyo3::conversion::PyTryFrom>::try_from(obj) { + Ok(s) => { + if s.len() == #len { + let slice = s.as_slice(); + if let (#field_pats) = (#ext) { + return Ok(#enum_ident::#var_ident(#fields)) + } + } + }, + Err(_) => {} + } + ) + } + + /// Build match for transparent enum variants. + fn build_transparent_variant( + &self, + var_ident: &Ident, + field_ident: Option<&Ident>, + ) -> TokenStream { + let enum_ident = self.enum_ident; + if let Some(ident) = field_ident { + quote!( + if let Ok(#ident) = obj.extract() { + return Ok(#enum_ident::#var_ident{#ident}) + } + ) + } else { + quote!( + if let Ok(inner) = obj.extract() { + return Ok(#enum_ident::#var_ident(inner)) + } + ) + } + } + + /// Build match for struct variant with named fields. + fn build_struct_variant( + &self, + tups: &[(&'a Ident, ExprCall)], + var_ident: &Ident, + ) -> TokenStream { + let enum_ident = self.enum_ident; + let mut field_pats = PatTuple { + attrs: vec![], + paren_token: Paren::default(), + elems: Default::default(), + }; + let mut fields: Punctuated = Punctuated::new(); + let mut ext: Punctuated = Punctuated::new(); + for (ident, ext_fn) in tups { + field_pats.elems.push(parse_quote!(Ok(#ident))); + fields.push(parse_quote!(#ident)); + ext.push(parse_quote!(obj.#ext_fn.and_then(|o| o.extract()))); + } + quote!(if let #field_pats = #ext { + return Ok(#enum_ident::#var_ident{#fields}); + }) + } +} + +/// Container Style +/// +/// Covers Structs, Tuplestructs and corresponding Newtypes. +#[derive(Clone, Debug)] +enum Style<'a> { + /// Struct Container, e.g. `struct Foo { a: String }` + /// + /// Variant contains the list of field identifiers and the corresponding extraction call. + Struct(Vec<(&'a Ident, ExprCall)>), + /// Newtype struct container, e.g. `#[transparent] struct Foo { a: String }` + /// + /// The field specified by the identifier is extracted directly from the object. + StructNewtype(&'a Ident), + /// Tuple struct, e.g. `struct Foo(String)`. + /// + /// Fields are extracted from a tuple. + Tuple(usize), + /// Tuple newtype, e.g. `#[transparent] struct Foo(String)` + /// + /// The wrapped field is directly extracted from the object. + TupleNewtype, +} + +/// Data container +/// +/// Either describes a struct or an enum variant. +#[derive(Debug)] +struct Container<'a> { + ident: &'a Ident, + style: Style<'a>, + err_name: String, +} + +impl<'a> Container<'a> { + /// Construct a container from an enum Variant. + /// + /// Fails if the variant has no fields or incompatible attributes. + fn from_variant(var: &'a Variant) -> Result { + Self::new(&var.fields, &var.ident, &var.attrs) + } + + /// Construct a container based on fields, identifier and attributes. + /// + /// Fails if the variant has no fields or incompatible attributes. + fn new(fields: &'a Fields, ident: &'a Ident, attrs: &'a [Attribute]) -> Result { + let transparent = attrs.iter().any(|a| a.path.is_ident("transparent")); + if transparent { + Self::check_transparent_len(fields)?; + } + let style = match fields { + Fields::Unnamed(unnamed) => { + if transparent { + Style::TupleNewtype + } else { + Style::Tuple(unnamed.unnamed.len()) + } + } + Fields::Named(named) => { + if transparent { + let field = named + .named + .iter() + .next() + .expect("Check for len 1 is done above"); + let ident = field + .ident + .as_ref() + .expect("Named fields should have identifiers"); + Style::StructNewtype(ident) + } else { + let mut fields = Vec::new(); + for field in named.named.iter() { + let ident = field + .ident + .as_ref() + .expect("Named fields should have identifiers"); + fields.push((ident, ext_fn(&field.attrs, ident)?)) + } + Style::Struct(fields) + } + } + Fields::Unit => { + return Err(syn::Error::new_spanned( + &fields, + "Cannot derive FromPyObject for Unit structs and variants", + )) + } + }; + let err_name = maybe_renamed_err(&attrs)? + .map(|s| s.value()) + .unwrap_or_else(|| ident.to_string()); + + let v = Container { + ident: &ident, + style, + err_name, + }; + Ok(v) + } + + /// Build derivation body for a struct. + fn derive_struct(&self) -> TokenStream { + let ext = match &self.style { + Style::StructNewtype(ident) => self.build_newtype_struct(Some(&ident)), + Style::TupleNewtype => self.build_newtype_struct(None), + Style::Tuple(len) => self.build_tuple_struct(*len), + Style::Struct(tups) => self.build_struct(tups), + }; + quote!(Ok(#ext)) + } + + fn build_newtype_struct(&self, field_ident: Option<&Ident>) -> TokenStream { + if let Some(ident) = field_ident { + quote!( + Self{#ident: obj.extract()?} + ) + } else { + quote!(Self(obj.extract()?)) + } + } + + fn build_tuple_struct(&self, len: usize) -> TokenStream { + let mut fields: Punctuated = Punctuated::new(); + for i in 0..len { + fields.push(quote!(slice[#i].extract()?)); + } + quote!( + let s = <::pyo3::types::PyTuple as ::pyo3::conversion::PyTryFrom>::try_from(obj)?; + let seq_len = s.len(); + if seq_len != #len { + let msg = format!( + "Expected tuple of length {}, but got length {}.", + #len, + seq_len + ); + return Err(::pyo3::exceptions::PyValueError::py_err(msg)) + } + let slice = s.as_slice(); + Self(#fields) + ) + } + + fn build_struct(&self, tups: &[(&Ident, syn::ExprCall)]) -> TokenStream { + let mut fields: Punctuated = Punctuated::new(); + for (ident, ext_fn) in tups { + fields.push(quote!(#ident: obj.#ext_fn?.extract()?)); + } + quote!(Self{#fields}) + } + + fn check_transparent_len(fields: &Fields) -> Result<()> { + if fields.len() != 1 { + return Err(syn::Error::new_spanned( + fields, + "Transparent structs and variants can only have 1 field", + )); + } + Ok(()) + } +} + +/// Get the extraction function that's called on the input object. +/// +/// Valid arguments are `get_item`, `get_attr` which are called with the +/// stringified field identifier or a function call on `PyAny`, e.g. `get_attr("attr")` +fn ext_fn(attrs: &[Attribute], field_ident: &Ident) -> Result { + let attr = if let Some(attr) = attrs.iter().find(|a| a.path.is_ident("extract")) { + attr + } else { + return Ok(parse_quote!(getattr(stringify!(#field_ident)))); + }; + if let Ok(ident) = attr.parse_args::() { + if ident != "getattr" && ident != "get_item" { + Err(syn::Error::new_spanned( + ident, + "Only get_item and getattr are valid for extraction.", + )) + } else { + let arg = field_ident.to_string(); + Ok(parse_quote!(#ident(#arg))) + } + } else if let Ok(call) = attr.parse_args() { + Ok(call) + } else { + Err(syn::Error::new_spanned( + attr, + "Only get_item and getattr are valid for extraction,\ + both can be passed with or without an argument, e.g. \ + #[extract(getattr(\"attr\")] and #[extract(getattr)]", + )) + } +} + +/// Returns the name of the variant for the error message if no variants match. +fn maybe_renamed_err(attrs: &[syn::Attribute]) -> Result> { + for attr in attrs { + if !attr.path.is_ident("rename_err") { + continue; + } + let attr = attr.parse_meta()?; + if let syn::Meta::NameValue(nv) = &attr { + match &nv.lit { + syn::Lit::Str(s) => { + return Ok(Some(s.clone())); + } + _ => { + return Err(syn::Error::new_spanned( + attr, + "rename_err attribute must be string literal: #[rename_err=\"Name\"]", + )) + } + } + } + } + Ok(None) +} + +fn verify_and_get_lifetime(generics: &syn::Generics) -> Result> { + let lifetimes = generics.lifetimes().collect::>(); + if lifetimes.len() > 1 { + return Err(syn::Error::new_spanned( + &generics, + "Only a single lifetime parameter can be specified.", + )); + } + Ok(lifetimes.into_iter().next()) +} + +/// Derive FromPyObject for enums and structs. +/// +/// * Max 1 lifetime specifier, will be tied to `FromPyObject`'s specifier +/// * At least one field, in case of `#[transparent]`, exactly one field +/// * At least one variant for enums. +/// * Fields of input structs and enums must implement `FromPyObject` +/// * Derivation for structs with generic fields like `struct Foo(T)` +/// adds `T: FromPyObject` on the derived implementation. +pub fn build_derive_from_pyobject(tokens: &mut DeriveInput) -> Result { + let mut trait_generics = tokens.generics.clone(); + let generics = &tokens.generics; + let lt_param = if let Some(lt) = verify_and_get_lifetime(generics)? { + lt.clone() + } else { + trait_generics.params.push(parse_quote!('source)); + parse_quote!('source) + }; + let mut where_clause: syn::WhereClause = parse_quote!(where); + for param in generics.type_params() { + let gen_ident = ¶m.ident; + where_clause + .predicates + .push(parse_quote!(#gen_ident: FromPyObject<#lt_param>)) + } + let derives = match &tokens.data { + syn::Data::Enum(en) => { + let en = Enum::new(en, &tokens.ident)?; + en.derive_enum() + } + syn::Data::Struct(st) => { + let st = Container::new(&st.fields, &tokens.ident, &tokens.attrs)?; + st.derive_struct() + } + _ => { + return Err(syn::Error::new_spanned( + tokens, + "FromPyObject can only be derived for structs and enums.", + )) + } + }; + + let ident = &tokens.ident; + Ok(quote!( + #[automatically_derived] + impl#trait_generics ::pyo3::FromPyObject<#lt_param> for #ident#generics #where_clause { + fn extract(obj: &#lt_param ::pyo3::PyAny) -> ::pyo3::PyResult { + #derives + } + } + )) +} diff --git a/pyo3-derive-backend/src/lib.rs b/pyo3-derive-backend/src/lib.rs index 5695adf42e8..78db37368fb 100644 --- a/pyo3-derive-backend/src/lib.rs +++ b/pyo3-derive-backend/src/lib.rs @@ -4,6 +4,7 @@ #![recursion_limit = "1024"] mod defs; +mod frompy; mod konst; mod method; mod module; @@ -15,6 +16,7 @@ mod pymethod; mod pyproto; mod utils; +pub use frompy::build_derive_from_pyobject; pub use module::{add_fn_to_module, process_functions_in_module, py_init}; pub use pyclass::{build_py_class, PyClassArgs}; pub use pyfunction::{build_py_function, PyFunctionAttr}; diff --git a/pyo3cls/src/lib.rs b/pyo3cls/src/lib.rs index 795423cbff3..2377ec03bbe 100644 --- a/pyo3cls/src/lib.rs +++ b/pyo3cls/src/lib.rs @@ -5,8 +5,8 @@ extern crate proc_macro; use proc_macro::TokenStream; use pyo3_derive_backend::{ - build_py_class, build_py_function, build_py_methods, build_py_proto, get_doc, - process_functions_in_module, py_init, PyClassArgs, PyFunctionAttr, + build_derive_from_pyobject, build_py_class, build_py_function, build_py_methods, + build_py_proto, get_doc, process_functions_in_module, py_init, PyClassArgs, PyFunctionAttr, }; use quote::quote; use syn::parse_macro_input; @@ -91,3 +91,13 @@ pub fn pyfunction(attr: TokenStream, input: TokenStream) -> TokenStream { ) .into() } + +#[proc_macro_derive(FromPyObject, attributes(transparent, extract, rename_err))] +pub fn derive_from_py_object(item: TokenStream) -> TokenStream { + let mut ast = parse_macro_input!(item as syn::DeriveInput); + let expanded = build_derive_from_pyobject(&mut ast).unwrap_or_else(|e| e.to_compile_error()); + quote!( + #expanded + ) + .into() +} diff --git a/src/prelude.rs b/src/prelude.rs index e3bc5f0b8fe..8046096f7f2 100644 --- a/src/prelude.rs +++ b/src/prelude.rs @@ -20,4 +20,4 @@ pub use crate::{FromPyObject, IntoPy, IntoPyPointer, PyTryFrom, PyTryInto, ToPyO // PyModule is only part of the prelude because we need it for the pymodule function pub use crate::types::{PyAny, PyModule}; #[cfg(feature = "macros")] -pub use pyo3cls::{pyclass, pyfunction, pymethods, pymodule, pyproto}; +pub use pyo3cls::{pyclass, pyfunction, pymethods, pymodule, pyproto, FromPyObject}; diff --git a/tests/test_frompyobject.rs b/tests/test_frompyobject.rs new file mode 100644 index 00000000000..83a64960ff0 --- /dev/null +++ b/tests/test_frompyobject.rs @@ -0,0 +1,287 @@ +use pyo3::exceptions::PyValueError; +use pyo3::prelude::*; +use pyo3::types::{PyDict, PyString, PyTuple}; +use pyo3::{PyErrValue, PyMappingProtocol}; + +#[macro_use] +mod common; + +#[derive(Debug, FromPyObject)] +pub struct A<'a> { + #[extract(getattr)] + s: String, + #[extract(get_item)] + t: &'a PyString, + #[extract(getattr("foo"))] + p: &'a PyAny, +} + +#[pyclass] +pub struct PyA { + #[pyo3(get)] + s: String, + #[pyo3(get)] + foo: Option, +} + +#[pyproto] +impl PyMappingProtocol for PyA { + fn __getitem__(&self, key: String) -> pyo3::PyResult { + if key == "t" { + Ok("bar".into()) + } else { + Err(PyValueError::py_err("Failed")) + } + } +} + +#[test] +fn test_named_fields_struct() { + let gil = Python::acquire_gil(); + let py = gil.python(); + let pya = PyA { + s: "foo".into(), + foo: None, + }; + let py_c = Py::new(py, pya).unwrap(); + let a: A = FromPyObject::extract(py_c.as_ref(py)).expect("Failed to extract A from PyA"); + assert_eq!(a.s, "foo"); + assert_eq!(a.t.to_string_lossy(), "bar"); + assert!(a.p.is_none()); +} + +#[derive(Debug, FromPyObject)] +#[transparent] +pub struct B { + test: String, +} + +#[test] +fn test_transparent_named_field_struct() { + let gil = Python::acquire_gil(); + let py = gil.python(); + let test = "test".into_py(py); + let b: B = FromPyObject::extract(test.as_ref(py)).expect("Failed to extract B from String"); + assert_eq!(b.test, "test"); + let test: PyObject = 1.into_py(py); + let b = B::extract(test.as_ref(py)); + assert!(b.is_err()) +} + +#[derive(Debug, FromPyObject)] +#[transparent] +pub struct D { + test: T, +} + +#[test] +fn test_generic_transparent_named_field_struct() { + let gil = Python::acquire_gil(); + let py = gil.python(); + let test = "test".into_py(py); + let d: D = + D::extract(test.as_ref(py)).expect("Failed to extract D from String"); + assert_eq!(d.test, "test"); + let test = 1usize.into_py(py); + let d: D = D::extract(test.as_ref(py)).expect("Failed to extract D from String"); + assert_eq!(d.test, 1); +} + +#[derive(Debug, FromPyObject)] +pub struct E { + test: T, + test2: T2, +} + +#[pyclass] +pub struct PyE { + #[pyo3(get)] + test: String, + #[pyo3(get)] + test2: usize, +} + +#[test] +fn test_generic_named_fields_struct() { + let gil = Python::acquire_gil(); + let py = gil.python(); + let pye = PyE { + test: "test".into(), + test2: 2, + } + .into_py(py); + + let e: E = + E::extract(pye.as_ref(py)).expect("Failed to extract E from PyE"); + assert_eq!(e.test, "test"); + assert_eq!(e.test2, 2); + let e = E::::extract(pye.as_ref(py)); + assert!(e.is_err()); +} + +#[derive(Debug, FromPyObject)] +pub struct C { + #[extract(getattr("test"))] + test: String, +} + +#[test] +fn test_named_field_with_ext_fn() { + let gil = Python::acquire_gil(); + let py = gil.python(); + let pyc = PyE { + test: "foo".into(), + test2: 0, + } + .into_py(py); + let c = C::extract(pyc.as_ref(py)).expect("Failed to extract C from PyE"); + assert_eq!(c.test, "foo"); +} + +#[derive(Debug, FromPyObject)] +pub enum Foo<'a> { + TupleVar(usize, String), + StructVar { + test: &'a PyString, + }, + #[transparent] + TransparentTuple(usize), + #[transparent] + TransparentStructVar { + a: Option, + }, + StructVarGetAttrArg { + #[extract(getattr("bla"))] + a: bool, + }, + StructWithGetItem { + #[extract(get_item)] + a: String, + }, + StructWithGetItemArg { + #[extract(get_item("foo"))] + a: String, + }, + #[transparent] + CatchAll(&'a PyAny), +} + +#[pyclass] +pub struct PyBool { + #[pyo3(get)] + bla: bool, +} + +#[test] +fn test_enum() { + let gil = Python::acquire_gil(); + let py = gil.python(); + + let tup = PyTuple::new(py, &[1.into_py(py), "test".into_py(py)]); + let foo = Foo::extract(tup.as_ref()).expect("Failed to extract Foo from tuple"); + match foo { + Foo::TupleVar(test, test2) => { + assert_eq!(test, 1); + assert_eq!(test2, "test"); + } + _ => panic!("Expected extracting Foo::TupleVar, got {:?}", foo), + } + + let pye = PyE { + test: "foo".into(), + test2: 0, + } + .into_py(py); + let foo = Foo::extract(pye.as_ref(py)).expect("Failed to extract Foo from PyE"); + match foo { + Foo::StructVar { test } => assert_eq!(test.to_string_lossy(), "foo"), + _ => panic!("Expected extracting Foo::StructVar, got {:?}", foo), + } + + let int: PyObject = 1.into_py(py); + let foo = Foo::extract(int.as_ref(py)).expect("Failed to extract Foo from int"); + match foo { + Foo::TransparentTuple(test) => assert_eq!(test, 1), + _ => panic!("Expected extracting Foo::TransparentTuple, got {:?}", foo), + } + let none = py.None(); + let foo = Foo::extract(none.as_ref(py)).expect("Failed to extract Foo from int"); + match foo { + Foo::TransparentStructVar { a } => assert!(a.is_none()), + _ => panic!( + "Expected extracting Foo::TransparentStructVar, got {:?}", + foo + ), + } + + let pybool = PyBool { bla: true }.into_py(py); + let foo = Foo::extract(pybool.as_ref(py)).expect("Failed to extract Foo from PyBool"); + match foo { + Foo::StructVarGetAttrArg { a } => assert!(a), + _ => panic!( + "Expected extracting Foo::StructVarGetAttrArg, got {:?}", + foo + ), + } + + let dict = PyDict::new(py); + dict.set_item("a", "test").expect("Failed to set item"); + let foo = Foo::extract(dict.as_ref()).expect("Failed to extract Foo from dict"); + match foo { + Foo::StructWithGetItem { a } => assert_eq!(a, "test"), + _ => panic!("Expected extracting Foo::StructWithGetItem, got {:?}", foo), + } + + let dict = PyDict::new(py); + dict.set_item("foo", "test").expect("Failed to set item"); + let foo = Foo::extract(dict.as_ref()).expect("Failed to extract Foo from dict"); + match foo { + Foo::StructWithGetItemArg { a } => assert_eq!(a, "test"), + _ => panic!( + "Expected extracting Foo::StructWithGetItemArg, got {:?}", + foo + ), + } + + let dict = PyDict::new(py); + let foo = Foo::extract(dict.as_ref()).expect("Failed to extract Foo from dict"); + match foo { + Foo::CatchAll(any) => { + let d = <&PyDict>::extract(any).expect("Expected pydict"); + assert!(d.is_empty()); + } + _ => panic!("Expected extracting Foo::CatchAll, got {:?}", foo), + } +} + +#[derive(FromPyObject)] +pub enum Bar { + #[rename_err = "str"] + A(String), + #[rename_err = "uint"] + B(usize), + #[rename_err = "int"] + C(isize), +} + +#[test] +fn test_err_rename() { + let gil = Python::acquire_gil(); + let py = gil.python(); + let dict = PyDict::new(py); + let foo = Bar::extract(dict.as_ref()); + assert!(foo.is_err()); + match foo { + Ok(_) => {} + Err(e) => match e.pvalue { + PyErrValue::Value(val) => assert_eq!(String::extract(val.as_ref(py)).expect(""), ""), + PyErrValue::None => {} + PyErrValue::ToArgs(_) => {} + PyErrValue::ToObject(to) => { + let o = to.to_object(py); + let s = String::extract(o.as_ref(py)).expect("Err val is not a string"); + assert_eq!(s, "Can't convert {} (dict) to str, uint, int") + } + }, + } +}