diff --git a/tower-http/CHANGELOG.md b/tower-http/CHANGELOG.md index fb73e8d1..89a1a865 100644 --- a/tower-http/CHANGELOG.md +++ b/tower-http/CHANGELOG.md @@ -26,13 +26,15 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - **trace:** Correctly identify gRPC requests in default `on_response` callback ([#278]) - **cors:** Panic if a wildcard (`*`) is passed to `AllowOrigin::list`. Use `AllowOrigin::any()` instead ([#285]) +- **serve_dir:** Call the fallback on non-uft8 request paths ([#310]) +[#275]: https://github.com/tower-rs/tower-http/pull/275 [#278]: https://github.com/tower-rs/tower-http/pull/278 [#285]: https://github.com/tower-rs/tower-http/pull/285 -[#275]: https://github.com/tower-rs/tower-http/pull/275 [#289]: https://github.com/tower-rs/tower-http/pull/289 [#299]: https://github.com/tower-rs/tower-http/pull/299 [#303]: https://github.com/tower-rs/tower-http/pull/303 +[#310]: https://github.com/tower-rs/tower-http/pull/310 # 0.3.4 (June 06, 2022) diff --git a/tower-http/src/services/fs/serve_dir/future.rs b/tower-http/src/services/fs/serve_dir/future.rs index 5a0c1b3b..33573783 100644 --- a/tower-http/src/services/fs/serve_dir/future.rs +++ b/tower-http/src/services/fs/serve_dir/future.rs @@ -43,9 +43,11 @@ impl ResponseFuture { } } - pub(super) fn invalid_path() -> Self { + pub(super) fn invalid_path(fallback_and_request: Option<(F, Request)>) -> Self { Self { - inner: ResponseFutureInner::InvalidPath, + inner: ResponseFutureInner::InvalidPath { + fallback_and_request, + }, } } @@ -67,7 +69,9 @@ pin_project! { FallbackFuture { future: BoxFuture<'static, io::Result>>, }, - InvalidPath, + InvalidPath { + fallback_and_request: Option<(F, Request)>, + }, MethodNotAllowed, } } @@ -138,8 +142,14 @@ where break Pin::new(future).poll(cx) } - ResponseFutureInnerProj::InvalidPath => { - break Poll::Ready(Ok(not_found())); + ResponseFutureInnerProj::InvalidPath { + fallback_and_request, + } => { + if let Some((mut fallback, request)) = fallback_and_request.take() { + call_fallback(&mut fallback, request) + } else { + break Poll::Ready(Ok(not_found())); + } } ResponseFutureInnerProj::MethodNotAllowed => { diff --git a/tower-http/src/services/fs/serve_dir/mod.rs b/tower-http/src/services/fs/serve_dir/mod.rs index 4921a3f2..41c56ad3 100644 --- a/tower-http/src/services/fs/serve_dir/mod.rs +++ b/tower-http/src/services/fs/serve_dir/mod.rs @@ -299,16 +299,6 @@ where let extensions = std::mem::take(&mut parts.extensions); let req = Request::from_parts(parts, Empty::::new()); - let path_to_file = match self - .variant - .build_and_validate_path(&self.base, req.uri().path()) - { - Some(path_to_file) => path_to_file, - None => { - return ResponseFuture::invalid_path(); - } - }; - let fallback_and_request = self.fallback.as_mut().map(|fallback| { let mut fallback_req = Request::new(body); *fallback_req.method_mut() = req.method().clone(); @@ -323,6 +313,16 @@ where (fallback, fallback_req) }); + let path_to_file = match self + .variant + .build_and_validate_path(&self.base, req.uri().path()) + { + Some(path_to_file) => path_to_file, + None => { + return ResponseFuture::invalid_path(fallback_and_request); + } + }; + let buf_chunk_size = self.buf_chunk_size; let range_header = req .headers() diff --git a/tower-http/src/services/fs/serve_dir/tests.rs b/tower-http/src/services/fs/serve_dir/tests.rs index 6388db93..27242013 100644 --- a/tower-http/src/services/fs/serve_dir/tests.rs +++ b/tower-http/src/services/fs/serve_dir/tests.rs @@ -7,8 +7,9 @@ use http::{header, Method, Response}; use http::{Request, StatusCode}; use http_body::Body as HttpBody; use hyper::Body; +use std::convert::Infallible; use std::io::{self, Read}; -use tower::ServiceExt; +use tower::{service_fn, ServiceExt}; #[tokio::test] async fn basic() { @@ -688,3 +689,25 @@ async fn with_fallback_svc_and_not_append_index_html_on_directories() { let body = body_into_text(res.into_body()).await; assert_eq!(body, "from fallback /"); } + +// https://github.com/tower-rs/tower-http/issues/308 +#[tokio::test] +async fn calls_fallback_on_invalid_paths() { + async fn fallback(_: T) -> Result, std::io::Error> { + let mut res = Response::new(Body::empty()); + res.headers_mut() + .insert("from-fallback", "1".parse().unwrap()); + Ok(res) + } + + let svc = ServeDir::new("..").fallback(service_fn(fallback)); + + let req = Request::builder() + .uri("/weird_%c3%28_path") + .body(Body::empty()) + .unwrap(); + + let res = svc.oneshot(req).await.unwrap(); + + assert_eq!(res.headers()["from-fallback"], "1"); +}