Skip to content

Commit

Permalink
Merge pull request #148 from dtolnay/bounds
Browse files Browse the repository at this point in the history
Implied bounds for Display and Error impl
  • Loading branch information
dtolnay committed Sep 5, 2021
2 parents ec9ac76 + 1e6e267 commit e95b4ad
Show file tree
Hide file tree
Showing 7 changed files with 299 additions and 20 deletions.
28 changes: 21 additions & 7 deletions impl/src/ast.rs
@@ -1,4 +1,5 @@
use crate::attr::{self, Attrs};
use crate::generics::ParamsInScope;
use proc_macro2::Span;
use syn::{
Data, DataEnum, DataStruct, DeriveInput, Error, Fields, Generics, Ident, Index, Member, Result,
Expand Down Expand Up @@ -38,6 +39,7 @@ pub struct Field<'a> {
pub attrs: Attrs<'a>,
pub member: Member,
pub ty: &'a Type,
pub contains_generic: bool,
}

impl<'a> Input<'a> {
Expand All @@ -56,8 +58,9 @@ impl<'a> Input<'a> {
impl<'a> Struct<'a> {
fn from_syn(node: &'a DeriveInput, data: &'a DataStruct) -> Result<Self> {
let mut attrs = attr::get(&node.attrs)?;
let scope = ParamsInScope::new(&node.generics);
let span = attrs.span().unwrap_or_else(Span::call_site);
let fields = Field::multiple_from_syn(&data.fields, span)?;
let fields = Field::multiple_from_syn(&data.fields, &scope, span)?;
if let Some(display) = &mut attrs.display {
display.expand_shorthand(&fields);
}
Expand All @@ -74,12 +77,13 @@ impl<'a> Struct<'a> {
impl<'a> Enum<'a> {
fn from_syn(node: &'a DeriveInput, data: &'a DataEnum) -> Result<Self> {
let attrs = attr::get(&node.attrs)?;
let scope = ParamsInScope::new(&node.generics);
let span = attrs.span().unwrap_or_else(Span::call_site);
let variants = data
.variants
.iter()
.map(|node| {
let mut variant = Variant::from_syn(node, span)?;
let mut variant = Variant::from_syn(node, &scope, span)?;
if let display @ None = &mut variant.attrs.display {
*display = attrs.display.clone();
}
Expand All @@ -102,28 +106,37 @@ impl<'a> Enum<'a> {
}

impl<'a> Variant<'a> {
fn from_syn(node: &'a syn::Variant, span: Span) -> Result<Self> {
fn from_syn(node: &'a syn::Variant, scope: &ParamsInScope<'a>, span: Span) -> Result<Self> {
let attrs = attr::get(&node.attrs)?;
let span = attrs.span().unwrap_or(span);
Ok(Variant {
original: node,
attrs,
ident: node.ident.clone(),
fields: Field::multiple_from_syn(&node.fields, span)?,
fields: Field::multiple_from_syn(&node.fields, scope, span)?,
})
}
}

impl<'a> Field<'a> {
fn multiple_from_syn(fields: &'a Fields, span: Span) -> Result<Vec<Self>> {
fn multiple_from_syn(
fields: &'a Fields,
scope: &ParamsInScope<'a>,
span: Span,
) -> Result<Vec<Self>> {
fields
.iter()
.enumerate()
.map(|(i, field)| Field::from_syn(i, field, span))
.map(|(i, field)| Field::from_syn(i, field, scope, span))
.collect()
}

fn from_syn(i: usize, node: &'a syn::Field, span: Span) -> Result<Self> {
fn from_syn(
i: usize,
node: &'a syn::Field,
scope: &ParamsInScope<'a>,
span: Span,
) -> Result<Self> {
Ok(Field {
original: node,
attrs: attr::get(&node.attrs)?,
Expand All @@ -134,6 +147,7 @@ impl<'a> Field<'a> {
})
}),
ty: &node.ty,
contains_generic: scope.intersects(&node.ty),
})
}
}
Expand Down
18 changes: 18 additions & 0 deletions impl/src/attr.rs
@@ -1,5 +1,6 @@
use proc_macro2::{Delimiter, Group, Span, TokenStream, TokenTree};
use quote::{format_ident, quote, ToTokens};
use std::collections::BTreeSet as Set;
use std::iter::FromIterator;
use syn::parse::{Nothing, ParseStream};
use syn::{
Expand All @@ -21,6 +22,7 @@ pub struct Display<'a> {
pub fmt: LitStr,
pub args: TokenStream,
pub has_bonus_display: bool,
pub implied_bounds: Set<(usize, Trait)>,
}

#[derive(Copy, Clone)]
Expand All @@ -29,6 +31,12 @@ pub struct Transparent<'a> {
pub span: Span,
}

#[derive(Copy, Clone, Eq, PartialEq, Ord, PartialOrd)]
pub enum Trait {
Debug,
Display,
}

pub fn get(input: &[Attribute]) -> Result<Attrs> {
let mut attrs = Attrs {
display: None,
Expand Down Expand Up @@ -91,6 +99,7 @@ fn parse_error_attribute<'a>(attrs: &mut Attrs<'a>, attr: &'a Attribute) -> Resu
fmt: input.parse()?,
args: parse_token_expr(input, false)?,
has_bonus_display: false,
implied_bounds: Set::new(),
};
if attrs.display.is_some() {
return Err(Error::new_spanned(
Expand Down Expand Up @@ -188,3 +197,12 @@ impl ToTokens for Display<'_> {
});
}
}

impl ToTokens for Trait {
fn to_tokens(&self, tokens: &mut TokenStream) {
tokens.extend(match self {
Trait::Debug => quote!(std::fmt::Debug),
Trait::Display => quote!(std::fmt::Display),
});
}
}
71 changes: 65 additions & 6 deletions impl/src/expand.rs
@@ -1,8 +1,12 @@
use crate::ast::{Enum, Field, Input, Struct};
use proc_macro2::TokenStream;
use quote::{format_ident, quote, quote_spanned, ToTokens};
use std::collections::BTreeSet as Set;
use syn::spanned::Spanned;
use syn::{Data, DeriveInput, GenericArgument, Member, PathArguments, Result, Type, Visibility};
use syn::{
parse_quote, Data, DeriveInput, GenericArgument, Member, PathArguments, Result, Type,
Visibility,
};

pub fn derive(node: &DeriveInput) -> Result<TokenStream> {
let input = Input::from_syn(node)?;
Expand All @@ -16,6 +20,8 @@ pub fn derive(node: &DeriveInput) -> Result<TokenStream> {
fn impl_struct(input: Struct) -> TokenStream {
let ty = &input.ident;
let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
let mut error_generics = input.generics.clone();
let error_where_clause = error_generics.make_where_clause();

let source_body = if input.attrs.transparent.is_some() {
let only_field = &input.fields[0].member;
Expand All @@ -24,6 +30,12 @@ fn impl_struct(input: Struct) -> TokenStream {
})
} else if let Some(source_field) = input.source_field() {
let source = &source_field.member;
if source_field.contains_generic {
let ty = unoptional_type(source_field.ty);
error_where_clause
.predicates
.push(parse_quote!(#ty: std::error::Error + 'static));
}
let asref = if type_is_option(source_field.ty) {
Some(quote_spanned!(source.span()=> .as_ref()?))
} else {
Expand Down Expand Up @@ -89,12 +101,14 @@ fn impl_struct(input: Struct) -> TokenStream {
}
});

let mut display_implied_bounds = &Set::new();
let display_body = if input.attrs.transparent.is_some() {
let only_field = &input.fields[0].member;
Some(quote! {
std::fmt::Display::fmt(&self.#only_field, __formatter)
})
} else if let Some(display) = &input.attrs.display {
display_implied_bounds = &display.implied_bounds;
let use_as_display = if display.has_bonus_display {
Some(quote! {
#[allow(unused_imports)]
Expand All @@ -114,9 +128,20 @@ fn impl_struct(input: Struct) -> TokenStream {
None
};
let display_impl = display_body.map(|body| {
let mut display_generics = input.generics.clone();
let display_where_clause = display_generics.make_where_clause();
for &(field, bound) in display_implied_bounds {
let field = &input.fields[field];
if field.contains_generic {
let field_ty = field.ty;
display_where_clause
.predicates
.push(parse_quote!(#field_ty: #bound));
}
}
quote! {
#[allow(unused_qualifications)]
impl #impl_generics std::fmt::Display for #ty #ty_generics #where_clause {
impl #impl_generics std::fmt::Display for #ty #ty_generics #display_where_clause {
#[allow(clippy::used_underscore_binding)]
fn fmt(&self, __formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
#body
Expand All @@ -141,10 +166,15 @@ fn impl_struct(input: Struct) -> TokenStream {
});

let error_trait = spanned_error_trait(input.original);
if input.generics.type_params().next().is_some() {
error_where_clause
.predicates
.push(parse_quote!(Self: std::fmt::Debug + std::fmt::Display));
}

quote! {
#[allow(unused_qualifications)]
impl #impl_generics #error_trait for #ty #ty_generics #where_clause {
impl #impl_generics #error_trait for #ty #ty_generics #error_where_clause {
#source_method
#backtrace_method
}
Expand All @@ -156,6 +186,8 @@ fn impl_struct(input: Struct) -> TokenStream {
fn impl_enum(input: Enum) -> TokenStream {
let ty = &input.ident;
let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
let mut error_generics = input.generics.clone();
let error_where_clause = error_generics.make_where_clause();

let source_method = if input.has_source() {
let arms = input.variants.iter().map(|variant| {
Expand All @@ -168,6 +200,12 @@ fn impl_enum(input: Enum) -> TokenStream {
}
} else if let Some(source_field) = variant.source_field() {
let source = &source_field.member;
if source_field.contains_generic {
let ty = unoptional_type(source_field.ty);
error_where_clause
.predicates
.push(parse_quote!(#ty: std::error::Error + 'static));
}
let asref = if type_is_option(source_field.ty) {
Some(quote_spanned!(source.span()=> .as_ref()?))
} else {
Expand Down Expand Up @@ -286,6 +324,8 @@ fn impl_enum(input: Enum) -> TokenStream {
};

let display_impl = if input.has_display() {
let mut display_generics = input.generics.clone();
let display_where_clause = display_generics.make_where_clause();
let use_as_display = if input.variants.iter().any(|v| {
v.attrs
.display
Expand All @@ -305,8 +345,12 @@ fn impl_enum(input: Enum) -> TokenStream {
None
};
let arms = input.variants.iter().map(|variant| {
let mut display_implied_bounds = &Set::new();
let display = match &variant.attrs.display {
Some(display) => display.to_token_stream(),
Some(display) => {
display_implied_bounds = &display.implied_bounds;
display.to_token_stream()
}
None => {
let only_field = match &variant.fields[0].member {
Member::Named(ident) => ident.clone(),
Expand All @@ -315,15 +359,25 @@ fn impl_enum(input: Enum) -> TokenStream {
quote!(std::fmt::Display::fmt(#only_field, __formatter))
}
};
for &(field, bound) in display_implied_bounds {
let field = &variant.fields[field];
if field.contains_generic {
let field_ty = field.ty;
display_where_clause
.predicates
.push(parse_quote!(#field_ty: #bound));
}
}
let ident = &variant.ident;
let pat = fields_pat(&variant.fields);
quote! {
#ty::#ident #pat => #display
}
});
let arms = arms.collect::<Vec<_>>();
Some(quote! {
#[allow(unused_qualifications)]
impl #impl_generics std::fmt::Display for #ty #ty_generics #where_clause {
impl #impl_generics std::fmt::Display for #ty #ty_generics #display_where_clause {
fn fmt(&self, __formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
#use_as_display
#[allow(unused_variables, deprecated, clippy::used_underscore_binding)]
Expand Down Expand Up @@ -355,10 +409,15 @@ fn impl_enum(input: Enum) -> TokenStream {
});

let error_trait = spanned_error_trait(input.original);
if input.generics.type_params().next().is_some() {
error_where_clause
.predicates
.push(parse_quote!(Self: std::fmt::Debug + std::fmt::Display));
}

quote! {
#[allow(unused_qualifications)]
impl #impl_generics #error_trait for #ty #ty_generics #where_clause {
impl #impl_generics #error_trait for #ty #ty_generics #error_where_clause {
#source_method
#backtrace_method
}
Expand Down
31 changes: 24 additions & 7 deletions impl/src/fmt.rs
@@ -1,8 +1,8 @@
use crate::ast::Field;
use crate::attr::Display;
use crate::attr::{Display, Trait};
use proc_macro2::TokenTree;
use quote::{format_ident, quote_spanned};
use std::collections::HashSet as Set;
use std::collections::{BTreeSet as Set, HashMap as Map};
use syn::ext::IdentExt;
use syn::parse::{ParseStream, Parser};
use syn::{Ident, Index, LitStr, Member, Result, Token};
Expand All @@ -12,14 +12,18 @@ impl Display<'_> {
pub fn expand_shorthand(&mut self, fields: &[Field]) {
let raw_args = self.args.clone();
let mut named_args = explicit_named_args.parse2(raw_args).unwrap();
let fields: Set<Member> = fields.iter().map(|f| f.member.clone()).collect();
let mut member_index = Map::new();
for (i, field) in fields.iter().enumerate() {
member_index.insert(field.member.clone(), i);
}

let span = self.fmt.span();
let fmt = self.fmt.value();
let mut read = fmt.as_str();
let mut out = String::new();
let mut args = self.args.clone();
let mut has_bonus_display = false;
let mut implied_bounds = Set::new();

let mut has_trailing_comma = false;
if let Some(TokenTree::Punct(punct)) = args.clone().into_iter().last() {
Expand Down Expand Up @@ -47,7 +51,7 @@ impl Display<'_> {
Ok(index) => Member::Unnamed(Index { index, span }),
Err(_) => return,
};
if !fields.contains(&member) {
if !member_index.contains_key(&member) {
out += &int;
continue;
}
Expand Down Expand Up @@ -82,9 +86,21 @@ impl Display<'_> {
args.extend(quote_spanned!(span=> ,));
}
args.extend(quote_spanned!(span=> #formatvar = #local));
if read.starts_with('}') && fields.contains(&member) {
has_bonus_display = true;
args.extend(quote_spanned!(span=> .as_display()));
if let Some(&field) = member_index.get(&member) {
let end_spec = match read.find('}') {
Some(end_spec) => end_spec,
None => return,
};
let bound = match read[..end_spec].chars().next_back() {
Some('?') => Trait::Debug,
Some(_) => Trait::Display,
None => {
has_bonus_display = true;
args.extend(quote_spanned!(span=> .as_display()));
Trait::Display
}
};
implied_bounds.insert((field, bound));
}
has_trailing_comma = false;
}
Expand All @@ -93,6 +109,7 @@ impl Display<'_> {
self.fmt = LitStr::new(&out, self.fmt.span());
self.args = args;
self.has_bonus_display = has_bonus_display;
self.implied_bounds = implied_bounds;
}
}

Expand Down

0 comments on commit e95b4ad

Please sign in to comment.