Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Drop limitation of disallowing generics for the TypedPath derive in axum-extra #2723

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
186 changes: 146 additions & 40 deletions axum-macros/src/typed_path.rs
@@ -1,6 +1,11 @@
use std::collections::HashSet;

use proc_macro2::{Span, TokenStream};
use quote::{format_ident, quote, quote_spanned};
use syn::{parse::Parse, ItemStruct, LitStr, Token};
use syn::{
parse::Parse, parse_quote, punctuated::Punctuated, Generics, ItemStruct, LitStr, Token,
WhereClause, WherePredicate,
};

use crate::attr_parsing::{combine_attribute, parse_parenthesized_attribute, second, Combine};

Expand All @@ -11,16 +16,9 @@ pub(crate) fn expand(item_struct: ItemStruct) -> syn::Result<TokenStream> {
generics,
fields,
..
} = &item_struct;

if !generics.params.is_empty() || generics.where_clause.is_some() {
return Err(syn::Error::new_spanned(
generics,
"`#[derive(TypedPath)]` doesn't support generics",
));
}
} = item_struct;

let Attrs { path, rejection } = crate::attr_parsing::parse_attrs("typed_path", attrs)?;
let Attrs { path, rejection } = crate::attr_parsing::parse_attrs("typed_path", &attrs)?;

