Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Deduplicate inferred bounds #151

Merged
merged 1 commit into from Sep 5, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
60 changes: 22 additions & 38 deletions impl/src/expand.rs
@@ -1,12 +1,12 @@
use crate::ast::{Enum, Field, Input, Struct};
use crate::attr::Trait;
use crate::generics::InferredBounds;
use proc_macro2::TokenStream;
use quote::{format_ident, quote, quote_spanned, ToTokens};
use std::collections::BTreeSet as Set;
use syn::spanned::Spanned;
use syn::{
parse_quote, Data, DeriveInput, GenericArgument, Member, PathArguments, Result, Type,
Visibility,
Data, DeriveInput, GenericArgument, Member, PathArguments, Result, Token, Type, Visibility,
};

pub fn derive(node: &DeriveInput) -> Result<TokenStream> {
Expand All @@ -21,16 +21,12 @@ pub fn derive(node: &DeriveInput) -> Result<TokenStream> {
fn impl_struct(input: Struct) -> TokenStream {
let ty = &input.ident;
let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
let mut error_generics = input.generics.clone();
let error_where_clause = error_generics.make_where_clause();
let mut error_inferred_bounds = InferredBounds::new();

let source_body = if input.attrs.transparent.is_some() {
let only_field = &input.fields[0];
if only_field.contains_generic {
let ty = only_field.ty;
error_where_clause
.predicates
.push(parse_quote!(#ty: std::error::Error));
error_inferred_bounds.insert(only_field.ty, quote!(std::error::Error));
}
let member = &only_field.member;
Some(quote! {
Expand All @@ -40,9 +36,7 @@ fn impl_struct(input: Struct) -> TokenStream {
let source = &source_field.member;
if source_field.contains_generic {
let ty = unoptional_type(source_field.ty);
error_where_clause
.predicates
.push(parse_quote!(#ty: std::error::Error + 'static));
error_inferred_bounds.insert(ty, quote!(std::error::Error + 'static));
}
let asref = if type_is_option(source_field.ty) {
Some(quote_spanned!(source.span()=> .as_ref()?))
Expand Down Expand Up @@ -137,17 +131,14 @@ fn impl_struct(input: Struct) -> TokenStream {
None
};
let display_impl = display_body.map(|body| {
let mut display_generics = input.generics.clone();
let display_where_clause = display_generics.make_where_clause();
let mut display_inferred_bounds = InferredBounds::new();
for (field, bound) in display_implied_bounds {
let field = &input.fields[field];
if field.contains_generic {
let field_ty = field.ty;
display_where_clause
.predicates
.push(parse_quote!(#field_ty: #bound));
display_inferred_bounds.insert(field.ty, bound);
}
}
let display_where_clause = display_inferred_bounds.augment_where_clause(input.generics);
quote! {
#[allow(unused_qualifications)]
impl #impl_generics std::fmt::Display for #ty #ty_generics #display_where_clause {
Expand Down Expand Up @@ -176,10 +167,11 @@ fn impl_struct(input: Struct) -> TokenStream {

let error_trait = spanned_error_trait(input.original);
if input.generics.type_params().next().is_some() {
error_where_clause
.predicates
.push(parse_quote!(Self: std::fmt::Debug + std::fmt::Display));
let self_token = <Token![Self]>::default();
error_inferred_bounds.insert(self_token, Trait::Debug);
error_inferred_bounds.insert(self_token, Trait::Display);
}
let error_where_clause = error_inferred_bounds.augment_where_clause(input.generics);

quote! {
#[allow(unused_qualifications)]
Expand All @@ -195,19 +187,15 @@ fn impl_struct(input: Struct) -> TokenStream {
fn impl_enum(input: Enum) -> TokenStream {
let ty = &input.ident;
let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
let mut error_generics = input.generics.clone();
let error_where_clause = error_generics.make_where_clause();
let mut error_inferred_bounds = InferredBounds::new();

let source_method = if input.has_source() {
let arms = input.variants.iter().map(|variant| {
let ident = &variant.ident;
if variant.attrs.transparent.is_some() {
let only_field = &variant.fields[0];
if only_field.contains_generic {
let ty = only_field.ty;
error_where_clause
.predicates
.push(parse_quote!(#ty: std::error::Error));
error_inferred_bounds.insert(only_field.ty, quote!(std::error::Error));
}
let member = &only_field.member;
let source = quote!(std::error::Error::source(transparent.as_dyn_error()));
Expand All @@ -218,9 +206,7 @@ fn impl_enum(input: Enum) -> TokenStream {
let source = &source_field.member;
if source_field.contains_generic {
let ty = unoptional_type(source_field.ty);
error_where_clause
.predicates
.push(parse_quote!(#ty: std::error::Error + 'static));
error_inferred_bounds.insert(ty, quote!(std::error::Error + 'static));
}
let asref = if type_is_option(source_field.ty) {
Some(quote_spanned!(source.span()=> .as_ref()?))
Expand Down Expand Up @@ -340,8 +326,7 @@ fn impl_enum(input: Enum) -> TokenStream {
};

let display_impl = if input.has_display() {
let mut display_generics = input.generics.clone();
let display_where_clause = display_generics.make_where_clause();
let mut display_inferred_bounds = InferredBounds::new();
let use_as_display = if input.variants.iter().any(|v| {
v.attrs
.display
Expand Down Expand Up @@ -379,10 +364,7 @@ fn impl_enum(input: Enum) -> TokenStream {
for (field, bound) in display_implied_bounds {
let field = &variant.fields[field];
if field.contains_generic {
let field_ty = field.ty;
display_where_clause
.predicates
.push(parse_quote!(#field_ty: #bound));
display_inferred_bounds.insert(field.ty, bound);
}
}
let ident = &variant.ident;
Expand All @@ -392,6 +374,7 @@ fn impl_enum(input: Enum) -> TokenStream {
}
});
let arms = arms.collect::<Vec<_>>();
let display_where_clause = display_inferred_bounds.augment_where_clause(input.generics);
Some(quote! {
#[allow(unused_qualifications)]
impl #impl_generics std::fmt::Display for #ty #ty_generics #display_where_clause {
Expand Down Expand Up @@ -427,10 +410,11 @@ fn impl_enum(input: Enum) -> TokenStream {

let error_trait = spanned_error_trait(input.original);
if input.generics.type_params().next().is_some() {
error_where_clause
.predicates
.push(parse_quote!(Self: std::fmt::Debug + std::fmt::Display));
let self_token = <Token![Self]>::default();
error_inferred_bounds.insert(self_token, Trait::Debug);
error_inferred_bounds.insert(self_token, Trait::Display);
}
let error_where_clause = error_inferred_bounds.augment_where_clause(input.generics);

quote! {
#[allow(unused_qualifications)]
Expand Down
45 changes: 43 additions & 2 deletions impl/src/generics.rs
@@ -1,5 +1,9 @@
use std::collections::BTreeSet as Set;
use syn::{GenericArgument, Generics, Ident, PathArguments, Type};
use proc_macro2::TokenStream;
use quote::ToTokens;
use std::collections::btree_map::Entry;
use std::collections::{BTreeMap as Map, BTreeSet as Set};
use syn::punctuated::Punctuated;
use syn::{parse_quote, GenericArgument, Generics, Ident, PathArguments, Token, Type, WhereClause};

pub struct ParamsInScope<'a> {
names: Set<&'a Ident>,
Expand Down Expand Up @@ -39,3 +43,40 @@ fn crawl(in_scope: &ParamsInScope, ty: &Type, found: &mut bool) {
}
}
}

pub struct InferredBounds {
bounds: Map<String, (Set<String>, Punctuated<TokenStream, Token![+]>)>,
order: Vec<TokenStream>,
}

impl InferredBounds {
pub fn new() -> Self {
InferredBounds {
bounds: Map::new(),
order: Vec::new(),
}
}

pub fn insert(&mut self, ty: impl ToTokens, bound: impl ToTokens) {
let ty = ty.to_token_stream();
let bound = bound.to_token_stream();
let entry = self.bounds.entry(ty.to_string());
if let Entry::Vacant(_) = entry {
self.order.push(ty);
}
let (set, tokens) = entry.or_default();
if set.insert(bound.to_string()) {
tokens.push(bound);
}
}

pub fn augment_where_clause(&self, generics: &Generics) -> WhereClause {
let mut generics = generics.clone();
let where_clause = generics.make_where_clause();
for ty in &self.order {
let (_set, bounds) = &self.bounds[&ty.to_string()];
where_clause.predicates.push(parse_quote!(#ty: #bounds));
}
generics.where_clause.unwrap()
}
}