From 7090649377adae7537ba4e479e6a60dca57ad846 Mon Sep 17 00:00:00 2001 From: David Pedersen Date: Fri, 18 Nov 2022 11:25:04 +0100 Subject: [PATCH] Add fallback inheritance for nested routers (#1521) * fallback inheritance * cleanup * changelog --- axum/CHANGELOG.md | 3 + axum/src/docs/routing/nest.md | 39 ++++++----- axum/src/routing/method_routing.rs | 6 +- axum/src/routing/mod.rs | 77 ++++++++++++++++------ axum/src/routing/service.rs | 36 +++++++++-- axum/src/routing/tests/fallback.rs | 100 +++++++++++++++++++++++++++++ axum/src/routing/tests/nest.rs | 17 ----- 7 files changed, 218 insertions(+), 60 deletions(-) diff --git a/axum/CHANGELOG.md b/axum/CHANGELOG.md index 4d0ae283a4..9f28fd6492 100644 --- a/axum/CHANGELOG.md +++ b/axum/CHANGELOG.md @@ -7,8 +7,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 # Unreleased +- **fixed:** Nested routers will now inherit fallbacks from outer routers ([#1521]) - **added:** Add `accept_unmasked_frames` setting in WebSocketUpgrade ([#1529]) +[#1521]: https://github.com/tokio-rs/axum/pull/1521 + # 0.6.0-rc.4 (9. November, 2022) - **changed**: The inner error of a `JsonRejection` is now diff --git a/axum/src/docs/routing/nest.md b/axum/src/docs/routing/nest.md index 3b03d4d9c6..6ea05478c4 100644 --- a/axum/src/docs/routing/nest.md +++ b/axum/src/docs/routing/nest.md @@ -90,8 +90,8 @@ let app = Router::new() # Fallbacks -When nesting a router, if a request matches the prefix but the nested router doesn't have a matching -route, the outer fallback will _not_ be called: +If a nested router doesn't have its own fallback then it will inherit the +fallback from the outer router: ```rust use axum::{routing::get, http::StatusCode, handler::Handler, Router}; @@ -100,7 +100,7 @@ async fn fallback() -> (StatusCode, &'static str) { (StatusCode::NOT_FOUND, "Not Found") } -let api_routes = Router::new().nest_service("/users", get(|| async {})); +let api_routes = Router::new().route("/users", get(|| async {})); let app = Router::new() .nest("/api", api_routes) @@ -108,30 +108,35 @@ let app = Router::new() # let _: Router = app; ``` -Here requests like `GET /api/not-found` will go into `api_routes` and then to -the fallback of `api_routes` which will return an empty `404 Not Found` -response. The outer fallback declared on `app` will _not_ be called. +Here requests like `GET /api/not-found` will go into `api_routes` but because +it doesn't have a matching route and doesn't have its own fallback it will call +the fallback from the outer router, i.e. the `fallback` function. -Think of nested services as swallowing requests that matches the prefix and -not falling back to outer router even if they don't have a matching route. - -You can still add separate fallbacks to nested routers: +If the nested router has its own fallback then the outer fallback will not be +inherited: ```rust -use axum::{routing::get, http::StatusCode, handler::Handler, Json, Router}; -use serde_json::{json, Value}; +use axum::{ + routing::get, + http::StatusCode, + handler::Handler, + Json, + Router, +}; async fn fallback() -> (StatusCode, &'static str) { (StatusCode::NOT_FOUND, "Not Found") } -async fn api_fallback() -> (StatusCode, Json) { - (StatusCode::NOT_FOUND, Json(json!({ "error": "Not Found" }))) +async fn api_fallback() -> (StatusCode, Json) { + ( + StatusCode::NOT_FOUND, + Json(serde_json::json!({ "status": "Not Found" })), + ) } let api_routes = Router::new() - .nest_service("/users", get(|| async {})) - // add dedicated fallback for requests starting with `/api` + .route("/users", get(|| async {})) .fallback(api_fallback); let app = Router::new() @@ -140,6 +145,8 @@ let app = Router::new() # let _: Router = app; ``` +Here requests like `GET /api/not-found` will go to `api_fallback`. + # Panics - If the route overlaps with another route. See [`Router::route`] diff --git a/axum/src/routing/method_routing.rs b/axum/src/routing/method_routing.rs index e8f3887ae0..8672b1b2c6 100644 --- a/axum/src/routing/method_routing.rs +++ b/axum/src/routing/method_routing.rs @@ -1,6 +1,6 @@ //! Route to services and handlers based on HTTP methods. -use super::IntoMakeService; +use super::{FallbackRoute, IntoMakeService}; #[cfg(feature = "tokio")] use crate::extract::connect_info::IntoMakeServiceWithConnectInfo; use crate::{ @@ -744,7 +744,7 @@ where post: self.post.into_route(&state), put: self.put.into_route(&state), trace: self.trace.into_route(&state), - fallback: self.fallback.into_route(&state), + fallback: self.fallback.into_fallback_route(&state), allow_header: self.allow_header, } } @@ -1284,7 +1284,7 @@ pub struct WithState { post: Option>, put: Option>, trace: Option>, - fallback: Route, + fallback: FallbackRoute, allow_header: AllowHeader, } diff --git a/axum/src/routing/mod.rs b/axum/src/routing/mod.rs index c5156a4799..88a23127f6 100644 --- a/axum/src/routing/mod.rs +++ b/axum/src/routing/mod.rs @@ -8,7 +8,7 @@ use crate::{ handler::{BoxedHandler, Handler}, util::try_downcast, }; -use axum_core::response::IntoResponse; +use axum_core::response::{IntoResponse, Response}; use http::Request; use matchit::MatchError; use std::{ @@ -18,7 +18,10 @@ use std::{ fmt, sync::Arc, }; -use tower::{util::MapResponseLayer, ServiceBuilder}; +use tower::{ + util::{BoxCloneService, MapResponseLayer, Oneshot}, + ServiceBuilder, +}; use tower_layer::Layer; use tower_service::Service; @@ -639,11 +642,29 @@ where } } - fn into_route(self, state: &S) -> Route { + fn into_fallback_route(self, state: &S) -> FallbackRoute { match self { - Self::Default(route) => route, - Self::Service(route) => route, - Self::BoxedHandler(handler) => handler.into_route(state.clone()), + Self::Default(route) => FallbackRoute::Default(route), + Self::Service(route) => FallbackRoute::Service(route), + Self::BoxedHandler(handler) => { + FallbackRoute::Service(handler.into_route(state.clone())) + } + } + } + + fn map(self, f: F) -> Fallback + where + S: 'static, + B: 'static, + E: 'static, + F: FnOnce(Route) -> Route + Clone + Send + 'static, + B2: 'static, + E2: 'static, + { + match self { + Self::Default(route) => Fallback::Default(f(route)), + Self::Service(route) => Fallback::Service(f(route)), + Self::BoxedHandler(handler) => Fallback::BoxedHandler(handler.map(f)), } } } @@ -668,20 +689,38 @@ impl fmt::Debug for Fallback { } } -impl Fallback { - fn map(self, f: F) -> Fallback - where - S: 'static, - B: 'static, - E: 'static, - F: FnOnce(Route) -> Route + Clone + Send + 'static, - B2: 'static, - E2: 'static, - { +/// Like `Fallback` but without the `S` param so it can be stored in `RouterService` +pub(crate) enum FallbackRoute { + Default(Route), + Service(Route), +} + +impl fmt::Debug for FallbackRoute { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::Default(inner) => f.debug_tuple("Default").field(inner).finish(), + Self::Service(inner) => f.debug_tuple("Service").field(inner).finish(), + } + } +} + +impl Clone for FallbackRoute { + fn clone(&self) -> Self { + match self { + Self::Default(inner) => Self::Default(inner.clone()), + Self::Service(inner) => Self::Service(inner.clone()), + } + } +} + +impl FallbackRoute { + pub(crate) fn oneshot_inner( + &mut self, + req: Request, + ) -> Oneshot, Response, E>, Request> { match self { - Self::Default(inner) => Fallback::Default(f(inner)), - Self::Service(inner) => Fallback::Service(f(inner)), - Self::BoxedHandler(inner) => Fallback::BoxedHandler(inner.map(f)), + FallbackRoute::Default(inner) => inner.oneshot_inner(req), + FallbackRoute::Service(inner) => inner.oneshot_inner(req), } } } diff --git a/axum/src/routing/service.rs b/axum/src/routing/service.rs index 3bd9a629f3..f19f385677 100644 --- a/axum/src/routing/service.rs +++ b/axum/src/routing/service.rs @@ -1,4 +1,6 @@ -use super::{future::RouteFuture, url_params, Endpoint, Node, Route, RouteId, Router}; +use super::{ + future::RouteFuture, url_params, Endpoint, FallbackRoute, Node, Route, RouteId, Router, +}; use crate::{ body::{Body, HttpBody}, response::Response, @@ -11,6 +13,7 @@ use std::{ sync::Arc, task::{Context, Poll}, }; +use sync_wrapper::SyncWrapper; use tower::Service; /// A [`Router`] converted into a [`Service`]. @@ -18,7 +21,7 @@ use tower::Service; pub struct RouterService { routes: HashMap>, node: Arc, - fallback: Route, + fallback: FallbackRoute, } impl RouterService @@ -52,7 +55,7 @@ where Self { routes, node: router.node, - fallback: router.fallback.into_route(&state), + fallback: router.fallback.into_fallback_route(&state), } } @@ -121,12 +124,35 @@ where let path = req.uri().path().to_owned(); match self.node.at(&path) { - Ok(match_) => self.call_route(match_, req), + Ok(match_) => { + match &self.fallback { + FallbackRoute::Default(_) => {} + FallbackRoute::Service(fallback) => { + req.extensions_mut() + .insert(SuperFallback(SyncWrapper::new(fallback.clone()))); + } + } + + self.call_route(match_, req) + } Err( MatchError::NotFound | MatchError::ExtraTrailingSlash | MatchError::MissingTrailingSlash, - ) => self.fallback.clone().call(req), + ) => match &mut self.fallback { + FallbackRoute::Default(fallback) => { + if let Some(super_fallback) = req.extensions_mut().remove::>() + { + let mut super_fallback = super_fallback.0.into_inner(); + super_fallback.call(req) + } else { + fallback.call(req) + } + } + FallbackRoute::Service(fallback) => fallback.call(req), + }, } } } + +struct SuperFallback(SyncWrapper>); diff --git a/axum/src/routing/tests/fallback.rs b/axum/src/routing/tests/fallback.rs index 4da166baea..b9993ab806 100644 --- a/axum/src/routing/tests/fallback.rs +++ b/axum/src/routing/tests/fallback.rs @@ -1,4 +1,5 @@ use super::*; +use crate::middleware::{map_request, map_response}; #[tokio::test] async fn basic() { @@ -58,3 +59,102 @@ async fn fallback_accessing_state() { assert_eq!(res.status(), StatusCode::OK); assert_eq!(res.text().await, "state"); } + +async fn inner_fallback() -> impl IntoResponse { + (StatusCode::NOT_FOUND, "inner") +} + +async fn outer_fallback() -> impl IntoResponse { + (StatusCode::NOT_FOUND, "outer") +} + +#[tokio::test] +async fn nested_router_inherits_fallback() { + let inner = Router::new(); + let app = Router::new().nest("/foo", inner).fallback(outer_fallback); + + let client = TestClient::new(app); + + let res = client.get("/foo/bar").send().await; + assert_eq!(res.status(), StatusCode::NOT_FOUND); + assert_eq!(res.text().await, "outer"); +} + +#[tokio::test] +async fn doesnt_inherit_fallback_if_overriden() { + let inner = Router::new().fallback(inner_fallback); + let app = Router::new().nest("/foo", inner).fallback(outer_fallback); + + let client = TestClient::new(app); + + let res = client.get("/foo/bar").send().await; + assert_eq!(res.status(), StatusCode::NOT_FOUND); + assert_eq!(res.text().await, "inner"); +} + +#[tokio::test] +async fn deeply_nested_inherit_from_top() { + let app = Router::new() + .nest("/foo", Router::new().nest("/bar", Router::new())) + .fallback(outer_fallback); + + let client = TestClient::new(app); + + let res = client.get("/foo/bar/baz").send().await; + assert_eq!(res.status(), StatusCode::NOT_FOUND); + assert_eq!(res.text().await, "outer"); +} + +#[tokio::test] +async fn deeply_nested_inherit_from_middle() { + let app = Router::new().nest( + "/foo", + Router::new() + .nest("/bar", Router::new()) + .fallback(outer_fallback), + ); + + let client = TestClient::new(app); + + let res = client.get("/foo/bar/baz").send().await; + assert_eq!(res.status(), StatusCode::NOT_FOUND); + assert_eq!(res.text().await, "outer"); +} + +#[tokio::test] +async fn with_middleware_on_inner_fallback() { + async fn never_called(_: Request) -> Request { + panic!("should never be called") + } + + let inner = Router::new().layer(map_request(never_called)); + let app = Router::new().nest("/foo", inner).fallback(outer_fallback); + + let client = TestClient::new(app); + + let res = client.get("/foo/bar").send().await; + assert_eq!(res.status(), StatusCode::NOT_FOUND); + assert_eq!(res.text().await, "outer"); +} + +#[tokio::test] +async fn also_inherits_default_layered_fallback() { + async fn set_header(mut res: Response) -> Response { + res.headers_mut() + .insert("x-from-fallback", "1".parse().unwrap()); + res + } + + let inner = Router::new(); + let app = Router::new() + .nest("/foo", inner) + .fallback(outer_fallback) + .layer(map_response(set_header)); + + let client = TestClient::new(app); + + let res = client.get("/foo/bar").send().await; + assert_eq!(res.status(), StatusCode::NOT_FOUND); + assert_eq!(res.headers()["x-from-fallback"], "1"); + assert_eq!(res.text().await, "outer"); +} diff --git a/axum/src/routing/tests/nest.rs b/axum/src/routing/tests/nest.rs index 5fd3324787..29e8286076 100644 --- a/axum/src/routing/tests/nest.rs +++ b/axum/src/routing/tests/nest.rs @@ -351,23 +351,6 @@ async fn nest_with_and_without_trailing() { assert_eq!(res.status(), StatusCode::OK); } -#[tokio::test] -async fn doesnt_call_outer_fallback() { - let app = Router::new() - .nest("/foo", Router::new().route("/", get(|| async {}))) - .fallback(|| async { (StatusCode::NOT_FOUND, "outer fallback") }); - - let client = TestClient::new(app); - - let res = client.get("/foo").send().await; - assert_eq!(res.status(), StatusCode::OK); - - let res = client.get("/foo/not-found").send().await; - assert_eq!(res.status(), StatusCode::NOT_FOUND); - // the default fallback returns an empty body - assert_eq!(res.text().await, ""); -} - #[tokio::test] async fn nesting_with_root_inner_router() { let app = Router::new().nest(