Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Automatic Generic-Derived Qualifiers #145

Closed
2 changes: 1 addition & 1 deletion impl/Cargo.toml
Expand Up @@ -13,7 +13,7 @@ proc-macro = true
[dependencies]
proc-macro2 = "1.0"
quote = "1.0"
syn = "1.0.45"
syn = { version = "1.0.45", features = ["visit"] }

[package.metadata.docs.rs]
targets = ["x86_64-unknown-linux-gnu"]
20 changes: 20 additions & 0 deletions impl/src/attr.rs
Expand Up @@ -15,6 +15,12 @@ pub struct Attrs<'a> {
pub transparent: Option<Transparent<'a>>,
}

#[derive(Clone, Copy)]
pub enum DisplayFormatMarking<'a> {
Debug(&'a crate::ast::Field<'a>),
Display(&'a crate::ast::Field<'a>),
}

#[derive(Clone)]
pub struct Display<'a> {
pub original: &'a Attribute,
Expand All @@ -23,6 +29,20 @@ pub struct Display<'a> {
pub has_bonus_display: bool,
}

impl<'a> Display<'a> {
pub fn iter_fmt_types(
&'a self,
fields: &'a [crate::ast::Field],
) -> impl Iterator<Item = DisplayFormatMarking> + 'a {
// TODO: Parse format string literal, return only fields at offsets which are Display or Debug referenced by each format reference
// If a field position is referenced multiple times for the same format class, deduplicate
fields
.iter()
.map(DisplayFormatMarking::Display)
.chain(fields.iter().map(DisplayFormatMarking::Debug))
}
}

#[derive(Copy, Clone)]
pub struct Transparent<'a> {
pub original: &'a Attribute,
Expand Down
150 changes: 149 additions & 1 deletion impl/src/expand.rs
Expand Up @@ -2,7 +2,7 @@ use crate::ast::{Enum, Field, Input, Struct};
use proc_macro2::TokenStream;
use quote::{format_ident, quote, quote_spanned, ToTokens};
use syn::spanned::Spanned;
use syn::{Data, DeriveInput, Member, PathArguments, Result, Type, Visibility};
use syn::{parse_quote, Data, DeriveInput, Member, PathArguments, Result, Type, Visibility};

pub fn derive(node: &DeriveInput) -> Result<TokenStream> {
let input = Input::from_syn(node)?;
Expand Down Expand Up @@ -285,6 +285,51 @@ fn impl_enum(input: Enum) -> TokenStream {
} else {
None
};

let mut extra_predicates = Vec::new();
for (fields, display) in input
.variants
.iter()
.filter_map(|v| v.attrs.display.as_ref().map(|d| (v.fields.as_slice(), d)))
{
for field in display.iter_fmt_types(fields) {
use crate::attr::DisplayFormatMarking;
let (ty, bound): (
&syn::Type,
syn::punctuated::Punctuated<syn::TypeParamBound, _>,
);
match field {
DisplayFormatMarking::Debug(f) => {
ty = &f.ty;
bound = parse_quote! { ::std::fmt::Debug };
}
DisplayFormatMarking::Display(f) => {
ty = &f.ty;
bound = parse_quote! { ::std::fmt::Display };
// f.original.
}
}
let matcher = |param: &&syn::TypeParam| match ty {
syn::Type::Path(syn::TypePath { path, .. }) => {
path.get_ident() == Some(&param.ident)
}
syn::Type::Reference(syn::TypeReference { elem, .. }) if matches!(Box::as_ref(&elem), &syn::Type::Path(syn::TypePath { ref path, .. }) if path.get_ident() == Some(&param.ident)) => {
true
}
_ => false,
};
if input.generics.type_params().find(|x| matcher(x)).is_none() {
continue;
}
extra_predicates.push(syn::WherePredicate::Type(syn::PredicateType {
bounded_ty: ty.clone(),
colon_token: syn::token::Colon::default(),
bounds: bound,
lifetimes: None,
}));
}
}

let arms = input.variants.iter().map(|variant| {
let display = match &variant.attrs.display {
Some(display) => display.to_token_stream(),
Expand All @@ -302,6 +347,9 @@ fn impl_enum(input: Enum) -> TokenStream {
#ty::#ident #pat => #display
}
});

let where_clause = augment_where_clause(where_clause, extra_predicates);

Some(quote! {
#[allow(unused_qualifications)]
impl #impl_generics std::fmt::Display for #ty #ty_generics #where_clause {
Expand Down Expand Up @@ -343,6 +391,10 @@ fn impl_enum(input: Enum) -> TokenStream {

let error_trait = spanned_error_trait(input.original);

let where_clause = augment_where_clause(
where_clause,
[parse_quote! { Self: ::std::fmt::Display + ::std::fmt::Debug }],
);
quote! {
#[allow(unused_qualifications)]
impl #impl_generics #error_trait for #ty #ty_generics #where_clause {
Expand All @@ -354,6 +406,102 @@ fn impl_enum(input: Enum) -> TokenStream {
}
}

#[cfg_attr(test, derive(Debug))]
struct GenericUsageVisitor {
generics: std::collections::HashMap<syn::Ident, bool>,
}

impl GenericUsageVisitor {
pub fn new<TPairs>(generics: TPairs) -> Self
where
TPairs: IntoIterator<Item = (syn::Ident, bool)>,
{
Self {
generics: generics.into_iter().collect(),
}
}
}

impl<'ast> syn::visit::Visit<'ast> for GenericUsageVisitor {
fn visit_type_path(&mut self, i: &'ast syn::TypePath) {
if let Some(ident) = i.path.get_ident() {
if let Some(entry) = self.generics.get_mut(ident) {
*entry = true;
}
}
syn::visit::visit_type_path(self, i)
}

fn visit_type_param(&mut self, i: &'ast syn::TypeParam) {
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is not reachable — except maybe in contrived code like field: [T; { fn f<T>() {}; 0 }] where you actually don't want the behavior here of considering the inner T TypeParam as being a use of the outer T.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I used this to deep-search for generic usages because- when a generic wasn't used- I was running into circumstances where fixed types would succeed in compilation where they should not, because it generated impossible/unlikely bounds, such as PathBuf: Display, which you seem to have resolved using the PathAsDisplay trait in the past. My solution was to deep-scan for generics in types, and only add bounds for fields which used at least one generic.

if let Some(entry) = self.generics.get_mut(&i.ident) {
*entry = true;
}
syn::visit::visit_type_param(self, i)
}
}

#[cfg(test)]
mod test_visitors {
use proc_macro2::Span;

use crate::expand::GenericUsageVisitor;

#[test]
fn test_generic_usage_visitor() {
fn ident_for<'a>(ident_str: &'a str) -> syn::Ident {
syn::Ident::new(ident_str, Span::call_site())
}

let field: syn::Variant = syn::parse_quote! { X(Foo<Bar>) };

let mut visitor = GenericUsageVisitor::new(
vec!["Bar", "Baz"]
.into_iter()
.map(|ident_str| (ident_for(ident_str), false)),
);
syn::visit::visit_variant(&mut visitor, &field);
assert_eq!(
visitor.generics.get(&ident_for("Bar")),
Some(&true),
"Bar must be marked"
);
assert_eq!(
visitor.generics.get(&ident_for("Baz")),
Some(&false),
"Baz must be present but not marked"
);
}
}

fn augment_where_clause<TPredicates>(
where_clause: Option<&syn::WhereClause>,
extra_predicates: TPredicates,
) -> Option<syn::WhereClause>
where
TPredicates: IntoIterator<Item = syn::WherePredicate>,
TPredicates::IntoIter: std::iter::ExactSizeIterator,
{
let extra_predicates = extra_predicates.into_iter();
if extra_predicates.len() == 0 {
return where_clause.cloned();
}
Some(match where_clause {
Some(w) => syn::WhereClause {
where_token: w.where_token.clone(),
predicates: w
.predicates
.iter()
.cloned()
.chain(extra_predicates)
.collect(),
},
None => syn::WhereClause {
where_token: syn::token::Where::default(),
predicates: extra_predicates.into_iter().collect(),
},
})
}

fn fields_pat(fields: &[Field]) -> TokenStream {
let mut members = fields.iter().map(|field| &field.member).peekable();
match members.peek() {
Expand Down
112 changes: 112 additions & 0 deletions tests/test_generics.rs
@@ -0,0 +1,112 @@
#![deny(clippy::all, clippy::pedantic)]
#![allow(
// Clippy bug: https://github.com/rust-lang/rust-clippy/issues/7422
clippy::nonstandard_macro_braces,
)]
#![allow(dead_code)]

use std::fmt::{Debug, Display};

use thiserror::Error;

struct NoFormattingType;

struct DisplayType;

impl Display for DisplayType {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, stringify!(DisplayType))
}
}

impl Debug for DisplayType {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, stringify!(DisplayType))
}
}

