Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add bound attribute for generics support #143

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
28 changes: 27 additions & 1 deletion impl/src/attr.rs
Expand Up @@ -2,9 +2,10 @@ use proc_macro2::{Delimiter, Group, Span, TokenStream, TokenTree};
use quote::{format_ident, quote, ToTokens};
use std::iter::FromIterator;
use syn::parse::{Nothing, ParseStream};
use syn::punctuated::Punctuated;
use syn::{
braced, bracketed, parenthesized, token, Attribute, Error, Ident, Index, LitInt, LitStr,
Result, Token,
Result, Token, TypeParamBound,
};

pub struct Attrs<'a> {
Expand All @@ -13,6 +14,7 @@ pub struct Attrs<'a> {
pub backtrace: Option<&'a Attribute>,
pub from: Option<&'a Attribute>,
pub transparent: Option<Transparent<'a>>,
pub bound: Option<Bound<'a>>,
}

#[derive(Clone)]
Expand All @@ -29,13 +31,20 @@ pub struct Transparent<'a> {
pub span: Span,
}

#[derive(Clone)]
pub struct Bound<'a> {
pub original: &'a Attribute,
pub bounds: Punctuated<TypeParamBound, token::Add>,
}

pub fn get(input: &[Attribute]) -> Result<Attrs> {
let mut attrs = Attrs {
display: None,
source: None,
backtrace: None,
from: None,
transparent: None,
bound: None,
};

for attr in input {
Expand Down Expand Up @@ -70,6 +79,7 @@ pub fn get(input: &[Attribute]) -> Result<Attrs> {

fn parse_error_attribute<'a>(attrs: &mut Attrs<'a>, attr: &'a Attribute) -> Result<()> {
syn::custom_keyword!(transparent);
syn::custom_keyword!(bound);

attr.parse_args_with(|input: ParseStream| {
if let Some(kw) = input.parse::<Option<transparent>>()? {
Expand All @@ -84,6 +94,22 @@ fn parse_error_attribute<'a>(attrs: &mut Attrs<'a>, attr: &'a Attribute) -> Resu
span: kw.span,
});
return Ok(());
} else if input.parse::<Option<bound>>()?.is_some() {
if attrs.bound.is_some() {
return Err(Error::new_spanned(
attr,
"duplicate #[error(bound)] attribute",
));
}
input.parse::<token::Eq>().map_err(|_| {
Error::new_spanned(attr, "\"bound\" keyword must be followed by '='")
})?;
let bound = Bound {
original: attr,
bounds: Punctuated::<TypeParamBound, token::Add>::parse_separated_nonempty(input)?,
};
attrs.bound = Some(bound);
return Ok(());
}

let display = Display {
Expand Down
112 changes: 103 additions & 9 deletions impl/src/expand.rs
Expand Up @@ -2,7 +2,9 @@ use crate::ast::{Enum, Field, Input, Struct};
use proc_macro2::TokenStream;
use quote::{format_ident, quote, quote_spanned, ToTokens};
use syn::spanned::Spanned;
use syn::{Data, DeriveInput, Member, PathArguments, Result, Type, Visibility};
use syn::{
token, Data, DeriveInput, Member, PathArguments, Result, Type, Visibility, WherePredicate,
};

pub fn derive(node: &DeriveInput) -> Result<TokenStream> {
let input = Input::from_syn(node)?;
Expand Down Expand Up @@ -111,9 +113,26 @@ fn impl_struct(input: Struct) -> TokenStream {
} else {
None
};

let bounded_where_predicates: Vec<syn::WherePredicate> = input
.attrs
.bound
.as_ref()
.map(|bound| apply_type_bounds(input.generics.type_params(), bound))
.unwrap_or_default();
let where_clause = extend_where_clause(
where_clause.cloned(),
input.original.span(),
bounded_where_predicates.into_iter(),
);

let display_impl = display_body.map(|body| {
quote! {
#[allow(unused_qualifications)]
#[allow(
unused_qualifications,
// Since we don't merge bounds that cover the same type, suppress this issue
clippy::type_repetition_in_bounds,
)]
impl #impl_generics std::fmt::Display for #ty #ty_generics #where_clause {
#[allow(
// Clippy bug: https://github.com/rust-lang/rust-clippy/issues/7422
Expand All @@ -132,7 +151,11 @@ fn impl_struct(input: Struct) -> TokenStream {
let from = from_field.ty;
let body = from_initializer(from_field, backtrace_field);
quote! {
#[allow(unused_qualifications)]
#[allow(
unused_qualifications,
// Since we don't merge bounds that cover the same type, suppress this issue
clippy::type_repetition_in_bounds,
)]
impl #impl_generics std::convert::From<#from> for #ty #ty_generics #where_clause {
#[allow(deprecated)]
fn from(source: #from) -> Self {
Expand All @@ -143,9 +166,12 @@ fn impl_struct(input: Struct) -> TokenStream {
});

let error_trait = spanned_error_trait(input.original);

quote! {
#[allow(unused_qualifications)]
#[allow(
unused_qualifications,
// Since we don't merge bounds that cover the same type, suppress this issue
clippy::type_repetition_in_bounds,
)]
impl #impl_generics #error_trait for #ty #ty_generics #where_clause {
#source_method
#backtrace_method
Expand Down Expand Up @@ -266,6 +292,18 @@ fn impl_enum(input: Enum) -> TokenStream {
None
};

let bounded_where_predicates: Vec<syn::WherePredicate> = input
.attrs
.bound
.as_ref()
.map(|bound| apply_type_bounds(input.generics.type_params(), bound))
.unwrap_or_default();
let where_clause = extend_where_clause(
where_clause.cloned(),
input.original.span(),
bounded_where_predicates.into_iter(),
);

let display_impl = if input.has_display() {
let use_as_display = if input.variants.iter().any(|v| {
v.attrs
Expand Down Expand Up @@ -303,7 +341,11 @@ fn impl_enum(input: Enum) -> TokenStream {
}
});
Some(quote! {
#[allow(unused_qualifications)]
#[allow(
unused_qualifications,
// Since we don't merge bounds that cover the same type, suppress this issue
clippy::type_repetition_in_bounds,
)]
impl #impl_generics std::fmt::Display for #ty #ty_generics #where_clause {
fn fmt(&self, __formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
#use_as_display
Expand Down Expand Up @@ -331,7 +373,11 @@ fn impl_enum(input: Enum) -> TokenStream {
let from = from_field.ty;
let body = from_initializer(from_field, backtrace_field);
Some(quote! {
#[allow(unused_qualifications)]
#[allow(
unused_qualifications,
// Since we don't merge bounds that cover the same type, suppress this issue
clippy::type_repetition_in_bounds,
)]
impl #impl_generics std::convert::From<#from> for #ty #ty_generics #where_clause {
#[allow(deprecated)]
fn from(source: #from) -> Self {
Expand All @@ -342,9 +388,12 @@ fn impl_enum(input: Enum) -> TokenStream {
});

let error_trait = spanned_error_trait(input.original);

quote! {
#[allow(unused_qualifications)]
#[allow(
unused_qualifications,
// Since we don't merge bounds that cover the same type, suppress this issue
clippy::type_repetition_in_bounds,
)]
impl #impl_generics #error_trait for #ty #ty_generics #where_clause {
#source_method
#backtrace_method
Expand Down Expand Up @@ -424,3 +473,48 @@ fn spanned_error_trait(input: &DeriveInput) -> TokenStream {
let error = quote_spanned!(last_span=> Error);
quote!(#path #error)
}

/// Enhance a where clause with the given predicates, or create one with them if needed.
/// When no new predicates are provided, return without alteration.
fn extend_where_clause<TPredicates: std::iter::ExactSizeIterator<Item = syn::WherePredicate>>(
// Clause to be enhanced; created if absent when predicates are provided
where_clause: Option<syn::WhereClause>,
// Used to create span for new where clause if populating
where_span: proc_macro2::Span,
predicates: TPredicates,
) -> Option<syn::WhereClause> {
// If we don't have any predicates to add, it doesn't matter if we
// have a where clause to extend or not; return whatever was given
if predicates.len() == 0 {
return where_clause;
}
Some(match where_clause {
// Extend the existing clause with the new predicates
Some(mut where_clause) => {
where_clause.predicates.extend(predicates);
where_clause
}
// No where clause provided; create a new one with the provided span
None => syn::WhereClause {
where_token: token::Where(where_span),
predicates: predicates.collect(),
},
})
}

fn apply_type_bounds<'a, TTypeParams: std::iter::Iterator<Item = &'a syn::TypeParam>>(
type_params: TTypeParams,
bound_attr: &'a crate::attr::Bound<'_>,
) -> Vec<WherePredicate> {
let bounds = &bound_attr.bounds;
if bounds.is_empty() {
return Vec::new();
}
type_params
.map(move |syn::TypeParam { ident: tparam, .. }| {
let predicate = quote! { #tparam: #bounds };
syn::parse2::<syn::WherePredicate>(predicate)
.expect("quasiquote must create predicate bounds")
})
.collect()
}
24 changes: 24 additions & 0 deletions impl/src/valid.rs
Expand Up @@ -30,6 +30,18 @@ impl Struct<'_> {
));
}
}
if let Some(crate::attr::Bound {
original: bound_span,
..
}) = self.attrs.bound
{
if self.generics.params.is_empty() {
return Err(Error::new_spanned(
bound_span,
"#[error(bound = ...)] requires at least one generic type parameter",
));
}
}
check_field_attrs(&self.fields)?;
for field in &self.fields {
field.validate()?;
Expand All @@ -52,6 +64,18 @@ impl Enum<'_> {
));
}
}
if let Some(crate::attr::Bound {
original: bound_span,
..
}) = self.attrs.bound
{
if self.generics.params.is_empty() {
return Err(Error::new_spanned(
bound_span,
"#[error(bound = ...)] requires at least one generic type parameter",
));
}
}
let mut from_types = Set::new();
for variant in &self.variants {
if let Some(from_field) = variant.from_field() {
Expand Down
79 changes: 79 additions & 0 deletions tests/test_error.rs
Expand Up @@ -6,6 +6,20 @@ use std::io;
use thiserror::Error;

macro_rules! unimplemented_display {
($($tl:lifetime),*; $tp:tt; $ty:ty) => {
impl<$($tl),*, $tp> Display for $ty {
fn fmt(&self, _formatter: &mut fmt::Formatter) -> fmt::Result {
unimplemented!()
}
}
};
($tp:tt; $ty:ty) => {
impl<$tp> Display for $ty {
fn fmt(&self, _formatter: &mut fmt::Formatter) -> fmt::Result {
unimplemented!()
}
}
};
($ty:ty) => {
impl Display for $ty {
fn fmt(&self, _formatter: &mut fmt::Formatter) -> fmt::Result {
Expand Down Expand Up @@ -49,9 +63,74 @@ enum EnumError {
Unit,
}

#[derive(Error, Debug)]
#[error(bound = std::fmt::Display + std::error::Error + 'static)]
enum WithGeneric<T> {
Variant,
Generic(T),
}

#[derive(Error, Debug)]
#[error(bound = std::fmt::Debug + std::error::Error + 'static)]
enum WithGenericFrom<T> {
Variant,
Generic(#[from] T),
}

#[derive(Error, Debug)]
#[error(bound = std::fmt::Display + std::fmt::Debug + std::error::Error + 'static)]
enum WithGenericTransparent<T> {
#[error("variant")]
Variant,
#[error(transparent)]
Generic(#[from] T),
}

#[derive(Error, Debug)]
#[error(bound = std::error::Error + 'static)]
struct WithGenericStruct<T> {
#[from]
inner: T,
}

#[derive(Error, Debug)]
#[error(bound = std::error::Error + 'static)]
struct WithGenericStructRef<'a, T> {
inner: &'a WithGenericStruct<T>,
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is WithGenericStructRef based on a real world use case? Or do your real world uses pretty much all look like one of the ones above, where the type parameter is used as the whole field type?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have at least one of this cases occurring in my current codebase.

I am, however, seeing some difficulties wherein I'd like to be able to "skip" bounds on generics that only exist due to a lack of universal type qualifiers on content fields in enums. That, or some way of specifying bounds on a per-generic-parameter basis.

Copy link
Author

@Dessix Dessix Aug 23, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unless you have a suggested solution here, I've fixed the other issue and believe that, while some patterns may be inconvenient, this change does help quite a bit in my code. It may be a stepping stone toward a more robust solution?

}

#[derive(Error, Debug)]
#[error(bound = std::error::Error + 'a)]
struct WithGenericStructRefNonStaticInline<'a, T: 'a> {
inner: &'a WithGenericStruct<T>,
}

#[derive(Error, Debug)]
#[error(bound = std::error::Error + 'a)]
struct WithGenericStructRefNonStaticWhere<'a, T>
where
T: 'a,
{
inner: &'a WithGenericStruct<T>,
}

#[derive(Error, Debug)]
#[error(bound = std::fmt::Display + std::error::Error + 'static)]
#[error(transparent)]
struct WithGenericStructTransparent<T> {
#[from]
inner: T,
}

unimplemented_display!(BracedError);
unimplemented_display!(TupleError);
unimplemented_display!(UnitError);
unimplemented_display!(WithSource);
unimplemented_display!(WithAnyhow);
unimplemented_display!(EnumError);
unimplemented_display!(T; WithGeneric<T>);
unimplemented_display!(T; WithGenericFrom<T>);
unimplemented_display!(T; WithGenericStruct<T>);
unimplemented_display!('a; T; WithGenericStructRef<'a, T>);
unimplemented_display!('a; T; WithGenericStructRefNonStaticInline<'a, T>);
unimplemented_display!('a; T; WithGenericStructRefNonStaticWhere<'a, T>);
15 changes: 15 additions & 0 deletions tests/ui/bound-enum-without-generic.rs
@@ -0,0 +1,15 @@
use thiserror::Error;

#[derive(Error, Debug)]
#[error(bound = std::error::Error + 'static)]
enum BoundsWithoutGeneric {
Variant(u32),
}

impl std::fmt::Display for BoundsWithoutGeneric {
fn fmt(&self, _formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
unimplemented!()
}
}

fn main() {}