Skip to content

Commit

Permalink
Don't allow extracting MatchedPath in middleware for nested routes (#…
Browse files Browse the repository at this point in the history
…1462)

* Don't allow extracting `MatchedPath` for nested paths

* misc clean up

* Update docs

* changelog

* Apply suggestions from code review

Co-authored-by: Jonas Platte <jplatte+git@posteo.de>

* Add test for nested handler service

* change to `debug_assert`

* apply suggestions from review

Co-authored-by: Jonas Platte <jplatte+git@posteo.de>
  • Loading branch information
davidpdrsn and jplatte committed Nov 8, 2022
1 parent e0ef641 commit 0e3f9d0
Show file tree
Hide file tree
Showing 5 changed files with 216 additions and 140 deletions.
3 changes: 3 additions & 0 deletions axum/CHANGELOG.md
Expand Up @@ -46,6 +46,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- **added:** Add `#[derive(axum::extract::FromRef)]` ([#1430])
- **added:** `FromRequest` and `FromRequestParts` derive macro re-exports from
[`axum-macros`] behind the `macros` feature ([#1352])
- **breaking:** `MatchedPath` can now no longer be extracted in middleware for
nested routes ([#1462])
- **added:** Add `extract::RawForm` for accessing raw urlencoded query bytes or request body ([#1487])
- **breaking:** Rename `FormRejection::FailedToDeserializeQueryString` to
`FormRejection::FailedToDeserializeForm` ([#1496])
Expand All @@ -64,6 +66,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
[#1418]: https://github.com/tokio-rs/axum/pull/1418
[#1420]: https://github.com/tokio-rs/axum/pull/1420
[#1421]: https://github.com/tokio-rs/axum/pull/1421
[#1462]: https://github.com/tokio-rs/axum/pull/1462
[#1487]: https://github.com/tokio-rs/axum/pull/1487
[#1496]: https://github.com/tokio-rs/axum/pull/1496

Expand Down
290 changes: 198 additions & 92 deletions axum/src/extract/matched_path.rs
@@ -1,7 +1,8 @@
use super::{rejection::*, FromRequestParts};
use crate::routing::{RouteId, NEST_TAIL_PARAM_CAPTURE};
use async_trait::async_trait;
use http::request::Parts;
use std::sync::Arc;
use std::{collections::HashMap, sync::Arc};

/// Access the path in the router that matches the request.
///
Expand All @@ -24,7 +25,10 @@ use std::sync::Arc;
/// # };
/// ```
///
/// # Accessing `MatchedPath` via extensions
///
/// `MatchedPath` can also be accessed from middleware via request extensions.
///
/// This is useful for example with [`Trace`](tower_http::trace::Trace) to
/// create a span that contains the matched path:
///
Expand All @@ -49,10 +53,46 @@ use std::sync::Arc;
/// tracing::info_span!("http-request", %path)
/// }),
/// );
/// # async {
/// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
/// # };
/// # let _: Router = app;
/// ```
///
/// # Matched path in nested routers
///
/// Because of how [nesting] works `MatchedPath` isn't accessible in middleware on nested routes:
///
/// ```
/// use axum::{
/// Router,
/// RequestExt,
/// routing::get,
/// extract::{MatchedPath, rejection::MatchedPathRejection},
/// middleware::map_request,
/// http::Request,
/// body::Body,
/// };
///
/// async fn access_matched_path(mut request: Request<Body>) -> Request<Body> {
/// // if `/foo/bar` is called this will be `Err(_)` since that matches
/// // a nested route
/// let matched_path: Result<MatchedPath, MatchedPathRejection> =
/// request.extract_parts::<MatchedPath>().await;
///
/// request
/// }
///
/// // `MatchedPath` is always accessible on handlers added via `Router::route`
/// async fn handler(matched_path: MatchedPath) {}
///
/// let app = Router::new()
/// .nest(
/// "/foo",
/// Router::new().route("/bar", get(handler)),
/// )
/// .layer(map_request(access_matched_path));
/// # let _: Router = app;
/// ```
///
/// [nesting]: crate::Router::nest
#[cfg_attr(docsrs, doc(cfg(feature = "matched-path")))]
#[derive(Clone, Debug)]
pub struct MatchedPath(pub(crate) Arc<str>);
Expand Down Expand Up @@ -82,130 +122,196 @@ where
}
}

#[derive(Clone, Debug)]
struct MatchedNestedPath(Arc<str>);

pub(crate) fn set_matched_path_for_request(
id: RouteId,
route_id_to_path: &HashMap<RouteId, Arc<str>>,
extensions: &mut http::Extensions,
) {
let matched_path = if let Some(matched_path) = route_id_to_path.get(&id) {
matched_path
} else {
#[cfg(debug_assertions)]
panic!("should always have a matched path for a route id");
};

let matched_path = append_nested_matched_path(matched_path, extensions);

if matched_path.ends_with(NEST_TAIL_PARAM_CAPTURE) {
extensions.insert(MatchedNestedPath(matched_path));
debug_assert!(matches!(dbg!(extensions.remove::<MatchedPath>()), None));
} else {
extensions.insert(MatchedPath(matched_path));
extensions.remove::<MatchedNestedPath>();
}
}

// a previous `MatchedPath` might exist if we're inside a nested Router
fn append_nested_matched_path(matched_path: &Arc<str>, extensions: &http::Extensions) -> Arc<str> {
if let Some(previous) = extensions
.get::<MatchedPath>()
.map(|matched_path| matched_path.as_str())
.or_else(|| Some(&extensions.get::<MatchedNestedPath>()?.0))
{
let previous = previous
.strip_suffix(NEST_TAIL_PARAM_CAPTURE)
.unwrap_or(previous);

let matched_path = format!("{previous}{matched_path}");
matched_path.into()
} else {
Arc::clone(matched_path)
}
}

#[cfg(test)]
mod tests {
use super::*;
use crate::{
extract::Extension, handler::HandlerWithoutStateExt, routing::get, test_helpers::*, Router,
handler::HandlerWithoutStateExt, middleware::map_request, routing::get, test_helpers::*,
Router,
};
use http::{Request, StatusCode};
use std::task::{Context, Poll};
use tower::layer::layer_fn;
use tower_service::Service;

#[derive(Clone)]
struct SetMatchedPathExtension<S>(S);

impl<S, B> Service<Request<B>> for SetMatchedPathExtension<S>
where
S: Service<Request<B>>,
{
type Response = S::Response;
type Error = S::Error;
type Future = S::Future;
#[tokio::test]
async fn extracting_on_handler() {
let app = Router::new().route(
"/:a",
get(|path: MatchedPath| async move { path.as_str().to_owned() }),
);

fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.0.poll_ready(cx)
}
let client = TestClient::new(app);

fn call(&mut self, mut req: Request<B>) -> Self::Future {
let path = req
.extensions()
.get::<MatchedPath>()
.unwrap()
.as_str()
.to_owned();
req.extensions_mut().insert(MatchedPathFromMiddleware(path));
self.0.call(req)
}
let res = client.get("/foo").send().await;
assert_eq!(res.text().await, "/:a");
}

#[derive(Clone)]
struct MatchedPathFromMiddleware(String);
#[tokio::test]
async fn extracting_on_handler_in_nested_router() {
let app = Router::new().nest(
"/:a",
Router::new().route(
"/:b",
get(|path: MatchedPath| async move { path.as_str().to_owned() }),
),
);

let client = TestClient::new(app);

let res = client.get("/foo/bar").send().await;
assert_eq!(res.text().await, "/:a/:b");
}

#[tokio::test]
async fn access_matched_path() {
let api = Router::new().route(
"/users/:id",
get(|path: MatchedPath| async move { path.as_str().to_owned() }),
async fn extracting_on_handler_in_deeply_nested_router() {
let app = Router::new().nest(
"/:a",
Router::new().nest(
"/:b",
Router::new().route(
"/:c",
get(|path: MatchedPath| async move { path.as_str().to_owned() }),
),
),
);

async fn handler(
path: MatchedPath,
Extension(MatchedPathFromMiddleware(path_from_middleware)): Extension<
MatchedPathFromMiddleware,
>,
) -> String {
format!(
"extractor = {}, middleware = {}",
path.as_str(),
path_from_middleware
)
let client = TestClient::new(app);

let res = client.get("/foo/bar/baz").send().await;
assert_eq!(res.text().await, "/:a/:b/:c");
}

#[tokio::test]
async fn cannot_extract_nested_matched_path_in_middleware() {
async fn extract_matched_path<B>(
matched_path: Option<MatchedPath>,
req: Request<B>,
) -> Request<B> {
assert!(matched_path.is_none());
req
}

let app = Router::new()
.route(
"/:key",
get(|path: MatchedPath| async move { path.as_str().to_owned() }),
)
.nest("/api", api)
.nest(
"/public",
Router::new()
.route("/assets/*path", get(handler))
// have to set the middleware here since otherwise the
// matched path is just `/public/*` since we're nesting
// this router
.layer(layer_fn(SetMatchedPathExtension)),
)
.nest_service("/foo", handler.into_service())
.layer(layer_fn(SetMatchedPathExtension));
.nest("/:a", Router::new().route("/:b", get(|| async move {})))
.layer(map_request(extract_matched_path));

let client = TestClient::new(app);

let res = client.get("/api/users/123").send().await;
assert_eq!(res.text().await, "/api/users/:id");
let res = client.get("/foo/bar").send().await;
assert_eq!(res.status(), StatusCode::OK);
}

// the router nested at `/public` doesn't handle `/`
let res = client.get("/public").send().await;
assert_eq!(res.status(), StatusCode::NOT_FOUND);
#[tokio::test]
async fn cannot_extract_nested_matched_path_in_middleware_via_extension() {
async fn assert_no_matched_path<B>(req: Request<B>) -> Request<B> {
assert!(req.extensions().get::<MatchedPath>().is_none());
req
}

let res = client.get("/public/assets/css/style.css").send().await;
assert_eq!(
res.text().await,
"extractor = /public/assets/*path, middleware = /public/assets/*path"
);
let app = Router::new()
.nest("/:a", Router::new().route("/:b", get(|| async move {})))
.layer(map_request(assert_no_matched_path));

let res = client.get("/foo").send().await;
assert_eq!(res.text().await, "extractor = /foo, middleware = /foo");
let client = TestClient::new(app);

let res = client.get("/foo/").send().await;
assert_eq!(res.text().await, "extractor = /foo/, middleware = /foo/");
let res = client.get("/foo/bar").send().await;
assert_eq!(res.status(), StatusCode::OK);
}

let res = client.get("/foo/bar/baz").send().await;
assert_eq!(
res.text().await,
format!(
"extractor = /foo/*{}, middleware = /foo/*{}",
crate::routing::NEST_TAIL_PARAM,
crate::routing::NEST_TAIL_PARAM,
),
#[tokio::test]
async fn can_extract_nested_matched_path_in_middleware_on_nested_router() {
async fn extract_matched_path<B>(matched_path: MatchedPath, req: Request<B>) -> Request<B> {
assert_eq!(matched_path.as_str(), "/:a/:b");
req
}

let app = Router::new().nest(
"/:a",
Router::new()
.route("/:b", get(|| async move {}))
.layer(map_request(extract_matched_path)),
);

let client = TestClient::new(app);

let res = client.get("/foo/bar").send().await;
assert_eq!(res.status(), StatusCode::OK);
}

#[tokio::test]
async fn nested_opaque_routers_append_to_matched_path() {
async fn can_extract_nested_matched_path_in_middleware_on_nested_router_via_extension() {
async fn extract_matched_path<B>(req: Request<B>) -> Request<B> {
let matched_path = req.extensions().get::<MatchedPath>().unwrap();
assert_eq!(matched_path.as_str(), "/:a/:b");
req
}

let app = Router::new().nest(
"/:a",
Router::new().route(
"/:b",
get(|path: MatchedPath| async move { path.as_str().to_owned() }),
),
Router::new()
.route("/:b", get(|| async move {}))
.layer(map_request(extract_matched_path)),
);

let client = TestClient::new(app);

let res = client.get("/foo/bar").send().await;
assert_eq!(res.text().await, "/:a/:b");
assert_eq!(res.status(), StatusCode::OK);
}

#[tokio::test]
async fn extracting_on_nested_handler() {
async fn handler(path: Option<MatchedPath>) {
assert!(path.is_none());
}

let app = Router::new().nest_service("/:a", handler.into_service());

let client = TestClient::new(app);

let res = client.get("/foo/bar").send().await;
assert_eq!(res.status(), StatusCode::OK);
}
}
2 changes: 1 addition & 1 deletion axum/src/extract/mod.rs
Expand Up @@ -49,7 +49,7 @@ pub use crate::Extension;
pub use crate::form::Form;

#[cfg(feature = "matched-path")]
mod matched_path;
pub(crate) mod matched_path;

#[cfg(feature = "matched-path")]
#[doc(inline)]
Expand Down
4 changes: 2 additions & 2 deletions axum/src/routing/mod.rs
Expand Up @@ -48,7 +48,7 @@ pub use self::method_routing::{
};

#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
struct RouteId(u32);
pub(crate) struct RouteId(u32);

impl RouteId {
fn next() -> Self {
Expand Down Expand Up @@ -110,7 +110,7 @@ where
}

pub(crate) const NEST_TAIL_PARAM: &str = "__private__axum_nest_tail_param";
const NEST_TAIL_PARAM_CAPTURE: &str = "/*__private__axum_nest_tail_param";
pub(crate) const NEST_TAIL_PARAM_CAPTURE: &str = "/*__private__axum_nest_tail_param";

impl<B> Router<(), B>
where
Expand Down

0 comments on commit 0e3f9d0

Please sign in to comment.