diff --git a/axum-extra/src/extract/cookie/private.rs b/axum-extra/src/extract/cookie/private.rs index 9fe6469a08..aebf1cf4e5 100644 --- a/axum-extra/src/extract/cookie/private.rs +++ b/axum-extra/src/extract/cookie/private.rs @@ -68,7 +68,7 @@ use std::{convert::Infallible, fmt, marker::PhantomData}; /// .route("/set", post(set_secret)) /// .route("/get", get(get_secret)) /// .with_state(state); -/// # let _: axum::routing::RouterService = app; +/// # let _: axum::Router = app; /// ``` /// /// If you have been using `Arc` you cannot implement `FromRef> for Key`. diff --git a/axum-extra/src/extract/cookie/signed.rs b/axum-extra/src/extract/cookie/signed.rs index edc601a361..462c65a1f5 100644 --- a/axum-extra/src/extract/cookie/signed.rs +++ b/axum-extra/src/extract/cookie/signed.rs @@ -86,7 +86,7 @@ use std::{convert::Infallible, fmt, marker::PhantomData}; /// .route("/sessions", post(create_session)) /// .route("/me", get(me)) /// .with_state(state); -/// # let _: axum::routing::RouterService = app; +/// # let _: axum::Router = app; /// ``` /// If you have been using `Arc` you cannot implement `FromRef> for Key`. /// You can use a new type instead: diff --git a/axum-extra/src/routing/resource.rs b/axum-extra/src/routing/resource.rs index 25d566435a..19c6236a83 100644 --- a/axum-extra/src/routing/resource.rs +++ b/axum-extra/src/routing/resource.rs @@ -147,7 +147,7 @@ impl From> for Router { mod tests { #[allow(unused_imports)] use super::*; - use axum::{extract::Path, http::Method, routing::RouterService, Router}; + use axum::{extract::Path, http::Method, Router}; use http::Request; use tower::{Service, ServiceExt}; @@ -162,7 +162,7 @@ mod tests { .update(|Path(id): Path| async move { format!("users#update id={}", id) }) .destroy(|Path(id): Path| async move { format!("users#destroy id={}", id) }); - let mut app = Router::new().merge(users).into_service(); + let mut app = Router::new().merge(users); assert_eq!( call_route(&mut app, Method::GET, "/users").await, @@ -205,7 +205,7 @@ mod tests { ); } - async fn call_route(app: &mut RouterService, method: Method, uri: &str) -> String { + async fn call_route(app: &mut Router, method: Method, uri: &str) -> String { let res = app .ready() .await diff --git a/axum-extra/src/routing/spa.rs b/axum-extra/src/routing/spa.rs index 5fc6d8527b..ad8d8af135 100644 --- a/axum-extra/src/routing/spa.rs +++ b/axum-extra/src/routing/spa.rs @@ -270,7 +270,7 @@ mod tests { #[allow(dead_code)] fn works_with_router_with_state() { - let _: axum::RouterService = Router::new() + let _: Router = Router::new() .merge(SpaRouter::new("/assets", "test_files")) .route("/", get(|_: axum::extract::State| async {})) .with_state(String::new()); diff --git a/axum-macros/src/lib.rs b/axum-macros/src/lib.rs index 95f400aab9..053f2b0d3f 100644 --- a/axum-macros/src/lib.rs +++ b/axum-macros/src/lib.rs @@ -606,7 +606,7 @@ pub fn derive_typed_path(input: TokenStream) -> TokenStream { /// let app = Router::new() /// .route("/", get(handler).post(other_handler)) /// .with_state(state); -/// # let _: axum::routing::RouterService = app; +/// # let _: axum::Router = app; /// ``` /// /// [`FromRef`]: https://docs.rs/axum/latest/axum/extract/trait.FromRef.html diff --git a/axum-macros/tests/from_ref/pass/basic.rs b/axum-macros/tests/from_ref/pass/basic.rs index 055d63b8d9..e410e11a05 100644 --- a/axum-macros/tests/from_ref/pass/basic.rs +++ b/axum-macros/tests/from_ref/pass/basic.rs @@ -14,7 +14,7 @@ fn main() { auth_token: Default::default(), }; - let _: axum::routing::RouterService = Router::new() + let _: axum::Router = Router::new() .route("/", get(handler)) .with_state(state); } diff --git a/axum-macros/tests/from_request/pass/state_enum_via.rs b/axum-macros/tests/from_request/pass/state_enum_via.rs index 8da81901e1..632adb56ca 100644 --- a/axum-macros/tests/from_request/pass/state_enum_via.rs +++ b/axum-macros/tests/from_request/pass/state_enum_via.rs @@ -6,7 +6,7 @@ use axum::{ use axum_macros::FromRequest; fn main() { - let _: axum::routing::RouterService = Router::new() + let _: axum::Router = Router::new() .route("/a", get(|_: AppState| async {})) .route("/b", get(|_: InnerState| async {})) .with_state(AppState::default()); diff --git a/axum-macros/tests/from_request/pass/state_enum_via_parts.rs b/axum-macros/tests/from_request/pass/state_enum_via_parts.rs index 3d5b5b0bc3..664ff3ab4d 100644 --- a/axum-macros/tests/from_request/pass/state_enum_via_parts.rs +++ b/axum-macros/tests/from_request/pass/state_enum_via_parts.rs @@ -6,7 +6,7 @@ use axum::{ use axum_macros::FromRequestParts; fn main() { - let _: axum::routing::RouterService = Router::new() + let _: axum::Router = Router::new() .route("/a", get(|_: AppState| async {})) .route("/b", get(|_: InnerState| async {})) .route("/c", get(|_: AppState, _: InnerState| async {})) diff --git a/axum-macros/tests/from_request/pass/state_explicit.rs b/axum-macros/tests/from_request/pass/state_explicit.rs index bea2958d7b..aed9dad6d5 100644 --- a/axum-macros/tests/from_request/pass/state_explicit.rs +++ b/axum-macros/tests/from_request/pass/state_explicit.rs @@ -6,7 +6,7 @@ use axum::{ }; fn main() { - let _: axum::routing::RouterService = Router::new() + let _: axum::Router = Router::new() .route("/b", get(|_: Extractor| async {})) .with_state(AppState::default()); } diff --git a/axum-macros/tests/from_request/pass/state_explicit_parts.rs b/axum-macros/tests/from_request/pass/state_explicit_parts.rs index a4865fc85b..94f37cf6b8 100644 --- a/axum-macros/tests/from_request/pass/state_explicit_parts.rs +++ b/axum-macros/tests/from_request/pass/state_explicit_parts.rs @@ -7,7 +7,7 @@ use axum::{ use std::collections::HashMap; fn main() { - let _: axum::routing::RouterService = Router::new() + let _: axum::Router = Router::new() .route("/b", get(|_: Extractor| async {})) .with_state(AppState::default()); } diff --git a/axum-macros/tests/from_request/pass/state_field_explicit.rs b/axum-macros/tests/from_request/pass/state_field_explicit.rs index 1caccf461f..b6d003dc00 100644 --- a/axum-macros/tests/from_request/pass/state_field_explicit.rs +++ b/axum-macros/tests/from_request/pass/state_field_explicit.rs @@ -6,7 +6,7 @@ use axum::{ use axum_macros::FromRequest; fn main() { - let _: axum::routing::RouterService = Router::new() + let _: axum::Router = Router::new() .route("/", get(|_: Extractor| async {})) .with_state(AppState::default()); } diff --git a/axum-macros/tests/from_request/pass/state_field_infer.rs b/axum-macros/tests/from_request/pass/state_field_infer.rs index 08884dcb9f..a24861a162 100644 --- a/axum-macros/tests/from_request/pass/state_field_infer.rs +++ b/axum-macros/tests/from_request/pass/state_field_infer.rs @@ -6,7 +6,7 @@ use axum::{ use axum_macros::FromRequest; fn main() { - let _: axum::routing::RouterService = Router::new() + let _: axum::Router = Router::new() .route("/", get(|_: Extractor| async {})) .with_state(AppState::default()); } diff --git a/axum-macros/tests/from_request/pass/state_via.rs b/axum-macros/tests/from_request/pass/state_via.rs index 590ec53564..7d196b0395 100644 --- a/axum-macros/tests/from_request/pass/state_via.rs +++ b/axum-macros/tests/from_request/pass/state_via.rs @@ -6,7 +6,7 @@ use axum::{ use axum_macros::FromRequest; fn main() { - let _: axum::routing::RouterService = Router::new() + let _: axum::Router = Router::new() .route("/b", get(|_: (), _: AppState| async {})) .route("/c", get(|_: (), _: InnerState| async {})) .with_state(AppState::default()); diff --git a/axum-macros/tests/from_request/pass/state_via_infer.rs b/axum-macros/tests/from_request/pass/state_via_infer.rs index 50b4a1bc1c..40c52d8d4d 100644 --- a/axum-macros/tests/from_request/pass/state_via_infer.rs +++ b/axum-macros/tests/from_request/pass/state_via_infer.rs @@ -6,7 +6,7 @@ use axum::{ use axum_macros::FromRequest; fn main() { - let _: axum::routing::RouterService = Router::new() + let _: axum::Router = Router::new() .route("/b", get(|_: AppState| async {})) .with_state(AppState::default()); } diff --git a/axum-macros/tests/from_request/pass/state_via_parts.rs b/axum-macros/tests/from_request/pass/state_via_parts.rs index 2817e8327b..44da20dbf0 100644 --- a/axum-macros/tests/from_request/pass/state_via_parts.rs +++ b/axum-macros/tests/from_request/pass/state_via_parts.rs @@ -6,7 +6,7 @@ use axum::{ use axum_macros::FromRequestParts; fn main() { - let _: axum::routing::RouterService = Router::new() + let _: axum::Router = Router::new() .route("/a", get(|_: AppState, _: InnerState, _: String| async {})) .route("/b", get(|_: AppState, _: String| async {})) .route("/c", get(|_: InnerState, _: String| async {})) diff --git a/axum-macros/tests/from_request/pass/state_with_rejection.rs b/axum-macros/tests/from_request/pass/state_with_rejection.rs index aef3d9c773..9921add02b 100644 --- a/axum-macros/tests/from_request/pass/state_with_rejection.rs +++ b/axum-macros/tests/from_request/pass/state_with_rejection.rs @@ -8,7 +8,7 @@ use axum::{ use axum_macros::FromRequest; fn main() { - let _: axum::routing::RouterService = Router::new() + let _: axum::Router = Router::new() .route("/a", get(|_: Extractor| async {})) .with_state(AppState::default()); } diff --git a/axum/CHANGELOG.md b/axum/CHANGELOG.md index bbf51ef8a6..7d257c680f 100644 --- a/axum/CHANGELOG.md +++ b/axum/CHANGELOG.md @@ -7,9 +7,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 # Unreleased -- **added:** Add `RouterService::{layer, route_layer}` ([#1550]) +- **breaking:** `RouterService` has been removed since `Router` now implements + `Service` when the state is `()`. Use `Router::with_state` to provide the + state and get a `Router<()>` ([#1552]) -[#1550]: https://github.com/tokio-rs/axum/pull/1550 +[#1552]: https://github.com/tokio-rs/axum/pull/1552 # 0.6.0-rc.5 (18. November, 2022) diff --git a/axum/benches/benches.rs b/axum/benches/benches.rs index 0ff269f003..3a7dd998da 100644 --- a/axum/benches/benches.rs +++ b/axum/benches/benches.rs @@ -1,7 +1,7 @@ use axum::{ extract::State, routing::{get, post}, - Extension, Json, Router, RouterService, Server, + Extension, Json, Router, Server, }; use hyper::server::conn::AddrIncoming; use serde::{Deserialize, Serialize}; @@ -17,13 +17,9 @@ fn main() { ensure_rewrk_is_installed(); } - benchmark("minimal").run(|| Router::new().into_service()); + benchmark("minimal").run(Router::new); - benchmark("basic").run(|| { - Router::new() - .route("/", get(|| async { "Hello, World!" })) - .into_service() - }); + benchmark("basic").run(|| Router::new().route("/", get(|| async { "Hello, World!" }))); benchmark("routing").path("/foo/bar/baz").run(|| { let mut app = Router::new(); @@ -34,32 +30,26 @@ fn main() { } } } - app.route("/foo/bar/baz", get(|| async {})).into_service() + app.route("/foo/bar/baz", get(|| async {})) }); benchmark("receive-json") .method("post") .headers(&[("content-type", "application/json")]) .body(r#"{"n": 123, "s": "hi there", "b": false}"#) - .run(|| { - Router::new() - .route("/", post(|_: Json| async {})) - .into_service() - }); + .run(|| Router::new().route("/", post(|_: Json| async {}))); benchmark("send-json").run(|| { - Router::new() - .route( - "/", - get(|| async { - Json(Payload { - n: 123, - s: "hi there".to_owned(), - b: false, - }) - }), - ) - .into_service() + Router::new().route( + "/", + get(|| async { + Json(Payload { + n: 123, + s: "hi there".to_owned(), + b: false, + }) + }), + ) }); let state = AppState { @@ -75,7 +65,6 @@ fn main() { Router::new() .route("/", get(|_: Extension| async {})) .layer(Extension(state.clone())) - .into_service() }); benchmark("state").run(|| { @@ -133,7 +122,7 @@ impl BenchmarkBuilder { fn run(self, f: F) where - F: FnOnce() -> RouterService, + F: FnOnce() -> Router<()>, { // support only running some benchmarks with // ``` diff --git a/axum/src/boxed.rs b/axum/src/boxed.rs index 3612d211db..6aaea39a66 100644 --- a/axum/src/boxed.rs +++ b/axum/src/boxed.rs @@ -1,6 +1,14 @@ use std::{convert::Infallible, fmt}; -use crate::{body::HttpBody, handler::Handler, routing::Route, Router}; +use http::Request; +use tower::Service; + +use crate::{ + body::HttpBody, + handler::Handler, + routing::{future::RouteFuture, Route}, + Router, +}; pub(crate) struct BoxedIntoRoute(Box>); @@ -13,6 +21,7 @@ where where H: Handler, T: 'static, + B: HttpBody, { Self(Box::new(MakeErasedHandler { handler, @@ -30,6 +39,14 @@ where into_route: |router, state| Route::new(router.with_state(state)), })) } + + pub(crate) fn call_with_state( + self, + request: Request, + state: S, + ) -> RouteFuture { + self.0.call_with_state(request, state) + } } impl BoxedIntoRoute { @@ -39,7 +56,7 @@ impl BoxedIntoRoute { B: 'static, E: 'static, F: FnOnce(Route) -> Route + Clone + Send + 'static, - B2: 'static, + B2: HttpBody + 'static, E2: 'static, { BoxedIntoRoute(Box::new(Map { @@ -69,6 +86,8 @@ pub(crate) trait ErasedIntoRoute: Send { fn clone_box(&self) -> Box>; fn into_route(self: Box, state: S) -> Route; + + fn call_with_state(self: Box, request: Request, state: S) -> RouteFuture; } pub(crate) struct MakeErasedHandler { @@ -80,7 +99,7 @@ impl ErasedIntoRoute for MakeErasedHandler where H: Clone + Send + 'static, S: 'static, - B: 'static, + B: HttpBody + 'static, { fn clone_box(&self) -> Box> { Box::new(self.clone()) @@ -89,6 +108,14 @@ where fn into_route(self: Box, state: S) -> Route { (self.into_route)(self.handler, state) } + + fn call_with_state( + self: Box, + request: Request, + state: S, + ) -> RouteFuture { + self.into_route(state).call(request) + } } impl Clone for MakeErasedHandler @@ -110,8 +137,8 @@ pub(crate) struct MakeErasedRouter { impl ErasedIntoRoute for MakeErasedRouter where - S: Clone + Send + 'static, - B: 'static, + S: Clone + Send + Sync + 'static, + B: HttpBody + Send + 'static, { fn clone_box(&self) -> Box> { Box::new(self.clone()) @@ -120,6 +147,14 @@ where fn into_route(self: Box, state: S) -> Route { (self.into_route)(self.router, state) } + + fn call_with_state( + mut self: Box, + request: Request, + state: S, + ) -> RouteFuture { + self.router.call_with_state(request, state) + } } impl Clone for MakeErasedRouter @@ -144,7 +179,7 @@ where S: 'static, B: 'static, E: 'static, - B2: 'static, + B2: HttpBody + 'static, E2: 'static, { fn clone_box(&self) -> Box> { @@ -157,6 +192,10 @@ where fn into_route(self: Box, state: S) -> Route { (self.layer)(self.inner.into_route(state)) } + + fn call_with_state(self: Box, request: Request, state: S) -> RouteFuture { + (self.layer)(self.inner.into_route(state)).call(request) + } } pub(crate) trait LayerFn: FnOnce(Route) -> Route + Send { diff --git a/axum/src/docs/middleware.md b/axum/src/docs/middleware.md index 603fac78e8..8ae44ff158 100644 --- a/axum/src/docs/middleware.md +++ b/axum/src/docs/middleware.md @@ -466,7 +466,7 @@ let app = Router::new() .route("/", get(handler)) .layer(MyLayer { state: state.clone() }) .with_state(state); -# let _: axum::routing::RouterService = app; +# let _: axum::Router = app; ``` # Passing state from middleware to handlers @@ -556,7 +556,7 @@ async fn rewrite_request_uri(req: Request, next: Next) -> Response { // this can be any `tower::Layer` let middleware = axum::middleware::from_fn(rewrite_request_uri); -let app = Router::new().into_service(); +let app = Router::new(); // apply the layer around the whole `Router` // this way the middleware will run before `Router` receives the request diff --git a/axum/src/docs/routing/into_make_service_with_connect_info.md b/axum/src/docs/routing/into_make_service_with_connect_info.md index 86165361e5..05ee750c56 100644 --- a/axum/src/docs/routing/into_make_service_with_connect_info.md +++ b/axum/src/docs/routing/into_make_service_with_connect_info.md @@ -2,10 +2,6 @@ Convert this router into a [`MakeService`], that will store `C`'s associated `ConnectInfo` in a request extension such that [`ConnectInfo`] can extract it. -This is a convenience method for routers that don't have any state (i.e. the -state type is `()`). Use [`RouterService::into_make_service_with_connect_info`] -otherwise. - This enables extracting things like the client's remote address. Extracting [`std::net::SocketAddr`] is supported out of the box: diff --git a/axum/src/docs/routing/merge.md b/axum/src/docs/routing/merge.md index 5d2b94bebd..0e103c83ca 100644 --- a/axum/src/docs/routing/merge.md +++ b/axum/src/docs/routing/merge.md @@ -39,13 +39,39 @@ let app = Router::new() # Merging routers with state -When combining [`Router`]s with this function, each [`Router`] must have the -same type of state. See ["Combining stateful routers"][combining-stateful-routers] -for details. +When combining [`Router`]s with this method, each [`Router`] must have the +same type of state. If your routers have different types you can use +[`Router::with_state`] to provide the state and make the types match: + +```rust +use axum::{ + Router, + routing::get, + extract::State, +}; + +#[derive(Clone)] +struct InnerState {} + +#[derive(Clone)] +struct OuterState {} + +async fn inner_handler(state: State) {} + +let inner_router = Router::new() + .route("/bar", get(inner_handler)) + .with_state(InnerState {}); + +async fn outer_handler(state: State) {} + +let app = Router::new() + .route("/", get(outer_handler)) + .merge(inner_router) + .with_state(OuterState {}); +# let _: axum::Router = app; +``` # Panics - If two routers that each have a [fallback](Router::fallback) are merged. This is because `Router` only allows a single fallback. - -[combining-stateful-routers]: crate::extract::State#combining-stateful-routers diff --git a/axum/src/docs/routing/nest.md b/axum/src/docs/routing/nest.md index d1a8dafce8..b40d0fc951 100644 --- a/axum/src/docs/routing/nest.md +++ b/axum/src/docs/routing/nest.md @@ -149,12 +149,40 @@ Here requests like `GET /api/not-found` will go to `api_fallback`. # Nesting routers with state -When combining [`Router`]s with this function, each [`Router`] must have the -same type of state. See ["Combining stateful routers"][combining-stateful-routers] -for details. +When combining [`Router`]s with this method, each [`Router`] must have the +same type of state. If your routers have different types you can use +[`Router::with_state`] to provide the state and make the types match: -If you want to compose axum services with different types of state, use -[`Router::nest_service`]. +```rust +use axum::{ + Router, + routing::get, + extract::State, +}; + +#[derive(Clone)] +struct InnerState {} + +#[derive(Clone)] +struct OuterState {} + +async fn inner_handler(state: State) {} + +let inner_router = Router::new() + .route("/bar", get(inner_handler)) + .with_state(InnerState {}); + +async fn outer_handler(state: State) {} + +let app = Router::new() + .route("/", get(outer_handler)) + .nest("/foo", inner_router) + .with_state(OuterState {}); +# let _: axum::Router = app; +``` + +Note that the inner router will still inherit the fallback from the outer +router. # Panics @@ -165,4 +193,3 @@ for more details. [`OriginalUri`]: crate::extract::OriginalUri [fallbacks]: Router::fallback -[combining-stateful-routers]: crate::extract::State#combining-stateful-routers diff --git a/axum/src/docs/routing/route_service.md b/axum/src/docs/routing/route_service.md index 623c6cb628..a14323a933 100644 --- a/axum/src/docs/routing/route_service.md +++ b/axum/src/docs/routing/route_service.md @@ -69,7 +69,7 @@ use axum::{routing::get, Router}; let app = Router::new().route_service( "/", - Router::new().route("/foo", get(|| async {})).into_service(), + Router::new().route("/foo", get(|| async {})), ); # async { # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); diff --git a/axum/src/extract/state.rs b/axum/src/extract/state.rs index 3563cef110..447935d86e 100644 --- a/axum/src/extract/state.rs +++ b/axum/src/extract/state.rs @@ -43,7 +43,7 @@ use std::{ /// ) { /// // use `state`... /// } -/// # let _: axum::routing::RouterService = app; +/// # let _: axum::Router = app; /// ``` /// /// ## Combining stateful routers @@ -71,19 +71,19 @@ use std::{ /// async fn posts_handler(State(state): State) { /// // use `state`... /// } -/// # let _: axum::routing::RouterService = app; +/// # let _: axum::Router = app; /// ``` /// /// However, if you are composing [`Router`]s that are defined in separate scopes, /// you may need to annotate the [`State`] type explicitly: /// /// ``` -/// use axum::{Router, RouterService, routing::get, extract::State}; +/// use axum::{Router, routing::get, extract::State}; /// /// #[derive(Clone)] /// struct AppState {} /// -/// fn make_app() -> RouterService { +/// fn make_app() -> Router { /// let state = AppState {}; /// /// Router::new() @@ -101,19 +101,15 @@ use std::{ /// async fn posts_handler(State(state): State) { /// // use `state`... /// } -/// # let _: axum::routing::RouterService = make_app(); +/// # let _: axum::Router = make_app(); /// ``` /// /// In short, a [`Router`]'s generic state type defaults to `()` /// (no state) unless [`Router::with_state`] is called or the value /// of the generic type is given explicitly. /// -/// It's also possible to combine multiple axum services with different state -/// types. See [`Router::nest_service`] for details. -/// /// [`Router`]: crate::Router /// [`Router::merge`]: crate::Router::merge -/// [`Router::nest_service`]: crate::Router::nest_service /// [`Router::nest`]: crate::Router::nest /// [`Router::with_state`]: crate::Router::with_state /// @@ -209,7 +205,7 @@ use std::{ /// State(state): State, /// ) { /// } -/// # let _: axum::routing::RouterService = app; +/// # let _: axum::Router = app; /// ``` /// /// For convenience `FromRef` can also be derived using `#[derive(FromRef)]`. diff --git a/axum/src/handler/mod.rs b/axum/src/handler/mod.rs index a52f09f237..d80f2507f6 100644 --- a/axum/src/handler/mod.rs +++ b/axum/src/handler/mod.rs @@ -358,7 +358,7 @@ mod tests { format!("you said: {}", body) } - let client = TestClient::from_service(handle.into_service()); + let client = TestClient::new(handle.into_service()); let res = client.post("/").body("hi there!").send().await; assert_eq!(res.status(), StatusCode::OK); @@ -381,7 +381,7 @@ mod tests { .layer(MapRequestBodyLayer::new(body::boxed)) .with_state("foo"); - let client = TestClient::from_service(svc); + let client = TestClient::new(svc); let res = client.get("/").send().await; assert_eq!(res.text().await, "foo"); } diff --git a/axum/src/lib.rs b/axum/src/lib.rs index f73d279b13..314f844ff0 100644 --- a/axum/src/lib.rs +++ b/axum/src/lib.rs @@ -476,7 +476,7 @@ pub use self::extension::Extension; #[cfg(feature = "json")] pub use self::json::Json; #[doc(inline)] -pub use self::routing::{Router, RouterService}; +pub use self::routing::Router; #[doc(inline)] #[cfg(feature = "headers")] diff --git a/axum/src/middleware/from_fn.rs b/axum/src/middleware/from_fn.rs index e9a525ee8e..e15988ed62 100644 --- a/axum/src/middleware/from_fn.rs +++ b/axum/src/middleware/from_fn.rs @@ -137,7 +137,7 @@ pub fn from_fn(f: F) -> FromFnLayer { /// .route("/", get(|| async { /* ... */ })) /// .route_layer(middleware::from_fn_with_state(state.clone(), my_middleware)) /// .with_state(state); -/// # let _: axum::routing::RouterService = app; +/// # let _: axum::Router = app; /// ``` pub fn from_fn_with_state(state: S, f: F) -> FromFnLayer { FromFnLayer { @@ -381,7 +381,6 @@ mod tests { .layer(from_fn(insert_header)); let res = app - .into_service() .oneshot(Request::builder().uri("/").body(Body::empty()).unwrap()) .await .unwrap(); diff --git a/axum/src/middleware/map_request.rs b/axum/src/middleware/map_request.rs index eeac81610a..6b17f7a0d4 100644 --- a/axum/src/middleware/map_request.rs +++ b/axum/src/middleware/map_request.rs @@ -152,7 +152,7 @@ pub fn map_request(f: F) -> MapRequestLayer { /// .route("/", get(|| async { /* ... */ })) /// .route_layer(map_request_with_state(state.clone(), my_middleware)) /// .with_state(state); -/// # let _: axum::routing::RouterService = app; +/// # let _: axum::Router = app; /// ``` pub fn map_request_with_state(state: S, f: F) -> MapRequestLayer { MapRequestLayer { diff --git a/axum/src/middleware/map_response.rs b/axum/src/middleware/map_response.rs index 3be34ab6b6..67a5110a4d 100644 --- a/axum/src/middleware/map_response.rs +++ b/axum/src/middleware/map_response.rs @@ -136,7 +136,7 @@ pub fn map_response(f: F) -> MapResponseLayer { /// .route("/", get(|| async { /* ... */ })) /// .route_layer(map_response_with_state(state.clone(), my_middleware)) /// .with_state(state); -/// # let _: axum::routing::RouterService = app; +/// # let _: axum::Router = app; /// ``` pub fn map_response_with_state(state: S, f: F) -> MapResponseLayer { MapResponseLayer { diff --git a/axum/src/routing/method_routing.rs b/axum/src/routing/method_routing.rs index 4b5d5b3fa5..4e2dae3505 100644 --- a/axum/src/routing/method_routing.rs +++ b/axum/src/routing/method_routing.rs @@ -1,6 +1,6 @@ //! Route to services and handlers based on HTTP methods. -use super::{FallbackRoute, IntoMakeService}; +use super::IntoMakeService; #[cfg(feature = "tokio")] use crate::extract::connect_info::IntoMakeServiceWithConnectInfo; use crate::{ @@ -83,7 +83,7 @@ macro_rules! top_level_service_fn { T: Service> + Clone + Send + 'static, T::Response: IntoResponse + 'static, T::Future: Send + 'static, - B: Send + 'static, + B: HttpBody + Send + 'static, S: Clone, { on_service(MethodFilter::$method, svc) @@ -143,7 +143,7 @@ macro_rules! top_level_handler_fn { pub fn $name(handler: H) -> MethodRouter where H: Handler, - B: Send + 'static, + B: HttpBody + Send + 'static, T: 'static, S: Clone + Send + Sync + 'static, { @@ -327,7 +327,7 @@ where T: Service> + Clone + Send + 'static, T::Response: IntoResponse + 'static, T::Future: Send + 'static, - B: Send + 'static, + B: HttpBody + Send + 'static, S: Clone, { MethodRouter::new().on_service(filter, svc) @@ -391,7 +391,7 @@ where T: Service> + Clone + Send + 'static, T::Response: IntoResponse + 'static, T::Future: Send + 'static, - B: Send + 'static, + B: HttpBody + Send + 'static, S: Clone, { MethodRouter::new() @@ -430,7 +430,7 @@ top_level_handler_fn!(trace, TRACE); pub fn on(filter: MethodFilter, handler: H) -> MethodRouter where H: Handler, - B: Send + 'static, + B: HttpBody + Send + 'static, T: 'static, S: Clone + Send + Sync + 'static, { @@ -477,7 +477,7 @@ where pub fn any(handler: H) -> MethodRouter where H: Handler, - B: Send + 'static, + B: HttpBody + Send + 'static, T: 'static, S: Clone + Send + Sync + 'static, { @@ -571,7 +571,7 @@ impl fmt::Debug for MethodRouter { impl MethodRouter where - B: Send + 'static, + B: HttpBody + Send + 'static, S: Clone, { /// Chain an additional handler that will accept requests matching the given @@ -633,7 +633,7 @@ where impl MethodRouter<(), B, Infallible> where - B: Send + 'static, + B: HttpBody + Send + 'static, { /// Convert the handler into a [`MakeService`]. /// @@ -665,7 +665,7 @@ where /// /// [`MakeService`]: tower::make::MakeService pub fn into_make_service(self) -> IntoMakeService { - IntoMakeService::new(self) + IntoMakeService::new(self.with_state(())) } /// Convert the router into a [`MakeService`] which stores information @@ -701,13 +701,13 @@ where /// [`Router::into_make_service_with_connect_info`]: crate::routing::Router::into_make_service_with_connect_info #[cfg(feature = "tokio")] pub fn into_make_service_with_connect_info(self) -> IntoMakeServiceWithConnectInfo { - IntoMakeServiceWithConnectInfo::new(self) + IntoMakeServiceWithConnectInfo::new(self.with_state(())) } } impl MethodRouter where - B: Send + 'static, + B: HttpBody + Send + 'static, S: Clone, { /// Create a default `MethodRouter` that will respond with `405 Method Not Allowed` to all @@ -731,21 +731,19 @@ where } } - /// Provide the state. - /// - /// See [`State`](crate::extract::State) for more details about accessing state. - pub fn with_state(self, state: S) -> WithState { - WithState { - get: self.get.into_route(&state), - head: self.head.into_route(&state), - delete: self.delete.into_route(&state), - options: self.options.into_route(&state), - patch: self.patch.into_route(&state), - post: self.post.into_route(&state), - put: self.put.into_route(&state), - trace: self.trace.into_route(&state), - fallback: self.fallback.into_fallback_route(&state), + /// Provide the state for the router. + pub fn with_state(self, state: S) -> MethodRouter { + MethodRouter { + get: self.get.with_state(state.clone()), + head: self.head.with_state(state.clone()), + delete: self.delete.with_state(state.clone()), + options: self.options.with_state(state.clone()), + patch: self.patch.with_state(state.clone()), + post: self.post.with_state(state.clone()), + put: self.put.with_state(state.clone()), + trace: self.trace.with_state(state.clone()), allow_header: self.allow_header, + fallback: self.fallback.with_state(state), } } @@ -918,10 +916,7 @@ where } #[doc = include_str!("../docs/method_routing/layer.md")] - pub fn layer( - self, - layer: L, - ) -> MethodRouter + pub fn layer(self, layer: L) -> MethodRouter where L: Layer> + Clone + Send + 'static, L::Service: Service> + Clone + Send + 'static, @@ -930,6 +925,8 @@ where >>::Future: Send + 'static, E: 'static, S: 'static, + NewReqBody: HttpBody + 'static, + NewError: 'static, { let layer_fn = move |route: Route| route.layer(layer.clone()); @@ -1069,6 +1066,74 @@ where self.allow_header = AllowHeader::Skip; self } + + pub(crate) fn call_with_state(&mut self, req: Request, state: S) -> RouteFuture { + macro_rules! call { + ( + $req:expr, + $method:expr, + $method_variant:ident, + $svc:expr + ) => { + if $method == Method::$method_variant { + match $svc { + MethodEndpoint::None => {} + MethodEndpoint::Route(route) => { + return RouteFuture::from_future(route.oneshot_inner($req)) + .strip_body($method == Method::HEAD); + } + MethodEndpoint::BoxedHandler(handler) => { + let mut route = handler.clone().into_route(state); + return RouteFuture::from_future(route.oneshot_inner($req)) + .strip_body($method == Method::HEAD); + } + } + } + }; + } + + let method = req.method().clone(); + + // written with a pattern match like this to ensure we call all routes + let Self { + get, + head, + delete, + options, + patch, + post, + put, + trace, + fallback, + allow_header, + } = self; + + call!(req, method, HEAD, head); + call!(req, method, HEAD, get); + call!(req, method, GET, get); + call!(req, method, POST, post); + call!(req, method, OPTIONS, options); + call!(req, method, PATCH, patch); + call!(req, method, PUT, put); + call!(req, method, DELETE, delete); + call!(req, method, TRACE, trace); + + let future = match fallback { + Fallback::Default(route) | Fallback::Service(route) => { + RouteFuture::from_future(route.oneshot_inner(req)) + } + Fallback::BoxedHandler(handler) => { + let mut route = handler.clone().into_route(state); + RouteFuture::from_future(route.oneshot_inner(req)) + } + }; + + match allow_header { + AllowHeader::None => future.allow_header(Bytes::new()), + AllowHeader::Skip => future, + AllowHeader::Bytes(allow_header) => future.allow_header(allow_header.clone().freeze()), + } + } } fn append_allow_header(allow_header: &mut AllowHeader, method: &'static str) { @@ -1091,24 +1156,6 @@ fn append_allow_header(allow_header: &mut AllowHeader, method: &'static str) { } } -impl Service> for MethodRouter<(), B, E> -where - B: HttpBody + Send + 'static, -{ - type Response = Response; - type Error = E; - type Future = RouteFuture; - - #[inline] - fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { - Poll::Ready(Ok(())) - } - - fn call(&mut self, req: Request) -> Self::Future { - self.clone().with_state(()).call(req) - } -} - impl Clone for MethodRouter { fn clone(&self) -> Self { Self { @@ -1128,7 +1175,7 @@ impl Clone for MethodRouter { impl Default for MethodRouter where - B: Send + 'static, + B: HttpBody + Send + 'static, S: Clone, { fn default() -> Self { @@ -1160,7 +1207,7 @@ where B: 'static, E: 'static, F: FnOnce(Route) -> Route + Clone + Send + 'static, - B2: 'static, + B2: HttpBody + 'static, E2: 'static, { match self { @@ -1170,11 +1217,13 @@ where } } - fn into_route(self, state: &S) -> Option> { + fn with_state(self, state: S) -> MethodEndpoint { match self { - Self::None => None, - Self::Route(route) => Some(route), - Self::BoxedHandler(handler) => Some(handler.into_route(state.clone())), + MethodEndpoint::None => MethodEndpoint::None, + MethodEndpoint::Route(route) => MethodEndpoint::Route(route), + MethodEndpoint::BoxedHandler(handler) => { + MethodEndpoint::Route(handler.into_route(state)) + } } } } @@ -1199,85 +1248,9 @@ impl fmt::Debug for MethodEndpoint { } } -/// A [`MethodRouter`] which has access to some state. -/// -/// Implements [`Service`]. -/// -/// The state can be extracted with [`State`](crate::extract::State). -/// -/// Created with [`MethodRouter::with_state`] -pub struct WithState { - get: Option>, - head: Option>, - delete: Option>, - options: Option>, - patch: Option>, - post: Option>, - put: Option>, - trace: Option>, - fallback: FallbackRoute, - allow_header: AllowHeader, -} - -impl WithState { - /// Convert the handler into a [`MakeService`]. - /// - /// See [`MethodRouter::into_make_service`] for more details. - /// - /// [`MakeService`]: tower::make::MakeService - pub fn into_make_service(self) -> IntoMakeService { - IntoMakeService::new(self) - } - - /// Convert the router into a [`MakeService`] which stores information - /// about the incoming connection. - /// - /// See [`MethodRouter::into_make_service_with_connect_info`] for more details. - /// - /// [`MakeService`]: tower::make::MakeService - #[cfg(feature = "tokio")] - pub fn into_make_service_with_connect_info(self) -> IntoMakeServiceWithConnectInfo { - IntoMakeServiceWithConnectInfo::new(self) - } -} - -impl Clone for WithState { - fn clone(&self) -> Self { - Self { - get: self.get.clone(), - head: self.head.clone(), - delete: self.delete.clone(), - options: self.options.clone(), - patch: self.patch.clone(), - post: self.post.clone(), - put: self.put.clone(), - trace: self.trace.clone(), - fallback: self.fallback.clone(), - allow_header: self.allow_header.clone(), - } - } -} - -impl fmt::Debug for WithState { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("WithState") - .field("get", &self.get) - .field("head", &self.head) - .field("delete", &self.delete) - .field("options", &self.options) - .field("patch", &self.patch) - .field("post", &self.post) - .field("put", &self.put) - .field("trace", &self.trace) - .field("fallback", &self.fallback) - .field("allow_header", &self.allow_header) - .finish() - } -} - -impl Service> for WithState +impl Service> for MethodRouter<(), B, E> where - B: HttpBody + Send, + B: HttpBody + Send + 'static, { type Response = Response; type Error = E; @@ -1288,56 +1261,9 @@ where Poll::Ready(Ok(())) } + #[inline] fn call(&mut self, req: Request) -> Self::Future { - macro_rules! call { - ( - $req:expr, - $method:expr, - $method_variant:ident, - $svc:expr - ) => { - if $method == Method::$method_variant { - if let Some(svc) = $svc { - return RouteFuture::from_future(svc.oneshot_inner($req)) - .strip_body($method == Method::HEAD); - } - } - }; - } - - let method = req.method().clone(); - - // written with a pattern match like this to ensure we call all routes - let Self { - get, - head, - delete, - options, - patch, - post, - put, - trace, - fallback, - allow_header, - } = self; - - call!(req, method, HEAD, head); - call!(req, method, HEAD, get); - call!(req, method, GET, get); - call!(req, method, POST, post); - call!(req, method, OPTIONS, options); - call!(req, method, PATCH, patch); - call!(req, method, PUT, put); - call!(req, method, DELETE, delete); - call!(req, method, TRACE, trace); - - let future = RouteFuture::from_future(fallback.oneshot_inner(req)); - - match allow_header { - AllowHeader::None => future.allow_header(Bytes::new()), - AllowHeader::Skip => future, - AllowHeader::Bytes(allow_header) => future.allow_header(allow_header.clone().freeze()), - } + self.call_with_state(req, ()) } } diff --git a/axum/src/routing/mod.rs b/axum/src/routing/mod.rs index 2618cb66ec..367ced338b 100644 --- a/axum/src/routing/mod.rs +++ b/axum/src/routing/mod.rs @@ -1,6 +1,6 @@ //! Routing between [`Service`]s and handlers. -use self::{not_found::NotFound, strip_prefix::StripPrefix}; +use self::{future::RouteFuture, not_found::NotFound, strip_prefix::StripPrefix}; #[cfg(feature = "tokio")] use crate::extract::connect_info::IntoMakeServiceWithConnectInfo; use crate::{ @@ -12,8 +12,14 @@ use crate::{ use axum_core::response::{IntoResponse, Response}; use http::Request; use matchit::MatchError; -use std::{collections::HashMap, convert::Infallible, fmt, sync::Arc}; -use tower::util::{BoxCloneService, Oneshot}; +use std::{ + collections::HashMap, + convert::Infallible, + fmt, + sync::Arc, + task::{Context, Poll}, +}; +use sync_wrapper::SyncWrapper; use tower_layer::Layer; use tower_service::Service; @@ -27,14 +33,10 @@ mod route; mod strip_prefix; pub(crate) mod url_params; -mod service; #[cfg(test)] mod tests; -pub use self::{ - into_make_service::IntoMakeService, method_filter::MethodFilter, route::Route, - service::RouterService, -}; +pub use self::{into_make_service::IntoMakeService, method_filter::MethodFilter, route::Route}; pub use self::method_routing::{ any, any_service, delete, delete_service, get, get_service, head, head_service, on, on_service, @@ -168,10 +170,10 @@ where T::Response: IntoResponse, T::Future: Send + 'static, { - let service = match try_downcast::, _>(service) { + let service = match try_downcast::, _>(service) { Ok(_) => { panic!( - "Invalid route: `Router::route_service` cannot be used with `RouterService`s. \ + "Invalid route: `Router::route_service` cannot be used with `Router`s. \ Use `Router::nest` instead" ); } @@ -212,41 +214,6 @@ where } /// Like [`nest`](Self::nest), but accepts an arbitrary `Service`. - /// - /// While [`nest`](Self::nest) requires [`Router`]s with the same type of - /// state, you can use this method to combine [`Router`]s with different - /// types of state: - /// - /// ``` - /// use axum::{ - /// Router, - /// routing::get, - /// extract::State, - /// }; - /// - /// #[derive(Clone)] - /// struct InnerState {} - /// - /// #[derive(Clone)] - /// struct OuterState {} - /// - /// async fn inner_handler(state: State) {} - /// - /// let inner_router = Router::new() - /// .route("/bar", get(inner_handler)) - /// .with_state(InnerState {}); - /// - /// async fn outer_handler(state: State) {} - /// - /// let app = Router::new() - /// .route("/", get(outer_handler)) - /// .nest_service("/foo", inner_router) - /// .with_state(OuterState {}); - /// # let _: axum::routing::RouterService = app; - /// ``` - /// - /// Note that the inner router will still inherit the fallback from the outer - /// router. #[track_caller] pub fn nest_service(self, path: &str, svc: T) -> Self where @@ -353,7 +320,7 @@ where >>::Response: IntoResponse + 'static, >>::Error: Into + 'static, >>::Future: Send + 'static, - NewReqBody: 'static, + NewReqBody: HttpBody + 'static, { let routes = self .routes @@ -429,11 +396,171 @@ where self } - /// Convert this router into a [`RouterService`] by providing the state. + /// Provide the state for the router. + /// + /// This method returns a router with a different state type. This can be used to nest or merge + /// routers with different state types. See [`Router::nest`] and [`Router::merge`] for more + /// details. + /// + /// # Implementing `Service` + /// + /// This can also be used to get a `Router` that implements [`Service`], since it only does so + /// when the state is `()`: + /// + /// ``` + /// use axum::{ + /// Router, + /// body::Body, + /// http::Request, + /// }; + /// use tower::{Service, ServiceExt}; + /// + /// #[derive(Clone)] + /// struct AppState {} + /// + /// // this router doesn't implement `Service` because its state isn't `()` + /// let router: Router = Router::new(); + /// + /// // by providing the state and setting the new state to `()`... + /// let router_service: Router<()> = router.with_state(AppState {}); + /// + /// // ...makes it implement `Service` + /// # async { + /// router_service.oneshot(Request::new(Body::empty())).await; + /// # }; + /// ``` /// - /// Once this method has been called you cannot add more routes. So it must be called as last. - pub fn with_state(self, state: S) -> RouterService { - RouterService::new(self, state) + /// # A note about performance + /// + /// If you need a `Router` that implements `Service` but you don't need any state (perhaps + /// you're making a library that uses axum internally) then it is recommended to call this + /// method before you start serving requests: + /// + /// ``` + /// use axum::{Router, routing::get}; + /// + /// let app = Router::new() + /// .route("/", get(|| async { /* ... */ })) + /// // even though we don't need any state, call `with_state(())` anyway + /// .with_state(()); + /// # let _: Router = app; + /// ``` + /// + /// This is not required but it gives axum a chance to update some internals in the router + /// which may impact performance and reduce allocations. + /// + /// Note that [`Router::into_make_service`] and [`Router::into_make_service_with_connect_info`] + /// do this automatically. + pub fn with_state(self, state: S) -> Router { + let routes = self + .routes + .into_iter() + .map(|(id, endpoint)| { + let endpoint: Endpoint = match endpoint { + Endpoint::MethodRouter(method_router) => { + Endpoint::MethodRouter(method_router.with_state(state.clone())) + } + Endpoint::Route(route) => Endpoint::Route(route), + Endpoint::NestedRouter(router) => { + Endpoint::Route(router.into_route(state.clone())) + } + }; + (id, endpoint) + }) + .collect(); + + let fallback = self.fallback.with_state(state); + + Router { + routes, + node: self.node, + fallback, + } + } + + pub(crate) fn call_with_state( + &mut self, + mut req: Request, + state: S, + ) -> RouteFuture { + #[cfg(feature = "original-uri")] + { + use crate::extract::OriginalUri; + + if req.extensions().get::().is_none() { + let original_uri = OriginalUri(req.uri().clone()); + req.extensions_mut().insert(original_uri); + } + } + + let path = req.uri().path().to_owned(); + + match self.node.at(&path) { + Ok(match_) => { + match &self.fallback { + Fallback::Default(_) => {} + Fallback::Service(fallback) => { + req.extensions_mut() + .insert(SuperFallback(SyncWrapper::new(fallback.clone()))); + } + Fallback::BoxedHandler(fallback) => { + req.extensions_mut().insert(SuperFallback(SyncWrapper::new( + fallback.clone().into_route(state.clone()), + ))); + } + } + + self.call_route(match_, req, state) + } + Err( + MatchError::NotFound + | MatchError::ExtraTrailingSlash + | MatchError::MissingTrailingSlash, + ) => match &mut self.fallback { + Fallback::Default(fallback) => { + if let Some(super_fallback) = req.extensions_mut().remove::>() + { + let mut super_fallback = super_fallback.0.into_inner(); + super_fallback.call(req) + } else { + fallback.call(req) + } + } + Fallback::Service(fallback) => fallback.call(req), + Fallback::BoxedHandler(handler) => handler.clone().into_route(state).call(req), + }, + } + } + + #[inline] + fn call_route( + &self, + match_: matchit::Match<&RouteId>, + mut req: Request, + state: S, + ) -> RouteFuture { + let id = *match_.value; + + #[cfg(feature = "matched-path")] + crate::extract::matched_path::set_matched_path_for_request( + id, + &self.node.route_id_to_path, + req.extensions_mut(), + ); + + url_params::insert_url_params(req.extensions_mut(), match_.params); + + let endpont = self + .routes + .get(&id) + .expect("no route for id. This is a bug in axum. Please file an issue") + .clone(); + + match endpont { + Endpoint::MethodRouter(mut method_router) => method_router.call_with_state(req, state), + Endpoint::Route(mut route) => route.call(req), + Endpoint::NestedRouter(router) => router.call_with_state(req, state), + } } } @@ -441,16 +568,6 @@ impl Router<(), B> where B: HttpBody + Send + 'static, { - /// Convert this router into a [`RouterService`]. - /// - /// This is a convenience method for routers that don't have any state (i.e. the state type is - /// `()`). Use [`Router::with_state`] otherwise. - /// - /// Once this method has been called you cannot add more routes. So it must be called as last. - pub fn into_service(self) -> RouterService { - RouterService::new(self, ()) - } - /// Convert this router into a [`MakeService`], that is a [`Service`] whose /// response is another service. /// @@ -473,20 +590,38 @@ where /// # }; /// ``` /// - /// This is a convenience method for routers that don't have any state (i.e. the state type is - /// `()`). Use [`RouterService::into_make_service`] otherwise. - /// /// [`MakeService`]: tower::make::MakeService - pub fn into_make_service(self) -> IntoMakeService> { - IntoMakeService::new(self.into_service()) + pub fn into_make_service(self) -> IntoMakeService { + // call `Router::with_state` such that everything is turned into `Route` eagerly + // rather than doing that per request + IntoMakeService::new(self.with_state(())) } #[doc = include_str!("../docs/routing/into_make_service_with_connect_info.md")] #[cfg(feature = "tokio")] - pub fn into_make_service_with_connect_info( - self, - ) -> IntoMakeServiceWithConnectInfo, C> { - IntoMakeServiceWithConnectInfo::new(self.into_service()) + pub fn into_make_service_with_connect_info(self) -> IntoMakeServiceWithConnectInfo { + // call `Router::with_state` such that everything is turned into `Route` eagerly + // rather than doing that per request + IntoMakeServiceWithConnectInfo::new(self.with_state(())) + } +} + +impl Service> for Router<(), B> +where + B: HttpBody + Send + 'static, +{ + type Response = Response; + type Error = Infallible; + type Future = RouteFuture; + + #[inline] + fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + #[inline] + fn call(&mut self, req: Request) -> Self::Future { + self.call_with_state(req, ()) } } @@ -549,23 +684,13 @@ where } } - fn into_fallback_route(self, state: &S) -> FallbackRoute { - match self { - Self::Default(route) => FallbackRoute::Default(route), - Self::Service(route) => FallbackRoute::Service(route), - Self::BoxedHandler(handler) => { - FallbackRoute::Service(handler.into_route(state.clone())) - } - } - } - fn map(self, f: F) -> Fallback where S: 'static, B: 'static, E: 'static, F: FnOnce(Route) -> Route + Clone + Send + 'static, - B2: 'static, + B2: HttpBody + 'static, E2: 'static, { match self { @@ -574,6 +699,14 @@ where Self::BoxedHandler(handler) => Fallback::BoxedHandler(handler.map(f)), } } + + fn with_state(self, state: S) -> Fallback { + match self { + Fallback::Default(route) => Fallback::Default(route), + Fallback::Service(route) => Fallback::Service(route), + Fallback::BoxedHandler(handler) => Fallback::Service(handler.into_route(state)), + } + } } impl Clone for Fallback { @@ -596,61 +729,7 @@ impl fmt::Debug for Fallback { } } -/// Like `Fallback` but without the `S` param so it can be stored in `RouterService` -pub(crate) enum FallbackRoute { - Default(Route), - Service(Route), -} - -impl FallbackRoute { - fn layer(self, layer: L) -> FallbackRoute - where - L: Layer> + Clone + Send + 'static, - L::Service: Service> + Clone + Send + 'static, - >>::Response: IntoResponse + 'static, - >>::Error: Into + 'static, - >>::Future: Send + 'static, - NewReqBody: 'static, - NewError: 'static, - { - match self { - FallbackRoute::Default(route) => FallbackRoute::Default(route.layer(layer)), - FallbackRoute::Service(route) => FallbackRoute::Service(route.layer(layer)), - } - } -} - -impl fmt::Debug for FallbackRoute { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - Self::Default(inner) => f.debug_tuple("Default").field(inner).finish(), - Self::Service(inner) => f.debug_tuple("Service").field(inner).finish(), - } - } -} - -impl Clone for FallbackRoute { - fn clone(&self) -> Self { - match self { - Self::Default(inner) => Self::Default(inner.clone()), - Self::Service(inner) => Self::Service(inner.clone()), - } - } -} - -impl FallbackRoute { - pub(crate) fn oneshot_inner( - &mut self, - req: Request, - ) -> Oneshot, Response, E>, Request> { - match self { - FallbackRoute::Default(inner) => inner.oneshot_inner(req), - FallbackRoute::Service(inner) => inner.oneshot_inner(req), - } - } -} - -#[allow(clippy::large_enum_variant)] // This type is only used at init time, probably fine +#[allow(clippy::large_enum_variant)] enum Endpoint { MethodRouter(MethodRouter), Route(Route), @@ -662,14 +741,6 @@ where B: HttpBody + Send + 'static, S: Clone + Send + Sync + 'static, { - fn into_route(self, state: S) -> Route { - match self { - Endpoint::MethodRouter(method_router) => Route::new(method_router.with_state(state)), - Endpoint::Route(route) => route, - Endpoint::NestedRouter(router) => router.into_route(state), - } - } - fn layer(self, layer: L) -> Endpoint where L: Layer> + Clone + Send + 'static, @@ -677,7 +748,7 @@ where >>::Response: IntoResponse + 'static, >>::Error: Into + 'static, >>::Future: Send + 'static, - NewReqBody: 'static, + NewReqBody: HttpBody + 'static, { match self { Endpoint::MethodRouter(method_router) => { @@ -721,6 +792,8 @@ enum RouterOrService { Service(T), } +struct SuperFallback(SyncWrapper>); + #[test] #[allow(warnings)] fn traits() { diff --git a/axum/src/routing/service.rs b/axum/src/routing/service.rs deleted file mode 100644 index cbd655d8d4..0000000000 --- a/axum/src/routing/service.rs +++ /dev/null @@ -1,225 +0,0 @@ -use super::{ - future::RouteFuture, url_params, FallbackRoute, IntoMakeService, Node, Route, RouteId, Router, -}; -use crate::{ - body::{Body, HttpBody}, - response::Response, -}; -use axum_core::response::IntoResponse; -use http::Request; -use matchit::MatchError; -use std::{ - collections::HashMap, - convert::Infallible, - sync::Arc, - task::{Context, Poll}, -}; -use sync_wrapper::SyncWrapper; -use tower::Service; -use tower_layer::Layer; - -/// A [`Router`] converted into a [`Service`]. -#[derive(Debug)] -pub struct RouterService { - routes: HashMap>, - node: Arc, - fallback: FallbackRoute, -} - -impl RouterService -where - B: HttpBody + Send + 'static, -{ - pub(super) fn new(router: Router, state: S) -> Self - where - S: Clone + Send + Sync + 'static, - { - let fallback = router.fallback.into_fallback_route(&state); - - let routes = router - .routes - .into_iter() - .map(|(route_id, endpoint)| { - let route = endpoint.into_route(state.clone()); - (route_id, route) - }) - .collect(); - - Self { - routes, - node: router.node, - fallback, - } - } - - #[inline] - fn call_route( - &self, - match_: matchit::Match<&RouteId>, - mut req: Request, - ) -> RouteFuture { - let id = *match_.value; - - #[cfg(feature = "matched-path")] - crate::extract::matched_path::set_matched_path_for_request( - id, - &self.node.route_id_to_path, - req.extensions_mut(), - ); - - url_params::insert_url_params(req.extensions_mut(), match_.params); - - let mut route = self - .routes - .get(&id) - .expect("no route for id. This is a bug in axum. Please file an issue") - .clone(); - - route.call(req) - } - - /// Apply a [`tower::Layer`] to all routes in the router. - /// - /// See [`Router::layer`] for more details. - pub fn layer(self, layer: L) -> RouterService - where - L: Layer> + Clone + Send + 'static, - L::Service: Service> + Clone + Send + 'static, - >>::Response: IntoResponse + 'static, - >>::Error: Into + 'static, - >>::Future: Send + 'static, - NewReqBody: 'static, - { - let routes = self - .routes - .into_iter() - .map(|(id, route)| (id, route.layer(layer.clone()))) - .collect(); - - let fallback = self.fallback.layer(layer); - - RouterService { - routes, - node: self.node, - fallback, - } - } - - /// Apply a [`tower::Layer`] to the router that will only run if the request matches - /// a route. - /// - /// See [`Router::route_layer`] for more details. - pub fn route_layer(self, layer: L) -> Self - where - L: Layer> + Clone + Send + 'static, - L::Service: Service> + Clone + Send + 'static, - >>::Response: IntoResponse + 'static, - >>::Error: Into + 'static, - >>::Future: Send + 'static, - { - let routes = self - .routes - .into_iter() - .map(|(id, route)| (id, route.layer(layer.clone()))) - .collect(); - - Self { - routes, - node: self.node, - fallback: self.fallback, - } - } - - /// Convert the `RouterService` into a [`MakeService`]. - /// - /// See [`Router::into_make_service`] for more details. - /// - /// [`MakeService`]: tower::make::MakeService - pub fn into_make_service(self) -> IntoMakeService { - IntoMakeService::new(self) - } - - /// Convert the `RouterService` into a [`MakeService`] which stores information - /// about the incoming connection. - /// - /// See [`Router::into_make_service_with_connect_info`] for more details. - /// - /// [`MakeService`]: tower::make::MakeService - #[cfg(feature = "tokio")] - pub fn into_make_service_with_connect_info( - self, - ) -> crate::extract::connect_info::IntoMakeServiceWithConnectInfo { - crate::extract::connect_info::IntoMakeServiceWithConnectInfo::new(self) - } -} - -impl Clone for RouterService { - fn clone(&self) -> Self { - Self { - routes: self.routes.clone(), - node: Arc::clone(&self.node), - fallback: self.fallback.clone(), - } - } -} - -impl Service> for RouterService -where - B: HttpBody + Send + 'static, -{ - type Response = Response; - type Error = Infallible; - type Future = RouteFuture; - - #[inline] - fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll> { - Poll::Ready(Ok(())) - } - - #[inline] - fn call(&mut self, mut req: Request) -> Self::Future { - #[cfg(feature = "original-uri")] - { - use crate::extract::OriginalUri; - - if req.extensions().get::().is_none() { - let original_uri = OriginalUri(req.uri().clone()); - req.extensions_mut().insert(original_uri); - } - } - - let path = req.uri().path().to_owned(); - - match self.node.at(&path) { - Ok(match_) => { - match &self.fallback { - FallbackRoute::Default(_) => {} - FallbackRoute::Service(fallback) => { - req.extensions_mut() - .insert(SuperFallback(SyncWrapper::new(fallback.clone()))); - } - } - - self.call_route(match_, req) - } - Err( - MatchError::NotFound - | MatchError::ExtraTrailingSlash - | MatchError::MissingTrailingSlash, - ) => match &mut self.fallback { - FallbackRoute::Default(fallback) => { - if let Some(super_fallback) = req.extensions_mut().remove::>() - { - let mut super_fallback = super_fallback.0.into_inner(); - super_fallback.call(req) - } else { - fallback.call(req) - } - } - FallbackRoute::Service(fallback) => fallback.call(req), - }, - } - } -} - -struct SuperFallback(SyncWrapper>); diff --git a/axum/src/routing/tests/fallback.rs b/axum/src/routing/tests/fallback.rs index 923e10d124..8f070ad757 100644 --- a/axum/src/routing/tests/fallback.rs +++ b/axum/src/routing/tests/fallback.rs @@ -56,7 +56,7 @@ async fn fallback_accessing_state() { .fallback(|State(state): State<&'static str>| async move { state }) .with_state("state"); - let client = TestClient::from_service(app); + let client = TestClient::new(app); let res = client.get("/does-not-exist").send().await; assert_eq!(res.status(), StatusCode::OK); diff --git a/axum/src/routing/tests/get_to_head.rs b/axum/src/routing/tests/get_to_head.rs index f0cd201c62..21888e6eb8 100644 --- a/axum/src/routing/tests/get_to_head.rs +++ b/axum/src/routing/tests/get_to_head.rs @@ -19,7 +19,6 @@ mod for_handlers { // don't use reqwest because it always strips bodies from HEAD responses let res = app - .into_service() .oneshot( Request::builder() .uri("/") @@ -55,7 +54,6 @@ mod for_services { // don't use reqwest because it always strips bodies from HEAD responses let res = app - .into_service() .oneshot( Request::builder() .uri("/") diff --git a/axum/src/routing/tests/mod.rs b/axum/src/routing/tests/mod.rs index 43db151244..ea09501e2c 100644 --- a/axum/src/routing/tests/mod.rs +++ b/axum/src/routing/tests/mod.rs @@ -447,11 +447,11 @@ async fn middleware_still_run_for_unmatched_requests() { #[tokio::test] #[should_panic(expected = "\ - Invalid route: `Router::route_service` cannot be used with `RouterService`s. \ + Invalid route: `Router::route_service` cannot be used with `Router`s. \ Use `Router::nest` instead\ ")] async fn routing_to_router_panics() { - TestClient::new(Router::new().route_service("/", Router::new().into_service())); + TestClient::new(Router::new().route_service("/", Router::new())); } #[tokio::test] @@ -761,7 +761,7 @@ async fn extract_state() { }; let app = Router::new().route("/", get(handler)).with_state(state); - let client = TestClient::from_service(app); + let client = TestClient::new(app); let res = client.get("/").send().await; assert_eq!(res.status(), StatusCode::OK); @@ -776,7 +776,7 @@ async fn explicitly_set_state() { ) .with_state("..."); - let client = TestClient::from_service(app); + let client = TestClient::new(app); let res = client.get("/").send().await; assert_eq!(res.text().await, "foo"); } diff --git a/axum/src/test_helpers/mod.rs b/axum/src/test_helpers/mod.rs index cddb38e548..0b60ee9007 100644 --- a/axum/src/test_helpers/mod.rs +++ b/axum/src/test_helpers/mod.rs @@ -1,6 +1,6 @@ #![allow(clippy::disallowed_names)] -use crate::{body::HttpBody, BoxError, Router}; +use crate::{body::HttpBody, BoxError}; mod test_client; pub(crate) use self::test_client::*; diff --git a/axum/src/test_helpers/test_client.rs b/axum/src/test_helpers/test_client.rs index 296b8131fd..45a72b6cb1 100644 --- a/axum/src/test_helpers/test_client.rs +++ b/axum/src/test_helpers/test_client.rs @@ -1,4 +1,4 @@ -use super::{BoxError, HttpBody, Router}; +use super::{BoxError, HttpBody}; use bytes::Bytes; use http::{ header::{HeaderName, HeaderValue}, @@ -15,11 +15,7 @@ pub(crate) struct TestClient { } impl TestClient { - pub(crate) fn new(router: Router<(), Body>) -> Self { - Self::from_service(router.into_service()) - } - - pub(crate) fn from_service(svc: S) -> Self + pub(crate) fn new(svc: S) -> Self where S: Service, Response = http::Response> + Clone + Send + 'static, ResBody: HttpBody + Send + 'static, diff --git a/examples/handle-head-request/src/main.rs b/examples/handle-head-request/src/main.rs index 9624835cca..492d34251f 100644 --- a/examples/handle-head-request/src/main.rs +++ b/examples/handle-head-request/src/main.rs @@ -50,7 +50,7 @@ mod tests { #[tokio::test] async fn test_get() { - let app = app().into_service(); + let app = app(); let response = app .oneshot(Request::get("/get-head").body(Body::empty()).unwrap()) @@ -66,7 +66,7 @@ mod tests { #[tokio::test] async fn test_implicit_head() { - let app = app().into_service(); + let app = app(); let response = app .oneshot(Request::head("/get-head").body(Body::empty()).unwrap()) diff --git a/examples/http-proxy/src/main.rs b/examples/http-proxy/src/main.rs index e57a73f430..bfb25199be 100644 --- a/examples/http-proxy/src/main.rs +++ b/examples/http-proxy/src/main.rs @@ -35,9 +35,7 @@ async fn main() { .with(tracing_subscriber::fmt::layer()) .init(); - let router_svc = Router::new() - .route("/", get(|| async { "Hello, World!" })) - .into_service(); + let router_svc = Router::new().route("/", get(|| async { "Hello, World!" })); let service = tower::service_fn(move |req: Request| { let router_svc = router_svc.clone(); diff --git a/examples/key-value-store/src/main.rs b/examples/key-value-store/src/main.rs index 334e0e254b..5d6c2bf7e7 100644 --- a/examples/key-value-store/src/main.rs +++ b/examples/key-value-store/src/main.rs @@ -26,7 +26,7 @@ use std::{ use tower::{BoxError, ServiceBuilder}; use tower_http::{ auth::RequireAuthorizationLayer, compression::CompressionLayer, limit::RequestBodyLimitLayer, - trace::TraceLayer, ServiceBuilderExt, + trace::TraceLayer, }; use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; diff --git a/examples/query-params-with-empty-strings/src/main.rs b/examples/query-params-with-empty-strings/src/main.rs index d5f2ba2d6e..0af20111d7 100644 --- a/examples/query-params-with-empty-strings/src/main.rs +++ b/examples/query-params-with-empty-strings/src/main.rs @@ -104,7 +104,6 @@ mod tests { async fn send_request_get_body(query: &str) -> String { let body = app() - .into_service() .oneshot( Request::builder() .uri(format!("/?{}", query)) diff --git a/examples/rest-grpc-multiplex/src/main.rs b/examples/rest-grpc-multiplex/src/main.rs index 7a0ea3d61b..8376fcb0aa 100644 --- a/examples/rest-grpc-multiplex/src/main.rs +++ b/examples/rest-grpc-multiplex/src/main.rs @@ -55,7 +55,7 @@ async fn main() { .init(); // build the rest service - let rest = Router::new().route("/", get(web_root)).into_service(); + let rest = Router::new().route("/", get(web_root)); // build the grpc service let grpc = GreeterServer::new(GrpcServiceImpl::default()); diff --git a/examples/simple-router-wasm/src/main.rs b/examples/simple-router-wasm/src/main.rs index 5b1e7907b6..5f82109ed0 100644 --- a/examples/simple-router-wasm/src/main.rs +++ b/examples/simple-router-wasm/src/main.rs @@ -40,8 +40,7 @@ fn main() { #[allow(clippy::let_and_return)] async fn app(request: Request) -> Response { - let mut router = Router::new().route("/api/", get(index)).into_service(); - + let mut router = Router::new().route("/api/", get(index)); let response = router.call(request).await.unwrap(); response } diff --git a/examples/testing/src/main.rs b/examples/testing/src/main.rs index 2671df6d15..0bb9b352a0 100644 --- a/examples/testing/src/main.rs +++ b/examples/testing/src/main.rs @@ -61,7 +61,7 @@ mod tests { #[tokio::test] async fn hello_world() { - let app = app().into_service(); + let app = app(); // `Router` implements `tower::Service>` so we can // call it like any tower service, no need to run an HTTP server. @@ -78,7 +78,7 @@ mod tests { #[tokio::test] async fn json() { - let app = app().into_service(); + let app = app(); let response = app .oneshot( @@ -103,7 +103,7 @@ mod tests { #[tokio::test] async fn not_found() { - let app = app().into_service(); + let app = app(); let response = app .oneshot( @@ -154,7 +154,7 @@ mod tests { // in multiple request #[tokio::test] async fn multiple_request() { - let mut app = app().into_service(); + let mut app = app(); let request = Request::builder().uri("/").body(Body::empty()).unwrap(); let response = app.ready().await.unwrap().call(request).await.unwrap();