From 8d2c7d4753c0dfca5c049dbaff74b0912ddd944d Mon Sep 17 00:00:00 2001 From: Zoey Date: Sat, 28 Aug 2021 01:48:24 -0700 Subject: [PATCH 1/7] Initial rough-out for enum generic handling --- impl/src/attr.rs | 20 ++++++++ impl/src/expand.rs | 82 +++++++++++++++++++++++++++++- tests/test_generics.rs | 112 +++++++++++++++++++++++++++++++++++++++++ 3 files changed, 213 insertions(+), 1 deletion(-) create mode 100644 tests/test_generics.rs diff --git a/impl/src/attr.rs b/impl/src/attr.rs index 1ab1e28..82f3d98 100644 --- a/impl/src/attr.rs +++ b/impl/src/attr.rs @@ -15,6 +15,12 @@ pub struct Attrs<'a> { pub transparent: Option>, } +#[derive(Clone, Copy)] +pub enum DisplayFormatMarking<'a> { + Debug(&'a crate::ast::Field<'a>), + Display(&'a crate::ast::Field<'a>), +} + #[derive(Clone)] pub struct Display<'a> { pub original: &'a Attribute, @@ -23,6 +29,20 @@ pub struct Display<'a> { pub has_bonus_display: bool, } +impl<'a> Display<'a> { + pub fn iter_fmt_types( + &'a self, + fields: &'a [crate::ast::Field], + ) -> impl Iterator + 'a { + // TODO: Parse format string literal, return only fields at offsets which are Display or Debug referenced by each format reference + // If a field position is referenced multiple times for the same format class, deduplicate + fields + .iter() + .map(DisplayFormatMarking::Display) + .chain(fields.iter().map(DisplayFormatMarking::Debug)) + } +} + #[derive(Copy, Clone)] pub struct Transparent<'a> { pub original: &'a Attribute, diff --git a/impl/src/expand.rs b/impl/src/expand.rs index 789eee6..58a7293 100644 --- a/impl/src/expand.rs +++ b/impl/src/expand.rs @@ -2,7 +2,7 @@ use crate::ast::{Enum, Field, Input, Struct}; use proc_macro2::TokenStream; use quote::{format_ident, quote, quote_spanned, ToTokens}; use syn::spanned::Spanned; -use syn::{Data, DeriveInput, Member, PathArguments, Result, Type, Visibility}; +use syn::{parse_quote, Data, DeriveInput, Member, PathArguments, Result, Type, Visibility}; pub fn derive(node: &DeriveInput) -> Result { let input = Input::from_syn(node)?; @@ -285,6 +285,50 @@ fn impl_enum(input: Enum) -> TokenStream { } else { None }; + + let mut extra_predicates = Vec::new(); + for (fields, display) in input + .variants + .iter() + .filter_map(|v| v.attrs.display.as_ref().map(|d| (v.fields.as_slice(), d))) + { + for field in display.iter_fmt_types(fields) { + use crate::attr::DisplayFormatMarking; + let (ty, bound): ( + &syn::Type, + syn::punctuated::Punctuated, + ); + match field { + DisplayFormatMarking::Debug(f) => { + ty = &f.ty; + bound = parse_quote! { ::std::fmt::Debug }; + } + DisplayFormatMarking::Display(f) => { + ty = &f.ty; + bound = parse_quote! { ::std::fmt::Display }; + } + } + let matcher = |param: &&syn::TypeParam| match ty { + syn::Type::Path(syn::TypePath { path, .. }) => { + path.get_ident() == Some(¶m.ident) + } + syn::Type::Reference(syn::TypeReference { elem, .. }) if matches!(Box::as_ref(&elem), &syn::Type::Path(syn::TypePath { ref path, .. }) if path.get_ident() == Some(¶m.ident)) => { + true + } + _ => false, + }; + if input.generics.type_params().find(|x| matcher(x)).is_none() { + continue; + } + extra_predicates.push(syn::WherePredicate::Type(syn::PredicateType { + bounded_ty: ty.clone(), + colon_token: syn::token::Colon::default(), + bounds: bound, + lifetimes: None, + })); + } + } + let arms = input.variants.iter().map(|variant| { let display = match &variant.attrs.display { Some(display) => display.to_token_stream(), @@ -302,6 +346,9 @@ fn impl_enum(input: Enum) -> TokenStream { #ty::#ident #pat => #display } }); + + let where_clause = augment_where_clause(where_clause, extra_predicates); + Some(quote! { #[allow(unused_qualifications)] impl #impl_generics std::fmt::Display for #ty #ty_generics #where_clause { @@ -343,6 +390,10 @@ fn impl_enum(input: Enum) -> TokenStream { let error_trait = spanned_error_trait(input.original); + let where_clause = augment_where_clause( + where_clause, + [parse_quote! { Self: ::std::fmt::Display + ::std::fmt::Debug }], + ); quote! { #[allow(unused_qualifications)] impl #impl_generics #error_trait for #ty #ty_generics #where_clause { @@ -354,6 +405,35 @@ fn impl_enum(input: Enum) -> TokenStream { } } +fn augment_where_clause( + where_clause: Option<&syn::WhereClause>, + extra_predicates: TPredicates, +) -> Option +where + TPredicates: IntoIterator, + TPredicates::IntoIter: std::iter::ExactSizeIterator, +{ + let extra_predicates = extra_predicates.into_iter(); + if extra_predicates.len() == 0 { + return where_clause.cloned(); + } + Some(match where_clause { + Some(w) => syn::WhereClause { + where_token: w.where_token.clone(), + predicates: w + .predicates + .iter() + .cloned() + .chain(extra_predicates) + .collect(), + }, + None => syn::WhereClause { + where_token: syn::token::Where::default(), + predicates: extra_predicates.into_iter().collect(), + }, + }) +} + fn fields_pat(fields: &[Field]) -> TokenStream { let mut members = fields.iter().map(|field| &field.member).peekable(); match members.peek() { diff --git a/tests/test_generics.rs b/tests/test_generics.rs new file mode 100644 index 0000000..000c30c --- /dev/null +++ b/tests/test_generics.rs @@ -0,0 +1,112 @@ +#![deny(clippy::all, clippy::pedantic)] +#![allow( + // Clippy bug: https://github.com/rust-lang/rust-clippy/issues/7422 + clippy::nonstandard_macro_braces, +)] +#![allow(dead_code)] + +use std::fmt::{Debug, Display}; + +use thiserror::Error; + +struct NoFormattingType; + +struct DisplayType; + +impl Display for DisplayType { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, stringify!(DisplayType)) + } +} + +impl Debug for DisplayType { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, stringify!(DisplayType)) + } +} + +struct DebugType; + +impl Debug for DebugType { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, stringify!(DebugType)) + } +} + +/// Direct embedding of a generic in a field +/// +/// Should produce the following instances: +/// +/// ```rust +/// impl Display for DirectEmbedding +/// where +/// Embedded: Debug; +/// +/// impl Error for DirectEmbedding +/// where +/// Self: Debug + Display; +/// ``` +#[derive(Error, Debug)] +enum DirectEmbedding { + #[error("{0:?}")] + FatalError(Embedded), +} + +/// #[from] handling but no Debug usage of the generic +/// +/// Should produce the following instances: +/// +/// ```rust +/// impl Display for FromGenericError; +/// +/// impl Error for FromGenericError +/// where +/// SourceGenericError: Error, +/// Indirect: 'static, +/// Self: Debug + Display; +/// ``` +// TODO: Parse generic parameters contained in #[from] usages and add them to Error impl +// #[derive(Error, Debug)] +// enum FromGenericError { +// #[error("Tadah")] +// SourceEmbedded(#[from] DirectEmbedding), +// } + +/// Direct embedding of a generic in a field +/// +/// Should produce the following instances: +/// +/// ```rust +/// impl Display for DirectEmbedding +/// where +/// HasDisplay: Display, +/// HasDebug: Debug; +/// +/// impl Error for DirectEmbedding +/// where +/// Self: Debug + Display; +/// ``` +#[derive(Error)] +enum HybridDisplayType { + #[error("{0} : {1:?}")] + HybridDisplayCase(HasDisplay, HasDebug), + #[error("{0}")] + DisplayCase(HasDisplay, HasNeither), + #[error("{1:?}")] + DebugCase(HasNeither, HasDebug), +} + +impl Debug + for HybridDisplayType +{ + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "HybridDisplayType") + } +} + +fn display_hybrid_display_type( + instance: HybridDisplayType, + f: &mut std::fmt::Formatter<'_>, +) -> std::fmt::Result { + Debug::fmt(&instance, f) +} From fedb0204254419bf6ed64818e854ae05f023ea7d Mon Sep 17 00:00:00 2001 From: Zoey Date: Sat, 28 Aug 2021 02:34:56 -0700 Subject: [PATCH 2/7] Add visitor which marks provided generics by name in variants --- impl/Cargo.toml | 2 +- impl/src/expand.rs | 68 ++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 69 insertions(+), 1 deletion(-) diff --git a/impl/Cargo.toml b/impl/Cargo.toml index 39bbc3d..8b54bad 100644 --- a/impl/Cargo.toml +++ b/impl/Cargo.toml @@ -13,7 +13,7 @@ proc-macro = true [dependencies] proc-macro2 = "1.0" quote = "1.0" -syn = "1.0.45" +syn = { version = "1.0.45", features = ["visit"] } [package.metadata.docs.rs] targets = ["x86_64-unknown-linux-gnu"] diff --git a/impl/src/expand.rs b/impl/src/expand.rs index 58a7293..aedbcab 100644 --- a/impl/src/expand.rs +++ b/impl/src/expand.rs @@ -306,6 +306,7 @@ fn impl_enum(input: Enum) -> TokenStream { DisplayFormatMarking::Display(f) => { ty = &f.ty; bound = parse_quote! { ::std::fmt::Display }; + // f.original. } } let matcher = |param: &&syn::TypeParam| match ty { @@ -405,6 +406,73 @@ fn impl_enum(input: Enum) -> TokenStream { } } +#[cfg_attr(test, derive(Debug))] +struct GenericUsageVisitor { + generics: std::collections::HashMap, +} + +impl GenericUsageVisitor { + pub fn new(generics: TPairs) -> Self + where + TPairs: IntoIterator, + { + Self { + generics: generics.into_iter().collect(), + } + } +} + +impl<'ast> syn::visit::Visit<'ast> for GenericUsageVisitor { + fn visit_type_path(&mut self, i: &'ast syn::TypePath) { + if let Some(ident) = i.path.get_ident() { + if let Some(entry) = self.generics.get_mut(ident) { + *entry = true; + } + } + syn::visit::visit_type_path(self, i) + } + + fn visit_type_param(&mut self, i: &'ast syn::TypeParam) { + if let Some(entry) = self.generics.get_mut(&i.ident) { + *entry = true; + } + syn::visit::visit_type_param(self, i) + } +} + +#[cfg(test)] +mod test_visitors { + use proc_macro2::Span; + + use crate::expand::GenericUsageVisitor; + + #[test] + fn test_generic_usage_visitor() { + fn ident_for<'a>(ident_str: &'a str) -> syn::Ident { + syn::Ident::new(ident_str, Span::call_site()) + } + + let field: syn::Variant = syn::parse_quote! { X(Foo) }; + + let mut visitor = GenericUsageVisitor::new( + vec!["Bar", "Baz"] + .into_iter() + .map(|ident_str| (ident_for(ident_str), false)), + ); + syn::visit::visit_variant(&mut visitor, &field); + assert_eq!( + visitor.generics.get(&ident_for("Bar")), + Some(&true), + "Bar must be marked" + ); + assert_eq!( + visitor.generics.get(&ident_for("Baz")), + Some(&false), + "Baz must be present but not marked" + ); + } +} + fn augment_where_clause( where_clause: Option<&syn::WhereClause>, extra_predicates: TPredicates, From 64b2a0d1fb700d59e3d90f9f6cd4090dfcf8be70 Mon Sep 17 00:00:00 2001 From: Zoey Date: Sat, 28 Aug 2021 19:51:14 -0700 Subject: [PATCH 3/7] Support `#[from]` usage parsing using GenericUsageVisitor --- impl/src/expand.rs | 126 ++++++++++++++++++++++++++++++----------- tests/test_generics.rs | 13 ++--- 2 files changed, 100 insertions(+), 39 deletions(-) diff --git a/impl/src/expand.rs b/impl/src/expand.rs index aedbcab..b8fe533 100644 --- a/impl/src/expand.rs +++ b/impl/src/expand.rs @@ -294,39 +294,37 @@ fn impl_enum(input: Enum) -> TokenStream { { for field in display.iter_fmt_types(fields) { use crate::attr::DisplayFormatMarking; - let (ty, bound): ( + let (ty, bound, ast): ( &syn::Type, syn::punctuated::Punctuated, + &syn::Field, ); match field { DisplayFormatMarking::Debug(f) => { - ty = &f.ty; + ty = f.ty; bound = parse_quote! { ::std::fmt::Debug }; + ast = f.original; } DisplayFormatMarking::Display(f) => { - ty = &f.ty; + ty = f.ty; bound = parse_quote! { ::std::fmt::Display }; - // f.original. + ast = f.original; } } - let matcher = |param: &&syn::TypeParam| match ty { - syn::Type::Path(syn::TypePath { path, .. }) => { - path.get_ident() == Some(¶m.ident) - } - syn::Type::Reference(syn::TypeReference { elem, .. }) if matches!(Box::as_ref(&elem), &syn::Type::Path(syn::TypePath { ref path, .. }) if path.get_ident() == Some(¶m.ident)) => { - true - } - _ => false, - }; - if input.generics.type_params().find(|x| matcher(x)).is_none() { - continue; + // If a generic is at all present, a constraint will be applied to the field type + // This may create redundant `AlwaysDebug: Debug` scenarios, but covers T: Debug and &T: Debug cleanly + let mut usages = GenericUsageVisitor::new_unmarked( + input.generics.type_params().map(|p| p.ident.clone()), + ); + syn::visit::visit_field(&mut usages, ast); + if usages.iter_marked().next().is_some() { + extra_predicates.push(syn::WherePredicate::Type(syn::PredicateType { + bounded_ty: ty.clone(), + colon_token: syn::token::Colon::default(), + bounds: bound, + lifetimes: None, + })); } - extra_predicates.push(syn::WherePredicate::Type(syn::PredicateType { - bounded_ty: ty.clone(), - colon_token: syn::token::Colon::default(), - bounds: bound, - lifetimes: None, - })); } } @@ -389,12 +387,46 @@ fn impl_enum(input: Enum) -> TokenStream { }) }); + let (generic_field_types, generics_in_from_types): (Vec<&syn::Type>, Vec) = { + let mut generics_in_from_types = GenericUsageVisitor::new_unmarked( + input.generics.type_params().map(|p| p.ident.clone()), + ); + let mut generic_field_types = Vec::new(); + for from_field in input + .variants + .iter() + .filter_map(|variant| variant.from_field()) + { + let mut generics_in_this_field = GenericUsageVisitor::new_unmarked( + input.generics.type_params().map(|p| p.ident.clone()), + ); + syn::visit::visit_type(&mut generics_in_this_field, &from_field.original.ty); + if generics_in_from_types.mark_from(generics_in_this_field.iter_marked()) { + generic_field_types.push(from_field.ty); + } + } + ( + generic_field_types, + generics_in_from_types.generics.into_keys().collect(), + ) + }; + let error_trait = spanned_error_trait(input.original); - let where_clause = augment_where_clause( - where_clause, - [parse_quote! { Self: ::std::fmt::Display + ::std::fmt::Debug }], - ); + let extra_predicates = + std::iter::once(parse_quote! { Self: ::std::fmt::Display + ::std::fmt::Debug }) + .chain( + generics_in_from_types + .into_iter() + .map(|generic| parse_quote! { #generic: 'static }), + ) + .chain( + generic_field_types + .into_iter() + .map(|ty| parse_quote! { #ty: ::std::error::Error }), + ); + + let where_clause = augment_where_clause(where_clause, extra_predicates); quote! { #[allow(unused_qualifications)] impl #impl_generics #error_trait for #ty #ty_generics #where_clause { @@ -407,8 +439,9 @@ fn impl_enum(input: Enum) -> TokenStream { } #[cfg_attr(test, derive(Debug))] +#[derive(Clone)] struct GenericUsageVisitor { - generics: std::collections::HashMap, + generics: std::collections::HashMap, } impl GenericUsageVisitor { @@ -420,6 +453,36 @@ impl GenericUsageVisitor { generics: generics.into_iter().collect(), } } + + pub fn new_unmarked(generics: TIdents) -> Self + where + TIdents: IntoIterator, + { + Self::new(generics.into_iter().map(|ident| (ident, false))) + } + + pub fn iter_marked(&self) -> impl Iterator { + self.generics.iter().filter(|(_k, v)| **v).map(|(k, _v)| k) + } + + #[allow(dead_code)] + pub fn is_marked<'a, T: PartialEq<&'a proc_macro2::Ident> + 'a>(&'a self, item: T) -> bool { + self.iter_marked().any(|ident| item == ident) + } + + pub fn mark_from<'a, TMarkedSource: IntoIterator + 'a>( + &'a mut self, + source: TMarkedSource, + ) -> bool { + let mut any_marked = false; + for other_marked in source { + if let Some(is_marked) = self.generics.get_mut(other_marked) { + *is_marked = true; + any_marked = true; + } + } + any_marked + } } impl<'ast> syn::visit::Visit<'ast> for GenericUsageVisitor { @@ -429,14 +492,14 @@ impl<'ast> syn::visit::Visit<'ast> for GenericUsageVisitor { *entry = true; } } - syn::visit::visit_type_path(self, i) + syn::visit::visit_type_path(self, i); } fn visit_type_param(&mut self, i: &'ast syn::TypeParam) { if let Some(entry) = self.generics.get_mut(&i.ident) { *entry = true; } - syn::visit::visit_type_param(self, i) + syn::visit::visit_type_param(self, i); } } @@ -479,15 +542,14 @@ fn augment_where_clause( ) -> Option where TPredicates: IntoIterator, - TPredicates::IntoIter: std::iter::ExactSizeIterator, { - let extra_predicates = extra_predicates.into_iter(); - if extra_predicates.len() == 0 { + let mut extra_predicates = extra_predicates.into_iter().peekable(); + if extra_predicates.peek().is_none() { return where_clause.cloned(); } Some(match where_clause { Some(w) => syn::WhereClause { - where_token: w.where_token.clone(), + where_token: w.where_token, predicates: w .predicates .iter() diff --git a/tests/test_generics.rs b/tests/test_generics.rs index 000c30c..ef30d52 100644 --- a/tests/test_generics.rs +++ b/tests/test_generics.rs @@ -61,16 +61,15 @@ enum DirectEmbedding { /// /// impl Error for FromGenericError /// where -/// SourceGenericError: Error, +/// DirectEmbedding: Error, /// Indirect: 'static, /// Self: Debug + Display; /// ``` -// TODO: Parse generic parameters contained in #[from] usages and add them to Error impl -// #[derive(Error, Debug)] -// enum FromGenericError { -// #[error("Tadah")] -// SourceEmbedded(#[from] DirectEmbedding), -// } +#[derive(Error, Debug)] +enum FromGenericError { + #[error("Tadah")] + SourceEmbedded(#[from] DirectEmbedding), +} /// Direct embedding of a generic in a field /// From f9ca4561437f795211267c0bd4fd2e6208c2fa2f Mon Sep 17 00:00:00 2001 From: Zoey Date: Sat, 28 Aug 2021 21:24:57 -0700 Subject: [PATCH 4/7] Add support for generic structs --- impl/src/expand.rs | 154 ++++++++++++++++++++++++++++++----------- tests/test_generics.rs | 21 +++++- 2 files changed, 135 insertions(+), 40 deletions(-) diff --git a/impl/src/expand.rs b/impl/src/expand.rs index b8fe533..41a951c 100644 --- a/impl/src/expand.rs +++ b/impl/src/expand.rs @@ -112,6 +112,49 @@ fn impl_struct(input: Struct) -> TokenStream { None }; let display_impl = display_body.map(|body| { + let mut extra_predicates = Vec::new(); + for field in input + .attrs + .display + .iter() + .flat_map(|d| d.iter_fmt_types(input.fields.as_slice())) + { + use crate::attr::DisplayFormatMarking; + let (ty, bound, ast): ( + &syn::Type, + syn::punctuated::Punctuated, + &syn::Field, + ); + match field { + DisplayFormatMarking::Debug(f) => { + ty = f.ty; + bound = parse_quote! { ::std::fmt::Debug }; + ast = f.original; + } + DisplayFormatMarking::Display(f) => { + ty = f.ty; + bound = parse_quote! { ::std::fmt::Display }; + ast = f.original; + } + } + // If a generic is at all present, a constraint will be applied to the field type + // This may create redundant `AlwaysDebug: Debug` scenarios, but covers T: Debug and &T: Debug cleanly + let mut usages = GenericUsageVisitor::new_unmarked( + input.generics.type_params().map(|p| p.ident.clone()), + ); + syn::visit::visit_field(&mut usages, ast); + if usages.iter_marked().next().is_some() { + extra_predicates.push(syn::WherePredicate::Type(syn::PredicateType { + bounded_ty: ty.clone(), + colon_token: syn::token::Colon::default(), + bounds: bound, + lifetimes: None, + })); + } + } + + let where_clause = augment_where_clause(where_clause, extra_predicates); + quote! { #[allow(unused_qualifications)] impl #impl_generics std::fmt::Display for #ty #ty_generics #where_clause { @@ -142,8 +185,42 @@ fn impl_struct(input: Struct) -> TokenStream { } }); - let error_trait = spanned_error_trait(input.original); + let (generic_field_types, generics_in_from_types): (Vec<&syn::Type>, Vec) = { + let mut generics_in_from_types = GenericUsageVisitor::new_unmarked( + input.generics.type_params().map(|p| p.ident.clone()), + ); + let mut generic_field_types = Vec::new(); + if let Some(from_field) = input.from_field() { + let mut generics_in_this_field = GenericUsageVisitor::new_unmarked( + input.generics.type_params().map(|p| p.ident.clone()), + ); + syn::visit::visit_type(&mut generics_in_this_field, &from_field.original.ty); + if generics_in_from_types.mark_from(generics_in_this_field.iter_marked()) { + generic_field_types.push(from_field.ty); + } + } + ( + generic_field_types, + generics_in_from_types.generics.into_keys().collect(), + ) + }; + let extra_predicates = + std::iter::once(parse_quote! { Self: ::std::fmt::Display + ::std::fmt::Debug }) + .chain( + generics_in_from_types + .into_iter() + .map(|generic| parse_quote! { #generic: 'static }), + ) + .chain( + generic_field_types + .into_iter() + .map(|ty| parse_quote! { #ty: ::std::error::Error }), + ); + + let where_clause = augment_where_clause(where_clause, extra_predicates); + + let error_trait = spanned_error_trait(input.original); quote! { #[allow(unused_qualifications)] impl #impl_generics #error_trait for #ty #ty_generics #where_clause { @@ -287,45 +364,44 @@ fn impl_enum(input: Enum) -> TokenStream { }; let mut extra_predicates = Vec::new(); - for (fields, display) in input - .variants - .iter() - .filter_map(|v| v.attrs.display.as_ref().map(|d| (v.fields.as_slice(), d))) - { - for field in display.iter_fmt_types(fields) { - use crate::attr::DisplayFormatMarking; - let (ty, bound, ast): ( - &syn::Type, - syn::punctuated::Punctuated, - &syn::Field, - ); - match field { - DisplayFormatMarking::Debug(f) => { - ty = f.ty; - bound = parse_quote! { ::std::fmt::Debug }; - ast = f.original; - } - DisplayFormatMarking::Display(f) => { - ty = f.ty; - bound = parse_quote! { ::std::fmt::Display }; - ast = f.original; - } + for field in input.variants.iter().flat_map(|v| { + v.attrs + .display + .iter() + .flat_map(move |d| d.iter_fmt_types(&v.fields)) + }) { + use crate::attr::DisplayFormatMarking; + let (ty, bound, ast): ( + &syn::Type, + syn::punctuated::Punctuated, + &syn::Field, + ); + match field { + DisplayFormatMarking::Debug(f) => { + ty = f.ty; + bound = parse_quote! { ::std::fmt::Debug }; + ast = f.original; } - // If a generic is at all present, a constraint will be applied to the field type - // This may create redundant `AlwaysDebug: Debug` scenarios, but covers T: Debug and &T: Debug cleanly - let mut usages = GenericUsageVisitor::new_unmarked( - input.generics.type_params().map(|p| p.ident.clone()), - ); - syn::visit::visit_field(&mut usages, ast); - if usages.iter_marked().next().is_some() { - extra_predicates.push(syn::WherePredicate::Type(syn::PredicateType { - bounded_ty: ty.clone(), - colon_token: syn::token::Colon::default(), - bounds: bound, - lifetimes: None, - })); + DisplayFormatMarking::Display(f) => { + ty = f.ty; + bound = parse_quote! { ::std::fmt::Display }; + ast = f.original; } } + // If a generic is at all present, a constraint will be applied to the field type + // This may create redundant `AlwaysDebug: Debug` scenarios, but covers T: Debug and &T: Debug cleanly + let mut usages = GenericUsageVisitor::new_unmarked( + input.generics.type_params().map(|p| p.ident.clone()), + ); + syn::visit::visit_field(&mut usages, ast); + if usages.iter_marked().next().is_some() { + extra_predicates.push(syn::WherePredicate::Type(syn::PredicateType { + bounded_ty: ty.clone(), + colon_token: syn::token::Colon::default(), + bounds: bound, + lifetimes: None, + })); + } } let arms = input.variants.iter().map(|variant| { @@ -411,8 +487,6 @@ fn impl_enum(input: Enum) -> TokenStream { ) }; - let error_trait = spanned_error_trait(input.original); - let extra_predicates = std::iter::once(parse_quote! { Self: ::std::fmt::Display + ::std::fmt::Debug }) .chain( @@ -427,6 +501,8 @@ fn impl_enum(input: Enum) -> TokenStream { ); let where_clause = augment_where_clause(where_clause, extra_predicates); + + let error_trait = spanned_error_trait(input.original); quote! { #[allow(unused_qualifications)] impl #impl_generics #error_trait for #ty #ty_generics #where_clause { diff --git a/tests/test_generics.rs b/tests/test_generics.rs index ef30d52..bb67c4b 100644 --- a/tests/test_generics.rs +++ b/tests/test_generics.rs @@ -99,7 +99,7 @@ impl Debug for HybridDisplayType { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "HybridDisplayType") + write!(f, stringify!(HybridDisplayType)) } } @@ -109,3 +109,22 @@ fn display_hybrid_display_type( ) -> std::fmt::Result { Debug::fmt(&instance, f) } + +#[derive(Error, Debug)] +#[error("{0:?}")] +struct DirectEmbeddingStructTuple(Embedded); + +#[derive(Error, Debug)] +#[error("{direct:?}")] +struct DirectEmbeddingStructNominal { + direct: Embedded, +} + +#[derive(Error, Debug)] +struct FromGenericErrorStructTuple(#[from] DirectEmbedding); + +#[derive(Error, Debug)] +struct FromGenericErrorStructNominal { + #[from] + indirect: DirectEmbedding, +} From 28751e051b89576068f1c8d7a4865b91fb8d63de Mon Sep 17 00:00:00 2001 From: Zoey Date: Mon, 30 Aug 2021 03:31:42 -0700 Subject: [PATCH 5/7] Fully functional. Format literal parsing is present yet atrocious. --- impl/src/attr.rs | 20 ----- impl/src/expand.rs | 3 +- impl/src/fmt.rs | 169 +++++++++++++++++++++++++++++++++++++++-- tests/test_generics.rs | 8 +- 4 files changed, 166 insertions(+), 34 deletions(-) diff --git a/impl/src/attr.rs b/impl/src/attr.rs index 82f3d98..1ab1e28 100644 --- a/impl/src/attr.rs +++ b/impl/src/attr.rs @@ -15,12 +15,6 @@ pub struct Attrs<'a> { pub transparent: Option>, } -#[derive(Clone, Copy)] -pub enum DisplayFormatMarking<'a> { - Debug(&'a crate::ast::Field<'a>), - Display(&'a crate::ast::Field<'a>), -} - #[derive(Clone)] pub struct Display<'a> { pub original: &'a Attribute, @@ -29,20 +23,6 @@ pub struct Display<'a> { pub has_bonus_display: bool, } -impl<'a> Display<'a> { - pub fn iter_fmt_types( - &'a self, - fields: &'a [crate::ast::Field], - ) -> impl Iterator + 'a { - // TODO: Parse format string literal, return only fields at offsets which are Display or Debug referenced by each format reference - // If a field position is referenced multiple times for the same format class, deduplicate - fields - .iter() - .map(DisplayFormatMarking::Display) - .chain(fields.iter().map(DisplayFormatMarking::Debug)) - } -} - #[derive(Copy, Clone)] pub struct Transparent<'a> { pub original: &'a Attribute, diff --git a/impl/src/expand.rs b/impl/src/expand.rs index 41a951c..4907c61 100644 --- a/impl/src/expand.rs +++ b/impl/src/expand.rs @@ -1,4 +1,5 @@ use crate::ast::{Enum, Field, Input, Struct}; +use crate::fmt::DisplayFormatMarking; use proc_macro2::TokenStream; use quote::{format_ident, quote, quote_spanned, ToTokens}; use syn::spanned::Spanned; @@ -119,7 +120,6 @@ fn impl_struct(input: Struct) -> TokenStream { .iter() .flat_map(|d| d.iter_fmt_types(input.fields.as_slice())) { - use crate::attr::DisplayFormatMarking; let (ty, bound, ast): ( &syn::Type, syn::punctuated::Punctuated, @@ -370,7 +370,6 @@ fn impl_enum(input: Enum) -> TokenStream { .iter() .flat_map(move |d| d.iter_fmt_types(&v.fields)) }) { - use crate::attr::DisplayFormatMarking; let (ty, bound, ast): ( &syn::Type, syn::punctuated::Punctuated, diff --git a/impl/src/fmt.rs b/impl/src/fmt.rs index e12e94b..25b4595 100644 --- a/impl/src/fmt.rs +++ b/impl/src/fmt.rs @@ -1,13 +1,20 @@ -use crate::ast::Field; -use crate::attr::Display; -use proc_macro2::TokenTree; +use crate::{ast::Field, attr::Display}; +use proc_macro2::{Span, TokenTree}; use quote::{format_ident, quote_spanned}; -use std::collections::HashSet as Set; -use syn::ext::IdentExt; -use syn::parse::{ParseStream, Parser}; -use syn::{Ident, Index, LitStr, Member, Result, Token}; +use std::{collections::HashSet as Set, convert::TryInto}; +use syn::{ + ext::IdentExt, + parse::{ParseStream, Parser}, + Ident, Index, LitStr, Member, Result, Token, +}; -impl Display<'_> { +#[derive(Clone, Copy)] +pub enum DisplayFormatMarking<'a> { + Debug(&'a crate::ast::Field<'a>), + Display(&'a crate::ast::Field<'a>), +} + +impl<'a> Display<'a> { // Transform `"error {var}"` to `"error {}", var`. pub fn expand_shorthand(&mut self, fields: &[Field]) { let raw_args = self.args.clone(); @@ -94,6 +101,122 @@ impl Display<'_> { self.args = args; self.has_bonus_display = has_bonus_display; } + + pub fn iter_fmt_types(&'a self, fields: &'a [Field]) -> Vec> { + let members: Set = fields.iter().map(|f| f.member.clone()).collect(); + let fmt = self.fmt.value(); + let read = fmt.as_str(); + + let mut member_refs: Vec<(&Member, &Field, bool)> = Vec::new(); + + for template in parse_fmt_template(read) { + if let Some(target) = &template.target { + if members.contains(target) { + if let Some(matching) = fields.iter().find(|f| &f.member == target) { + member_refs.push((&matching.member, matching, template.is_display())); + } + } + } + } + + member_refs + .iter() + .map(|(_m, f, is_display)| { + if *is_display { + DisplayFormatMarking::Display(*f) + } else { + DisplayFormatMarking::Debug(*f) + } + }) + .collect() + } +} + +struct FormatInterpolation { + pub target: Option, + pub format: Option, +} + +impl FormatInterpolation { + pub fn is_debug(&self) -> bool { + self.format + .as_ref() + .map(|x| x.contains('?')) + .unwrap_or(false) + } + + pub fn is_display(&self) -> bool { + !self.is_debug() + } +} + +impl From<(Option<&str>, Option<&str>)> for FormatInterpolation { + fn from((target, style): (Option<&str>, Option<&str>)) -> Self { + let target = match target { + None => None, + Some(s) => Some(if let std::result::Result::::Ok(i) = s.parse() { + Member::Unnamed(syn::Index { + index: i.try_into().unwrap(), + span: Span::call_site(), + }) + } else { + let mut s = s; + let ident = take_ident(&mut s); + Member::Named(ident) + }), + }; + let format = style.map(String::from); + FormatInterpolation { target, format } + } +} + +fn read_format_template(mut read: &str) -> Option<(FormatInterpolation, &str)> { + // If we aren't in a bracketed area, or we are in an escaped bracket, return None + if !read.starts_with('{') || read.starts_with("{{") { + return None; + } + // Read past the starting bracket + read = &read[1..]; + // If there is no end bracket, bail + let end_bracket = read.find('}')?; + let contents = &read[..end_bracket]; + let (name, style) = if let Some(colon) = contents.find(':') { + (&contents[..colon], &contents[colon + 1..]) + } else { + (contents, "") + }; + + // Strip expanded identifier-prefixes since we just want the non-shorthand version + let name = if name.starts_with("field_") { + &name["field_".len()..] + } else if name.starts_with("r_") { + &name["r_".len()..] + } else { + name + }; + let name = if name.starts_with('_') { + &name["_".len()..] + } else { + name + }; + + let name = if name.is_empty() { None } else { Some(name) }; + let style = if style.is_empty() { None } else { Some(style) }; + Some(((name, style).into(), &read[end_bracket + 1..])) +} + +fn parse_fmt_template(mut read: &str) -> Vec { + let mut output = Vec::new(); + // From each "{", try reading a template; double-bracket escape handling is done by the template reader + while let Some(opening_bracket) = read.find('{') { + read = &read[opening_bracket..]; + if let Some((template, next)) = read_format_template(read) { + read = next; + output.push(template); + } + read = &read[read.char_indices().nth(1).map(|(x, _)| x).unwrap_or(0)..]; + } + output } fn explicit_named_args(input: ParseStream) -> Result> { @@ -145,3 +268,33 @@ fn take_ident(read: &mut &str) -> Ident { } Ident::parse_any.parse_str(&ident).unwrap() } + +#[cfg(test)] +mod tests { + use quote::ToTokens; + use syn::Member; + + use super::{parse_fmt_template, FormatInterpolation}; + + #[test] + fn parse_and_emit_format_strings() { + let test_str = "\"hello world {{{:} {x:#?} {1} {2:}\""; + let template_groups = parse_fmt_template(test_str); + assert!(matches!( + &template_groups[0], + FormatInterpolation { + target: None, + format: None + } + )); + assert!( + matches!(&template_groups[1], FormatInterpolation { target: Some(Member::Named(x)), format: Some(fmt) } if x.to_token_stream().to_string() == "x" && fmt == "#?") + ); + assert!( + matches!(&template_groups[2], FormatInterpolation { target: Some(Member::Unnamed(idx)), format: None } if idx.index == 1) + ); + assert!( + matches!(&template_groups[3], FormatInterpolation { target: Some(Member::Unnamed(idx)), format: None } if idx.index == 2) + ); + } +} diff --git a/tests/test_generics.rs b/tests/test_generics.rs index bb67c4b..8a37f36 100644 --- a/tests/test_generics.rs +++ b/tests/test_generics.rs @@ -88,11 +88,11 @@ enum FromGenericError { #[derive(Error)] enum HybridDisplayType { #[error("{0} : {1:?}")] - HybridDisplayCase(HasDisplay, HasDebug), + HybridDisplay(HasDisplay, HasDebug), #[error("{0}")] - DisplayCase(HasDisplay, HasNeither), + Display(HasDisplay, HasNeither), #[error("{1:?}")] - DebugCase(HasNeither, HasDebug), + Debug(HasNeither, HasDebug), } impl Debug @@ -104,7 +104,7 @@ impl Debug } fn display_hybrid_display_type( - instance: HybridDisplayType, + instance: &HybridDisplayType, f: &mut std::fmt::Formatter<'_>, ) -> std::fmt::Result { Debug::fmt(&instance, f) From 5de95890b3821665a6c3e3c73a11eccafd132d0c Mon Sep 17 00:00:00 2001 From: Zoey Date: Mon, 30 Aug 2021 03:46:14 -0700 Subject: [PATCH 6/7] Fix compatibility with rust 1.31.0 and 1.36.0 --- impl/src/expand.rs | 17 ++++++++++++++--- impl/src/fmt.rs | 6 +++--- 2 files changed, 17 insertions(+), 6 deletions(-) diff --git a/impl/src/expand.rs b/impl/src/expand.rs index 871fd2d..a0f933b 100644 --- a/impl/src/expand.rs +++ b/impl/src/expand.rs @@ -3,7 +3,10 @@ use crate::fmt::DisplayFormatMarking; use proc_macro2::TokenStream; use quote::{format_ident, quote, quote_spanned, ToTokens}; use syn::spanned::Spanned; -use syn::{parse_quote, Data, DeriveInput, GenericArgument, Member, PathArguments, Result, Type, Visibility}; +use syn::{ + parse_quote, Data, DeriveInput, GenericArgument, Member, PathArguments, Result, Type, + Visibility, +}; pub fn derive(node: &DeriveInput) -> Result { let input = Input::from_syn(node)?; @@ -203,7 +206,11 @@ fn impl_struct(input: Struct) -> TokenStream { } ( generic_field_types, - generics_in_from_types.generics.into_keys().collect(), + generics_in_from_types + .generics + .into_iter() + .map(|(k, _v)| k) + .collect(), ) }; @@ -505,7 +512,11 @@ fn impl_enum(input: Enum) -> TokenStream { } ( generic_field_types, - generics_in_from_types.generics.into_keys().collect(), + generics_in_from_types + .generics + .into_iter() + .map(|(k, _v)| k) + .collect(), ) }; diff --git a/impl/src/fmt.rs b/impl/src/fmt.rs index 25b4595..b9e858f 100644 --- a/impl/src/fmt.rs +++ b/impl/src/fmt.rs @@ -1,7 +1,7 @@ use crate::{ast::Field, attr::Display}; use proc_macro2::{Span, TokenTree}; use quote::{format_ident, quote_spanned}; -use std::{collections::HashSet as Set, convert::TryInto}; +use std::collections::HashSet as Set; use syn::{ ext::IdentExt, parse::{ParseStream, Parser}, @@ -154,9 +154,9 @@ impl From<(Option<&str>, Option<&str>)> for FormatInterpolation { fn from((target, style): (Option<&str>, Option<&str>)) -> Self { let target = match target { None => None, - Some(s) => Some(if let std::result::Result::::Ok(i) = s.parse() { + Some(s) => Some(if let Ok(i) = s.parse::() { Member::Unnamed(syn::Index { - index: i.try_into().unwrap(), + index: i as u32, span: Span::call_site(), }) } else { From 377de4811778b1e59cdfd242f8387dee10a3c1cb Mon Sep 17 00:00:00 2001 From: Zoey Date: Mon, 30 Aug 2021 03:53:37 -0700 Subject: [PATCH 7/7] Apparently `matches` is newer than Rust 1.36.0 --- impl/src/fmt.rs | 40 ++++++++++++++++++++++++++-------------- 1 file changed, 26 insertions(+), 14 deletions(-) diff --git a/impl/src/fmt.rs b/impl/src/fmt.rs index b9e858f..a82ee45 100644 --- a/impl/src/fmt.rs +++ b/impl/src/fmt.rs @@ -280,21 +280,33 @@ mod tests { fn parse_and_emit_format_strings() { let test_str = "\"hello world {{{:} {x:#?} {1} {2:}\""; let template_groups = parse_fmt_template(test_str); - assert!(matches!( - &template_groups[0], + assert!(match &template_groups[0] { FormatInterpolation { target: None, - format: None - } - )); - assert!( - matches!(&template_groups[1], FormatInterpolation { target: Some(Member::Named(x)), format: Some(fmt) } if x.to_token_stream().to_string() == "x" && fmt == "#?") - ); - assert!( - matches!(&template_groups[2], FormatInterpolation { target: Some(Member::Unnamed(idx)), format: None } if idx.index == 1) - ); - assert!( - matches!(&template_groups[3], FormatInterpolation { target: Some(Member::Unnamed(idx)), format: None } if idx.index == 2) - ); + format: None, + } => true, + _ => false, + }); + assert!(match &template_groups[1] { + FormatInterpolation { + target: Some(Member::Named(x)), + format: Some(fmt), + } if x.to_token_stream().to_string() == "x" && fmt == "#?" => true, + _ => false, + }); + assert!(match &template_groups[2] { + FormatInterpolation { + target: Some(Member::Unnamed(idx)), + format: None, + } if idx.index == 1 => true, + _ => false, + }); + assert!(match &template_groups[3] { + FormatInterpolation { + target: Some(Member::Unnamed(idx)), + format: None, + } if idx.index == 2 => true, + _ => false, + }); } }