Skip to content

Commit

Permalink
Merge pull request #151 from dtolnay/bounds
Browse files Browse the repository at this point in the history
Deduplicate inferred bounds
  • Loading branch information
dtolnay committed Sep 5, 2021
2 parents 34f5931 + 72abba6 commit 113fcaa
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 40 deletions.
60 changes: 22 additions & 38 deletions impl/src/expand.rs
@@ -1,12 +1,12 @@
use crate::ast::{Enum, Field, Input, Struct};
use crate::attr::Trait;
use crate::generics::InferredBounds;
use proc_macro2::TokenStream;
use quote::{format_ident, quote, quote_spanned, ToTokens};
use std::collections::BTreeSet as Set;
use syn::spanned::Spanned;
use syn::{
parse_quote, Data, DeriveInput, GenericArgument, Member, PathArguments, Result, Type,
Visibility,
Data, DeriveInput, GenericArgument, Member, PathArguments, Result, Token, Type, Visibility,
};

pub fn derive(node: &DeriveInput) -> Result<TokenStream> {
Expand All @@ -21,16 +21,12 @@ pub fn derive(node: &DeriveInput) -> Result<TokenStream> {
fn impl_struct(input: Struct) -> TokenStream {
let ty = &input.ident;
let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
let mut error_generics = input.generics.clone();
let error_where_clause = error_generics.make_where_clause();
let mut error_inferred_bounds = InferredBounds::new();

let source_body = if input.attrs.transparent.is_some() {
let only_field = &input.fields[0];
if only_field.contains_generic {
let ty = only_field.ty;
error_where_clause
.predicates
.push(parse_quote!(#ty: std::error::Error));
error_inferred_bounds.insert(only_field.ty, quote!(std::error::Error));
}
let member = &only_field.member;
Some(quote! {
Expand All @@ -40,9 +36,7 @@ fn impl_struct(input: Struct) -> TokenStream {
let source = &source_field.member;
if source_field.contains_generic {
let ty = unoptional_type(source_field.ty);
error_where_clause
.predicates
.push(parse_quote!(#ty: std::error::Error + 'static));
error_inferred_bounds.insert(ty, quote!(std::error::Error + 'static));
}
let asref = if type_is_option(source_field.ty) {
Some(quote_spanned!(source.span()=> .as_ref()?))
Expand Down Expand Up @@ -137,17 +131,14 @@ fn impl_struct(input: Struct) -> TokenStream {
None
};
let display_impl = display_body.map(|body| {
let mut display_generics = input.generics.clone();
let display_where_clause = display_generics.make_where_clause();
let mut display_inferred_bounds = InferredBounds::new();
for (field, bound) in display_implied_bounds {
let field = &input.fields[field];
if field.contains_generic {
let field_ty = field.ty;
display_where_clause
.predicates
.push(parse_quote!(#field_ty: #bound));
display_inferred_bounds.insert(field.ty, bound);
}
}
let display_where_clause = display_inferred_bounds.augment_where_clause(input.generics);
quote! {
#[allow(unused_qualifications)]
impl #impl_generics std::fmt::Display for #ty #ty_generics #display_where_clause {
Expand Down Expand Up @@ -176,10 +167,11 @@ fn impl_struct(input: Struct) -> TokenStream {

let error_trait = spanned_error_trait(input.original);
if input.generics.type_params().next().is_some() {
error_where_clause
.predicates
.push(parse_quote!(Self: std::fmt::Debug + std::fmt::Display));
let self_token = <Token![Self]>::default();
error_inferred_bounds.insert(self_token, Trait::Debug);
error_inferred_bounds.insert(self_token, Trait::Display);
}
let error_where_clause = error_inferred_bounds.augment_where_clause(input.generics);

quote! {
#[allow(unused_qualifications)]
Expand All @@ -195,19 +187,15 @@ fn impl_struct(input: Struct) -> TokenStream {
fn impl_enum(input: Enum) -> TokenStream {
let ty = &input.ident;
let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
let mut error_generics = input.generics.clone();
let error_where_clause = error_generics.make_where_clause();
let mut error_inferred_bounds = InferredBounds::new();

let source_method = if input.has_source() {
let arms = input.variants.iter().map(|variant| {
let ident = &variant.ident;
if variant.attrs.transparent.is_some() {
let only_field = &variant.fields[0];
if only_field.contains_generic {
let ty = only_field.ty;
error_where_clause
.predicates
.push(parse_quote!(#ty: std::error::Error));
error_inferred_bounds.insert(only_field.ty, quote!(std::error::Error));
}
let member = &only_field.member;
let source = quote!(std::error::Error::source(transparent.as_dyn_error()));
Expand All @@ -218,9 +206,7 @@ fn impl_enum(input: Enum) -> TokenStream {
let source = &source_field.member;
if source_field.contains_generic {
let ty = unoptional_type(source_field.ty);
error_where_clause
.predicates
.push(parse_quote!(#ty: std::error::Error + 'static));
error_inferred_bounds.insert(ty, quote!(std::error::Error + 'static));
}
let asref = if type_is_option(source_field.ty) {
Some(quote_spanned!(source.span()=> .as_ref()?))
Expand Down Expand Up @@ -340,8 +326,7 @@ fn impl_enum(input: Enum) -> TokenStream {
};

let display_impl = if input.has_display() {
let mut display_generics = input.generics.clone();
let display_where_clause = display_generics.make_where_clause();
let mut display_inferred_bounds = InferredBounds::new();
let use_as_display = if input.variants.iter().any(|v| {
v.attrs
.display
Expand Down Expand Up @@ -379,10 +364,7 @@ fn impl_enum(input: Enum) -> TokenStream {
for (field, bound) in display_implied_bounds {
let field = &variant.fields[field];
if field.contains_generic {
let field_ty = field.ty;
display_where_clause
.predicates
.push(parse_quote!(#field_ty: #bound));
display_inferred_bounds.insert(field.ty, bound);
}
}
let ident = &variant.ident;
Expand All @@ -392,6 +374,7 @@ fn impl_enum(input: Enum) -> TokenStream {
}
});
let arms = arms.collect::<Vec<_>>();
let display_where_clause = display_inferred_bounds.augment_where_clause(input.generics);
Some(quote! {
#[allow(unused_qualifications)]
impl #impl_generics std::fmt::Display for #ty #ty_generics #display_where_clause {
Expand Down Expand Up @@ -427,10 +410,11 @@ fn impl_enum(input: Enum) -> TokenStream {

let error_trait = spanned_error_trait(input.original);
if input.generics.type_params().next().is_some() {
error_where_clause
.predicates
.push(parse_quote!(Self: std::fmt::Debug + std::fmt::Display));
let self_token = <Token![Self]>::default();
error_inferred_bounds.insert(self_token, Trait::Debug);
error_inferred_bounds.insert(self_token, Trait::Display);
}
let error_where_clause = error_inferred_bounds.augment_where_clause(input.generics);

quote! {
#[allow(unused_qualifications)]
Expand Down
45 changes: 43 additions & 2 deletions impl/src/generics.rs
@@ -1,5 +1,9 @@
use std::collections::BTreeSet as Set;
use syn::{GenericArgument, Generics, Ident, PathArguments, Type};
use proc_macro2::TokenStream;
use quote::ToTokens;
use std::collections::btree_map::Entry;
use std::collections::{BTreeMap as Map, BTreeSet as Set};
use syn::punctuated::Punctuated;
use syn::{parse_quote, GenericArgument, Generics, Ident, PathArguments, Token, Type, WhereClause};

pub struct ParamsInScope<'a> {
names: Set<&'a Ident>,
Expand Down Expand Up @@ -39,3 +43,40 @@ fn crawl(in_scope: &ParamsInScope, ty: &Type, found: &mut bool) {
}
}
}

pub struct InferredBounds {
bounds: Map<String, (Set<String>, Punctuated<TokenStream, Token![+]>)>,
order: Vec<TokenStream>,
}

impl InferredBounds {
pub fn new() -> Self {
InferredBounds {
bounds: Map::new(),
order: Vec::new(),
}
}

pub fn insert(&mut self, ty: impl ToTokens, bound: impl ToTokens) {
let ty = ty.to_token_stream();
let bound = bound.to_token_stream();
let entry = self.bounds.entry(ty.to_string());
if let Entry::Vacant(_) = entry {
self.order.push(ty);
}
let (set, tokens) = entry.or_default();
if set.insert(bound.to_string()) {
tokens.push(bound);
}
}

pub fn augment_where_clause(&self, generics: &Generics) -> WhereClause {
let mut generics = generics.clone();
let where_clause = generics.make_where_clause();
for ty in &self.order {
let (_set, bounds) = &self.bounds[&ty.to_string()];
where_clause.predicates.push(parse_quote!(#ty: #bounds));
}
generics.where_clause.unwrap()
}
}

0 comments on commit 113fcaa

Please sign in to comment.