diff --git a/impl/src/expand.rs b/impl/src/expand.rs index 3976ddd..9202276 100644 --- a/impl/src/expand.rs +++ b/impl/src/expand.rs @@ -1,4 +1,5 @@ use crate::ast::{Enum, Field, Input, Struct}; +use crate::attr::Trait; use proc_macro2::TokenStream; use quote::{format_ident, quote, quote_spanned, ToTokens}; use std::collections::BTreeSet as Set; @@ -24,9 +25,16 @@ fn impl_struct(input: Struct) -> TokenStream { let error_where_clause = error_generics.make_where_clause(); let source_body = if input.attrs.transparent.is_some() { - let only_field = &input.fields[0].member; + let only_field = &input.fields[0]; + if only_field.contains_generic { + let ty = only_field.ty; + error_where_clause + .predicates + .push(parse_quote!(#ty: std::error::Error)); + } + let member = &only_field.member; Some(quote! { - std::error::Error::source(self.#only_field.as_dyn_error()) + std::error::Error::source(self.#member.as_dyn_error()) }) } else if let Some(source_field) = input.source_field() { let source = &source_field.member; @@ -101,14 +109,15 @@ fn impl_struct(input: Struct) -> TokenStream { } }); - let mut display_implied_bounds = &Set::new(); + let mut display_implied_bounds = Set::new(); let display_body = if input.attrs.transparent.is_some() { let only_field = &input.fields[0].member; + display_implied_bounds.insert((0, Trait::Display)); Some(quote! { std::fmt::Display::fmt(&self.#only_field, __formatter) }) } else if let Some(display) = &input.attrs.display { - display_implied_bounds = &display.implied_bounds; + display_implied_bounds = display.implied_bounds.clone(); let use_as_display = if display.has_bonus_display { Some(quote! { #[allow(unused_imports)] @@ -130,7 +139,7 @@ fn impl_struct(input: Struct) -> TokenStream { let display_impl = display_body.map(|body| { let mut display_generics = input.generics.clone(); let display_where_clause = display_generics.make_where_clause(); - for &(field, bound) in display_implied_bounds { + for (field, bound) in display_implied_bounds { let field = &input.fields[field]; if field.contains_generic { let field_ty = field.ty; @@ -193,10 +202,17 @@ fn impl_enum(input: Enum) -> TokenStream { let arms = input.variants.iter().map(|variant| { let ident = &variant.ident; if variant.attrs.transparent.is_some() { - let only_field = &variant.fields[0].member; + let only_field = &variant.fields[0]; + if only_field.contains_generic { + let ty = only_field.ty; + error_where_clause + .predicates + .push(parse_quote!(#ty: std::error::Error)); + } + let member = &only_field.member; let source = quote!(std::error::Error::source(transparent.as_dyn_error())); quote! { - #ty::#ident {#only_field: transparent} => #source, + #ty::#ident {#member: transparent} => #source, } } else if let Some(source_field) = variant.source_field() { let source = &source_field.member; @@ -345,10 +361,10 @@ fn impl_enum(input: Enum) -> TokenStream { None }; let arms = input.variants.iter().map(|variant| { - let mut display_implied_bounds = &Set::new(); + let mut display_implied_bounds = Set::new(); let display = match &variant.attrs.display { Some(display) => { - display_implied_bounds = &display.implied_bounds; + display_implied_bounds = display.implied_bounds.clone(); display.to_token_stream() } None => { @@ -356,10 +372,11 @@ fn impl_enum(input: Enum) -> TokenStream { Member::Named(ident) => ident.clone(), Member::Unnamed(index) => format_ident!("_{}", index), }; + display_implied_bounds.insert((0, Trait::Display)); quote!(std::fmt::Display::fmt(#only_field, __formatter)) } }; - for &(field, bound) in display_implied_bounds { + for (field, bound) in display_implied_bounds { let field = &variant.fields[field]; if field.contains_generic { let field_ty = field.ty; diff --git a/tests/test_generics.rs b/tests/test_generics.rs index 1297c76..f5e1de2 100644 --- a/tests/test_generics.rs +++ b/tests/test_generics.rs @@ -99,6 +99,23 @@ fn test_display_enum_compound() { assert_eq!(format!("{}", instance), "DebugOnly"); } +// Should expand to: +// +// impl Display for EnumTransparentGeneric +// where +// E: Display; +// +// impl Error for EnumTransparentGeneric +// where +// E: Error, +// Self: Debug + Display; +// +#[derive(Error, Debug)] +pub enum EnumTransparentGeneric { + #[error(transparent)] + Other(E), +} + // Should expand to: // // impl Display for StructDebugGeneric @@ -127,3 +144,18 @@ pub struct StructFromGeneric { #[from] pub source: StructDebugGeneric, } + +// Should expand to: +// +// impl Display for StructTransparentGeneric +// where +// E: Display; +// +// impl Error for StructTransparentGeneric +// where +// E: Error, +// Self: Debug + Display; +// +#[derive(Error, Debug)] +#[error(transparent)] +pub struct StructTransparentGeneric(E);