From 72abba6f046ddfd1dc590b3ba8f4d9de8856a4bc Mon Sep 17 00:00:00 2001 From: David Tolnay Date: Sat, 4 Sep 2021 18:47:33 -0700 Subject: [PATCH] Deduplicate inferred bounds --- impl/src/expand.rs | 60 ++++++++++++++++---------------------------- impl/src/generics.rs | 45 +++++++++++++++++++++++++++++++-- 2 files changed, 65 insertions(+), 40 deletions(-) diff --git a/impl/src/expand.rs b/impl/src/expand.rs index 9202276..435ad48 100644 --- a/impl/src/expand.rs +++ b/impl/src/expand.rs @@ -1,12 +1,12 @@ use crate::ast::{Enum, Field, Input, Struct}; use crate::attr::Trait; +use crate::generics::InferredBounds; use proc_macro2::TokenStream; use quote::{format_ident, quote, quote_spanned, ToTokens}; use std::collections::BTreeSet as Set; use syn::spanned::Spanned; use syn::{ - parse_quote, Data, DeriveInput, GenericArgument, Member, PathArguments, Result, Type, - Visibility, + Data, DeriveInput, GenericArgument, Member, PathArguments, Result, Token, Type, Visibility, }; pub fn derive(node: &DeriveInput) -> Result { @@ -21,16 +21,12 @@ pub fn derive(node: &DeriveInput) -> Result { fn impl_struct(input: Struct) -> TokenStream { let ty = &input.ident; let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl(); - let mut error_generics = input.generics.clone(); - let error_where_clause = error_generics.make_where_clause(); + let mut error_inferred_bounds = InferredBounds::new(); let source_body = if input.attrs.transparent.is_some() { let only_field = &input.fields[0]; if only_field.contains_generic { - let ty = only_field.ty; - error_where_clause - .predicates - .push(parse_quote!(#ty: std::error::Error)); + error_inferred_bounds.insert(only_field.ty, quote!(std::error::Error)); } let member = &only_field.member; Some(quote! { @@ -40,9 +36,7 @@ fn impl_struct(input: Struct) -> TokenStream { let source = &source_field.member; if source_field.contains_generic { let ty = unoptional_type(source_field.ty); - error_where_clause - .predicates - .push(parse_quote!(#ty: std::error::Error + 'static)); + error_inferred_bounds.insert(ty, quote!(std::error::Error + 'static)); } let asref = if type_is_option(source_field.ty) { Some(quote_spanned!(source.span()=> .as_ref()?)) @@ -137,17 +131,14 @@ fn impl_struct(input: Struct) -> TokenStream { None }; let display_impl = display_body.map(|body| { - let mut display_generics = input.generics.clone(); - let display_where_clause = display_generics.make_where_clause(); + let mut display_inferred_bounds = InferredBounds::new(); for (field, bound) in display_implied_bounds { let field = &input.fields[field]; if field.contains_generic { - let field_ty = field.ty; - display_where_clause - .predicates - .push(parse_quote!(#field_ty: #bound)); + display_inferred_bounds.insert(field.ty, bound); } } + let display_where_clause = display_inferred_bounds.augment_where_clause(input.generics); quote! { #[allow(unused_qualifications)] impl #impl_generics std::fmt::Display for #ty #ty_generics #display_where_clause { @@ -176,10 +167,11 @@ fn impl_struct(input: Struct) -> TokenStream { let error_trait = spanned_error_trait(input.original); if input.generics.type_params().next().is_some() { - error_where_clause - .predicates - .push(parse_quote!(Self: std::fmt::Debug + std::fmt::Display)); + let self_token = ::default(); + error_inferred_bounds.insert(self_token, Trait::Debug); + error_inferred_bounds.insert(self_token, Trait::Display); } + let error_where_clause = error_inferred_bounds.augment_where_clause(input.generics); quote! { #[allow(unused_qualifications)] @@ -195,8 +187,7 @@ fn impl_struct(input: Struct) -> TokenStream { fn impl_enum(input: Enum) -> TokenStream { let ty = &input.ident; let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl(); - let mut error_generics = input.generics.clone(); - let error_where_clause = error_generics.make_where_clause(); + let mut error_inferred_bounds = InferredBounds::new(); let source_method = if input.has_source() { let arms = input.variants.iter().map(|variant| { @@ -204,10 +195,7 @@ fn impl_enum(input: Enum) -> TokenStream { if variant.attrs.transparent.is_some() { let only_field = &variant.fields[0]; if only_field.contains_generic { - let ty = only_field.ty; - error_where_clause - .predicates - .push(parse_quote!(#ty: std::error::Error)); + error_inferred_bounds.insert(only_field.ty, quote!(std::error::Error)); } let member = &only_field.member; let source = quote!(std::error::Error::source(transparent.as_dyn_error())); @@ -218,9 +206,7 @@ fn impl_enum(input: Enum) -> TokenStream { let source = &source_field.member; if source_field.contains_generic { let ty = unoptional_type(source_field.ty); - error_where_clause - .predicates - .push(parse_quote!(#ty: std::error::Error + 'static)); + error_inferred_bounds.insert(ty, quote!(std::error::Error + 'static)); } let asref = if type_is_option(source_field.ty) { Some(quote_spanned!(source.span()=> .as_ref()?)) @@ -340,8 +326,7 @@ fn impl_enum(input: Enum) -> TokenStream { }; let display_impl = if input.has_display() { - let mut display_generics = input.generics.clone(); - let display_where_clause = display_generics.make_where_clause(); + let mut display_inferred_bounds = InferredBounds::new(); let use_as_display = if input.variants.iter().any(|v| { v.attrs .display @@ -379,10 +364,7 @@ fn impl_enum(input: Enum) -> TokenStream { for (field, bound) in display_implied_bounds { let field = &variant.fields[field]; if field.contains_generic { - let field_ty = field.ty; - display_where_clause - .predicates - .push(parse_quote!(#field_ty: #bound)); + display_inferred_bounds.insert(field.ty, bound); } } let ident = &variant.ident; @@ -392,6 +374,7 @@ fn impl_enum(input: Enum) -> TokenStream { } }); let arms = arms.collect::>(); + let display_where_clause = display_inferred_bounds.augment_where_clause(input.generics); Some(quote! { #[allow(unused_qualifications)] impl #impl_generics std::fmt::Display for #ty #ty_generics #display_where_clause { @@ -427,10 +410,11 @@ fn impl_enum(input: Enum) -> TokenStream { let error_trait = spanned_error_trait(input.original); if input.generics.type_params().next().is_some() { - error_where_clause - .predicates - .push(parse_quote!(Self: std::fmt::Debug + std::fmt::Display)); + let self_token = ::default(); + error_inferred_bounds.insert(self_token, Trait::Debug); + error_inferred_bounds.insert(self_token, Trait::Display); } + let error_where_clause = error_inferred_bounds.augment_where_clause(input.generics); quote! { #[allow(unused_qualifications)] diff --git a/impl/src/generics.rs b/impl/src/generics.rs index ff77e50..254c2ed 100644 --- a/impl/src/generics.rs +++ b/impl/src/generics.rs @@ -1,5 +1,9 @@ -use std::collections::BTreeSet as Set; -use syn::{GenericArgument, Generics, Ident, PathArguments, Type}; +use proc_macro2::TokenStream; +use quote::ToTokens; +use std::collections::btree_map::Entry; +use std::collections::{BTreeMap as Map, BTreeSet as Set}; +use syn::punctuated::Punctuated; +use syn::{parse_quote, GenericArgument, Generics, Ident, PathArguments, Token, Type, WhereClause}; pub struct ParamsInScope<'a> { names: Set<&'a Ident>, @@ -39,3 +43,40 @@ fn crawl(in_scope: &ParamsInScope, ty: &Type, found: &mut bool) { } } } + +pub struct InferredBounds { + bounds: Map, Punctuated)>, + order: Vec, +} + +impl InferredBounds { + pub fn new() -> Self { + InferredBounds { + bounds: Map::new(), + order: Vec::new(), + } + } + + pub fn insert(&mut self, ty: impl ToTokens, bound: impl ToTokens) { + let ty = ty.to_token_stream(); + let bound = bound.to_token_stream(); + let entry = self.bounds.entry(ty.to_string()); + if let Entry::Vacant(_) = entry { + self.order.push(ty); + } + let (set, tokens) = entry.or_default(); + if set.insert(bound.to_string()) { + tokens.push(bound); + } + } + + pub fn augment_where_clause(&self, generics: &Generics) -> WhereClause { + let mut generics = generics.clone(); + let where_clause = generics.make_where_clause(); + for ty in &self.order { + let (_set, bounds) = &self.bounds[&ty.to_string()]; + where_clause.predicates.push(parse_quote!(#ty: #bounds)); + } + generics.where_clause.unwrap() + } +}