Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Don't allow extracting MatchedPath in middleware for nested routes #1462

Merged
merged 10 commits into from Nov 8, 2022
3 changes: 3 additions & 0 deletions axum/CHANGELOG.md
Expand Up @@ -46,6 +46,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- **added:** Add `#[derive(axum::extract::FromRef)]` ([#1430])
- **added:** `FromRequest` and `FromRequestParts` derive macro re-exports from
[`axum-macros`] behind the `macros` feature ([#1352])
- **breaking:** `MatchedPath` can now no longer be extracted in middleware for
nested routes ([#1462])
- **added:** Add `extract::RawForm` for accessing raw urlencoded query bytes or request body ([#1487])
- **breaking:** Rename `FormRejection::FailedToDeserializeQueryString` to
`FormRejection::FailedToDeserializeForm` ([#1496])
Expand All @@ -64,6 +66,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
[#1418]: https://github.com/tokio-rs/axum/pull/1418
[#1420]: https://github.com/tokio-rs/axum/pull/1420
[#1421]: https://github.com/tokio-rs/axum/pull/1421
[#1462]: https://github.com/tokio-rs/axum/pull/1462
[#1487]: https://github.com/tokio-rs/axum/pull/1487
[#1496]: https://github.com/tokio-rs/axum/pull/1496

Expand Down
290 changes: 198 additions & 92 deletions axum/src/extract/matched_path.rs
@@ -1,7 +1,8 @@
use super::{rejection::*, FromRequestParts};
use crate::routing::{RouteId, NEST_TAIL_PARAM_CAPTURE};
use async_trait::async_trait;
use http::request::Parts;
use std::sync::Arc;
use std::{collections::HashMap, sync::Arc};

/// Access the path in the router that matches the request.
///
Expand All @@ -24,7 +25,10 @@ use std::sync::Arc;
/// # };
/// ```
///
/// # Accessing `MatchedPath` via extensions
///
/// `MatchedPath` can also be accessed from middleware via request extensions.
///
/// This is useful for example with [`Trace`](tower_http::trace::Trace) to
/// create a span that contains the matched path:
///
Expand All @@ -49,10 +53,46 @@ use std::sync::Arc;
/// tracing::info_span!("http-request", %path)
/// }),
/// );
/// # async {
/// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
/// # };
/// # let _: Router = app;
/// ```
///
/// # Matched path in nested routers
///
/// Because of how [nesting] works `MatchedPath` isn't accessible in middleware on nested routes:
davidpdrsn marked this conversation as resolved.
Show resolved Hide resolved
///
/// ```
/// use axum::{
/// Router,
/// RequestExt,
/// routing::get,
/// extract::{MatchedPath, rejection::MatchedPathRejection},
/// middleware::map_request,
/// http::Request,
/// body::Body,
/// };
///
/// async fn access_matched_path(mut request: Request<Body>) -> Request<Body> {
/// // if `/foo/bar` is called this will be `Err(_)` since that matches
/// // a nested route
/// let matched_path: Result<MatchedPath, MatchedPathRejection> =
/// request.extract_parts::<MatchedPath>().await;
///
/// request
/// }
///
/// // `MatchedPath` is always accessible on handlers added via `Router::route`
/// async fn handler(matched_path: MatchedPath) {}
///
/// let app = Router::new()
/// .nest(
/// "/foo",
/// Router::new().route("/bar", get(handler)),
/// )
/// .layer(map_request(access_matched_path));
/// # let _: Router = app;
/// ```
///
/// [nesting]: crate::Router::nest
#[cfg_attr(docsrs, doc(cfg(feature = "matched-path")))]
#[derive(Clone, Debug)]
pub struct MatchedPath(pub(crate) Arc<str>);
Expand Down Expand Up @@ -82,130 +122,196 @@ where
}
}

#[derive(Clone, Debug)]
struct MatchedNestedPath(Arc<str>);

pub(crate) fn set_matched_path_for_request(
id: RouteId,
route_id_to_path: &HashMap<RouteId, Arc<str>>,
extensions: &mut http::Extensions,
) {
let matched_path = if let Some(matched_path) = route_id_to_path.get(&id) {
matched_path
} else {
#[cfg(debug_assertions)]
panic!("should always have a matched path for a route id");
};

let matched_path = append_nested_matched_path(matched_path, extensions);

if matched_path.ends_with(NEST_TAIL_PARAM_CAPTURE) {
extensions.insert(MatchedNestedPath(matched_path));
debug_assert!(matches!(dbg!(extensions.remove::<MatchedPath>()), None));
} else {
extensions.insert(MatchedPath(matched_path));
extensions.remove::<MatchedNestedPath>();
}
}

// a previous `MatchedPath` might exist if we're inside a nested Router
fn append_nested_matched_path(matched_path: &Arc<str>, extensions: &http::Extensions) -> Arc<str> {
if let Some(previous) = extensions
.get::<MatchedPath>()
.map(|matched_path| matched_path.as_str())
.or_else(|| Some(&extensions.get::<MatchedNestedPath>()?.0))
{
let previous = previous
.strip_suffix(NEST_TAIL_PARAM_CAPTURE)
.unwrap_or(previous);

let matched_path = format!("{previous}{matched_path}");
matched_path.into()
} else {
Arc::clone(matched_path)
}
}

#[cfg(test)]
mod tests {
use super::*;
use crate::{
extract::Extension, handler::HandlerWithoutStateExt, routing::get, test_helpers::*, Router,
handler::HandlerWithoutStateExt, middleware::map_request, routing::get, test_helpers::*,
Router,
};
use http::{Request, StatusCode};
use std::task::{Context, Poll};
use tower::layer::layer_fn;
use tower_service::Service;

#[derive(Clone)]
struct SetMatchedPathExtension<S>(S);

impl<S, B> Service<Request<B>> for SetMatchedPathExtension<S>
where
S: Service<Request<B>>,
{
type Response = S::Response;
type Error = S::Error;
type Future = S::Future;
#[tokio::test]
async fn extracting_on_handler() {
let app = Router::new().route(
"/:a",
get(|path: MatchedPath| async move { path.as_str().to_owned() }),
);

fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.0.poll_ready(cx)
}
let client = TestClient::new(app);

fn call(&mut self, mut req: Request<B>) -> Self::Future {
let path = req
.extensions()
.get::<MatchedPath>()
.unwrap()
.as_str()
.to_owned();
req.extensions_mut().insert(MatchedPathFromMiddleware(path));
self.0.call(req)
}
let res = client.get("/foo").send().await;
assert_eq!(res.text().await, "/:a");
}

#[derive(Clone)]
struct MatchedPathFromMiddleware(String);
#[tokio::test]
async fn extracting_on_handler_in_nested_router() {
let app = Router::new().nest(
"/:a",
Router::new().route(
"/:b",
get(|path: MatchedPath| async move { path.as_str().to_owned() }),
),
);

let client = TestClient::new(app);

let res = client.get("/foo/bar").send().await;
assert_eq!(res.text().await, "/:a/:b");
}

#[tokio::test]
async fn access_matched_path() {
let api = Router::new().route(
"/users/:id",
get(|path: MatchedPath| async move { path.as_str().to_owned() }),
async fn extracting_on_handler_in_deeply_nested_router() {
let app = Router::new().nest(
"/:a",
Router::new().nest(
"/:b",
Router::new().route(
"/:c",
get(|path: MatchedPath| async move { path.as_str().to_owned() }),
),
),
);

async fn handler(
path: MatchedPath,
Extension(MatchedPathFromMiddleware(path_from_middleware)): Extension<
MatchedPathFromMiddleware,
>,
) -> String {
format!(
"extractor = {}, middleware = {}",
path.as_str(),
path_from_middleware
)
let client = TestClient::new(app);

let res = client.get("/foo/bar/baz").send().await;
assert_eq!(res.text().await, "/:a/:b/:c");
}

#[tokio::test]
async fn cannot_extract_nested_matched_path_in_middleware() {
async fn extract_matched_path<B>(
matched_path: Option<MatchedPath>,
req: Request<B>,
) -> Request<B> {
assert!(matched_path.is_none());
req
}

let app = Router::new()
.route(
"/:key",
get(|path: MatchedPath| async move { path.as_str().to_owned() }),
)
.nest("/api", api)
.nest(
"/public",
Router::new()
.route("/assets/*path", get(handler))
// have to set the middleware here since otherwise the
// matched path is just `/public/*` since we're nesting
// this router
.layer(layer_fn(SetMatchedPathExtension)),
)
.nest_service("/foo", handler.into_service())
.layer(layer_fn(SetMatchedPathExtension));
.nest("/:a", Router::new().route("/:b", get(|| async move {})))
.layer(map_request(extract_matched_path));

let client = TestClient::new(app);

let res = client.get("/api/users/123").send().await;
assert_eq!(res.text().await, "/api/users/:id");
let res = client.get("/foo/bar").send().await;
assert_eq!(res.status(), StatusCode::OK);
}

// the router nested at `/public` doesn't handle `/`
let res = client.get("/public").send().await;
assert_eq!(res.status(), StatusCode::NOT_FOUND);
#[tokio::test]
async fn cannot_extract_nested_matched_path_in_middleware_via_extension() {
async fn assert_no_matched_path<B>(req: Request<B>) -> Request<B> {
assert!(req.extensions().get::<MatchedPath>().is_none());
req
}

let res = client.get("/public/assets/css/style.css").send().await;
assert_eq!(
res.text().await,
"extractor = /public/assets/*path, middleware = /public/assets/*path"
);
let app = Router::new()
.nest("/:a", Router::new().route("/:b", get(|| async move {})))
.layer(map_request(assert_no_matched_path));

let res = client.get("/foo").send().await;
assert_eq!(res.text().await, "extractor = /foo, middleware = /foo");
let client = TestClient::new(app);

let res = client.get("/foo/").send().await;
assert_eq!(res.text().await, "extractor = /foo/, middleware = /foo/");
let res = client.get("/foo/bar").send().await;
assert_eq!(res.status(), StatusCode::OK);
}

let res = client.get("/foo/bar/baz").send().await;
assert_eq!(
res.text().await,
format!(
"extractor = /foo/*{}, middleware = /foo/*{}",
crate::routing::NEST_TAIL_PARAM,
crate::routing::NEST_TAIL_PARAM,
),
#[tokio::test]
async fn can_extract_nested_matched_path_in_middleware_on_nested_router() {
async fn extract_matched_path<B>(matched_path: MatchedPath, req: Request<B>) -> Request<B> {
assert_eq!(matched_path.as_str(), "/:a/:b");
req
}

let app = Router::new().nest(
"/:a",
Router::new()
.route("/:b", get(|| async move {}))
.layer(map_request(extract_matched_path)),
);

let client = TestClient::new(app);

let res = client.get("/foo/bar").send().await;
assert_eq!(res.status(), StatusCode::OK);
}

#[tokio::test]
async fn nested_opaque_routers_append_to_matched_path() {
async fn can_extract_nested_matched_path_in_middleware_on_nested_router_via_extension() {
async fn extract_matched_path<B>(req: Request<B>) -> Request<B> {
let matched_path = req.extensions().get::<MatchedPath>().unwrap();
assert_eq!(matched_path.as_str(), "/:a/:b");
req
}

let app = Router::new().nest(
"/:a",
Router::new().route(
"/:b",
get(|path: MatchedPath| async move { path.as_str().to_owned() }),
),
Router::new()
.route("/:b", get(|| async move {}))
.layer(map_request(extract_matched_path)),
);

let client = TestClient::new(app);

let res = client.get("/foo/bar").send().await;
assert_eq!(res.text().await, "/:a/:b");
assert_eq!(res.status(), StatusCode::OK);
}

#[tokio::test]
async fn extracting_on_nested_handler() {
async fn handler(path: Option<MatchedPath>) {
assert!(path.is_none());
}

let app = Router::new().nest_service("/:a", handler.into_service());

let client = TestClient::new(app);

let res = client.get("/foo/bar").send().await;
assert_eq!(res.status(), StatusCode::OK);
}
}
2 changes: 1 addition & 1 deletion axum/src/extract/mod.rs
Expand Up @@ -49,7 +49,7 @@ pub use crate::Extension;
pub use crate::form::Form;

#[cfg(feature = "matched-path")]
mod matched_path;
pub(crate) mod matched_path;

#[cfg(feature = "matched-path")]
#[doc(inline)]
Expand Down
4 changes: 2 additions & 2 deletions axum/src/routing/mod.rs
Expand Up @@ -48,7 +48,7 @@ pub use self::method_routing::{
};

#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
struct RouteId(u32);
pub(crate) struct RouteId(u32);

impl RouteId {
fn next() -> Self {
Expand Down Expand Up @@ -110,7 +110,7 @@ where
}

pub(crate) const NEST_TAIL_PARAM: &str = "__private__axum_nest_tail_param";
const NEST_TAIL_PARAM_CAPTURE: &str = "/*__private__axum_nest_tail_param";
pub(crate) const NEST_TAIL_PARAM_CAPTURE: &str = "/*__private__axum_nest_tail_param";

impl<B> Router<(), B>
where
Expand Down