Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Don't allow extracting MatchedPath in middleware for nested routes #1462

Merged
merged 10 commits into from Nov 8, 2022
3 changes: 3 additions & 0 deletions axum/CHANGELOG.md
Expand Up @@ -45,6 +45,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
to be more inline with `tungstenite` ([#1421])
- **added:** `FromRequest` and `FromRequestParts` derive macro re-exports from
[`axum-macros`] behind the `macros` feature ([#1352])
- **breaking:** `MatchedPath` can now no longer be extracted in middleware for
nested routes ([#1462])

[#1352]: https://github.com/tokio-rs/axum/pull/1352
[#1368]: https://github.com/tokio-rs/axum/pull/1368
Expand All @@ -60,6 +62,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
[#1418]: https://github.com/tokio-rs/axum/pull/1418
[#1420]: https://github.com/tokio-rs/axum/pull/1420
[#1421]: https://github.com/tokio-rs/axum/pull/1421
[#1462]: https://github.com/tokio-rs/axum/pull/1462

# 0.6.0-rc.2 (10. September, 2022)

Expand Down
280 changes: 186 additions & 94 deletions axum/src/extract/matched_path.rs
@@ -1,7 +1,8 @@
use super::{rejection::*, FromRequestParts};
use crate::routing::{RouteId, NEST_TAIL_PARAM_CAPTURE};
use async_trait::async_trait;
use http::request::Parts;
use std::sync::Arc;
use std::{collections::HashMap, sync::Arc};

/// Access the path in the router that matches the request.
///
Expand All @@ -24,7 +25,10 @@ use std::sync::Arc;
/// # };
/// ```
///
/// # Accessing `MatchedPath` via extensions
///
/// `MatchedPath` can also be accessed from middleware via request extensions.
///
/// This is useful for example with [`Trace`](tower_http::trace::Trace) to
/// create a span that contains the matched path:
///
Expand All @@ -49,10 +53,47 @@ use std::sync::Arc;
/// tracing::info_span!("http-request", %path)
/// }),
/// );
/// # async {
/// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
/// # };
/// # let _: Router = app;
/// ```
///
/// # Matched path in nested routers
///
/// Because of how [nesting] works `MatchedPath` isn't accessible in middleware on nested routes:
davidpdrsn marked this conversation as resolved.
Show resolved Hide resolved
///
/// ```
/// use axum::{
/// Router,
/// RequestExt,
/// routing::get,
/// extract::{MatchedPath, rejection::MatchedPathRejection},
/// middleware::map_request,
/// http::Request,
/// body::Body,
/// };
///
/// async fn access_matched_path(mut request: Request<Body>) -> Request<Body> {
/// // if `/foo/bar` is called this will be `Err(_)` since that matches
/// // a nested route
/// let matched_path: Result<MatchedPath, MatchedPathRejection> =
/// request.extract_parts::<MatchedPath>().await;
///
/// request
/// }
///
/// // `MatchedPath` is always accessible on handlers regardless
/// // if its for a nested route or not
davidpdrsn marked this conversation as resolved.
Show resolved Hide resolved
/// async fn handler(matched_path: MatchedPath) {}
///
/// let app = Router::new()
/// .nest(
/// "/foo",
/// Router::new().route("/bar", get(handler)),
/// )
/// .layer(map_request(access_matched_path));
/// # let _: Router = app;
/// ```
///
/// [nesting]: crate::Router::nest
#[cfg_attr(docsrs, doc(cfg(feature = "matched-path")))]
#[derive(Clone, Debug)]
pub struct MatchedPath(pub(crate) Arc<str>);
Expand Down Expand Up @@ -82,130 +123,181 @@ where
}
}

#[derive(Clone, Debug)]
struct MatchedNestedPath(Arc<str>);

pub(crate) fn set_matched_path_for_request(
id: RouteId,
route_id_to_path: &HashMap<RouteId, Arc<str>>,
extensions: &mut http::Extensions,
) {
let matched_path = if let Some(matched_path) = route_id_to_path.get(&id) {
matched_path
} else {
#[cfg(debug_assertions)]
panic!("should always have a matched path for a route id");
};

let matched_path = append_nested_matched_path(matched_path, extensions);

if matched_path.contains(NEST_TAIL_PARAM_CAPTURE) {
davidpdrsn marked this conversation as resolved.
Show resolved Hide resolved
extensions.insert(MatchedNestedPath(matched_path));
extensions.remove::<MatchedPath>();
davidpdrsn marked this conversation as resolved.
Show resolved Hide resolved
} else {
extensions.insert(MatchedPath(matched_path));
extensions.remove::<MatchedNestedPath>();
}
}

// a previous `MatchedPath` might exist if we're inside a nested Router
fn append_nested_matched_path(matched_path: &Arc<str>, extensions: &http::Extensions) -> Arc<str> {
if let Some(previous) = extensions
.get::<MatchedPath>()
.map(|matched_path| matched_path.as_str())
.or_else(|| Some(&extensions.get::<MatchedNestedPath>()?.0))
{
let previous = if let Some(previous) = previous.strip_suffix(NEST_TAIL_PARAM_CAPTURE) {
previous
} else {
previous
};
davidpdrsn marked this conversation as resolved.
Show resolved Hide resolved

let matched_path = format!("{}{}", previous, matched_path);
davidpdrsn marked this conversation as resolved.
Show resolved Hide resolved
matched_path.into()
} else {
Arc::clone(matched_path)
}
}

#[cfg(test)]
mod tests {
use super::*;
use crate::{
extract::Extension, handler::HandlerWithoutStateExt, routing::get, test_helpers::*, Router,
};
use crate::{middleware::map_request, routing::get, test_helpers::*, Router};
use http::{Request, StatusCode};
use std::task::{Context, Poll};
use tower::layer::layer_fn;
use tower_service::Service;

#[derive(Clone)]
struct SetMatchedPathExtension<S>(S);

impl<S, B> Service<Request<B>> for SetMatchedPathExtension<S>
where
S: Service<Request<B>>,
{
type Response = S::Response;
type Error = S::Error;
type Future = S::Future;
#[tokio::test]
async fn extracting_on_handler() {
let app = Router::new().route(
"/:a",
get(|path: MatchedPath| async move { path.as_str().to_owned() }),
);

fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.0.poll_ready(cx)
}
let client = TestClient::new(app);

fn call(&mut self, mut req: Request<B>) -> Self::Future {
let path = req
.extensions()
.get::<MatchedPath>()
.unwrap()
.as_str()
.to_owned();
req.extensions_mut().insert(MatchedPathFromMiddleware(path));
self.0.call(req)
}
let res = client.get("/foo").send().await;
assert_eq!(res.text().await, "/:a");
}

#[derive(Clone)]
struct MatchedPathFromMiddleware(String);
#[tokio::test]
async fn extracting_on_handler_in_nested_router() {
let app = Router::new().nest(
"/:a",
Router::new().route(
"/:b",
get(|path: MatchedPath| async move { path.as_str().to_owned() }),
),
);

let client = TestClient::new(app);

let res = client.get("/foo/bar").send().await;
assert_eq!(res.text().await, "/:a/:b");
}

#[tokio::test]
async fn access_matched_path() {
let api = Router::new().route(
"/users/:id",
get(|path: MatchedPath| async move { path.as_str().to_owned() }),
async fn extracting_on_handler_in_deeply_nested_router() {
let app = Router::new().nest(
"/:a",
Router::new().nest(
"/:b",
Router::new().route(
"/:c",
get(|path: MatchedPath| async move { path.as_str().to_owned() }),
),
),
);

async fn handler(
path: MatchedPath,
Extension(MatchedPathFromMiddleware(path_from_middleware)): Extension<
MatchedPathFromMiddleware,
>,
) -> String {
format!(
"extractor = {}, middleware = {}",
path.as_str(),
path_from_middleware
)
let client = TestClient::new(app);

let res = client.get("/foo/bar/baz").send().await;
assert_eq!(res.text().await, "/:a/:b/:c");
}

#[tokio::test]
async fn cannot_extract_nested_matched_path_in_middleware() {
async fn extract_matched_path<B>(
matched_path: Option<MatchedPath>,
req: Request<B>,
) -> Request<B> {
assert!(matched_path.is_none());
req
}

let app = Router::new()
.route(
"/:key",
get(|path: MatchedPath| async move { path.as_str().to_owned() }),
)
.nest("/api", api)
.nest(
"/public",
Router::new()
.route("/assets/*path", get(handler))
// have to set the middleware here since otherwise the
// matched path is just `/public/*` since we're nesting
// this router
.layer(layer_fn(SetMatchedPathExtension)),
)
.nest_service("/foo", handler.into_service())
.layer(layer_fn(SetMatchedPathExtension));
.nest("/:a", Router::new().route("/:b", get(|| async move {})))
.layer(map_request(extract_matched_path));

let client = TestClient::new(app);

let res = client.get("/api/users/123").send().await;
assert_eq!(res.text().await, "/api/users/:id");
let res = client.get("/foo/bar").send().await;
assert_eq!(res.status(), StatusCode::OK);
}

// the router nested at `/public` doesn't handle `/`
let res = client.get("/public").send().await;
assert_eq!(res.status(), StatusCode::NOT_FOUND);
#[tokio::test]
async fn cannot_extract_nested_matched_path_in_middleware_via_extension() {
async fn assert_no_matched_path<B>(req: Request<B>) -> Request<B> {
assert!(req.extensions().get::<MatchedPath>().is_none());
req
}

let res = client.get("/public/assets/css/style.css").send().await;
assert_eq!(
res.text().await,
"extractor = /public/assets/*path, middleware = /public/assets/*path"
);
let app = Router::new()
.nest("/:a", Router::new().route("/:b", get(|| async move {})))
.layer(map_request(assert_no_matched_path));

let res = client.get("/foo").send().await;
assert_eq!(res.text().await, "extractor = /foo, middleware = /foo");
let client = TestClient::new(app);

let res = client.get("/foo/bar").send().await;
assert_eq!(res.status(), StatusCode::OK);
}

let res = client.get("/foo/").send().await;
assert_eq!(res.text().await, "extractor = /foo/, middleware = /foo/");
#[tokio::test]
async fn can_extract_nested_matched_path_in_middleware_on_nested_router() {
async fn extract_matched_path<B>(matched_path: MatchedPath, req: Request<B>) -> Request<B> {
assert_eq!(matched_path.as_str(), "/:a/:b");
req
}

let res = client.get("/foo/bar/baz").send().await;
assert_eq!(
res.text().await,
format!(
"extractor = /foo/*{}, middleware = /foo/*{}",
crate::routing::NEST_TAIL_PARAM,
crate::routing::NEST_TAIL_PARAM,
),
let app = Router::new().nest(
"/:a",
Router::new()
.route("/:b", get(|| async move {}))
.layer(map_request(extract_matched_path)),
);

let client = TestClient::new(app);

let res = client.get("/foo/bar").send().await;
assert_eq!(res.status(), StatusCode::OK);
}

#[tokio::test]
async fn nested_opaque_routers_append_to_matched_path() {
async fn can_extract_nested_matched_path_in_middleware_on_nested_router_via_extension() {
async fn extract_matched_path<B>(req: Request<B>) -> Request<B> {
let matched_path = req.extensions().get::<MatchedPath>().unwrap();
assert_eq!(matched_path.as_str(), "/:a/:b");
req
}

let app = Router::new().nest(
"/:a",
Router::new().route(
"/:b",
get(|path: MatchedPath| async move { path.as_str().to_owned() }),
),
Router::new()
.route("/:b", get(|| async move {}))
.layer(map_request(extract_matched_path)),
);

let client = TestClient::new(app);

let res = client.get("/foo/bar").send().await;
assert_eq!(res.text().await, "/:a/:b");
assert_eq!(res.status(), StatusCode::OK);
}
}
2 changes: 1 addition & 1 deletion axum/src/extract/mod.rs
Expand Up @@ -47,7 +47,7 @@ pub use crate::Extension;
pub use crate::form::Form;

#[cfg(feature = "matched-path")]
mod matched_path;
pub(crate) mod matched_path;

#[cfg(feature = "matched-path")]
#[doc(inline)]
Expand Down
4 changes: 2 additions & 2 deletions axum/src/routing/mod.rs
Expand Up @@ -48,7 +48,7 @@ pub use self::method_routing::{
};

#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
struct RouteId(u32);
pub(crate) struct RouteId(u32);

impl RouteId {
fn next() -> Self {
Expand Down Expand Up @@ -107,7 +107,7 @@ where
}

pub(crate) const NEST_TAIL_PARAM: &str = "__private__axum_nest_tail_param";
const NEST_TAIL_PARAM_CAPTURE: &str = "/*__private__axum_nest_tail_param";
pub(crate) const NEST_TAIL_PARAM_CAPTURE: &str = "/*__private__axum_nest_tail_param";

impl<B> Router<(), B>
where
Expand Down