diff --git a/impl/src/ast.rs b/impl/src/ast.rs index 8698ecf..2aa7246 100644 --- a/impl/src/ast.rs +++ b/impl/src/ast.rs @@ -1,4 +1,5 @@ use crate::attr::{self, Attrs}; +use crate::generics::ParamsInScope; use proc_macro2::Span; use syn::{ Data, DataEnum, DataStruct, DeriveInput, Error, Fields, Generics, Ident, Index, Member, Result, @@ -38,6 +39,7 @@ pub struct Field<'a> { pub attrs: Attrs<'a>, pub member: Member, pub ty: &'a Type, + pub contains_generic: bool, } impl<'a> Input<'a> { @@ -56,8 +58,9 @@ impl<'a> Input<'a> { impl<'a> Struct<'a> { fn from_syn(node: &'a DeriveInput, data: &'a DataStruct) -> Result { let mut attrs = attr::get(&node.attrs)?; + let scope = ParamsInScope::new(&node.generics); let span = attrs.span().unwrap_or_else(Span::call_site); - let fields = Field::multiple_from_syn(&data.fields, span)?; + let fields = Field::multiple_from_syn(&data.fields, &scope, span)?; if let Some(display) = &mut attrs.display { display.expand_shorthand(&fields); } @@ -74,12 +77,13 @@ impl<'a> Struct<'a> { impl<'a> Enum<'a> { fn from_syn(node: &'a DeriveInput, data: &'a DataEnum) -> Result { let attrs = attr::get(&node.attrs)?; + let scope = ParamsInScope::new(&node.generics); let span = attrs.span().unwrap_or_else(Span::call_site); let variants = data .variants .iter() .map(|node| { - let mut variant = Variant::from_syn(node, span)?; + let mut variant = Variant::from_syn(node, &scope, span)?; if let display @ None = &mut variant.attrs.display { *display = attrs.display.clone(); } @@ -102,28 +106,37 @@ impl<'a> Enum<'a> { } impl<'a> Variant<'a> { - fn from_syn(node: &'a syn::Variant, span: Span) -> Result { + fn from_syn(node: &'a syn::Variant, scope: &ParamsInScope<'a>, span: Span) -> Result { let attrs = attr::get(&node.attrs)?; let span = attrs.span().unwrap_or(span); Ok(Variant { original: node, attrs, ident: node.ident.clone(), - fields: Field::multiple_from_syn(&node.fields, span)?, + fields: Field::multiple_from_syn(&node.fields, scope, span)?, }) } } impl<'a> Field<'a> { - fn multiple_from_syn(fields: &'a Fields, span: Span) -> Result> { + fn multiple_from_syn( + fields: &'a Fields, + scope: &ParamsInScope<'a>, + span: Span, + ) -> Result> { fields .iter() .enumerate() - .map(|(i, field)| Field::from_syn(i, field, span)) + .map(|(i, field)| Field::from_syn(i, field, scope, span)) .collect() } - fn from_syn(i: usize, node: &'a syn::Field, span: Span) -> Result { + fn from_syn( + i: usize, + node: &'a syn::Field, + scope: &ParamsInScope<'a>, + span: Span, + ) -> Result { Ok(Field { original: node, attrs: attr::get(&node.attrs)?, @@ -134,6 +147,7 @@ impl<'a> Field<'a> { }) }), ty: &node.ty, + contains_generic: scope.intersects(&node.ty), }) } } diff --git a/impl/src/attr.rs b/impl/src/attr.rs index 1ab1e28..9793586 100644 --- a/impl/src/attr.rs +++ b/impl/src/attr.rs @@ -1,5 +1,6 @@ use proc_macro2::{Delimiter, Group, Span, TokenStream, TokenTree}; use quote::{format_ident, quote, ToTokens}; +use std::collections::BTreeSet as Set; use std::iter::FromIterator; use syn::parse::{Nothing, ParseStream}; use syn::{ @@ -21,6 +22,7 @@ pub struct Display<'a> { pub fmt: LitStr, pub args: TokenStream, pub has_bonus_display: bool, + pub implied_bounds: Set<(usize, Trait)>, } #[derive(Copy, Clone)] @@ -29,6 +31,12 @@ pub struct Transparent<'a> { pub span: Span, } +#[derive(Copy, Clone, Eq, PartialEq, Ord, PartialOrd)] +pub enum Trait { + Debug, + Display, +} + pub fn get(input: &[Attribute]) -> Result { let mut attrs = Attrs { display: None, @@ -91,6 +99,7 @@ fn parse_error_attribute<'a>(attrs: &mut Attrs<'a>, attr: &'a Attribute) -> Resu fmt: input.parse()?, args: parse_token_expr(input, false)?, has_bonus_display: false, + implied_bounds: Set::new(), }; if attrs.display.is_some() { return Err(Error::new_spanned( @@ -188,3 +197,12 @@ impl ToTokens for Display<'_> { }); } } + +impl ToTokens for Trait { + fn to_tokens(&self, tokens: &mut TokenStream) { + tokens.extend(match self { + Trait::Debug => quote!(std::fmt::Debug), + Trait::Display => quote!(std::fmt::Display), + }); + } +} diff --git a/impl/src/expand.rs b/impl/src/expand.rs index 19ce42a..3976ddd 100644 --- a/impl/src/expand.rs +++ b/impl/src/expand.rs @@ -1,8 +1,12 @@ use crate::ast::{Enum, Field, Input, Struct}; use proc_macro2::TokenStream; use quote::{format_ident, quote, quote_spanned, ToTokens}; +use std::collections::BTreeSet as Set; use syn::spanned::Spanned; -use syn::{Data, DeriveInput, GenericArgument, Member, PathArguments, Result, Type, Visibility}; +use syn::{ + parse_quote, Data, DeriveInput, GenericArgument, Member, PathArguments, Result, Type, + Visibility, +}; pub fn derive(node: &DeriveInput) -> Result { let input = Input::from_syn(node)?; @@ -16,6 +20,8 @@ pub fn derive(node: &DeriveInput) -> Result { fn impl_struct(input: Struct) -> TokenStream { let ty = &input.ident; let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl(); + let mut error_generics = input.generics.clone(); + let error_where_clause = error_generics.make_where_clause(); let source_body = if input.attrs.transparent.is_some() { let only_field = &input.fields[0].member; @@ -24,6 +30,12 @@ fn impl_struct(input: Struct) -> TokenStream { }) } else if let Some(source_field) = input.source_field() { let source = &source_field.member; + if source_field.contains_generic { + let ty = unoptional_type(source_field.ty); + error_where_clause + .predicates + .push(parse_quote!(#ty: std::error::Error + 'static)); + } let asref = if type_is_option(source_field.ty) { Some(quote_spanned!(source.span()=> .as_ref()?)) } else { @@ -89,12 +101,14 @@ fn impl_struct(input: Struct) -> TokenStream { } }); + let mut display_implied_bounds = &Set::new(); let display_body = if input.attrs.transparent.is_some() { let only_field = &input.fields[0].member; Some(quote! { std::fmt::Display::fmt(&self.#only_field, __formatter) }) } else if let Some(display) = &input.attrs.display { + display_implied_bounds = &display.implied_bounds; let use_as_display = if display.has_bonus_display { Some(quote! { #[allow(unused_imports)] @@ -114,9 +128,20 @@ fn impl_struct(input: Struct) -> TokenStream { None }; let display_impl = display_body.map(|body| { + let mut display_generics = input.generics.clone(); + let display_where_clause = display_generics.make_where_clause(); + for &(field, bound) in display_implied_bounds { + let field = &input.fields[field]; + if field.contains_generic { + let field_ty = field.ty; + display_where_clause + .predicates + .push(parse_quote!(#field_ty: #bound)); + } + } quote! { #[allow(unused_qualifications)] - impl #impl_generics std::fmt::Display for #ty #ty_generics #where_clause { + impl #impl_generics std::fmt::Display for #ty #ty_generics #display_where_clause { #[allow(clippy::used_underscore_binding)] fn fmt(&self, __formatter: &mut std::fmt::Formatter) -> std::fmt::Result { #body @@ -141,10 +166,15 @@ fn impl_struct(input: Struct) -> TokenStream { }); let error_trait = spanned_error_trait(input.original); + if input.generics.type_params().next().is_some() { + error_where_clause + .predicates + .push(parse_quote!(Self: std::fmt::Debug + std::fmt::Display)); + } quote! { #[allow(unused_qualifications)] - impl #impl_generics #error_trait for #ty #ty_generics #where_clause { + impl #impl_generics #error_trait for #ty #ty_generics #error_where_clause { #source_method #backtrace_method } @@ -156,6 +186,8 @@ fn impl_struct(input: Struct) -> TokenStream { fn impl_enum(input: Enum) -> TokenStream { let ty = &input.ident; let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl(); + let mut error_generics = input.generics.clone(); + let error_where_clause = error_generics.make_where_clause(); let source_method = if input.has_source() { let arms = input.variants.iter().map(|variant| { @@ -168,6 +200,12 @@ fn impl_enum(input: Enum) -> TokenStream { } } else if let Some(source_field) = variant.source_field() { let source = &source_field.member; + if source_field.contains_generic { + let ty = unoptional_type(source_field.ty); + error_where_clause + .predicates + .push(parse_quote!(#ty: std::error::Error + 'static)); + } let asref = if type_is_option(source_field.ty) { Some(quote_spanned!(source.span()=> .as_ref()?)) } else { @@ -286,6 +324,8 @@ fn impl_enum(input: Enum) -> TokenStream { }; let display_impl = if input.has_display() { + let mut display_generics = input.generics.clone(); + let display_where_clause = display_generics.make_where_clause(); let use_as_display = if input.variants.iter().any(|v| { v.attrs .display @@ -305,8 +345,12 @@ fn impl_enum(input: Enum) -> TokenStream { None }; let arms = input.variants.iter().map(|variant| { + let mut display_implied_bounds = &Set::new(); let display = match &variant.attrs.display { - Some(display) => display.to_token_stream(), + Some(display) => { + display_implied_bounds = &display.implied_bounds; + display.to_token_stream() + } None => { let only_field = match &variant.fields[0].member { Member::Named(ident) => ident.clone(), @@ -315,15 +359,25 @@ fn impl_enum(input: Enum) -> TokenStream { quote!(std::fmt::Display::fmt(#only_field, __formatter)) } }; + for &(field, bound) in display_implied_bounds { + let field = &variant.fields[field]; + if field.contains_generic { + let field_ty = field.ty; + display_where_clause + .predicates + .push(parse_quote!(#field_ty: #bound)); + } + } let ident = &variant.ident; let pat = fields_pat(&variant.fields); quote! { #ty::#ident #pat => #display } }); + let arms = arms.collect::>(); Some(quote! { #[allow(unused_qualifications)] - impl #impl_generics std::fmt::Display for #ty #ty_generics #where_clause { + impl #impl_generics std::fmt::Display for #ty #ty_generics #display_where_clause { fn fmt(&self, __formatter: &mut std::fmt::Formatter) -> std::fmt::Result { #use_as_display #[allow(unused_variables, deprecated, clippy::used_underscore_binding)] @@ -355,10 +409,15 @@ fn impl_enum(input: Enum) -> TokenStream { }); let error_trait = spanned_error_trait(input.original); + if input.generics.type_params().next().is_some() { + error_where_clause + .predicates + .push(parse_quote!(Self: std::fmt::Debug + std::fmt::Display)); + } quote! { #[allow(unused_qualifications)] - impl #impl_generics #error_trait for #ty #ty_generics #where_clause { + impl #impl_generics #error_trait for #ty #ty_generics #error_where_clause { #source_method #backtrace_method } diff --git a/impl/src/fmt.rs b/impl/src/fmt.rs index e12e94b..f214ec8 100644 --- a/impl/src/fmt.rs +++ b/impl/src/fmt.rs @@ -1,8 +1,8 @@ use crate::ast::Field; -use crate::attr::Display; +use crate::attr::{Display, Trait}; use proc_macro2::TokenTree; use quote::{format_ident, quote_spanned}; -use std::collections::HashSet as Set; +use std::collections::{BTreeSet as Set, HashMap as Map}; use syn::ext::IdentExt; use syn::parse::{ParseStream, Parser}; use syn::{Ident, Index, LitStr, Member, Result, Token}; @@ -12,7 +12,10 @@ impl Display<'_> { pub fn expand_shorthand(&mut self, fields: &[Field]) { let raw_args = self.args.clone(); let mut named_args = explicit_named_args.parse2(raw_args).unwrap(); - let fields: Set = fields.iter().map(|f| f.member.clone()).collect(); + let mut member_index = Map::new(); + for (i, field) in fields.iter().enumerate() { + member_index.insert(field.member.clone(), i); + } let span = self.fmt.span(); let fmt = self.fmt.value(); @@ -20,6 +23,7 @@ impl Display<'_> { let mut out = String::new(); let mut args = self.args.clone(); let mut has_bonus_display = false; + let mut implied_bounds = Set::new(); let mut has_trailing_comma = false; if let Some(TokenTree::Punct(punct)) = args.clone().into_iter().last() { @@ -47,7 +51,7 @@ impl Display<'_> { Ok(index) => Member::Unnamed(Index { index, span }), Err(_) => return, }; - if !fields.contains(&member) { + if !member_index.contains_key(&member) { out += ∫ continue; } @@ -82,9 +86,21 @@ impl Display<'_> { args.extend(quote_spanned!(span=> ,)); } args.extend(quote_spanned!(span=> #formatvar = #local)); - if read.starts_with('}') && fields.contains(&member) { - has_bonus_display = true; - args.extend(quote_spanned!(span=> .as_display())); + if let Some(&field) = member_index.get(&member) { + let end_spec = match read.find('}') { + Some(end_spec) => end_spec, + None => return, + }; + let bound = match read[..end_spec].chars().next_back() { + Some('?') => Trait::Debug, + Some(_) => Trait::Display, + None => { + has_bonus_display = true; + args.extend(quote_spanned!(span=> .as_display())); + Trait::Display + } + }; + implied_bounds.insert((field, bound)); } has_trailing_comma = false; } @@ -93,6 +109,7 @@ impl Display<'_> { self.fmt = LitStr::new(&out, self.fmt.span()); self.args = args; self.has_bonus_display = has_bonus_display; + self.implied_bounds = implied_bounds; } } diff --git a/impl/src/generics.rs b/impl/src/generics.rs new file mode 100644 index 0000000..ff77e50 --- /dev/null +++ b/impl/src/generics.rs @@ -0,0 +1,41 @@ +use std::collections::BTreeSet as Set; +use syn::{GenericArgument, Generics, Ident, PathArguments, Type}; + +pub struct ParamsInScope<'a> { + names: Set<&'a Ident>, +} + +impl<'a> ParamsInScope<'a> { + pub fn new(generics: &'a Generics) -> Self { + ParamsInScope { + names: generics.type_params().map(|param| ¶m.ident).collect(), + } + } + + pub fn intersects(&self, ty: &Type) -> bool { + let mut found = false; + crawl(self, ty, &mut found); + found + } +} + +fn crawl(in_scope: &ParamsInScope, ty: &Type, found: &mut bool) { + if let Type::Path(ty) = ty { + if ty.qself.is_none() { + if let Some(ident) = ty.path.get_ident() { + if in_scope.names.contains(ident) { + *found = true; + } + } + } + for segment in &ty.path.segments { + if let PathArguments::AngleBracketed(arguments) = &segment.arguments { + for arg in &arguments.args { + if let GenericArgument::Type(ty) = arg { + crawl(in_scope, ty, found); + } + } + } + } + } +} diff --git a/impl/src/lib.rs b/impl/src/lib.rs index f0577d4..a4d5ae7 100644 --- a/impl/src/lib.rs +++ b/impl/src/lib.rs @@ -16,6 +16,7 @@ mod ast; mod attr; mod expand; mod fmt; +mod generics; mod prop; mod valid; diff --git a/tests/test_generics.rs b/tests/test_generics.rs new file mode 100644 index 0000000..1297c76 --- /dev/null +++ b/tests/test_generics.rs @@ -0,0 +1,129 @@ +#![deny(clippy::all, clippy::pedantic)] + +use std::fmt::{self, Debug, Display}; +use thiserror::Error; + +pub struct NoFormat; + +#[derive(Debug)] +pub struct DebugOnly; + +pub struct DisplayOnly; + +impl Display for DisplayOnly { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + f.write_str("display only") + } +} + +#[derive(Debug)] +pub struct DebugAndDisplay; + +impl Display for DebugAndDisplay { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + f.write_str("debug and display") + } +} + +// Should expand to: +// +// impl Display for EnumDebugField +// where +// E: Debug; +// +// impl Error for EnumDebugField +// where +// Self: Debug + Display; +// +#[derive(Error, Debug)] +pub enum EnumDebugGeneric { + #[error("{0:?}")] + FatalError(E), +} + +// Should expand to: +// +// impl Display for EnumFromGeneric; +// +// impl Error for EnumFromGeneric +// where +// EnumDebugGeneric: Error + 'static, +// Self: Debug + Display; +// +#[derive(Error, Debug)] +pub enum EnumFromGeneric { + #[error("enum from generic")] + Source(#[from] EnumDebugGeneric), +} + +// Should expand to: +// +// impl Display +// for EnumCompound +// where +// HasDisplay: Display, +// HasDebug: Debug; +// +// impl Error +// for EnumCompound +// where +// Self: Debug + Display; +// +#[derive(Error)] +pub enum EnumCompound { + #[error("{0} {1:?}")] + DisplayDebug(HasDisplay, HasDebug), + #[error("{0}")] + Display(HasDisplay, HasNeither), + #[error("{1:?}")] + Debug(HasNeither, HasDebug), +} + +impl Debug for EnumCompound { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + f.write_str("EnumCompound") + } +} + +#[test] +fn test_display_enum_compound() { + let mut instance: EnumCompound; + + instance = EnumCompound::DisplayDebug(DisplayOnly, DebugOnly); + assert_eq!(format!("{}", instance), "display only DebugOnly"); + + instance = EnumCompound::Display(DisplayOnly, NoFormat); + assert_eq!(format!("{}", instance), "display only"); + + instance = EnumCompound::Debug(NoFormat, DebugOnly); + assert_eq!(format!("{}", instance), "DebugOnly"); +} + +// Should expand to: +// +// impl Display for StructDebugGeneric +// where +// E: Debug; +// +// impl Error for StructDebugGeneric +// where +// Self: Debug + Display; +// +#[derive(Error, Debug)] +#[error("{underlying:?}")] +pub struct StructDebugGeneric { + pub underlying: E, +} + +// Should expand to: +// +// impl Error for StructFromGeneric +// where +// StructDebugGeneric: Error + 'static, +// Self: Debug + Display; +// +#[derive(Error, Debug)] +pub struct StructFromGeneric { + #[from] + pub source: StructDebugGeneric, +}