diff --git a/axum-macros/src/typed_path.rs b/axum-macros/src/typed_path.rs index 61db3eb9ae..399c38aff4 100644 --- a/axum-macros/src/typed_path.rs +++ b/axum-macros/src/typed_path.rs @@ -1,6 +1,11 @@ +use std::collections::HashSet; + use proc_macro2::{Span, TokenStream}; use quote::{format_ident, quote, quote_spanned}; -use syn::{parse::Parse, ItemStruct, LitStr, Token}; +use syn::{ + parse::Parse, parse_quote, punctuated::Punctuated, Generics, ItemStruct, LitStr, Token, + WhereClause, WherePredicate, +}; use crate::attr_parsing::{combine_attribute, parse_parenthesized_attribute, second, Combine}; @@ -11,16 +16,9 @@ pub(crate) fn expand(item_struct: ItemStruct) -> syn::Result { generics, fields, .. - } = &item_struct; - - if !generics.params.is_empty() || generics.where_clause.is_some() { - return Err(syn::Error::new_spanned( - generics, - "`#[derive(TypedPath)]` doesn't support generics", - )); - } + } = item_struct; - let Attrs { path, rejection } = crate::attr_parsing::parse_attrs("typed_path", attrs)?; + let Attrs { path, rejection } = crate::attr_parsing::parse_attrs("typed_path", &attrs)?; let path = path.ok_or_else(|| { syn::Error::new( @@ -32,15 +30,17 @@ pub(crate) fn expand(item_struct: ItemStruct) -> syn::Result { let rejection = rejection.map(second); match fields { - syn::Fields::Named(_) => { + syn::Fields::Named(fields) => { let segments = parse_path(&path)?; - Ok(expand_named_fields(ident, path, &segments, rejection)) + Ok(expand_named_fields( + fields, ident, path, &segments, rejection, generics, + )) } syn::Fields::Unnamed(fields) => { let segments = parse_path(&path)?; - expand_unnamed_fields(fields, ident, path, &segments, rejection) + expand_unnamed_fields(fields, ident, path, &segments, rejection, generics) } - syn::Fields::Unit => expand_unit_fields(ident, path, rejection), + syn::Fields::Unit => expand_unit_fields(ident, path, rejection, generics), } } @@ -94,24 +94,51 @@ impl Combine for Attrs { } fn expand_named_fields( - ident: &syn::Ident, + fields: syn::FieldsNamed, + ident: syn::Ident, path: LitStr, segments: &[Segment], rejection: Option, + generics: Generics, ) -> TokenStream { let format_str = format_str_from_path(segments); let captures = captures_from_path(segments); + let (impl_generics, ty_generics, where_clause) = generics.split_for_impl(); + + // use field types here to avoid unneeded bounds on generics + // note: this might introduce trivial bounds like i32: Display + // even if that isn't required per se, but that isn't really a issue + // since they would be implied by the usage anyways + let field_types = fields + .named + .iter() + .map(|field| &field.ty) + .cloned() + .collect::>(); + + let path_where_clause = add_where_bounds_for_types( + where_clause, + &field_types, + |ty| parse_quote! { #ty: ::std::fmt::Display }, + ); + let typed_path_impl = quote_spanned! {path.span()=> #[automatically_derived] - impl ::axum_extra::routing::TypedPath for #ident { + impl #impl_generics ::axum_extra::routing::TypedPath for #ident #ty_generics #path_where_clause { const PATH: &'static str = #path; } }; + let display_where_clause = add_where_bounds_for_types( + where_clause, + &field_types, + |ty| parse_quote! { #ty: ::std::fmt::Display }, + ); + let display_impl = quote_spanned! {path.span()=> #[automatically_derived] - impl ::std::fmt::Display for #ident { + impl #impl_generics ::std::fmt::Display for #ident #ty_generics #display_where_clause { #[allow(clippy::unnecessary_to_owned)] fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result { let Self { #(#captures,)* } = self; @@ -132,18 +159,28 @@ fn expand_named_fields( let rejection_assoc_type = rejection_assoc_type(&rejection); let map_err_rejection = map_err_rejection(&rejection); + let parts_where_clause = add_where_bounds_for_types( + where_clause, + &field_types, + |ty| parse_quote! { for<'de> #ty: Send + Sync + ::serde::Deserialize<'de> }, + ); + + let mut parts_generics = generics.clone(); + + parts_generics + .params + .push(parse_quote! {__Derived_S: Send + Sync}); + let (impl_generics, _, _) = parts_generics.split_for_impl(); + let from_request_impl = quote! { #[::axum::async_trait] #[automatically_derived] - impl ::axum::extract::FromRequestParts for #ident - where - S: Send + Sync, - { + impl #impl_generics ::axum::extract::FromRequestParts<__Derived_S> for #ident #ty_generics #parts_where_clause { type Rejection = #rejection_assoc_type; async fn from_request_parts( parts: &mut ::axum::http::request::Parts, - state: &S, + state: &__Derived_S, ) -> ::std::result::Result { ::axum::extract::Path::from_request_parts(parts, state) .await @@ -161,11 +198,12 @@ fn expand_named_fields( } fn expand_unnamed_fields( - fields: &syn::FieldsUnnamed, - ident: &syn::Ident, + fields: syn::FieldsUnnamed, + ident: syn::Ident, path: LitStr, segments: &[Segment], rejection: Option, + generics: Generics, ) -> syn::Result { let num_captures = segments .iter() @@ -204,19 +242,44 @@ fn expand_unnamed_fields( } }); + let (impl_generics, ty_generics, where_clause) = generics.split_for_impl(); + + // use field types here to avoid unneeded bounds on generics + // note: this might introduce trivial bounds like i32: Display + // even if that isn't required per se, but that isn't really a issue + // since they would be implied by the usage anyways + let field_types = fields + .unnamed + .iter() + .map(|field| &field.ty) + .cloned() + .collect::>(); + + let path_where_clause = add_where_bounds_for_types( + where_clause, + &field_types, + |ty| parse_quote! { #ty: ::std::fmt::Display }, + ); + let format_str = format_str_from_path(segments); let captures = captures_from_path(segments); let typed_path_impl = quote_spanned! {path.span()=> #[automatically_derived] - impl ::axum_extra::routing::TypedPath for #ident { + impl #impl_generics ::axum_extra::routing::TypedPath for #ident #ty_generics #path_where_clause { const PATH: &'static str = #path; } }; + let display_where_clause = add_where_bounds_for_types( + where_clause, + &field_types, + |ty| parse_quote! {#ty: ::std::fmt::Display}, + ); + let display_impl = quote_spanned! {path.span()=> #[automatically_derived] - impl ::std::fmt::Display for #ident { + impl #impl_generics ::std::fmt::Display for #ident #ty_generics #display_where_clause { #[allow(clippy::unnecessary_to_owned)] fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result { let Self { #(#destructure_self)* } = self; @@ -237,18 +300,28 @@ fn expand_unnamed_fields( let rejection_assoc_type = rejection_assoc_type(&rejection); let map_err_rejection = map_err_rejection(&rejection); + let parts_where_clause = add_where_bounds_for_types( + where_clause, + &field_types, + |ty| parse_quote! {for<'de> #ty: Send + Sync + ::serde::Deserialize<'de> }, + ); + + let mut parts_generics = generics.clone(); + + parts_generics + .params + .push(parse_quote! {__Derived_S : Send + Sync}); + let (impl_generics, _, _) = parts_generics.split_for_impl(); + let from_request_impl = quote! { #[::axum::async_trait] #[automatically_derived] - impl ::axum::extract::FromRequestParts for #ident - where - S: Send + Sync, - { + impl #impl_generics ::axum::extract::FromRequestParts<__Derived_S> for #ident #ty_generics #parts_where_clause { type Rejection = #rejection_assoc_type; async fn from_request_parts( parts: &mut ::axum::http::request::Parts, - state: &S, + state: &__Derived_S, ) -> ::std::result::Result { ::axum::extract::Path::from_request_parts(parts, state) .await @@ -274,9 +347,10 @@ fn simple_pluralize(count: usize, word: &str) -> String { } fn expand_unit_fields( - ident: &syn::Ident, + ident: syn::Ident, path: LitStr, rejection: Option, + generics: Generics, ) -> syn::Result { for segment in parse_path(&path)? { match segment { @@ -290,16 +364,18 @@ fn expand_unit_fields( } } + let (impl_generics, ty_generics, where_clause) = generics.split_for_impl(); + let typed_path_impl = quote_spanned! {path.span()=> #[automatically_derived] - impl ::axum_extra::routing::TypedPath for #ident { + impl #impl_generics ::axum_extra::routing::TypedPath for #ident #ty_generics #where_clause { const PATH: &'static str = #path; } }; let display_impl = quote_spanned! {path.span()=> #[automatically_derived] - impl ::std::fmt::Display for #ident { + impl #impl_generics ::std::fmt::Display for #ident #ty_generics #where_clause { fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result { write!(f, #path) } @@ -321,18 +397,23 @@ fn expand_unit_fields( } }; + let mut parts_generics = generics.clone(); + + parts_generics + .params + .push(parse_quote! {__Derived_S: Send + Sync}); + + let (impl_generics, _, _) = parts_generics.split_for_impl(); + let from_request_impl = quote! { #[::axum::async_trait] #[automatically_derived] - impl ::axum::extract::FromRequestParts for #ident - where - S: Send + Sync, - { + impl #impl_generics ::axum::extract::FromRequestParts<__Derived_S> for #ident #ty_generics #where_clause { type Rejection = #rejection_assoc_type; async fn from_request_parts( parts: &mut ::axum::http::request::Parts, - _state: &S, + _state: &__Derived_S, ) -> ::std::result::Result { if parts.uri.path() == ::PATH { Ok(Self) @@ -404,7 +485,7 @@ enum Segment { fn path_rejection() -> TokenStream { quote! { - <::axum::extract::Path as ::axum::extract::FromRequestParts>::Rejection + <::axum::extract::Path as ::axum::extract::FromRequestParts<__Derived_S>>::Rejection } } @@ -429,6 +510,31 @@ fn map_err_rejection(rejection: &Option) -> TokenStream { .unwrap_or_default() } +fn empty_where_clause() -> WhereClause { + WhereClause { + where_token: Token![where](Span::mixed_site()), + predicates: Punctuated::new(), + } +} + +fn add_where_bounds_for_types<'a, 'b>( + where_clause: Option<&'a WhereClause>, + types: impl IntoIterator, + bound: impl Fn(&'b syn::Type) -> WherePredicate, +) -> Option { + let mut peekable_types = types.into_iter().peekable(); + + peekable_types.peek()?; + + let mut where_clause = where_clause.cloned().unwrap_or_else(empty_where_clause); + + for ty in peekable_types { + where_clause.predicates.push(bound(ty)); + } + + Some(where_clause) +} + #[test] fn ui() { crate::run_ui_tests("typed_path"); diff --git a/axum-macros/tests/typed_path/fail/not_deserialize.stderr b/axum-macros/tests/typed_path/fail/not_deserialize.stderr index 8513e2aeaf..cbdfc326ab 100644 --- a/axum-macros/tests/typed_path/fail/not_deserialize.stderr +++ b/axum-macros/tests/typed_path/fail/not_deserialize.stderr @@ -1,21 +1,21 @@ -error[E0277]: the trait bound `MyPath: serde::de::DeserializeOwned` is not satisfied +error[E0277]: the trait bound `MyPath: DeserializeOwned` is not satisfied --> tests/typed_path/fail/not_deserialize.rs:3:10 | 3 | #[derive(TypedPath)] - | ^^^^^^^^^ the trait `for<'de> serde::de::Deserialize<'de>` is not implemented for `MyPath`, which is required by `axum::extract::Path: FromRequestParts` + | ^^^^^^^^^ the trait `for<'de> Deserialize<'de>` is not implemented for `MyPath`, which is required by `axum::extract::Path: FromRequestParts<__Derived_S>` | = help: the trait `FromRequestParts` is implemented for `axum::extract::Path` - = note: required for `MyPath` to implement `serde::de::DeserializeOwned` - = note: required for `axum::extract::Path` to implement `FromRequestParts` + = note: required for `MyPath` to implement `DeserializeOwned` + = note: required for `axum::extract::Path` to implement `FromRequestParts<__Derived_S>` = note: this error originates in the derive macro `TypedPath` (in Nightly builds, run with -Z macro-backtrace for more info) -error[E0277]: the trait bound `MyPath: serde::de::DeserializeOwned` is not satisfied +error[E0277]: the trait bound `MyPath: DeserializeOwned` is not satisfied --> tests/typed_path/fail/not_deserialize.rs:3:10 | 3 | #[derive(TypedPath)] - | ^^^^^^^^^ the trait `for<'de> serde::de::Deserialize<'de>` is not implemented for `MyPath`, which is required by `axum::extract::Path: FromRequestParts` + | ^^^^^^^^^ the trait `for<'de> Deserialize<'de>` is not implemented for `MyPath`, which is required by `axum::extract::Path: FromRequestParts<__Derived_S>` | = help: the trait `FromRequestParts` is implemented for `axum::extract::Path` - = note: required for `MyPath` to implement `serde::de::DeserializeOwned` - = note: required for `axum::extract::Path` to implement `FromRequestParts` + = note: required for `MyPath` to implement `DeserializeOwned` + = note: required for `axum::extract::Path` to implement `FromRequestParts<__Derived_S>` = note: this error originates in the attribute macro `::axum::async_trait` (in Nightly builds, run with -Z macro-backtrace for more info) diff --git a/axum-macros/tests/typed_path/pass/generics.rs b/axum-macros/tests/typed_path/pass/generics.rs new file mode 100644 index 0000000000..7165c1a3ac --- /dev/null +++ b/axum-macros/tests/typed_path/pass/generics.rs @@ -0,0 +1,45 @@ +use axum_extra::routing::{RouterExt, TypedPath}; +use serde::Deserialize; + +#[derive(TypedPath, Deserialize)] +#[typed_path("/:foo")] +struct MyPathNamed { + foo: T, +} + +// types with wrappers should not get any where bounds auto-generated +struct WrapperStruct(T); + +impl<'de, T> Deserialize<'de> for WrapperStruct { + fn deserialize(_deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + unimplemented!() + } +} + +impl std::fmt::Display for WrapperStruct { + fn fmt(&self, _: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + unimplemented!() + } +} + +#[derive(TypedPath)] +#[typed_path("/:foo/:bar")] +struct MyPathUnnamed(T, WrapperStruct); + +impl<'de, T, U> Deserialize<'de> for MyPathUnnamed { + fn deserialize(_deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + unimplemented!() + } +} + +fn main() { + _ = axum::Router::<()>::new() + .typed_get(|_: MyPathNamed| async {}) + .typed_post(|_: MyPathUnnamed| async {}) +}