diff --git a/axum/CHANGELOG.md b/axum/CHANGELOG.md index bbf51ef8a6..98b719c60d 100644 --- a/axum/CHANGELOG.md +++ b/axum/CHANGELOG.md @@ -8,6 +8,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 # Unreleased - **added:** Add `RouterService::{layer, route_layer}` ([#1550]) +- **added:** Add `Router::provide_state` and `MethodRouter::provide_state` [#1550]: https://github.com/tokio-rs/axum/pull/1550 diff --git a/axum/src/routing/method_routing.rs b/axum/src/routing/method_routing.rs index 4b5d5b3fa5..0fc59c976e 100644 --- a/axum/src/routing/method_routing.rs +++ b/axum/src/routing/method_routing.rs @@ -749,6 +749,22 @@ where } } + /// Provide the state but keep the `MethodRouter`. + pub fn provide_state(self, state: S) -> MethodRouter { + MethodRouter { + get: self.get.provide_state(state.clone()), + head: self.head.provide_state(state.clone()), + delete: self.delete.provide_state(state.clone()), + options: self.options.provide_state(state.clone()), + patch: self.patch.provide_state(state.clone()), + post: self.post.provide_state(state.clone()), + put: self.put.provide_state(state.clone()), + trace: self.trace.provide_state(state.clone()), + allow_header: self.allow_header, + fallback: self.fallback.provide_state(state), + } + } + /// Chain an additional service that will accept requests matching the given /// `MethodFilter`. /// @@ -1177,6 +1193,16 @@ where Self::BoxedHandler(handler) => Some(handler.into_route(state.clone())), } } + + fn provide_state(self, state: S) -> MethodEndpoint { + match self { + MethodEndpoint::None => MethodEndpoint::None, + MethodEndpoint::Route(route) => MethodEndpoint::Route(route), + MethodEndpoint::BoxedHandler(handler) => { + MethodEndpoint::Route(handler.into_route(state)) + } + } + } } impl Clone for MethodEndpoint { diff --git a/axum/src/routing/mod.rs b/axum/src/routing/mod.rs index 2544ac4125..be7c179a02 100644 --- a/axum/src/routing/mod.rs +++ b/axum/src/routing/mod.rs @@ -400,6 +400,74 @@ where pub fn with_state(self, state: S) -> RouterService { RouterService::new(self, state) } + + /// Provide the state but keep the `Router`. + /// + /// This can be used to nest or merge routers with different state types: + /// + /// ```rust + /// use axum::{ + /// Router, + /// routing::get, + /// extract::State, + /// }; + /// + /// #[derive(Clone)] + /// struct StateOne {} + /// + /// let router_one = Router::new().route("/one", get(|_: State| async {})); + /// + /// #[derive(Clone)] + /// struct StateTwo {} + /// + /// let router_two = Router::new().route("/two", get(|_: State| async {})); + /// + /// #[derive(Clone)] + /// struct StateThree {} + /// + /// // our final router which requires `StateThree` + /// // the type annotations are just for clarity. Rust can infer them. + /// let app = Router::::new() + /// // provide the state such that the router can be nested + /// .nest("/one", router_one.provide_state(StateOne {})) + /// // same for merge + /// .merge(router_two.provide_state(StateTwo {})) + /// // we can still add routes that requires `StateThree` + /// .route("/three", get(|_: State| async {})) + /// // provide the final state + /// .with_state(StateThree {}); + /// # let _: axum::routing::RouterService = app; + /// ``` + /// + /// This is necessary because [`Router::nest`] and [`Router::merge`] both requires arguments of + /// type `Router`. If we used [`Router::with_state`] we'd get a [`RouterService`] which + /// wouldn't work. `Router::provide_state` maintains the `Router` type. + pub fn provide_state(self, state: S) -> Router { + let routes = self + .routes + .into_iter() + .map(|(id, endpoint)| { + let endpoint: Endpoint = match endpoint { + Endpoint::MethodRouter(method_router) => { + Endpoint::MethodRouter(method_router.provide_state(state.clone())) + } + Endpoint::Route(route) => Endpoint::Route(route), + Endpoint::NestedRouter(router) => { + Endpoint::Route(router.into_route(state.clone())) + } + }; + (id, endpoint) + }) + .collect(); + + let fallback = self.fallback.provide_state(state); + + Router { + routes, + node: self.node, + fallback, + } + } } impl Router<(), B> @@ -539,6 +607,14 @@ where Self::BoxedHandler(handler) => Fallback::BoxedHandler(handler.map(f)), } } + + fn provide_state(self, state: S) -> Fallback { + match self { + Fallback::Default(route) => Fallback::Default(route), + Fallback::Service(route) => Fallback::Service(route), + Fallback::BoxedHandler(handler) => Fallback::Service(handler.into_route(state)), + } + } } impl Clone for Fallback { diff --git a/axum/src/routing/tests/merge.rs b/axum/src/routing/tests/merge.rs index abad660a33..9d3d7d6f24 100644 --- a/axum/src/routing/tests/merge.rs +++ b/axum/src/routing/tests/merge.rs @@ -1,5 +1,8 @@ use super::*; -use crate::{error_handling::HandleErrorLayer, extract::OriginalUri, response::IntoResponse, Json}; +use crate::{ + error_handling::HandleErrorLayer, extract::OriginalUri, response::IntoResponse, Json, + RouterService, +}; use serde_json::{json, Value}; use tower::{limit::ConcurrencyLimitLayer, timeout::TimeoutLayer}; @@ -397,3 +400,25 @@ async fn middleware_that_return_early() { ); assert_eq!(client.get("/public").send().await.status(), StatusCode::OK); } + +#[tokio::test] +async fn merge_router_with_different_state() { + #[derive(Clone)] + struct A; + + #[derive(Clone)] + struct B; + + let router_a = Router::new().route("/", get(|_: State| async { "get" })); + let router_b = Router::new().route("/", post(|_: State| async { "post" })); + + let app = Router::new() + .merge(router_a) + .merge(router_b.provide_state(B)) + .with_state(A); + + let client = TestClient::from_service(app); + + assert_eq!(client.get("/").send().await.text().await, "get"); + assert_eq!(client.post("/").send().await.text().await, "post"); +}