Skip to content

Commit

Permalink
Initial attempt- messy but passes tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Dessix committed Aug 2, 2021
1 parent 031fea6 commit aee3e31
Show file tree
Hide file tree
Showing 3 changed files with 185 additions and 5 deletions.
28 changes: 27 additions & 1 deletion impl/src/attr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@ use proc_macro2::{Delimiter, Group, Span, TokenStream, TokenTree};
use quote::{format_ident, quote, ToTokens};
use std::iter::FromIterator;
use syn::parse::{Nothing, ParseStream};
use syn::punctuated::Punctuated;
use syn::{
braced, bracketed, parenthesized, token, Attribute, Error, Ident, Index, LitInt, LitStr,
Result, Token,
Result, Token, TypeParamBound,
};

pub struct Attrs<'a> {
Expand All @@ -13,6 +14,7 @@ pub struct Attrs<'a> {
pub backtrace: Option<&'a Attribute>,
pub from: Option<&'a Attribute>,
pub transparent: Option<Transparent<'a>>,
pub bound: Option<Bound<'a>>,
}

#[derive(Clone)]
Expand All @@ -29,13 +31,20 @@ pub struct Transparent<'a> {
pub span: Span,
}

#[derive(Clone)]
pub struct Bound<'a> {
pub original: &'a Attribute,
pub bounds: Punctuated<TypeParamBound, token::Add>,
}

