From 221fffed31f8156c18d99b8dd09bf9bc55f9290a Mon Sep 17 00:00:00 2001 From: David Pedersen Date: Fri, 11 Nov 2022 13:22:21 +0100 Subject: [PATCH 1/3] fallback inheritance --- axum/src/docs/routing/nest.md | 39 ++++++----- axum/src/extract/matched_path.rs | 2 +- axum/src/routing/method_routing.rs | 8 +-- axum/src/routing/mod.rs | 102 +++++++++++++++++++++++------ axum/src/routing/not_found.rs | 18 +++-- axum/src/routing/route.rs | 4 +- axum/src/routing/service.rs | 36 ++++++++-- axum/src/routing/tests/fallback.rs | 100 ++++++++++++++++++++++++++++ axum/src/routing/tests/nest.rs | 17 ----- 9 files changed, 256 insertions(+), 70 deletions(-) diff --git a/axum/src/docs/routing/nest.md b/axum/src/docs/routing/nest.md index 3b03d4d9c6..6ea05478c4 100644 --- a/axum/src/docs/routing/nest.md +++ b/axum/src/docs/routing/nest.md @@ -90,8 +90,8 @@ let app = Router::new() # Fallbacks -When nesting a router, if a request matches the prefix but the nested router doesn't have a matching -route, the outer fallback will _not_ be called: +If a nested router doesn't have its own fallback then it will inherit the +fallback from the outer router: ```rust use axum::{routing::get, http::StatusCode, handler::Handler, Router}; @@ -100,7 +100,7 @@ async fn fallback() -> (StatusCode, &'static str) { (StatusCode::NOT_FOUND, "Not Found") } -let api_routes = Router::new().nest_service("/users", get(|| async {})); +let api_routes = Router::new().route("/users", get(|| async {})); let app = Router::new() .nest("/api", api_routes) @@ -108,30 +108,35 @@ let app = Router::new() # let _: Router = app; ``` -Here requests like `GET /api/not-found` will go into `api_routes` and then to -the fallback of `api_routes` which will return an empty `404 Not Found` -response. The outer fallback declared on `app` will _not_ be called. +Here requests like `GET /api/not-found` will go into `api_routes` but because +it doesn't have a matching route and doesn't have its own fallback it will call +the fallback from the outer router, i.e. the `fallback` function. -Think of nested services as swallowing requests that matches the prefix and -not falling back to outer router even if they don't have a matching route. - -You can still add separate fallbacks to nested routers: +If the nested router has its own fallback then the outer fallback will not be +inherited: ```rust -use axum::{routing::get, http::StatusCode, handler::Handler, Json, Router}; -use serde_json::{json, Value}; +use axum::{ + routing::get, + http::StatusCode, + handler::Handler, + Json, + Router, +}; async fn fallback() -> (StatusCode, &'static str) { (StatusCode::NOT_FOUND, "Not Found") } -async fn api_fallback() -> (StatusCode, Json) { - (StatusCode::NOT_FOUND, Json(json!({ "error": "Not Found" }))) +async fn api_fallback() -> (StatusCode, Json) { + ( + StatusCode::NOT_FOUND, + Json(serde_json::json!({ "status": "Not Found" })), + ) } let api_routes = Router::new() - .nest_service("/users", get(|| async {})) - // add dedicated fallback for requests starting with `/api` + .route("/users", get(|| async {})) .fallback(api_fallback); let app = Router::new() @@ -140,6 +145,8 @@ let app = Router::new() # let _: Router = app; ``` +Here requests like `GET /api/not-found` will go to `api_fallback`. + # Panics - If the route overlaps with another route. See [`Router::route`] diff --git a/axum/src/extract/matched_path.rs b/axum/src/extract/matched_path.rs index 9ad5457bd0..064a9726c3 100644 --- a/axum/src/extract/matched_path.rs +++ b/axum/src/extract/matched_path.rs @@ -143,7 +143,7 @@ pub(crate) fn set_matched_path_for_request( if matched_path.ends_with(NEST_TAIL_PARAM_CAPTURE) { extensions.insert(MatchedNestedPath(matched_path)); - debug_assert!(matches!(dbg!(extensions.remove::()), None)); + debug_assert!(matches!(extensions.remove::(), None)); } else { extensions.insert(MatchedPath(matched_path)); extensions.remove::(); diff --git a/axum/src/routing/method_routing.rs b/axum/src/routing/method_routing.rs index e8f3887ae0..3e870420c9 100644 --- a/axum/src/routing/method_routing.rs +++ b/axum/src/routing/method_routing.rs @@ -1,6 +1,6 @@ //! Route to services and handlers based on HTTP methods. -use super::IntoMakeService; +use super::{FallbackRoute, IntoMakeService}; #[cfg(feature = "tokio")] use crate::extract::connect_info::IntoMakeServiceWithConnectInfo; use crate::{ @@ -744,7 +744,7 @@ where post: self.post.into_route(&state), put: self.put.into_route(&state), trace: self.trace.into_route(&state), - fallback: self.fallback.into_route(&state), + fallback: self.fallback.into_fallback_route(&state), allow_header: self.allow_header, } } @@ -1284,7 +1284,7 @@ pub struct WithState { post: Option>, put: Option>, trace: Option>, - fallback: Route, + fallback: FallbackRoute, allow_header: AllowHeader, } @@ -1346,7 +1346,7 @@ impl fmt::Debug for WithState { impl Service> for WithState where - B: HttpBody + Send, + B: HttpBody + Send + 'static, { type Response = Response; type Error = E; diff --git a/axum/src/routing/mod.rs b/axum/src/routing/mod.rs index c5156a4799..fdc95e977f 100644 --- a/axum/src/routing/mod.rs +++ b/axum/src/routing/mod.rs @@ -1,6 +1,6 @@ //! Routing between [`Service`]s and handlers. -use self::not_found::NotFound; +use self::{future::RouteFuture, not_found::NotFound}; #[cfg(feature = "tokio")] use crate::extract::connect_info::IntoMakeServiceWithConnectInfo; use crate::{ @@ -8,7 +8,7 @@ use crate::{ handler::{BoxedHandler, Handler}, util::try_downcast, }; -use axum_core::response::IntoResponse; +use axum_core::response::{IntoResponse, Response}; use http::Request; use matchit::MatchError; use std::{ @@ -17,8 +17,12 @@ use std::{ convert::Infallible, fmt, sync::Arc, + task::{Context, Poll}, +}; +use tower::{ + util::{BoxCloneService, MapResponseLayer, Oneshot}, + ServiceBuilder, }; -use tower::{util::MapResponseLayer, ServiceBuilder}; use tower_layer::Layer; use tower_service::Service; @@ -639,11 +643,29 @@ where } } - fn into_route(self, state: &S) -> Route { + fn into_fallback_route(self, state: &S) -> FallbackRoute { + match self { + Self::Default(route) => FallbackRoute::Default(route), + Self::Service(route) => FallbackRoute::Service(route), + Self::BoxedHandler(handler) => { + FallbackRoute::Service(handler.into_route(state.clone())) + } + } + } + + fn map(self, f: F) -> Fallback + where + S: 'static, + B: 'static, + E: 'static, + F: FnOnce(Route) -> Route + Clone + Send + 'static, + B2: 'static, + E2: 'static, + { match self { - Self::Default(route) => route, - Self::Service(route) => route, - Self::BoxedHandler(handler) => handler.into_route(state.clone()), + Self::Default(inner) => Fallback::Default(f(inner)), + Self::Service(inner) => Fallback::Service(f(inner)), + Self::BoxedHandler(inner) => Fallback::BoxedHandler(inner.map(f)), } } } @@ -668,20 +690,60 @@ impl fmt::Debug for Fallback { } } -impl Fallback { - fn map(self, f: F) -> Fallback - where - S: 'static, - B: 'static, - E: 'static, - F: FnOnce(Route) -> Route + Clone + Send + 'static, - B2: 'static, - E2: 'static, - { +/// Like `Fallback` but without the `S` param so it can be stored in `RouterService` +pub(crate) enum FallbackRoute { + Default(Route), + Service(Route), +} + +impl fmt::Debug for FallbackRoute { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { - Self::Default(inner) => Fallback::Default(f(inner)), - Self::Service(inner) => Fallback::Service(f(inner)), - Self::BoxedHandler(inner) => Fallback::BoxedHandler(inner.map(f)), + Self::Default(inner) => f.debug_tuple("Default").field(inner).finish(), + Self::Service(inner) => f.debug_tuple("Service").field(inner).finish(), + } + } +} + +impl Clone for FallbackRoute { + fn clone(&self) -> Self { + match self { + Self::Default(inner) => Self::Default(inner.clone()), + Self::Service(inner) => Self::Service(inner.clone()), + } + } +} + +impl FallbackRoute { + pub(crate) fn oneshot_inner( + &mut self, + req: Request, + ) -> Oneshot, Response, E>, Request> { + match self { + FallbackRoute::Default(inner) => inner.oneshot_inner(req), + FallbackRoute::Service(inner) => inner.oneshot_inner(req), + } + } +} + +impl Service> for FallbackRoute +where + B: HttpBody + Send + 'static, +{ + type Response = Response; + type Error = E; + type Future = RouteFuture; + + #[inline] + fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + #[inline] + fn call(&mut self, req: Request) -> Self::Future { + match self { + FallbackRoute::Default(inner) => inner.call(req), + FallbackRoute::Service(inner) => inner.call(req), } } } diff --git a/axum/src/routing/not_found.rs b/axum/src/routing/not_found.rs index dc3fec46ac..e6e55bfbea 100644 --- a/axum/src/routing/not_found.rs +++ b/axum/src/routing/not_found.rs @@ -1,11 +1,12 @@ -use crate::response::Response; +use crate::{response::Response, Extension}; use axum_core::response::IntoResponse; use http::{Request, StatusCode}; use std::{ convert::Infallible, - future::ready, + future::{ready, Ready}, task::{Context, Poll}, }; +use sync_wrapper::SyncWrapper; use tower_service::Service; /// A [`Service`] that responds with `404 Not Found` to all requests. @@ -21,14 +22,21 @@ where { type Response = Response; type Error = Infallible; - type Future = std::future::Ready>; + type Future = Ready>; #[inline] fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { Poll::Ready(Ok(())) } - fn call(&mut self, _req: Request) -> Self::Future { - ready(Ok(StatusCode::NOT_FOUND.into_response())) + fn call(&mut self, req: Request) -> Self::Future { + let res = ( + StatusCode::NOT_FOUND, + Extension(FromDefaultFallback(req.map(SyncWrapper::new))), + ) + .into_response(); + ready(Ok(res)) } } + +pub(super) struct FromDefaultFallback(pub(super) Request>); diff --git a/axum/src/routing/route.rs b/axum/src/routing/route.rs index 4b735a8d05..aafad75f7d 100644 --- a/axum/src/routing/route.rs +++ b/axum/src/routing/route.rs @@ -62,7 +62,7 @@ impl fmt::Debug for Route { impl Service> for Route where - B: HttpBody, + B: HttpBody + Send + 'static, { type Response = Response; type Error = E; @@ -129,7 +129,7 @@ impl RouteFuture { impl Future for RouteFuture where - B: HttpBody, + B: HttpBody + Send + 'static, { type Output = Result; diff --git a/axum/src/routing/service.rs b/axum/src/routing/service.rs index 3bd9a629f3..f19f385677 100644 --- a/axum/src/routing/service.rs +++ b/axum/src/routing/service.rs @@ -1,4 +1,6 @@ -use super::{future::RouteFuture, url_params, Endpoint, Node, Route, RouteId, Router}; +use super::{ + future::RouteFuture, url_params, Endpoint, FallbackRoute, Node, Route, RouteId, Router, +}; use crate::{ body::{Body, HttpBody}, response::Response, @@ -11,6 +13,7 @@ use std::{ sync::Arc, task::{Context, Poll}, }; +use sync_wrapper::SyncWrapper; use tower::Service; /// A [`Router`] converted into a [`Service`]. @@ -18,7 +21,7 @@ use tower::Service; pub struct RouterService { routes: HashMap>, node: Arc, - fallback: Route, + fallback: FallbackRoute, } impl RouterService @@ -52,7 +55,7 @@ where Self { routes, node: router.node, - fallback: router.fallback.into_route(&state), + fallback: router.fallback.into_fallback_route(&state), } } @@ -121,12 +124,35 @@ where let path = req.uri().path().to_owned(); match self.node.at(&path) { - Ok(match_) => self.call_route(match_, req), + Ok(match_) => { + match &self.fallback { + FallbackRoute::Default(_) => {} + FallbackRoute::Service(fallback) => { + req.extensions_mut() + .insert(SuperFallback(SyncWrapper::new(fallback.clone()))); + } + } + + self.call_route(match_, req) + } Err( MatchError::NotFound | MatchError::ExtraTrailingSlash | MatchError::MissingTrailingSlash, - ) => self.fallback.clone().call(req), + ) => match &mut self.fallback { + FallbackRoute::Default(fallback) => { + if let Some(super_fallback) = req.extensions_mut().remove::>() + { + let mut super_fallback = super_fallback.0.into_inner(); + super_fallback.call(req) + } else { + fallback.call(req) + } + } + FallbackRoute::Service(fallback) => fallback.call(req), + }, } } } + +struct SuperFallback(SyncWrapper>); diff --git a/axum/src/routing/tests/fallback.rs b/axum/src/routing/tests/fallback.rs index 4da166baea..b9993ab806 100644 --- a/axum/src/routing/tests/fallback.rs +++ b/axum/src/routing/tests/fallback.rs @@ -1,4 +1,5 @@ use super::*; +use crate::middleware::{map_request, map_response}; #[tokio::test] async fn basic() { @@ -58,3 +59,102 @@ async fn fallback_accessing_state() { assert_eq!(res.status(), StatusCode::OK); assert_eq!(res.text().await, "state"); } + +async fn inner_fallback() -> impl IntoResponse { + (StatusCode::NOT_FOUND, "inner") +} + +async fn outer_fallback() -> impl IntoResponse { + (StatusCode::NOT_FOUND, "outer") +} + +#[tokio::test] +async fn nested_router_inherits_fallback() { + let inner = Router::new(); + let app = Router::new().nest("/foo", inner).fallback(outer_fallback); + + let client = TestClient::new(app); + + let res = client.get("/foo/bar").send().await; + assert_eq!(res.status(), StatusCode::NOT_FOUND); + assert_eq!(res.text().await, "outer"); +} + +#[tokio::test] +async fn doesnt_inherit_fallback_if_overriden() { + let inner = Router::new().fallback(inner_fallback); + let app = Router::new().nest("/foo", inner).fallback(outer_fallback); + + let client = TestClient::new(app); + + let res = client.get("/foo/bar").send().await; + assert_eq!(res.status(), StatusCode::NOT_FOUND); + assert_eq!(res.text().await, "inner"); +} + +#[tokio::test] +async fn deeply_nested_inherit_from_top() { + let app = Router::new() + .nest("/foo", Router::new().nest("/bar", Router::new())) + .fallback(outer_fallback); + + let client = TestClient::new(app); + + let res = client.get("/foo/bar/baz").send().await; + assert_eq!(res.status(), StatusCode::NOT_FOUND); + assert_eq!(res.text().await, "outer"); +} + +#[tokio::test] +async fn deeply_nested_inherit_from_middle() { + let app = Router::new().nest( + "/foo", + Router::new() + .nest("/bar", Router::new()) + .fallback(outer_fallback), + ); + + let client = TestClient::new(app); + + let res = client.get("/foo/bar/baz").send().await; + assert_eq!(res.status(), StatusCode::NOT_FOUND); + assert_eq!(res.text().await, "outer"); +} + +#[tokio::test] +async fn with_middleware_on_inner_fallback() { + async fn never_called(_: Request) -> Request { + panic!("should never be called") + } + + let inner = Router::new().layer(map_request(never_called)); + let app = Router::new().nest("/foo", inner).fallback(outer_fallback); + + let client = TestClient::new(app); + + let res = client.get("/foo/bar").send().await; + assert_eq!(res.status(), StatusCode::NOT_FOUND); + assert_eq!(res.text().await, "outer"); +} + +#[tokio::test] +async fn also_inherits_default_layered_fallback() { + async fn set_header(mut res: Response) -> Response { + res.headers_mut() + .insert("x-from-fallback", "1".parse().unwrap()); + res + } + + let inner = Router::new(); + let app = Router::new() + .nest("/foo", inner) + .fallback(outer_fallback) + .layer(map_response(set_header)); + + let client = TestClient::new(app); + + let res = client.get("/foo/bar").send().await; + assert_eq!(res.status(), StatusCode::NOT_FOUND); + assert_eq!(res.headers()["x-from-fallback"], "1"); + assert_eq!(res.text().await, "outer"); +} diff --git a/axum/src/routing/tests/nest.rs b/axum/src/routing/tests/nest.rs index 42c4ffe94d..b53740c7b4 100644 --- a/axum/src/routing/tests/nest.rs +++ b/axum/src/routing/tests/nest.rs @@ -350,23 +350,6 @@ async fn nest_with_and_without_trailing() { assert_eq!(res.status(), StatusCode::OK); } -#[tokio::test] -async fn doesnt_call_outer_fallback() { - let app = Router::new() - .nest("/foo", Router::new().route("/", get(|| async {}))) - .fallback(|| async { (StatusCode::NOT_FOUND, "outer fallback") }); - - let client = TestClient::new(app); - - let res = client.get("/foo").send().await; - assert_eq!(res.status(), StatusCode::OK); - - let res = client.get("/foo/not-found").send().await; - assert_eq!(res.status(), StatusCode::NOT_FOUND); - // the default fallback returns an empty body - assert_eq!(res.text().await, ""); -} - #[tokio::test] async fn nesting_with_root_inner_router() { let app = Router::new().nest( From 3a9547e5895df2fadbcc89ce3e642aa142bc895d Mon Sep 17 00:00:00 2001 From: David Pedersen Date: Fri, 11 Nov 2022 13:35:32 +0100 Subject: [PATCH 2/3] cleanup --- axum/src/routing/method_routing.rs | 2 +- axum/src/routing/mod.rs | 31 ++++-------------------------- axum/src/routing/not_found.rs | 18 +++++------------ axum/src/routing/route.rs | 4 ++-- 4 files changed, 12 insertions(+), 43 deletions(-) diff --git a/axum/src/routing/method_routing.rs b/axum/src/routing/method_routing.rs index 3e870420c9..8672b1b2c6 100644 --- a/axum/src/routing/method_routing.rs +++ b/axum/src/routing/method_routing.rs @@ -1346,7 +1346,7 @@ impl fmt::Debug for WithState { impl Service> for WithState where - B: HttpBody + Send + 'static, + B: HttpBody + Send, { type Response = Response; type Error = E; diff --git a/axum/src/routing/mod.rs b/axum/src/routing/mod.rs index fdc95e977f..88a23127f6 100644 --- a/axum/src/routing/mod.rs +++ b/axum/src/routing/mod.rs @@ -1,6 +1,6 @@ //! Routing between [`Service`]s and handlers. -use self::{future::RouteFuture, not_found::NotFound}; +use self::not_found::NotFound; #[cfg(feature = "tokio")] use crate::extract::connect_info::IntoMakeServiceWithConnectInfo; use crate::{ @@ -17,7 +17,6 @@ use std::{ convert::Infallible, fmt, sync::Arc, - task::{Context, Poll}, }; use tower::{ util::{BoxCloneService, MapResponseLayer, Oneshot}, @@ -663,9 +662,9 @@ where E2: 'static, { match self { - Self::Default(inner) => Fallback::Default(f(inner)), - Self::Service(inner) => Fallback::Service(f(inner)), - Self::BoxedHandler(inner) => Fallback::BoxedHandler(inner.map(f)), + Self::Default(route) => Fallback::Default(f(route)), + Self::Service(route) => Fallback::Service(f(route)), + Self::BoxedHandler(handler) => Fallback::BoxedHandler(handler.map(f)), } } } @@ -726,28 +725,6 @@ impl FallbackRoute { } } -impl Service> for FallbackRoute -where - B: HttpBody + Send + 'static, -{ - type Response = Response; - type Error = E; - type Future = RouteFuture; - - #[inline] - fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { - Poll::Ready(Ok(())) - } - - #[inline] - fn call(&mut self, req: Request) -> Self::Future { - match self { - FallbackRoute::Default(inner) => inner.call(req), - FallbackRoute::Service(inner) => inner.call(req), - } - } -} - #[allow(clippy::large_enum_variant)] // This type is only used at init time, probably fine enum Endpoint { MethodRouter(MethodRouter), diff --git a/axum/src/routing/not_found.rs b/axum/src/routing/not_found.rs index e6e55bfbea..dc3fec46ac 100644 --- a/axum/src/routing/not_found.rs +++ b/axum/src/routing/not_found.rs @@ -1,12 +1,11 @@ -use crate::{response::Response, Extension}; +use crate::response::Response; use axum_core::response::IntoResponse; use http::{Request, StatusCode}; use std::{ convert::Infallible, - future::{ready, Ready}, + future::ready, task::{Context, Poll}, }; -use sync_wrapper::SyncWrapper; use tower_service::Service; /// A [`Service`] that responds with `404 Not Found` to all requests. @@ -22,21 +21,14 @@ where { type Response = Response; type Error = Infallible; - type Future = Ready>; + type Future = std::future::Ready>; #[inline] fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { Poll::Ready(Ok(())) } - fn call(&mut self, req: Request) -> Self::Future { - let res = ( - StatusCode::NOT_FOUND, - Extension(FromDefaultFallback(req.map(SyncWrapper::new))), - ) - .into_response(); - ready(Ok(res)) + fn call(&mut self, _req: Request) -> Self::Future { + ready(Ok(StatusCode::NOT_FOUND.into_response())) } } - -pub(super) struct FromDefaultFallback(pub(super) Request>); diff --git a/axum/src/routing/route.rs b/axum/src/routing/route.rs index aafad75f7d..4b735a8d05 100644 --- a/axum/src/routing/route.rs +++ b/axum/src/routing/route.rs @@ -62,7 +62,7 @@ impl fmt::Debug for Route { impl Service> for Route where - B: HttpBody + Send + 'static, + B: HttpBody, { type Response = Response; type Error = E; @@ -129,7 +129,7 @@ impl RouteFuture { impl Future for RouteFuture where - B: HttpBody + Send + 'static, + B: HttpBody, { type Output = Result; From eaa0687cff75c4d98946be4cc3430ae8c605c2e5 Mon Sep 17 00:00:00 2001 From: David Pedersen Date: Fri, 11 Nov 2022 13:49:22 +0100 Subject: [PATCH 3/3] changelog --- axum/CHANGELOG.md | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/axum/CHANGELOG.md b/axum/CHANGELOG.md index 17646e2673..05f08d216f 100644 --- a/axum/CHANGELOG.md +++ b/axum/CHANGELOG.md @@ -7,7 +7,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 # Unreleased -- None. +- **fixed:** Nested routers will now inherit fallbacks from outer routers ([#1521]) + +[#1521]: https://github.com/tokio-rs/axum/pull/1521 # 0.6.0-rc.4 (9. November, 2022)