struct DebugType;

impl Debug for DebugType {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, stringify!(DebugType))
}
}

/// Direct embedding of a generic in a field
///
/// Should produce the following instances:
///
/// ```rust
/// impl<Embedded> Display for DirectEmbedding<Embedded>
/// where
/// Embedded: Debug;
///
/// impl<Embedded> Error for DirectEmbedding<Embedded>
/// where
/// Self: Debug + Display;
/// ```
#[derive(Error, Debug)]
enum DirectEmbedding<Embedded> {
#[error("{0:?}")]
FatalError(Embedded),
}

/// #[from] handling but no Debug usage of the generic
///
/// Should produce the following instances:
///
/// ```rust
/// impl<Indirect> Display for FromGenericError<Indirect>;
///
/// impl<Indirect> Error for FromGenericError<Indirect>
/// where
/// SourceGenericError<Indirect>: Error,
/// Indirect: 'static,
/// Self: Debug + Display;
/// ```
// TODO: Parse generic parameters contained in #[from] usages and add them to Error impl
// #[derive(Error, Debug)]
// enum FromGenericError<Indirect> {
// #[error("Tadah")]
// SourceEmbedded(#[from] DirectEmbedding<Indirect>),
// }

/// Direct embedding of a generic in a field
///
/// Should produce the following instances:
///
/// ```rust
/// impl<HasDisplay, HasDebug> Display for DirectEmbedding<HasDisplay, HasDebug>
/// where
/// HasDisplay: Display,
/// HasDebug: Debug;
///
/// impl<HasDisplay, HasDebug> Error for DirectEmbedding<HasDisplay, HasDebug>
/// where
/// Self: Debug + Display;
/// ```
#[derive(Error)]
enum HybridDisplayType<HasDisplay, HasDebug, HasNeither> {
#[error("{0} : {1:?}")]
HybridDisplayCase(HasDisplay, HasDebug),
#[error("{0}")]
DisplayCase(HasDisplay, HasNeither),
#[error("{1:?}")]
DebugCase(HasNeither, HasDebug),
}

impl<HasDisplay, HasDebug, HasNeither> Debug
for HybridDisplayType<HasDisplay, HasDebug, HasNeither>
{
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "HybridDisplayType")
}
}

fn display_hybrid_display_type(
instance: HybridDisplayType<DisplayType, DebugType, NoFormattingType>,
f: &mut std::fmt::Formatter<'_>,
) -> std::fmt::Result {
Debug::fmt(&instance, f)
}