Skip to content

Commit

Permalink
Implied bounds for transparent attribute
Browse files Browse the repository at this point in the history
  • Loading branch information
dtolnay committed Sep 5, 2021
1 parent 42d36e5 commit 3e699aa
Showing 1 changed file with 27 additions and 10 deletions.
37 changes: 27 additions & 10 deletions impl/src/expand.rs
@@ -1,4 +1,5 @@
use crate::ast::{Enum, Field, Input, Struct};
use crate::attr::Trait;
use proc_macro2::TokenStream;
use quote::{format_ident, quote, quote_spanned, ToTokens};
use std::collections::BTreeSet as Set;
Expand All @@ -24,9 +25,16 @@ fn impl_struct(input: Struct) -> TokenStream {
let error_where_clause = error_generics.make_where_clause();

let source_body = if input.attrs.transparent.is_some() {
let only_field = &input.fields[0].member;
let only_field = &input.fields[0];
if only_field.contains_generic {
let ty = only_field.ty;
error_where_clause
.predicates
.push(parse_quote!(#ty: std::error::Error));
}
let member = &only_field.member;
Some(quote! {
std::error::Error::source(self.#only_field.as_dyn_error())
std::error::Error::source(self.#member.as_dyn_error())
})
} else if let Some(source_field) = input.source_field() {
let source = &source_field.member;
Expand Down Expand Up @@ -101,14 +109,15 @@ fn impl_struct(input: Struct) -> TokenStream {
}
});

let mut display_implied_bounds = &Set::new();
let mut display_implied_bounds = Set::new();
let display_body = if input.attrs.transparent.is_some() {
let only_field = &input.fields[0].member;
display_implied_bounds.insert((0, Trait::Display));
Some(quote! {
std::fmt::Display::fmt(&self.#only_field, __formatter)
})
} else if let Some(display) = &input.attrs.display {
display_implied_bounds = &display.implied_bounds;
display_implied_bounds = display.implied_bounds.clone();
let use_as_display = if display.has_bonus_display {
Some(quote! {
#[allow(unused_imports)]
Expand All @@ -130,7 +139,7 @@ fn impl_struct(input: Struct) -> TokenStream {
let display_impl = display_body.map(|body| {
let mut display_generics = input.generics.clone();
let display_where_clause = display_generics.make_where_clause();
for &(field, bound) in display_implied_bounds {
for (field, bound) in display_implied_bounds {
let field = &input.fields[field];
if field.contains_generic {
let field_ty = field.ty;
Expand Down Expand Up @@ -193,10 +202,17 @@ fn impl_enum(input: Enum) -> TokenStream {
let arms = input.variants.iter().map(|variant| {
let ident = &variant.ident;
if variant.attrs.transparent.is_some() {
let only_field = &variant.fields[0].member;
let only_field = &variant.fields[0];
if only_field.contains_generic {
let ty = only_field.ty;
error_where_clause
.predicates
.push(parse_quote!(#ty: std::error::Error));
}
let member = &only_field.member;
let source = quote!(std::error::Error::source(transparent.as_dyn_error()));
quote! {
#ty::#ident {#only_field: transparent} => #source,
#ty::#ident {#member: transparent} => #source,
}
} else if let Some(source_field) = variant.source_field() {
let source = &source_field.member;
Expand Down Expand Up @@ -345,21 +361,22 @@ fn impl_enum(input: Enum) -> TokenStream {
None
};
let arms = input.variants.iter().map(|variant| {
let mut display_implied_bounds = &Set::new();
let mut display_implied_bounds = Set::new();
let display = match &variant.attrs.display {
Some(display) => {
display_implied_bounds = &display.implied_bounds;
display_implied_bounds = display.implied_bounds.clone();
display.to_token_stream()
}
None => {
let only_field = match &variant.fields[0].member {
Member::Named(ident) => ident.clone(),
Member::Unnamed(index) => format_ident!("_{}", index),
};
display_implied_bounds.insert((0, Trait::Display));
quote!(std::fmt::Display::fmt(#only_field, __formatter))
}
};
for &(field, bound) in display_implied_bounds {
for (field, bound) in display_implied_bounds {
let field = &variant.fields[field];
if field.contains_generic {
let field_ty = field.ty;
Expand Down

0 comments on commit 3e699aa

Please sign in to comment.