let path = path.ok_or_else(|| {
syn::Error::new(
Expand All @@ -32,15 +30,17 @@ pub(crate) fn expand(item_struct: ItemStruct) -> syn::Result<TokenStream> {
let rejection = rejection.map(second);

match fields {
syn::Fields::Named(_) => {
syn::Fields::Named(fields) => {
let segments = parse_path(&path)?;
Ok(expand_named_fields(ident, path, &segments, rejection))
Ok(expand_named_fields(
fields, ident, path, &segments, rejection, generics,
))
}
syn::Fields::Unnamed(fields) => {
let segments = parse_path(&path)?;
expand_unnamed_fields(fields, ident, path, &segments, rejection)
expand_unnamed_fields(fields, ident, path, &segments, rejection, generics)
}
syn::Fields::Unit => expand_unit_fields(ident, path, rejection),
syn::Fields::Unit => expand_unit_fields(ident, path, rejection, generics),
}
}

Expand Down Expand Up @@ -94,24 +94,51 @@ impl Combine for Attrs {
}

fn expand_named_fields(
ident: &syn::Ident,
fields: syn::FieldsNamed,
ident: syn::Ident,
path: LitStr,
segments: &[Segment],
rejection: Option<syn::Path>,
generics: Generics,
) -> TokenStream {
let format_str = format_str_from_path(segments);
let captures = captures_from_path(segments);

let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();

// use field types here to avoid unneeded bounds on generics
// note: this might introduce trivial bounds like i32: Display
// even if that isn't required per se, but that isn't really a issue
// since they would be implied by the usage anyways
let field_types = fields
.named
.iter()
.map(|field| &field.ty)
.cloned()
.collect::<HashSet<_>>();

let path_where_clause = add_where_bounds_for_types(
where_clause,
&field_types,
|ty| parse_quote! { #ty: ::std::fmt::Display },
);

let typed_path_impl = quote_spanned! {path.span()=>
#[automatically_derived]
impl ::axum_extra::routing::TypedPath for #ident {
impl #impl_generics ::axum_extra::routing::TypedPath for #ident #ty_generics #path_where_clause {
const PATH: &'static str = #path;
}
};

let display_where_clause = add_where_bounds_for_types(
where_clause,
&field_types,
|ty| parse_quote! { #ty: ::std::fmt::Display },
);

let display_impl = quote_spanned! {path.span()=>
#[automatically_derived]
impl ::std::fmt::Display for #ident {
impl #impl_generics ::std::fmt::Display for #ident #ty_generics #display_where_clause {
#[allow(clippy::unnecessary_to_owned)]
fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result {
let Self { #(#captures,)* } = self;
Expand All @@ -132,18 +159,28 @@ fn expand_named_fields(
let rejection_assoc_type = rejection_assoc_type(&rejection);
let map_err_rejection = map_err_rejection(&rejection);

let parts_where_clause = add_where_bounds_for_types(
where_clause,
&field_types,
|ty| parse_quote! { for<'de> #ty: Send + Sync + ::serde::Deserialize<'de> },
);

let mut parts_generics = generics.clone();

parts_generics
.params
.push(parse_quote! {__Derived_S: Send + Sync});
let (impl_generics, _, _) = parts_generics.split_for_impl();

let from_request_impl = quote! {
#[::axum::async_trait]
#[automatically_derived]
impl<S> ::axum::extract::FromRequestParts<S> for #ident
where
S: Send + Sync,
{
impl #impl_generics ::axum::extract::FromRequestParts<__Derived_S> for #ident #ty_generics #parts_where_clause {
type Rejection = #rejection_assoc_type;

async fn from_request_parts(
parts: &mut ::axum::http::request::Parts,
state: &S,
state: &__Derived_S,
) -> ::std::result::Result<Self, Self::Rejection> {
::axum::extract::Path::from_request_parts(parts, state)
.await
Expand All @@ -161,11 +198,12 @@ fn expand_named_fields(
}

fn expand_unnamed_fields(
fields: &syn::FieldsUnnamed,
ident: &syn::Ident,
fields: syn::FieldsUnnamed,
ident: syn::Ident,
path: LitStr,
segments: &[Segment],
rejection: Option<syn::Path>,
generics: Generics,
) -> syn::Result<TokenStream> {
let num_captures = segments
.iter()
Expand Down Expand Up @@ -204,19 +242,44 @@ fn expand_unnamed_fields(
}
});

let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();

// use field types here to avoid unneeded bounds on generics
// note: this might introduce trivial bounds like i32: Display
// even if that isn't required per se, but that isn't really a issue
// since they would be implied by the usage anyways
let field_types = fields
.unnamed
.iter()
.map(|field| &field.ty)
.cloned()
.collect::<HashSet<_>>();

let path_where_clause = add_where_bounds_for_types(
where_clause,
&field_types,
|ty| parse_quote! { #ty: ::std::fmt::Display },
);

let format_str = format_str_from_path(segments);
let captures = captures_from_path(segments);

let typed_path_impl = quote_spanned! {path.span()=>
#[automatically_derived]
impl ::axum_extra::routing::TypedPath for #ident {
impl #impl_generics ::axum_extra::routing::TypedPath for #ident #ty_generics #path_where_clause {
const PATH: &'static str = #path;
}
};

let display_where_clause = add_where_bounds_for_types(
where_clause,
&field_types,
|ty| parse_quote! {#ty: ::std::fmt::Display},
);

let display_impl = quote_spanned! {path.span()=>
#[automatically_derived]
impl ::std::fmt::Display for #ident {
impl #impl_generics ::std::fmt::Display for #ident #ty_generics #display_where_clause {
#[allow(clippy::unnecessary_to_owned)]
fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result {
let Self { #(#destructure_self)* } = self;
Expand All @@ -237,18 +300,28 @@ fn expand_unnamed_fields(
let rejection_assoc_type = rejection_assoc_type(&rejection);
let map_err_rejection = map_err_rejection(&rejection);

let parts_where_clause = add_where_bounds_for_types(
where_clause,
&field_types,
|ty| parse_quote! {for<'de> #ty: Send + Sync + ::serde::Deserialize<'de> },
);

let mut parts_generics = generics.clone();

parts_generics
.params
.push(parse_quote! {__Derived_S : Send + Sync});
let (impl_generics, _, _) = parts_generics.split_for_impl();

let from_request_impl = quote! {
#[::axum::async_trait]
#[automatically_derived]
impl<S> ::axum::extract::FromRequestParts<S> for #ident
where
S: Send + Sync,
{
impl #impl_generics ::axum::extract::FromRequestParts<__Derived_S> for #ident #ty_generics #parts_where_clause {
type Rejection = #rejection_assoc_type;

async fn from_request_parts(
parts: &mut ::axum::http::request::Parts,
state: &S,
state: &__Derived_S,
) -> ::std::result::Result<Self, Self::Rejection> {
::axum::extract::Path::from_request_parts(parts, state)
.await
Expand All @@ -274,9 +347,10 @@ fn simple_pluralize(count: usize, word: &str) -> String {
}

fn expand_unit_fields(
ident: &syn::Ident,
ident: syn::Ident,
path: LitStr,
rejection: Option<syn::Path>,
generics: Generics,
) -> syn::Result<TokenStream> {
for segment in parse_path(&path)? {
match segment {
Expand All @@ -290,16 +364,18 @@ fn expand_unit_fields(
}
}

let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();

let typed_path_impl = quote_spanned! {path.span()=>
#[automatically_derived]
impl ::axum_extra::routing::TypedPath for #ident {
impl #impl_generics ::axum_extra::routing::TypedPath for #ident #ty_generics #where_clause {
const PATH: &'static str = #path;
}
};

let display_impl = quote_spanned! {path.span()=>
#[automatically_derived]
impl ::std::fmt::Display for #ident {
impl #impl_generics ::std::fmt::Display for #ident #ty_generics #where_clause {
fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result {
write!(f, #path)
}
Expand All @@ -321,18 +397,23 @@ fn expand_unit_fields(
}
};

let mut parts_generics = generics.clone();

parts_generics
.params
.push(parse_quote! {__Derived_S: Send + Sync});

let (impl_generics, _, _) = parts_generics.split_for_impl();

let from_request_impl = quote! {
#[::axum::async_trait]
#[automatically_derived]
impl<S> ::axum::extract::FromRequestParts<S> for #ident
where
S: Send + Sync,
{
impl #impl_generics ::axum::extract::FromRequestParts<__Derived_S> for #ident #ty_generics #where_clause {
type Rejection = #rejection_assoc_type;

async fn from_request_parts(
parts: &mut ::axum::http::request::Parts,
_state: &S,
_state: &__Derived_S,
) -> ::std::result::Result<Self, Self::Rejection> {
if parts.uri.path() == <Self as ::axum_extra::routing::TypedPath>::PATH {
Ok(Self)
Expand Down Expand Up @@ -404,7 +485,7 @@ enum Segment {

fn path_rejection() -> TokenStream {
quote! {
<::axum::extract::Path<Self> as ::axum::extract::FromRequestParts<S>>::Rejection
<::axum::extract::Path<Self> as ::axum::extract::FromRequestParts<__Derived_S>>::Rejection
}
}

Expand All @@ -429,6 +510,31 @@ fn map_err_rejection(rejection: &Option<syn::Path>) -> TokenStream {
.unwrap_or_default()
}

fn empty_where_clause() -> WhereClause {
WhereClause {
where_token: Token![where](Span::mixed_site()),
predicates: Punctuated::new(),
}
}

fn add_where_bounds_for_types<'a, 'b>(
where_clause: Option<&'a WhereClause>,
types: impl IntoIterator<Item = &'b syn::Type>,
bound: impl Fn(&'b syn::Type) -> WherePredicate,
) -> Option<WhereClause> {
let mut peekable_types = types.into_iter().peekable();

peekable_types.peek()?;

let mut where_clause = where_clause.cloned().unwrap_or_else(empty_where_clause);

for ty in peekable_types {
where_clause.predicates.push(bound(ty));
}

Some(where_clause)
}

#[test]
fn ui() {
crate::run_ui_tests("typed_path");
Expand Down
16 changes: 8 additions & 8 deletions axum-macros/tests/typed_path/fail/not_deserialize.stderr
@@ -1,21 +1,21 @@
error[E0277]: the trait bound `MyPath: serde::de::DeserializeOwned` is not satisfied
error[E0277]: the trait bound `MyPath: DeserializeOwned` is not satisfied
--> tests/typed_path/fail/not_deserialize.rs:3:10
|
3 | #[derive(TypedPath)]
| ^^^^^^^^^ the trait `for<'de> serde::de::Deserialize<'de>` is not implemented for `MyPath`, which is required by `axum::extract::Path<MyPath>: FromRequestParts<S>`
| ^^^^^^^^^ the trait `for<'de> Deserialize<'de>` is not implemented for `MyPath`, which is required by `axum::extract::Path<MyPath>: FromRequestParts<__Derived_S>`
|
= help: the trait `FromRequestParts<S>` is implemented for `axum::extract::Path<T>`
= note: required for `MyPath` to implement `serde::de::DeserializeOwned`
= note: required for `axum::extract::Path<MyPath>` to implement `FromRequestParts<S>`
= note: required for `MyPath` to implement `DeserializeOwned`
= note: required for `axum::extract::Path<MyPath>` to implement `FromRequestParts<__Derived_S>`
= note: this error originates in the derive macro `TypedPath` (in Nightly builds, run with -Z macro-backtrace for more info)

error[E0277]: the trait bound `MyPath: serde::de::DeserializeOwned` is not satisfied
error[E0277]: the trait bound `MyPath: DeserializeOwned` is not satisfied
--> tests/typed_path/fail/not_deserialize.rs:3:10
|
3 | #[derive(TypedPath)]
| ^^^^^^^^^ the trait `for<'de> serde::de::Deserialize<'de>` is not implemented for `MyPath`, which is required by `axum::extract::Path<MyPath>: FromRequestParts<S>`
| ^^^^^^^^^ the trait `for<'de> Deserialize<'de>` is not implemented for `MyPath`, which is required by `axum::extract::Path<MyPath>: FromRequestParts<__Derived_S>`
|
= help: the trait `FromRequestParts<S>` is implemented for `axum::extract::Path<T>`
= note: required for `MyPath` to implement `serde::de::DeserializeOwned`
= note: required for `axum::extract::Path<MyPath>` to implement `FromRequestParts<S>`
= note: required for `MyPath` to implement `DeserializeOwned`
= note: required for `axum::extract::Path<MyPath>` to implement `FromRequestParts<__Derived_S>`
= note: this error originates in the attribute macro `::axum::async_trait` (in Nightly builds, run with -Z macro-backtrace for more info)