Skip to content

Commit

Permalink
Add fallback inheritance for nested routers (#1521)
Browse files Browse the repository at this point in the history
* fallback inheritance

* cleanup

* changelog
  • Loading branch information
davidpdrsn committed Nov 18, 2022
1 parent 2e8a7e5 commit 7090649
Show file tree
Hide file tree
Showing 7 changed files with 218 additions and 60 deletions.
3 changes: 3 additions & 0 deletions axum/CHANGELOG.md
Expand Up @@ -7,8 +7,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

# Unreleased

- **fixed:** Nested routers will now inherit fallbacks from outer routers ([#1521])
- **added:** Add `accept_unmasked_frames` setting in WebSocketUpgrade ([#1529])

[#1521]: https://github.com/tokio-rs/axum/pull/1521

# 0.6.0-rc.4 (9. November, 2022)

- **changed**: The inner error of a `JsonRejection` is now
Expand Down
39 changes: 23 additions & 16 deletions axum/src/docs/routing/nest.md
Expand Up @@ -90,8 +90,8 @@ let app = Router::new()

# Fallbacks

When nesting a router, if a request matches the prefix but the nested router doesn't have a matching
route, the outer fallback will _not_ be called:
If a nested router doesn't have its own fallback then it will inherit the
fallback from the outer router:

```rust
use axum::{routing::get, http::StatusCode, handler::Handler, Router};
Expand All @@ -100,38 +100,43 @@ async fn fallback() -> (StatusCode, &'static str) {
(StatusCode::NOT_FOUND, "Not Found")
}

let api_routes = Router::new().nest_service("/users", get(|| async {}));
let api_routes = Router::new().route("/users", get(|| async {}));

let app = Router::new()
.nest("/api", api_routes)
.fallback(fallback);
# let _: Router = app;
```

Here requests like `GET /api/not-found` will go into `api_routes` and then to
the fallback of `api_routes` which will return an empty `404 Not Found`
response. The outer fallback declared on `app` will _not_ be called.
Here requests like `GET /api/not-found` will go into `api_routes` but because
it doesn't have a matching route and doesn't have its own fallback it will call
the fallback from the outer router, i.e. the `fallback` function.

Think of nested services as swallowing requests that matches the prefix and
not falling back to outer router even if they don't have a matching route.

You can still add separate fallbacks to nested routers:
If the nested router has its own fallback then the outer fallback will not be
inherited:

```rust
use axum::{routing::get, http::StatusCode, handler::Handler, Json, Router};
use serde_json::{json, Value};
use axum::{
routing::get,
http::StatusCode,
handler::Handler,
Json,
Router,
};

async fn fallback() -> (StatusCode, &'static str) {
(StatusCode::NOT_FOUND, "Not Found")
}

async fn api_fallback() -> (StatusCode, Json<Value>) {
(StatusCode::NOT_FOUND, Json(json!({ "error": "Not Found" })))
async fn api_fallback() -> (StatusCode, Json<serde_json::Value>) {
(
StatusCode::NOT_FOUND,
Json(serde_json::json!({ "status": "Not Found" })),
)
}

let api_routes = Router::new()
.nest_service("/users", get(|| async {}))
// add dedicated fallback for requests starting with `/api`
.route("/users", get(|| async {}))
.fallback(api_fallback);

let app = Router::new()
Expand All @@ -140,6 +145,8 @@ let app = Router::new()
# let _: Router = app;
```

Here requests like `GET /api/not-found` will go to `api_fallback`.

# Panics

- If the route overlaps with another route. See [`Router::route`]
Expand Down
6 changes: 3 additions & 3 deletions axum/src/routing/method_routing.rs
@@ -1,6 +1,6 @@
//! Route to services and handlers based on HTTP methods.

use super::IntoMakeService;
use super::{FallbackRoute, IntoMakeService};
#[cfg(feature = "tokio")]
use crate::extract::connect_info::IntoMakeServiceWithConnectInfo;
use crate::{
Expand Down Expand Up @@ -744,7 +744,7 @@ where
post: self.post.into_route(&state),
put: self.put.into_route(&state),
trace: self.trace.into_route(&state),
fallback: self.fallback.into_route(&state),
fallback: self.fallback.into_fallback_route(&state),
allow_header: self.allow_header,
}
}
Expand Down Expand Up @@ -1284,7 +1284,7 @@ pub struct WithState<B, E> {
post: Option<Route<B, E>>,
put: Option<Route<B, E>>,
trace: Option<Route<B, E>>,
fallback: Route<B, E>,
fallback: FallbackRoute<B, E>,
allow_header: AllowHeader,
}

Expand Down
77 changes: 58 additions & 19 deletions axum/src/routing/mod.rs
Expand Up @@ -8,7 +8,7 @@ use crate::{
handler::{BoxedHandler, Handler},
util::try_downcast,
};
use axum_core::response::IntoResponse;
use axum_core::response::{IntoResponse, Response};
use http::Request;
use matchit::MatchError;
use std::{
Expand All @@ -18,7 +18,10 @@ use std::{
fmt,
sync::Arc,
};
use tower::{util::MapResponseLayer, ServiceBuilder};
use tower::{
util::{BoxCloneService, MapResponseLayer, Oneshot},
ServiceBuilder,
};
use tower_layer::Layer;
use tower_service::Service;

Expand Down Expand Up @@ -639,11 +642,29 @@ where
}
}

fn into_route(self, state: &S) -> Route<B, E> {
fn into_fallback_route(self, state: &S) -> FallbackRoute<B, E> {
match self {
Self::Default(route) => route,
Self::Service(route) => route,
Self::BoxedHandler(handler) => handler.into_route(state.clone()),
Self::Default(route) => FallbackRoute::Default(route),
Self::Service(route) => FallbackRoute::Service(route),
Self::BoxedHandler(handler) => {
FallbackRoute::Service(handler.into_route(state.clone()))
}
}
}

fn map<F, B2, E2>(self, f: F) -> Fallback<S, B2, E2>
where
S: 'static,
B: 'static,
E: 'static,
F: FnOnce(Route<B, E>) -> Route<B2, E2> + Clone + Send + 'static,
B2: 'static,
E2: 'static,
{
match self {
Self::Default(route) => Fallback::Default(f(route)),
Self::Service(route) => Fallback::Service(f(route)),
Self::BoxedHandler(handler) => Fallback::BoxedHandler(handler.map(f)),
}
}
}
Expand All @@ -668,20 +689,38 @@ impl<S, B, E> fmt::Debug for Fallback<S, B, E> {
}
}

impl<S, B, E> Fallback<S, B, E> {
fn map<F, B2, E2>(self, f: F) -> Fallback<S, B2, E2>
where
S: 'static,
B: 'static,
E: 'static,
F: FnOnce(Route<B, E>) -> Route<B2, E2> + Clone + Send + 'static,
B2: 'static,
E2: 'static,
{
/// Like `Fallback` but without the `S` param so it can be stored in `RouterService`
pub(crate) enum FallbackRoute<B, E = Infallible> {
Default(Route<B, E>),
Service(Route<B, E>),
}

impl<B, E> fmt::Debug for FallbackRoute<B, E> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::Default(inner) => f.debug_tuple("Default").field(inner).finish(),
Self::Service(inner) => f.debug_tuple("Service").field(inner).finish(),
}
}
}

impl<B, E> Clone for FallbackRoute<B, E> {
fn clone(&self) -> Self {
match self {
Self::Default(inner) => Self::Default(inner.clone()),
Self::Service(inner) => Self::Service(inner.clone()),
}
}
}

impl<B, E> FallbackRoute<B, E> {
pub(crate) fn oneshot_inner(
&mut self,
req: Request<B>,
) -> Oneshot<BoxCloneService<Request<B>, Response, E>, Request<B>> {
match self {
Self::Default(inner) => Fallback::Default(f(inner)),
Self::Service(inner) => Fallback::Service(f(inner)),
Self::BoxedHandler(inner) => Fallback::BoxedHandler(inner.map(f)),
FallbackRoute::Default(inner) => inner.oneshot_inner(req),
FallbackRoute::Service(inner) => inner.oneshot_inner(req),
}
}
}
Expand Down
36 changes: 31 additions & 5 deletions axum/src/routing/service.rs
@@ -1,4 +1,6 @@
use super::{future::RouteFuture, url_params, Endpoint, Node, Route, RouteId, Router};
use super::{
future::RouteFuture, url_params, Endpoint, FallbackRoute, Node, Route, RouteId, Router,
};
use crate::{
body::{Body, HttpBody},
response::Response,
Expand All @@ -11,14 +13,15 @@ use std::{
sync::Arc,
task::{Context, Poll},
};
use sync_wrapper::SyncWrapper;
use tower::Service;

/// A [`Router`] converted into a [`Service`].
#[derive(Debug)]
pub struct RouterService<B = Body> {
routes: HashMap<RouteId, Route<B>>,
node: Arc<Node>,
fallback: Route<B>,
fallback: FallbackRoute<B>,
}

impl<B> RouterService<B>
Expand Down Expand Up @@ -52,7 +55,7 @@ where
Self {
routes,
node: router.node,
fallback: router.fallback.into_route(&state),
fallback: router.fallback.into_fallback_route(&state),
}
}

Expand Down Expand Up @@ -121,12 +124,35 @@ where
let path = req.uri().path().to_owned();

match self.node.at(&path) {
Ok(match_) => self.call_route(match_, req),
Ok(match_) => {
match &self.fallback {
FallbackRoute::Default(_) => {}
FallbackRoute::Service(fallback) => {
req.extensions_mut()
.insert(SuperFallback(SyncWrapper::new(fallback.clone())));
}
}

self.call_route(match_, req)
}
Err(
MatchError::NotFound
| MatchError::ExtraTrailingSlash
| MatchError::MissingTrailingSlash,
) => self.fallback.clone().call(req),
) => match &mut self.fallback {
FallbackRoute::Default(fallback) => {
if let Some(super_fallback) = req.extensions_mut().remove::<SuperFallback<B>>()
{
let mut super_fallback = super_fallback.0.into_inner();
super_fallback.call(req)
} else {
fallback.call(req)
}
}
FallbackRoute::Service(fallback) => fallback.call(req),
},
}
}
}

struct SuperFallback<B>(SyncWrapper<Route<B>>);

0 comments on commit 7090649

Please sign in to comment.