Skip to content

Commit

Permalink
Add compile fail tests for FromPyObject derives + some fixes.
Browse files Browse the repository at this point in the history
Fix some error messages and accidental passes.
  • Loading branch information
sebpuetz committed Aug 30, 2020
1 parent 7781bb7 commit a8c5379
Show file tree
Hide file tree
Showing 4 changed files with 401 additions and 51 deletions.
122 changes: 71 additions & 51 deletions pyo3-derive-backend/src/from_pyobject.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use proc_macro2::{Span, TokenStream};
use quote::quote;
use quote::{quote, ToTokens};
use syn::punctuated::Punctuated;
use syn::spanned::Spanned;
use syn::{parse_quote, Attribute, DataEnum, DeriveInput, Fields, Ident, Meta, Result};

/// Describes derivation input of an enum.
Expand All @@ -17,8 +18,8 @@ impl<'a> Enum<'a> {
/// `Identifier` of the enum.
fn new(data_enum: &'a DataEnum, ident: &'a Ident) -> Result<Self> {
if data_enum.variants.is_empty() {
return Err(syn::Error::new_spanned(
&data_enum.variants,
return Err(spanned_err(
&ident,
"Cannot derive FromPyObject for empty enum.",
));
}
Expand Down Expand Up @@ -121,6 +122,12 @@ impl<'a> Container<'a> {
attrs: Vec<ContainerAttribute>,
is_enum_variant: bool,
) -> Result<Self> {
if fields.is_empty() {
return Err(spanned_err(
fields,
"Cannot derive FromPyObject for empty structs and variants.",
));
}
let transparent = attrs.iter().any(ContainerAttribute::transparent);
if transparent {
Self::check_transparent_len(fields)?;
Expand Down Expand Up @@ -154,10 +161,11 @@ impl<'a> Container<'a> {
ContainerType::Struct(fields)
}
(Fields::Unit, _) => {
return Err(syn::Error::new_spanned(
// covered by length check above
return Err(spanned_err(
&fields,
"Cannot derive FromPyObject for Unit structs and variants",
))
));
}
};
let err_name = attrs
Expand All @@ -175,16 +183,30 @@ impl<'a> Container<'a> {
Ok(v)
}

fn verify_struct_container_attrs(attrs: &'a [ContainerAttribute]) -> Result<()> {
fn verify_struct_container_attrs(
attrs: &'a [ContainerAttribute],
original: &[Attribute],
) -> Result<()> {
for attr in attrs {
match attr {
ContainerAttribute::Transparent => continue,
ContainerAttribute::ErrorAnnotation(_) => {
let span = original
.iter()
.map(|a| a.span())
.fold(None, |mut acc: Option<Span>, span| {
if let Some(all) = acc.as_mut() {
all.join(span)
} else {
Some(span)
}
})
.unwrap_or_else(Span::call_site);
return Err(syn::Error::new(
Span::call_site(),
span,
"Annotating error messages for structs is \
not supported. Remove the annotation attribute.",
))
));
}
}
}
Expand Down Expand Up @@ -264,7 +286,7 @@ impl<'a> Container<'a> {

fn check_transparent_len(fields: &Fields) -> Result<()> {
if fields.len() != 1 {
return Err(syn::Error::new_spanned(
return Err(spanned_err(
fields,
"Transparent structs and variants can only have 1 field",
));
Expand Down Expand Up @@ -315,16 +337,18 @@ impl ContainerAttribute {
if let syn::Lit::Str(s) = &nv.lit {
attrs.push(ContainerAttribute::ErrorAnnotation(s.value()))
} else {
return Err(syn::Error::new_spanned(
&nv.lit,
"Expected string literal.",
));
return Err(spanned_err(&nv.lit, "Expected string literal."));
}
}
_ => (),
other => {
return Err(spanned_err(
other,
"Expected `transparent` or `annotation = \"name\"`",
))
}
}
} else {
return Err(syn::Error::new_spanned(
return Err(spanned_err(
meta,
"Unknown container attribute, expected `transparent` or \
`annotation(\"err_name\")`",
Expand Down Expand Up @@ -352,8 +376,8 @@ impl FieldAttribute {
return Ok(None);
}
if list.nested.len() > 1 {
return Err(syn::Error::new_spanned(
list,
return Err(spanned_err(
list.nested,
"Only one of `item`, `attribute` can be provided, possibly with an \
additional argument: `item(\"key\")` or `attribute(\"name\").",
));
Expand All @@ -362,7 +386,7 @@ impl FieldAttribute {
let meta = match metaitem {
syn::NestedMeta::Meta(meta) => meta,
syn::NestedMeta::Lit(lit) => {
return Err(syn::Error::new_spanned(
return Err(spanned_err(
lit,
"Expected `attribute` or `item`, not a literal.",
))
Expand All @@ -374,10 +398,7 @@ impl FieldAttribute {
} else if path.is_ident("item") {
Ok(Some(FieldAttribute::GetItem(Self::item_arg(meta)?)))
} else {
Err(syn::Error::new_spanned(
meta,
"Expected `attribute` or `item`.",
))
Err(spanned_err(meta, "Expected `attribute` or `item`."))
}
}

Expand All @@ -387,52 +408,56 @@ impl FieldAttribute {
syn::Meta::Path(_) => return Ok(None),
Meta::NameValue(nv) => {
let err_msg = "Expected a string literal or no argument: `pyo3(attribute(\"name\") or `pyo3(attribute)`";
return Err(syn::Error::new_spanned(nv, err_msg));
return Err(spanned_err(nv, err_msg));
}
};
if arg_list.nested.len() != 1 {
return Err(syn::Error::new_spanned(
arg_list,
"Expected a single string literal.",
));
let arg_msg = "Expected a single string literal argument.";
if arg_list.nested.is_empty() {
return Err(spanned_err(arg_list, arg_msg));
} else if arg_list.nested.len() > 1 {
return Err(spanned_err(arg_list.nested, arg_msg));
}
let first = arg_list.nested.first().unwrap();
if let syn::NestedMeta::Lit(lit) = first {
if let syn::Lit::Str(litstr) = lit {
if litstr.value().is_empty() {
return Err(spanned_err(litstr, "Attribute name cannot be empty."));
}
return Ok(Some(parse_quote!(#litstr)));
}
}
Err(syn::Error::new_spanned(
first,
"Expected a single string literal.",
))
Err(spanned_err(first, arg_msg))
}

fn item_arg(meta: syn::Meta) -> syn::Result<Option<syn::Lit>> {
let arg_list = match meta {
syn::Meta::List(list) => list,
syn::Meta::Path(_) => return Ok(None),
Meta::NameValue(nv) => {
return Err(syn::Error::new_spanned(
return Err(spanned_err(
nv,
"Expected a literal or no argument: `pyo3(item(\"key\") or `pyo3(item)`",
))
}
};
if arg_list.nested.len() != 1 {
return Err(syn::Error::new_spanned(
arg_list,
"Expected a single literal.",
));
let arg_msg = "Expected a single literal argument.";
if arg_list.nested.is_empty() {
return Err(spanned_err(arg_list, arg_msg));
} else if arg_list.nested.len() > 1 {
return Err(spanned_err(arg_list.nested, arg_msg));
}
let first = arg_list.nested.first().unwrap();
if let syn::NestedMeta::Lit(lit) = first {
return Ok(Some(parse_quote!(#lit)));
}
Err(syn::Error::new_spanned(first, "Expected a literal."))
Err(spanned_err(first, arg_msg))
}
}

fn spanned_err<T: ToTokens>(tokens: T, msg: &str) -> syn::Error {
syn::Error::new_spanned(tokens, msg)
}

/// Extract pyo3 metalist, flattens multiple lists into a single one.
fn get_pyo3_meta_list(attrs: &[Attribute]) -> Result<syn::MetaList> {
let mut list: Punctuated<syn::NestedMeta, syn::Token![,]> = Punctuated::new();
Expand All @@ -443,12 +468,7 @@ fn get_pyo3_meta_list(attrs: &[Attribute]) -> Result<syn::MetaList> {
list.push(meta);
}
}
_ => {
return Err(syn::Error::new_spanned(
value,
"Expected `pyo3()` attribute.",
))
}
_ => continue,
}
}
Ok(syn::MetaList {
Expand All @@ -461,9 +481,9 @@ fn get_pyo3_meta_list(attrs: &[Attribute]) -> Result<syn::MetaList> {
fn verify_and_get_lifetime(generics: &syn::Generics) -> Result<Option<&syn::LifetimeDef>> {
let lifetimes = generics.lifetimes().collect::<Vec<_>>();
if lifetimes.len() > 1 {
return Err(syn::Error::new_spanned(
return Err(spanned_err(
&generics,
"FromPyObject can only be derived with at most one lifetime parameter.",
"FromPyObject can be derived with at most one lifetime parameter.",
));
}
Ok(lifetimes.into_iter().next())
Expand Down Expand Up @@ -500,15 +520,15 @@ pub fn build_derive_from_pyobject(tokens: &DeriveInput) -> Result<TokenStream> {
}
syn::Data::Struct(st) => {
let attrs = ContainerAttribute::parse_attrs(&tokens.attrs)?;
Container::verify_struct_container_attrs(&attrs)?;
Container::verify_struct_container_attrs(&attrs, &tokens.attrs)?;
let ident = &tokens.ident;
let st = Container::new(&st.fields, parse_quote!(#ident), attrs, false)?;
st.build()
}
_ => {
return Err(syn::Error::new_spanned(
syn::Data::Union(_) => {
return Err(spanned_err(
tokens,
"FromPyObject can only be derived for structs and enums.",
"FromPyObject can not be derived for unions.",
))
}
};
Expand Down
1 change: 1 addition & 0 deletions tests/test_compile_error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#[test]
fn test_compile_errors() {
let t = trybuild::TestCases::new();
t.compile_fail("tests/ui/invalid_frompy_derive.rs");
t.compile_fail("tests/ui/invalid_macro_args.rs");
t.compile_fail("tests/ui/invalid_property_args.rs");
t.compile_fail("tests/ui/invalid_pyclass_args.rs");
Expand Down

0 comments on commit a8c5379

Please sign in to comment.