pub fn get(input: &[Attribute]) -> Result<Attrs> {
let mut attrs = Attrs {
display: None,
source: None,
backtrace: None,
from: None,
transparent: None,
bound: None,
};

for attr in input {
Expand Down Expand Up @@ -70,6 +79,7 @@ pub fn get(input: &[Attribute]) -> Result<Attrs> {

fn parse_error_attribute<'a>(attrs: &mut Attrs<'a>, attr: &'a Attribute) -> Result<()> {
syn::custom_keyword!(transparent);
syn::custom_keyword!(bound);

attr.parse_args_with(|input: ParseStream| {
if let Some(kw) = input.parse::<Option<transparent>>()? {
Expand All @@ -84,6 +94,22 @@ fn parse_error_attribute<'a>(attrs: &mut Attrs<'a>, attr: &'a Attribute) -> Resu
span: kw.span,
});
return Ok(());
} else if input.parse::<Option<bound>>()?.is_some() {
if attrs.bound.is_some() {
return Err(Error::new_spanned(
attr,
"duplicate #[error(bound)] attribute",
));
}
input.parse::<token::Eq>().map_err(|_| {
Error::new_spanned(attr, "\"bound\" keyword must be followed by '='")
})?;
let bound = Bound {
original: attr,
bounds: Punctuated::<TypeParamBound, token::Add>::parse_separated_nonempty(input)?,
};
attrs.bound = Some(bound);
return Ok(());
}

let display = Display {
Expand Down
112 changes: 108 additions & 4 deletions impl/src/expand.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@ use crate::ast::{Enum, Field, Input, Struct};
use proc_macro2::TokenStream;
use quote::{format_ident, quote, quote_spanned, ToTokens};
use syn::spanned::Spanned;
use syn::{Data, DeriveInput, Member, PathArguments, Result, Type, Visibility};
use syn::{
token, Data, DeriveInput, Member, PathArguments, Result, Type, Visibility, WherePredicate,
};

pub fn derive(node: &DeriveInput) -> Result<TokenStream> {
let input = Input::from_syn(node)?;
Expand Down Expand Up @@ -112,9 +114,27 @@ fn impl_struct(input: Struct) -> TokenStream {
None
};
let display_impl = display_body.map(|body| {
let display_impl_generics = {
let mut lifetime_params = input.generics.lifetimes().peekable();
let mut type_params = input.generics.type_params().peekable();
if let Some(bounds) = &input.attrs.bound {
let bounds = std::iter::repeat(bounds).map(|x| &x.bounds);
quote! {
<#(#type_params: #bounds),*>
}
} else if lifetime_params.peek().is_none() || type_params.peek().is_none() {
quote! {
<#(#lifetime_params),* #(#type_params),*>
}
} else {
quote! {
<#(#lifetime_params),* , #(#type_params),*>
}
}
};
quote! {
#[allow(unused_qualifications)]
impl #impl_generics std::fmt::Display for #ty #ty_generics #where_clause {
impl #display_impl_generics std::fmt::Display for #ty #ty_generics #where_clause {
#[allow(
// Clippy bug: https://github.com/rust-lang/rust-clippy/issues/7422
clippy::nonstandard_macro_braces,
Expand Down Expand Up @@ -143,6 +163,17 @@ fn impl_struct(input: Struct) -> TokenStream {
});

let error_trait = spanned_error_trait(input.original);
let bounded_where_predicates: Vec<syn::WherePredicate> = input
.attrs
.bound
.as_ref()
.map(|bound| apply_type_bounds(input.generics.type_params(), bound))
.unwrap_or_default();
let where_clause = extend_where_clause(
where_clause.cloned(),
input.original.span(),
bounded_where_predicates.into_iter(),
);

quote! {
#[allow(unused_qualifications)]
Expand Down Expand Up @@ -302,9 +333,27 @@ fn impl_enum(input: Enum) -> TokenStream {
#ty::#ident #pat => #display
}
});
let display_impl_generics = {
let mut lifetime_params = input.generics.lifetimes().peekable();
let mut type_params = input.generics.type_params().peekable();
if let Some(bounds) = &input.attrs.bound {
let bounds = std::iter::repeat(bounds).map(|x| &x.bounds);
quote! {
<#(#type_params: #bounds),*>
}
} else if lifetime_params.peek().is_none() || type_params.peek().is_none() {
quote! {
<#(#lifetime_params),* #(#type_params),*>
}
} else {
quote! {
<#(#lifetime_params),* , #(#type_params),*>
}
}
};
Some(quote! {
#[allow(unused_qualifications)]
impl #impl_generics std::fmt::Display for #ty #ty_generics #where_clause {
impl #display_impl_generics std::fmt::Display for #ty #ty_generics #where_clause {
fn fmt(&self, __formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
#use_as_display
#[allow(
Expand Down Expand Up @@ -342,7 +391,17 @@ fn impl_enum(input: Enum) -> TokenStream {
});

let error_trait = spanned_error_trait(input.original);

let bounded_where_predicates: Vec<syn::WherePredicate> = input
.attrs
.bound
.as_ref()
.map(|bound| apply_type_bounds(input.generics.type_params(), bound))
.unwrap_or_default();
let where_clause = extend_where_clause(
where_clause.cloned(),
input.original.span(),
bounded_where_predicates.into_iter(),
);
quote! {
#[allow(unused_qualifications)]
impl #impl_generics #error_trait for #ty #ty_generics #where_clause {
Expand Down Expand Up @@ -424,3 +483,48 @@ fn spanned_error_trait(input: &DeriveInput) -> TokenStream {
let error = quote_spanned!(last_span=> Error);
quote!(#path #error)
}

/// Enhance a where clause with the given predicates, or create one with them if needed.
/// When no new predicates are provided, return without alteration.
fn extend_where_clause<TPredicates: std::iter::ExactSizeIterator<Item = syn::WherePredicate>>(
// Clause to be enhanced; created if absent when predicates are provided
where_clause: Option<syn::WhereClause>,
// Used to create span for new where clause if populating
where_span: proc_macro2::Span,
predicates: TPredicates,
) -> Option<syn::WhereClause> {
// If we don't have any predicates to add, it doesn't matter if we
// have a where clause to extend or not; return whatever was given
if predicates.len() == 0 {
return where_clause;
}
Some(match where_clause {
// Extend the existing clause with the new predicates
Some(mut where_clause) => {
where_clause.predicates.extend(predicates);
where_clause
}
// No where clause provided; create a new one with the provided span
None => syn::WhereClause {
where_token: token::Where(where_span),
predicates: predicates.collect(),
},
})
}

fn apply_type_bounds<'a, TTypeParams: std::iter::Iterator<Item = &'a syn::TypeParam>>(
type_params: TTypeParams,
bound_attr: &'a crate::attr::Bound<'_>,
) -> Vec<WherePredicate> {
let bounds = &bound_attr.bounds;
if bounds.is_empty() {
return Vec::new();
}
type_params
.map(move |p| {
let predicate = quote! { #p: #bounds };
syn::parse2::<syn::WherePredicate>(predicate)
.expect("quasiquote must create predicate bounds")
})
.collect()
}
50 changes: 50 additions & 0 deletions tests/test_error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,56 @@ enum EnumError {
Unit,
}

// TODO: This scenario should fail because no generics are present to be bounded; add ui test for it
// #[derive(Error, Debug)]
// #[error(bound = std::error::Error + 'static)]
// enum WithoutBoundNoGeneric {
// #[error(transparent)]
// Variant(u32),
// }

#[derive(Error, Debug)]
#[error(bound = std::fmt::Display + std::error::Error + 'static)]
enum WithGeneric<T> {
Variant,
Generic(T),
}

#[derive(Error, Debug)]
#[error(bound = std::fmt::Debug + std::error::Error + 'static)]
enum WithGenericFrom<T> {
Variant,
Generic(#[from] T),
}

#[derive(Error, Debug)]
#[error(bound = std::fmt::Display + std::fmt::Debug + std::error::Error + 'static)]
enum WithGenericTransparent<T> {
#[error("variant")]
Variant,
#[error(transparent)]
Generic(#[from] T),
}

#[derive(Error, Debug)]
#[error(bound = std::fmt::Display + std::error::Error + 'static)]
#[error(transparent)]
struct WithGenericStruct<T> {
#[from] inner: T,
}

impl<T: Display> Display for WithGeneric<T> {
fn fmt(&self, _formatter: &mut fmt::Formatter) -> fmt::Result {
unimplemented!()
}
}

impl<T: Display> Display for WithGenericFrom<T> {
fn fmt(&self, _formatter: &mut fmt::Formatter) -> fmt::Result {
unimplemented!()
}
}

unimplemented_display!(BracedError);
unimplemented_display!(TupleError);
unimplemented_display!(UnitError);
Expand Down

0 comments on commit aee3e31

Please sign in to comment.