diff --git a/impl/Cargo.toml b/impl/Cargo.toml index 4965078..1a3ff87 100644 --- a/impl/Cargo.toml +++ b/impl/Cargo.toml @@ -13,7 +13,7 @@ proc-macro = true [dependencies] proc-macro2 = "1.0" quote = "1.0" -syn = "1.0.45" +syn = { version = "1.0.45", features = ["visit"] } [package.metadata.docs.rs] targets = ["x86_64-unknown-linux-gnu"] diff --git a/impl/src/expand.rs b/impl/src/expand.rs index 5855128..a0f933b 100644 --- a/impl/src/expand.rs +++ b/impl/src/expand.rs @@ -1,8 +1,12 @@ use crate::ast::{Enum, Field, Input, Struct}; +use crate::fmt::DisplayFormatMarking; use proc_macro2::TokenStream; use quote::{format_ident, quote, quote_spanned, ToTokens}; use syn::spanned::Spanned; -use syn::{Data, DeriveInput, GenericArgument, Member, PathArguments, Result, Type, Visibility}; +use syn::{ + parse_quote, Data, DeriveInput, GenericArgument, Member, PathArguments, Result, Type, + Visibility, +}; pub fn derive(node: &DeriveInput) -> Result { let input = Input::from_syn(node)?; @@ -114,6 +118,48 @@ fn impl_struct(input: Struct) -> TokenStream { None }; let display_impl = display_body.map(|body| { + let mut extra_predicates = Vec::new(); + for field in input + .attrs + .display + .iter() + .flat_map(|d| d.iter_fmt_types(input.fields.as_slice())) + { + let (ty, bound, ast): ( + &syn::Type, + syn::punctuated::Punctuated, + &syn::Field, + ); + match field { + DisplayFormatMarking::Debug(f) => { + ty = f.ty; + bound = parse_quote! { ::std::fmt::Debug }; + ast = f.original; + } + DisplayFormatMarking::Display(f) => { + ty = f.ty; + bound = parse_quote! { ::std::fmt::Display }; + ast = f.original; + } + } + // If a generic is at all present, a constraint will be applied to the field type + // This may create redundant `AlwaysDebug: Debug` scenarios, but covers T: Debug and &T: Debug cleanly + let mut usages = GenericUsageVisitor::new_unmarked( + input.generics.type_params().map(|p| p.ident.clone()), + ); + syn::visit::visit_field(&mut usages, ast); + if usages.iter_marked().next().is_some() { + extra_predicates.push(syn::WherePredicate::Type(syn::PredicateType { + bounded_ty: ty.clone(), + colon_token: syn::token::Colon::default(), + bounds: bound, + lifetimes: None, + })); + } + } + + let where_clause = augment_where_clause(where_clause, extra_predicates); + quote! { #[allow(unused_qualifications)] impl #impl_generics std::fmt::Display for #ty #ty_generics #where_clause { @@ -144,8 +190,46 @@ fn impl_struct(input: Struct) -> TokenStream { } }); - let error_trait = spanned_error_trait(input.original); + let (generic_field_types, generics_in_from_types): (Vec<&syn::Type>, Vec) = { + let mut generics_in_from_types = GenericUsageVisitor::new_unmarked( + input.generics.type_params().map(|p| p.ident.clone()), + ); + let mut generic_field_types = Vec::new(); + if let Some(from_field) = input.from_field() { + let mut generics_in_this_field = GenericUsageVisitor::new_unmarked( + input.generics.type_params().map(|p| p.ident.clone()), + ); + syn::visit::visit_type(&mut generics_in_this_field, &from_field.original.ty); + if generics_in_from_types.mark_from(generics_in_this_field.iter_marked()) { + generic_field_types.push(from_field.ty); + } + } + ( + generic_field_types, + generics_in_from_types + .generics + .into_iter() + .map(|(k, _v)| k) + .collect(), + ) + }; + + let extra_predicates = + std::iter::once(parse_quote! { Self: ::std::fmt::Display + ::std::fmt::Debug }) + .chain( + generics_in_from_types + .into_iter() + .map(|generic| parse_quote! { #generic: 'static }), + ) + .chain( + generic_field_types + .into_iter() + .map(|ty| parse_quote! { #ty: ::std::error::Error }), + ); + let where_clause = augment_where_clause(where_clause, extra_predicates); + + let error_trait = spanned_error_trait(input.original); quote! { #[allow(unused_qualifications)] impl #impl_generics #error_trait for #ty #ty_generics #where_clause { @@ -308,6 +392,47 @@ fn impl_enum(input: Enum) -> TokenStream { } else { None }; + + let mut extra_predicates = Vec::new(); + for field in input.variants.iter().flat_map(|v| { + v.attrs + .display + .iter() + .flat_map(move |d| d.iter_fmt_types(&v.fields)) + }) { + let (ty, bound, ast): ( + &syn::Type, + syn::punctuated::Punctuated, + &syn::Field, + ); + match field { + DisplayFormatMarking::Debug(f) => { + ty = f.ty; + bound = parse_quote! { ::std::fmt::Debug }; + ast = f.original; + } + DisplayFormatMarking::Display(f) => { + ty = f.ty; + bound = parse_quote! { ::std::fmt::Display }; + ast = f.original; + } + } + // If a generic is at all present, a constraint will be applied to the field type + // This may create redundant `AlwaysDebug: Debug` scenarios, but covers T: Debug and &T: Debug cleanly + let mut usages = GenericUsageVisitor::new_unmarked( + input.generics.type_params().map(|p| p.ident.clone()), + ); + syn::visit::visit_field(&mut usages, ast); + if usages.iter_marked().next().is_some() { + extra_predicates.push(syn::WherePredicate::Type(syn::PredicateType { + bounded_ty: ty.clone(), + colon_token: syn::token::Colon::default(), + bounds: bound, + lifetimes: None, + })); + } + } + let arms = input.variants.iter().map(|variant| { let display = match &variant.attrs.display { Some(display) => display.to_token_stream(), @@ -325,6 +450,9 @@ fn impl_enum(input: Enum) -> TokenStream { #ty::#ident #pat => #display } }); + + let where_clause = augment_where_clause(where_clause, extra_predicates); + Some(quote! { #[allow(unused_qualifications)] impl #impl_generics std::fmt::Display for #ty #ty_generics #where_clause { @@ -364,8 +492,50 @@ fn impl_enum(input: Enum) -> TokenStream { }) }); - let error_trait = spanned_error_trait(input.original); + let (generic_field_types, generics_in_from_types): (Vec<&syn::Type>, Vec) = { + let mut generics_in_from_types = GenericUsageVisitor::new_unmarked( + input.generics.type_params().map(|p| p.ident.clone()), + ); + let mut generic_field_types = Vec::new(); + for from_field in input + .variants + .iter() + .filter_map(|variant| variant.from_field()) + { + let mut generics_in_this_field = GenericUsageVisitor::new_unmarked( + input.generics.type_params().map(|p| p.ident.clone()), + ); + syn::visit::visit_type(&mut generics_in_this_field, &from_field.original.ty); + if generics_in_from_types.mark_from(generics_in_this_field.iter_marked()) { + generic_field_types.push(from_field.ty); + } + } + ( + generic_field_types, + generics_in_from_types + .generics + .into_iter() + .map(|(k, _v)| k) + .collect(), + ) + }; + + let extra_predicates = + std::iter::once(parse_quote! { Self: ::std::fmt::Display + ::std::fmt::Debug }) + .chain( + generics_in_from_types + .into_iter() + .map(|generic| parse_quote! { #generic: 'static }), + ) + .chain( + generic_field_types + .into_iter() + .map(|ty| parse_quote! { #ty: ::std::error::Error }), + ); + let where_clause = augment_where_clause(where_clause, extra_predicates); + + let error_trait = spanned_error_trait(input.original); quote! { #[allow(unused_qualifications)] impl #impl_generics #error_trait for #ty #ty_generics #where_clause { @@ -377,6 +547,132 @@ fn impl_enum(input: Enum) -> TokenStream { } } +#[cfg_attr(test, derive(Debug))] +#[derive(Clone)] +struct GenericUsageVisitor { + generics: std::collections::HashMap, +} + +impl GenericUsageVisitor { + pub fn new(generics: TPairs) -> Self + where + TPairs: IntoIterator, + { + Self { + generics: generics.into_iter().collect(), + } + } + + pub fn new_unmarked(generics: TIdents) -> Self + where + TIdents: IntoIterator, + { + Self::new(generics.into_iter().map(|ident| (ident, false))) + } + + pub fn iter_marked(&self) -> impl Iterator { + self.generics.iter().filter(|(_k, v)| **v).map(|(k, _v)| k) + } + + #[allow(dead_code)] + pub fn is_marked<'a, T: PartialEq<&'a proc_macro2::Ident> + 'a>(&'a self, item: T) -> bool { + self.iter_marked().any(|ident| item == ident) + } + + pub fn mark_from<'a, TMarkedSource: IntoIterator + 'a>( + &'a mut self, + source: TMarkedSource, + ) -> bool { + let mut any_marked = false; + for other_marked in source { + if let Some(is_marked) = self.generics.get_mut(other_marked) { + *is_marked = true; + any_marked = true; + } + } + any_marked + } +} + +impl<'ast> syn::visit::Visit<'ast> for GenericUsageVisitor { + fn visit_type_path(&mut self, i: &'ast syn::TypePath) { + if let Some(ident) = i.path.get_ident() { + if let Some(entry) = self.generics.get_mut(ident) { + *entry = true; + } + } + syn::visit::visit_type_path(self, i); + } + + fn visit_type_param(&mut self, i: &'ast syn::TypeParam) { + if let Some(entry) = self.generics.get_mut(&i.ident) { + *entry = true; + } + syn::visit::visit_type_param(self, i); + } +} + +#[cfg(test)] +mod test_visitors { + use proc_macro2::Span; + + use crate::expand::GenericUsageVisitor; + + #[test] + fn test_generic_usage_visitor() { + fn ident_for<'a>(ident_str: &'a str) -> syn::Ident { + syn::Ident::new(ident_str, Span::call_site()) + } + + let field: syn::Variant = syn::parse_quote! { X(Foo) }; + + let mut visitor = GenericUsageVisitor::new( + vec!["Bar", "Baz"] + .into_iter() + .map(|ident_str| (ident_for(ident_str), false)), + ); + syn::visit::visit_variant(&mut visitor, &field); + assert_eq!( + visitor.generics.get(&ident_for("Bar")), + Some(&true), + "Bar must be marked" + ); + assert_eq!( + visitor.generics.get(&ident_for("Baz")), + Some(&false), + "Baz must be present but not marked" + ); + } +} + +fn augment_where_clause( + where_clause: Option<&syn::WhereClause>, + extra_predicates: TPredicates, +) -> Option +where + TPredicates: IntoIterator, +{ + let mut extra_predicates = extra_predicates.into_iter().peekable(); + if extra_predicates.peek().is_none() { + return where_clause.cloned(); + } + Some(match where_clause { + Some(w) => syn::WhereClause { + where_token: w.where_token, + predicates: w + .predicates + .iter() + .cloned() + .chain(extra_predicates) + .collect(), + }, + None => syn::WhereClause { + where_token: syn::token::Where::default(), + predicates: extra_predicates.into_iter().collect(), + }, + }) +} + fn fields_pat(fields: &[Field]) -> TokenStream { let mut members = fields.iter().map(|field| &field.member).peekable(); match members.peek() { diff --git a/impl/src/fmt.rs b/impl/src/fmt.rs index e12e94b..a82ee45 100644 --- a/impl/src/fmt.rs +++ b/impl/src/fmt.rs @@ -1,13 +1,20 @@ -use crate::ast::Field; -use crate::attr::Display; -use proc_macro2::TokenTree; +use crate::{ast::Field, attr::Display}; +use proc_macro2::{Span, TokenTree}; use quote::{format_ident, quote_spanned}; use std::collections::HashSet as Set; -use syn::ext::IdentExt; -use syn::parse::{ParseStream, Parser}; -use syn::{Ident, Index, LitStr, Member, Result, Token}; +use syn::{ + ext::IdentExt, + parse::{ParseStream, Parser}, + Ident, Index, LitStr, Member, Result, Token, +}; -impl Display<'_> { +#[derive(Clone, Copy)] +pub enum DisplayFormatMarking<'a> { + Debug(&'a crate::ast::Field<'a>), + Display(&'a crate::ast::Field<'a>), +} + +impl<'a> Display<'a> { // Transform `"error {var}"` to `"error {}", var`. pub fn expand_shorthand(&mut self, fields: &[Field]) { let raw_args = self.args.clone(); @@ -94,6 +101,122 @@ impl Display<'_> { self.args = args; self.has_bonus_display = has_bonus_display; } + + pub fn iter_fmt_types(&'a self, fields: &'a [Field]) -> Vec> { + let members: Set = fields.iter().map(|f| f.member.clone()).collect(); + let fmt = self.fmt.value(); + let read = fmt.as_str(); + + let mut member_refs: Vec<(&Member, &Field, bool)> = Vec::new(); + + for template in parse_fmt_template(read) { + if let Some(target) = &template.target { + if members.contains(target) { + if let Some(matching) = fields.iter().find(|f| &f.member == target) { + member_refs.push((&matching.member, matching, template.is_display())); + } + } + } + } + + member_refs + .iter() + .map(|(_m, f, is_display)| { + if *is_display { + DisplayFormatMarking::Display(*f) + } else { + DisplayFormatMarking::Debug(*f) + } + }) + .collect() + } +} + +struct FormatInterpolation { + pub target: Option, + pub format: Option, +} + +impl FormatInterpolation { + pub fn is_debug(&self) -> bool { + self.format + .as_ref() + .map(|x| x.contains('?')) + .unwrap_or(false) + } + + pub fn is_display(&self) -> bool { + !self.is_debug() + } +} + +impl From<(Option<&str>, Option<&str>)> for FormatInterpolation { + fn from((target, style): (Option<&str>, Option<&str>)) -> Self { + let target = match target { + None => None, + Some(s) => Some(if let Ok(i) = s.parse::() { + Member::Unnamed(syn::Index { + index: i as u32, + span: Span::call_site(), + }) + } else { + let mut s = s; + let ident = take_ident(&mut s); + Member::Named(ident) + }), + }; + let format = style.map(String::from); + FormatInterpolation { target, format } + } +} + +fn read_format_template(mut read: &str) -> Option<(FormatInterpolation, &str)> { + // If we aren't in a bracketed area, or we are in an escaped bracket, return None + if !read.starts_with('{') || read.starts_with("{{") { + return None; + } + // Read past the starting bracket + read = &read[1..]; + // If there is no end bracket, bail + let end_bracket = read.find('}')?; + let contents = &read[..end_bracket]; + let (name, style) = if let Some(colon) = contents.find(':') { + (&contents[..colon], &contents[colon + 1..]) + } else { + (contents, "") + }; + + // Strip expanded identifier-prefixes since we just want the non-shorthand version + let name = if name.starts_with("field_") { + &name["field_".len()..] + } else if name.starts_with("r_") { + &name["r_".len()..] + } else { + name + }; + let name = if name.starts_with('_') { + &name["_".len()..] + } else { + name + }; + + let name = if name.is_empty() { None } else { Some(name) }; + let style = if style.is_empty() { None } else { Some(style) }; + Some(((name, style).into(), &read[end_bracket + 1..])) +} + +fn parse_fmt_template(mut read: &str) -> Vec { + let mut output = Vec::new(); + // From each "{", try reading a template; double-bracket escape handling is done by the template reader + while let Some(opening_bracket) = read.find('{') { + read = &read[opening_bracket..]; + if let Some((template, next)) = read_format_template(read) { + read = next; + output.push(template); + } + read = &read[read.char_indices().nth(1).map(|(x, _)| x).unwrap_or(0)..]; + } + output } fn explicit_named_args(input: ParseStream) -> Result> { @@ -145,3 +268,45 @@ fn take_ident(read: &mut &str) -> Ident { } Ident::parse_any.parse_str(&ident).unwrap() } + +#[cfg(test)] +mod tests { + use quote::ToTokens; + use syn::Member; + + use super::{parse_fmt_template, FormatInterpolation}; + + #[test] + fn parse_and_emit_format_strings() { + let test_str = "\"hello world {{{:} {x:#?} {1} {2:}\""; + let template_groups = parse_fmt_template(test_str); + assert!(match &template_groups[0] { + FormatInterpolation { + target: None, + format: None, + } => true, + _ => false, + }); + assert!(match &template_groups[1] { + FormatInterpolation { + target: Some(Member::Named(x)), + format: Some(fmt), + } if x.to_token_stream().to_string() == "x" && fmt == "#?" => true, + _ => false, + }); + assert!(match &template_groups[2] { + FormatInterpolation { + target: Some(Member::Unnamed(idx)), + format: None, + } if idx.index == 1 => true, + _ => false, + }); + assert!(match &template_groups[3] { + FormatInterpolation { + target: Some(Member::Unnamed(idx)), + format: None, + } if idx.index == 2 => true, + _ => false, + }); + } +} diff --git a/tests/test_generics.rs b/tests/test_generics.rs new file mode 100644 index 0000000..8a37f36 --- /dev/null +++ b/tests/test_generics.rs @@ -0,0 +1,130 @@ +#![deny(clippy::all, clippy::pedantic)] +#![allow( + // Clippy bug: https://github.com/rust-lang/rust-clippy/issues/7422 + clippy::nonstandard_macro_braces, +)] +#![allow(dead_code)] + +use std::fmt::{Debug, Display}; + +use thiserror::Error; + +struct NoFormattingType; + +struct DisplayType; + +impl Display for DisplayType { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, stringify!(DisplayType)) + } +} + +impl Debug for DisplayType { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, stringify!(DisplayType)) + } +} + +struct DebugType; + +impl Debug for DebugType { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, stringify!(DebugType)) + } +} + +/// Direct embedding of a generic in a field +/// +/// Should produce the following instances: +/// +/// ```rust +/// impl Display for DirectEmbedding +/// where +/// Embedded: Debug; +/// +/// impl Error for DirectEmbedding +/// where +/// Self: Debug + Display; +/// ``` +#[derive(Error, Debug)] +enum DirectEmbedding { + #[error("{0:?}")] + FatalError(Embedded), +} + +/// #[from] handling but no Debug usage of the generic +/// +/// Should produce the following instances: +/// +/// ```rust +/// impl Display for FromGenericError; +/// +/// impl Error for FromGenericError +/// where +/// DirectEmbedding: Error, +/// Indirect: 'static, +/// Self: Debug + Display; +/// ``` +#[derive(Error, Debug)] +enum FromGenericError { + #[error("Tadah")] + SourceEmbedded(#[from] DirectEmbedding), +} + +/// Direct embedding of a generic in a field +/// +/// Should produce the following instances: +/// +/// ```rust +/// impl Display for DirectEmbedding +/// where +/// HasDisplay: Display, +/// HasDebug: Debug; +/// +/// impl Error for DirectEmbedding +/// where +/// Self: Debug + Display; +/// ``` +#[derive(Error)] +enum HybridDisplayType { + #[error("{0} : {1:?}")] + HybridDisplay(HasDisplay, HasDebug), + #[error("{0}")] + Display(HasDisplay, HasNeither), + #[error("{1:?}")] + Debug(HasNeither, HasDebug), +} + +impl Debug + for HybridDisplayType +{ + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, stringify!(HybridDisplayType)) + } +} + +fn display_hybrid_display_type( + instance: &HybridDisplayType, + f: &mut std::fmt::Formatter<'_>, +) -> std::fmt::Result { + Debug::fmt(&instance, f) +} + +#[derive(Error, Debug)] +#[error("{0:?}")] +struct DirectEmbeddingStructTuple(Embedded); + +#[derive(Error, Debug)] +#[error("{direct:?}")] +struct DirectEmbeddingStructNominal { + direct: Embedded, +} + +#[derive(Error, Debug)] +struct FromGenericErrorStructTuple(#[from] DirectEmbedding); + +#[derive(Error, Debug)] +struct FromGenericErrorStructNominal { + #[from] + indirect: DirectEmbedding, +}