Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We鈥檒l occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add fallback inheritance for nested routers #1521

Merged
merged 4 commits into from Nov 18, 2022
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
4 changes: 3 additions & 1 deletion axum/CHANGELOG.md
Expand Up @@ -7,7 +7,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

# Unreleased

- None.
- **fixed:** Nested routers will now inherit fallbacks from outer routers ([#1521])

[#1521]: https://github.com/tokio-rs/axum/pull/1521

# 0.6.0-rc.4 (9. November, 2022)

Expand Down
39 changes: 23 additions & 16 deletions axum/src/docs/routing/nest.md
Expand Up @@ -90,8 +90,8 @@ let app = Router::new()

# Fallbacks

When nesting a router, if a request matches the prefix but the nested router doesn't have a matching
route, the outer fallback will _not_ be called:
If a nested router doesn't have its own fallback then it will inherit the
fallback from the outer router:

```rust
use axum::{routing::get, http::StatusCode, handler::Handler, Router};
Expand All @@ -100,38 +100,43 @@ async fn fallback() -> (StatusCode, &'static str) {
(StatusCode::NOT_FOUND, "Not Found")
}

let api_routes = Router::new().nest_service("/users", get(|| async {}));
let api_routes = Router::new().route("/users", get(|| async {}));

let app = Router::new()
.nest("/api", api_routes)
.fallback(fallback);
# let _: Router = app;
```

Here requests like `GET /api/not-found` will go into `api_routes` and then to
the fallback of `api_routes` which will return an empty `404 Not Found`
response. The outer fallback declared on `app` will _not_ be called.
Here requests like `GET /api/not-found` will go into `api_routes` but because
it doesn't have a matching route and doesn't have its own fallback it will call
the fallback from the outer router, i.e. the `fallback` function.

Think of nested services as swallowing requests that matches the prefix and
not falling back to outer router even if they don't have a matching route.

You can still add separate fallbacks to nested routers:
If the nested router has its own fallback then the outer fallback will not be
inherited:

```rust
use axum::{routing::get, http::StatusCode, handler::Handler, Json, Router};
use serde_json::{json, Value};
use axum::{
routing::get,
http::StatusCode,
handler::Handler,
Json,
Router,
};

async fn fallback() -> (StatusCode, &'static str) {
(StatusCode::NOT_FOUND, "Not Found")
}

async fn api_fallback() -> (StatusCode, Json<Value>) {
(StatusCode::NOT_FOUND, Json(json!({ "error": "Not Found" })))
async fn api_fallback() -> (StatusCode, Json<serde_json::Value>) {
(
StatusCode::NOT_FOUND,
Json(serde_json::json!({ "status": "Not Found" })),
)
}

let api_routes = Router::new()
.nest_service("/users", get(|| async {}))
// add dedicated fallback for requests starting with `/api`
.route("/users", get(|| async {}))
.fallback(api_fallback);

let app = Router::new()
Expand All @@ -140,6 +145,8 @@ let app = Router::new()
# let _: Router = app;
```

Here requests like `GET /api/not-found` will go to `api_fallback`.

# Panics

- If the route overlaps with another route. See [`Router::route`]
Expand Down
2 changes: 1 addition & 1 deletion axum/src/extract/matched_path.rs
Expand Up @@ -143,7 +143,7 @@ pub(crate) fn set_matched_path_for_request(

if matched_path.ends_with(NEST_TAIL_PARAM_CAPTURE) {
extensions.insert(MatchedNestedPath(matched_path));
debug_assert!(matches!(dbg!(extensions.remove::<MatchedPath>()), None));
debug_assert!(matches!(extensions.remove::<MatchedPath>(), None));
} else {
extensions.insert(MatchedPath(matched_path));
extensions.remove::<MatchedNestedPath>();
Expand Down
6 changes: 3 additions & 3 deletions axum/src/routing/method_routing.rs
@@ -1,6 +1,6 @@
//! Route to services and handlers based on HTTP methods.

use super::IntoMakeService;
use super::{FallbackRoute, IntoMakeService};
#[cfg(feature = "tokio")]
use crate::extract::connect_info::IntoMakeServiceWithConnectInfo;
use crate::{
Expand Down Expand Up @@ -744,7 +744,7 @@ where
post: self.post.into_route(&state),
put: self.put.into_route(&state),
trace: self.trace.into_route(&state),
fallback: self.fallback.into_route(&state),
fallback: self.fallback.into_fallback_route(&state),
allow_header: self.allow_header,
}
}
Expand Down Expand Up @@ -1284,7 +1284,7 @@ pub struct WithState<B, E> {
post: Option<Route<B, E>>,
put: Option<Route<B, E>>,
trace: Option<Route<B, E>>,
fallback: Route<B, E>,
fallback: FallbackRoute<B, E>,
allow_header: AllowHeader,
}

Expand Down
77 changes: 58 additions & 19 deletions axum/src/routing/mod.rs
Expand Up @@ -8,7 +8,7 @@ use crate::{
handler::{BoxedHandler, Handler},
util::try_downcast,
};
use axum_core::response::IntoResponse;
use axum_core::response::{IntoResponse, Response};
use http::Request;
use matchit::MatchError;
use std::{
Expand All @@ -18,7 +18,10 @@ use std::{
fmt,
sync::Arc,
};
use tower::{util::MapResponseLayer, ServiceBuilder};
use tower::{
util::{BoxCloneService, MapResponseLayer, Oneshot},
ServiceBuilder,
};
use tower_layer::Layer;
use tower_service::Service;

Expand Down Expand Up @@ -639,11 +642,29 @@ where
}
}

fn into_route(self, state: &S) -> Route<B, E> {
fn into_fallback_route(self, state: &S) -> FallbackRoute<B, E> {
match self {
Self::Default(route) => route,
Self::Service(route) => route,
Self::BoxedHandler(handler) => handler.into_route(state.clone()),
Self::Default(route) => FallbackRoute::Default(route),
Self::Service(route) => FallbackRoute::Service(route),
Self::BoxedHandler(handler) => {
FallbackRoute::Service(handler.into_route(state.clone()))
}
}
}

fn map<F, B2, E2>(self, f: F) -> Fallback<S, B2, E2>
where
S: 'static,
B: 'static,
E: 'static,
F: FnOnce(Route<B, E>) -> Route<B2, E2> + Clone + Send + 'static,
B2: 'static,
E2: 'static,
{
match self {
Self::Default(route) => Fallback::Default(f(route)),
Self::Service(route) => Fallback::Service(f(route)),
Self::BoxedHandler(handler) => Fallback::BoxedHandler(handler.map(f)),
}
}
}
Expand All @@ -668,20 +689,38 @@ impl<S, B, E> fmt::Debug for Fallback<S, B, E> {
}
}

impl<S, B, E> Fallback<S, B, E> {
fn map<F, B2, E2>(self, f: F) -> Fallback<S, B2, E2>
where
S: 'static,
B: 'static,
E: 'static,
F: FnOnce(Route<B, E>) -> Route<B2, E2> + Clone + Send + 'static,
B2: 'static,
E2: 'static,
{
/// Like `Fallback` but without the `S` param so it can be stored in `RouterService`
pub(crate) enum FallbackRoute<B, E = Infallible> {
Default(Route<B, E>),
Service(Route<B, E>),
}

impl<B, E> fmt::Debug for FallbackRoute<B, E> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::Default(inner) => f.debug_tuple("Default").field(inner).finish(),
Self::Service(inner) => f.debug_tuple("Service").field(inner).finish(),
}
}
}

impl<B, E> Clone for FallbackRoute<B, E> {
fn clone(&self) -> Self {
match self {
Self::Default(inner) => Self::Default(inner.clone()),
Self::Service(inner) => Self::Service(inner.clone()),
}
}
}

impl<B, E> FallbackRoute<B, E> {
pub(crate) fn oneshot_inner(
&mut self,
req: Request<B>,
) -> Oneshot<BoxCloneService<Request<B>, Response, E>, Request<B>> {
match self {
Self::Default(inner) => Fallback::Default(f(inner)),
Self::Service(inner) => Fallback::Service(f(inner)),
Self::BoxedHandler(inner) => Fallback::BoxedHandler(inner.map(f)),
FallbackRoute::Default(inner) => inner.oneshot_inner(req),
FallbackRoute::Service(inner) => inner.oneshot_inner(req),
}
}
}
Expand Down
36 changes: 31 additions & 5 deletions axum/src/routing/service.rs
@@ -1,4 +1,6 @@
use super::{future::RouteFuture, url_params, Endpoint, Node, Route, RouteId, Router};
use super::{
future::RouteFuture, url_params, Endpoint, FallbackRoute, Node, Route, RouteId, Router,
};
use crate::{
body::{Body, HttpBody},
response::Response,
Expand All @@ -11,14 +13,15 @@ use std::{
sync::Arc,
task::{Context, Poll},
};
use sync_wrapper::SyncWrapper;
use tower::Service;

/// A [`Router`] converted into a [`Service`].
#[derive(Debug)]
pub struct RouterService<B = Body> {
routes: HashMap<RouteId, Route<B>>,
node: Arc<Node>,
fallback: Route<B>,
fallback: FallbackRoute<B>,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not a big fan of this still being a thing at the service level (and using request extensions to make things work). Can we not copy the fallback handler to nested routers that don't have a custom one when building the RouterService?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hm I don't think so. When calling nest we might not have the fallback yet

Router::new()
    // nested Router added, no fallback yet
    .nest(
        "/foo",
        Router::new(),
    )
    // now the fallback comes but the nested
    // router has already been turned into a RouterService
    .fallback(...)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Well we could store nested routers (that want to inherit state) differently from nested services, same as we already do for method routers vs. services, no?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Almost 馃槄 we get into trouble when we have to apply middleware to such nested routers. When applying layers to such routers we have to call Router::layer, since they don't implement Service so we cannot apply the layer around the whole thing like we do today. But doing that subtly changes the behavior:

Router::new()
    .nest(
        "/foo",
        Router::new().route("/bar", ...),
    )
    .layer(log_url)

Here if /foo/bar is called log_url will see the URL with the /foo prefix, as it should because its applied around the whole router.

We would instead end up doing something equivalent to calling Router::layer on the inner router:

Router::new()
    .nest(
        "/foo",
        Router::new().route("/bar", ...).layer(log_url),
    )

But now log_url will see the URL with /foo removed, i.e. /bar.

We also cannot store the layer and apply it later, since it changes the request body and introduces new generic parameters.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Although I'm wondering if we can pull something similar to what we're doing for boxed handlers 馃

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I tried but failed 馃槥 Ran into some issues that I dunno how to resolve. See #1531

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm leaning towards merging this as is. We can always refactor things later.

This together with #1532 might be the final release candidate 馃

}

impl<B> RouterService<B>
Expand Down Expand Up @@ -52,7 +55,7 @@ where
Self {
routes,
node: router.node,
fallback: router.fallback.into_route(&state),
fallback: router.fallback.into_fallback_route(&state),
}
}

Expand Down Expand Up @@ -121,12 +124,35 @@ where
let path = req.uri().path().to_owned();

match self.node.at(&path) {
Ok(match_) => self.call_route(match_, req),
Ok(match_) => {
match &self.fallback {
FallbackRoute::Default(_) => {}
FallbackRoute::Service(fallback) => {
req.extensions_mut()
.insert(SuperFallback(SyncWrapper::new(fallback.clone())));
}
}

self.call_route(match_, req)
}
Err(
MatchError::NotFound
| MatchError::ExtraTrailingSlash
| MatchError::MissingTrailingSlash,
) => self.fallback.clone().call(req),
) => match &mut self.fallback {
FallbackRoute::Default(fallback) => {
if let Some(super_fallback) = req.extensions_mut().remove::<SuperFallback<B>>()
{
let mut super_fallback = super_fallback.0.into_inner();
super_fallback.call(req)
} else {
fallback.call(req)
}
}
FallbackRoute::Service(fallback) => fallback.call(req),
},
}
}
}

struct SuperFallback<B>(SyncWrapper<Route<B>>);