Skip to content

Commit

Permalink
Add provide_state
Browse files Browse the repository at this point in the history
  • Loading branch information
davidpdrsn committed Nov 20, 2022
1 parent b816ac7 commit f1b935b
Show file tree
Hide file tree
Showing 4 changed files with 129 additions and 1 deletion.
1 change: 1 addition & 0 deletions axum/CHANGELOG.md
Expand Up @@ -8,6 +8,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
# Unreleased

- **added:** Add `RouterService::{layer, route_layer}` ([#1550])
- **added:** Add `Router::provide_state` and `MethodRouter::provide_state`

[#1550]: https://github.com/tokio-rs/axum/pull/1550

Expand Down
26 changes: 26 additions & 0 deletions axum/src/routing/method_routing.rs
Expand Up @@ -749,6 +749,22 @@ where
}
}

/// Provide the state but keep the `MethodRouter`.
pub fn provide_state<S2>(self, state: S) -> MethodRouter<S2, B, E> {
MethodRouter {
get: self.get.provide_state(state.clone()),
head: self.head.provide_state(state.clone()),
delete: self.delete.provide_state(state.clone()),
options: self.options.provide_state(state.clone()),
patch: self.patch.provide_state(state.clone()),
post: self.post.provide_state(state.clone()),
put: self.put.provide_state(state.clone()),
trace: self.trace.provide_state(state.clone()),
allow_header: self.allow_header,
fallback: self.fallback.provide_state(state),
}
}

/// Chain an additional service that will accept requests matching the given
/// `MethodFilter`.
///
Expand Down Expand Up @@ -1177,6 +1193,16 @@ where
Self::BoxedHandler(handler) => Some(handler.into_route(state.clone())),
}
}

fn provide_state<S2>(self, state: S) -> MethodEndpoint<S2, B, E> {
match self {
MethodEndpoint::None => MethodEndpoint::None,
MethodEndpoint::Route(route) => MethodEndpoint::Route(route),
MethodEndpoint::BoxedHandler(handler) => {
MethodEndpoint::Route(handler.into_route(state))
}
}
}
}

impl<S, B, E> Clone for MethodEndpoint<S, B, E> {
Expand Down
76 changes: 76 additions & 0 deletions axum/src/routing/mod.rs
Expand Up @@ -400,6 +400,74 @@ where
pub fn with_state(self, state: S) -> RouterService<B> {
RouterService::new(self, state)
}

/// Provide the state but keep the `Router`.
///
/// This can be used to nest or merge routers with different state types:
///
/// ```rust
/// use axum::{
/// Router,
/// routing::get,
/// extract::State,
/// };
///
/// #[derive(Clone)]
/// struct StateOne {}
///
/// let router_one = Router::new().route("/one", get(|_: State<StateOne>| async {}));
///
/// #[derive(Clone)]
/// struct StateTwo {}
///
/// let router_two = Router::new().route("/two", get(|_: State<StateTwo>| async {}));
///
/// #[derive(Clone)]
/// struct StateThree {}
///
/// // our final router which requires `StateThree`
/// // the type annotations are just for clarity. Rust can infer them.
/// let app = Router::<StateThree>::new()
/// // provide the state such that the router can be nested
/// .nest("/one", router_one.provide_state(StateOne {}))
/// // same for merge
/// .merge(router_two.provide_state(StateTwo {}))
/// // we can still add routes that requires `StateThree`
/// .route("/three", get(|_: State<StateThree>| async {}))
/// // provide the final state
/// .with_state(StateThree {});
/// # let _: axum::routing::RouterService = app;
/// ```
///
/// This is necessary because [`Router::nest`] and [`Router::merge`] both requires arguments of
/// type `Router`. If we used [`Router::with_state`] we'd get a [`RouterService`] which
/// wouldn't work. `Router::provide_state` maintains the `Router` type.
pub fn provide_state<S2>(self, state: S) -> Router<S2, B> {
let routes = self
.routes
.into_iter()
.map(|(id, endpoint)| {
let endpoint: Endpoint<S2, B> = match endpoint {
Endpoint::MethodRouter(method_router) => {
Endpoint::MethodRouter(method_router.provide_state(state.clone()))
}
Endpoint::Route(route) => Endpoint::Route(route),
Endpoint::NestedRouter(router) => {
Endpoint::Route(router.into_route(state.clone()))
}
};
(id, endpoint)
})
.collect();

let fallback = self.fallback.provide_state(state);

Router {
routes,
node: self.node,
fallback,
}
}
}

impl<B> Router<(), B>
Expand Down Expand Up @@ -539,6 +607,14 @@ where
Self::BoxedHandler(handler) => Fallback::BoxedHandler(handler.map(f)),
}
}

fn provide_state<S2>(self, state: S) -> Fallback<S2, B, E> {
match self {
Fallback::Default(route) => Fallback::Default(route),
Fallback::Service(route) => Fallback::Service(route),
Fallback::BoxedHandler(handler) => Fallback::Service(handler.into_route(state)),
}
}
}

impl<S, B, E> Clone for Fallback<S, B, E> {
Expand Down
27 changes: 26 additions & 1 deletion axum/src/routing/tests/merge.rs
@@ -1,5 +1,8 @@
use super::*;
use crate::{error_handling::HandleErrorLayer, extract::OriginalUri, response::IntoResponse, Json};
use crate::{
error_handling::HandleErrorLayer, extract::OriginalUri, response::IntoResponse, Json,
RouterService,
};
use serde_json::{json, Value};
use tower::{limit::ConcurrencyLimitLayer, timeout::TimeoutLayer};

Expand Down Expand Up @@ -397,3 +400,25 @@ async fn middleware_that_return_early() {
);
assert_eq!(client.get("/public").send().await.status(), StatusCode::OK);
}

#[tokio::test]
async fn merge_router_with_different_state() {
#[derive(Clone)]
struct A;

#[derive(Clone)]
struct B;

let router_a = Router::new().route("/", get(|_: State<A>| async { "get" }));
let router_b = Router::new().route("/", post(|_: State<B>| async { "post" }));

let app = Router::new()
.merge(router_a)
.merge(router_b.provide_state(B))
.with_state(A);

let client = TestClient::from_service(app);

assert_eq!(client.get("/").send().await.text().await, "get");
assert_eq!(client.post("/").send().await.text().await, "post");
}

0 comments on commit f1b935b

Please sign in to comment.