diff --git a/tower-http/CHANGELOG.md b/tower-http/CHANGELOG.md index 613b69c0..bf90c8c7 100644 --- a/tower-http/CHANGELOG.md +++ b/tower-http/CHANGELOG.md @@ -9,6 +9,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## Added +- Add `ServeDir::{fallback, not_found_service}` for calling another service if + the file cannot be found. - Add `SetStatus` to override status codes. - **cors**: Added `CorsLayer::very_permissive` which is like `CorsLayer::permissive` except it (truly) allows credentials. This is made diff --git a/tower-http/Cargo.toml b/tower-http/Cargo.toml index b48e3067..d153d18f 100644 --- a/tower-http/Cargo.toml +++ b/tower-http/Cargo.toml @@ -79,7 +79,7 @@ auth = ["base64"] catch-panic = ["tracing", "futures-util/std"] cors = [] follow-redirect = ["iri-string", "tower/util"] -fs = ["tokio/fs", "tokio-util/io", "tokio/io-util", "mime_guess", "mime", "percent-encoding", "httpdate"] +fs = ["tokio/fs", "tokio-util/io", "tokio/io-util", "mime_guess", "mime", "percent-encoding", "httpdate", "set-status"] map-request-body = [] map-response-body = [] metrics = ["tokio/time"] diff --git a/tower-http/src/services/fs/mod.rs b/tower-http/src/services/fs/mod.rs index f5aca03d..b05526b4 100644 --- a/tower-http/src/services/fs/mod.rs +++ b/tower-http/src/services/fs/mod.rs @@ -2,8 +2,8 @@ use bytes::Bytes; use futures_util::Stream; -use http::{HeaderMap, Response, StatusCode}; -use http_body::{combinators::BoxBody, Body, Empty}; +use http::{HeaderMap, StatusCode}; +use http_body::Body; use httpdate::HttpDate; use pin_project_lite::pin_project; use std::fs::Metadata; @@ -28,6 +28,7 @@ use crate::content_encoding::{Encoding, QValue, SupportedEncodings}; pub use self::{ serve_dir::{ + DefaultServeDirFallback, // The response body and future are used for both ServeDir and ServeFile ResponseBody as ServeFileSystemResponseBody, ResponseFuture as ServeFileSystemResponseFuture, @@ -189,22 +190,6 @@ where } } -fn response_from_io_error( - err: io::Error, -) -> Result>, io::Error> { - match err.kind() { - io::ErrorKind::NotFound | io::ErrorKind::PermissionDenied => { - let res = Response::builder() - .status(StatusCode::NOT_FOUND) - .body(Empty::new().map_err(|err| match err {}).boxed()) - .unwrap(); - - Ok(res) - } - _ => Err(err), - } -} - struct LastModified(HttpDate); impl From for LastModified { diff --git a/tower-http/src/services/fs/serve_dir.rs b/tower-http/src/services/fs/serve_dir.rs index 0e6a31ca..e4221a9a 100644 --- a/tower-http/src/services/fs/serve_dir.rs +++ b/tower-http/src/services/fs/serve_dir.rs @@ -2,31 +2,32 @@ use super::{ check_modified_headers, open_file_with_fallback, AsyncReadBody, IfModifiedSince, IfUnmodifiedSince, LastModified, PrecompressedVariants, }; -use crate::services::fs::file_metadata_with_fallback; use crate::{ content_encoding::{encodings, Encoding}, - services::fs::DEFAULT_CAPACITY, + services::fs::{file_metadata_with_fallback, DEFAULT_CAPACITY}, + set_status::SetStatus, + BoxError, }; use bytes::Bytes; use futures_util::ready; use http::response::Builder; use http::{header, HeaderValue, Method, Request, Response, StatusCode, Uri}; -use http_body::{combinators::BoxBody, Body, Empty, Full}; +use http_body::{combinators::UnsyncBoxBody, Body, Empty, Full}; use http_range_header::RangeUnsatisfiableError; use percent_encoding::percent_decode; -use std::fs::Metadata; -use std::io::SeekFrom; -use std::ops::RangeInclusive; -use std::path::Component; +use pin_project_lite::pin_project; use std::{ + convert::Infallible, + fs::Metadata, future::Future, io, - path::{Path, PathBuf}, + io::SeekFrom, + ops::RangeInclusive, + path::{Component, Path, PathBuf}, pin::Pin, task::{Context, Poll}, }; -use tokio::fs::File; -use tokio::io::AsyncSeekExt; +use tokio::{fs::File, io::AsyncSeekExt}; use tower_service::Service; /// Service that serves files from a given directory and all its sub directories. @@ -58,53 +59,15 @@ use tower_service::Service; /// .expect("server error"); /// # }; /// ``` -/// -/// # Handling files not found -/// -/// By default `ServeDir` will return an empty `404 Not Found` response if there -/// is no file at the requested path. That can be customized by using -/// [`and_then`](tower::ServiceBuilder::and_then) to change the response: -/// -/// ``` -/// use tower_http::services::fs::{ServeDir, ServeFileSystemResponseBody}; -/// use tower::ServiceBuilder; -/// use http::{StatusCode, Response}; -/// use http_body::{Body as _, Full}; -/// use std::io; -/// -/// let service = ServiceBuilder::new() -/// .and_then(|response: Response| async move { -/// let response = if response.status() == StatusCode::NOT_FOUND { -/// let body = Full::from("Not Found") -/// .map_err(|err| match err {}) -/// .boxed(); -/// Response::builder() -/// .status(StatusCode::NOT_FOUND) -/// .body(body) -/// .unwrap() -/// } else { -/// response.map(|body| body.boxed()) -/// }; -/// -/// Ok::<_, io::Error>(response) -/// }) -/// .service(ServeDir::new("assets")); -/// # async { -/// # let addr = std::net::SocketAddr::from(([127, 0, 0, 1], 3000)); -/// # hyper::Server::bind(&addr) -/// # .serve(tower::make::Shared::new(service)) -/// # .await -/// # .expect("server error"); -/// # }; -/// ``` #[derive(Clone, Debug)] -pub struct ServeDir { +pub struct ServeDir { base: PathBuf, buf_chunk_size: usize, precompressed_variants: Option, // This is used to specialise implementation for // single files variant: ServeVariant, + fallback: Option, } // Allow the ServeDir service to be used in the ServeFile service @@ -120,12 +83,36 @@ enum ServeVariant { } impl ServeVariant { - fn full_path(&self, base_path: &Path, requested_path: &str) -> Option { + fn build_and_validate_path(&self, base_path: &Path, requested_path: &str) -> Option { match self { ServeVariant::Directory { append_index_html_on_directories: _, } => { - let full_path = build_and_validate_path(base_path, requested_path)?; + let path = requested_path.trim_start_matches('/'); + + let path_decoded = percent_decode(path.as_ref()).decode_utf8().ok()?; + let path_decoded = Path::new(&*path_decoded); + + let mut full_path = base_path.to_path_buf(); + for component in path_decoded.components() { + match component { + Component::Normal(comp) => { + // protect against paths like `/foo/c:/bar/baz` (#204) + if Path::new(&comp) + .components() + .all(|c| matches!(c, Component::Normal(_))) + { + full_path.push(comp) + } else { + return None; + } + } + Component::CurDir => {} + Component::Prefix(_) | Component::RootDir | Component::ParentDir => { + return None; + } + } + } Some(full_path) } ServeVariant::SingleFile { mime: _ } => Some(base_path.to_path_buf()), @@ -133,34 +120,7 @@ impl ServeVariant { } } -fn build_and_validate_path(base_path: &Path, requested_path: &str) -> Option { - let path = requested_path.trim_start_matches('/'); - - let path_decoded = percent_decode(path.as_ref()).decode_utf8().ok()?; - let path_decoded = Path::new(&*path_decoded); - - let mut full_path = base_path.to_path_buf(); - for component in path_decoded.components() { - match component { - Component::Normal(comp) => { - // protect against paths like `/foo/c:/bar/baz` (#204) - if Path::new(&comp) - .components() - .all(|c| matches!(c, Component::Normal(_))) - { - full_path.push(comp) - } else { - return None; - } - } - Component::CurDir => {} - Component::Prefix(_) | Component::RootDir | Component::ParentDir => return None, - } - } - Some(full_path) -} - -impl ServeDir { +impl ServeDir { /// Create a new [`ServeDir`]. pub fn new>(path: P) -> Self { let mut base = PathBuf::from("."); @@ -173,6 +133,7 @@ impl ServeDir { variant: ServeVariant::Directory { append_index_html_on_directories: true, }, + fallback: None, } } @@ -182,9 +143,12 @@ impl ServeDir { buf_chunk_size: DEFAULT_CAPACITY, precompressed_variants: None, variant: ServeVariant::SingleFile { mime }, + fallback: None, } } +} +impl ServeDir { /// If the requested path is a directory append `index.html`. /// /// This is useful for static sites. @@ -260,6 +224,73 @@ impl ServeDir { .deflate = true; self } + + /// Set the fallback service. + /// + /// This service will be called if there is no file at the path of the request. + /// + /// The status code returned by the fallback will not be altered. Use + /// [`ServeDir::not_found_service`] to set a fallback and always respond with `404 Not Found`. + /// + /// # Example + /// + /// This can be used to respond with a different file: + /// + /// ```rust + /// use tower_http::services::{ServeDir, ServeFile}; + /// + /// let service = ServeDir::new("assets") + /// // respond with `not_found.html` for missing files + /// .fallback(ServeFile::new("assets/not_found.html")); + /// + /// # async { + /// // Run our service using `hyper` + /// let addr = std::net::SocketAddr::from(([127, 0, 0, 1], 3000)); + /// hyper::Server::bind(&addr) + /// .serve(tower::make::Shared::new(service)) + /// .await + /// .expect("server error"); + /// # }; + /// ``` + pub fn fallback(self, new_fallback: F2) -> ServeDir { + ServeDir { + base: self.base, + buf_chunk_size: self.buf_chunk_size, + precompressed_variants: self.precompressed_variants, + variant: self.variant, + fallback: Some(new_fallback), + } + } + + /// Set the fallback service and override the fallback's status code to `404 Not Found`. + /// + /// This service will be called if there is no file at the path of the request. + /// + /// # Example + /// + /// This can be used to respond with a different file: + /// + /// ```rust + /// use tower_http::services::{ServeDir, ServeFile}; + /// + /// let service = ServeDir::new("assets") + /// // respond with `404 Not Found` and the contents of `not_found.html` for missing files + /// .not_found_service(ServeFile::new("assets/not_found.html")); + /// + /// # async { + /// // Run our service using `hyper` + /// let addr = std::net::SocketAddr::from(([127, 0, 0, 1], 3000)); + /// hyper::Server::bind(&addr) + /// .serve(tower::make::Shared::new(service)) + /// .await + /// .expect("server error"); + /// # }; + /// ``` + /// + /// Setups like this are often found in single page applications. + pub fn not_found_service(self, new_fallback: F2) -> ServeDir> { + self.fallback(SetStatus::new(new_fallback, StatusCode::NOT_FOUND)) + } } async fn maybe_redirect_or_append_path( @@ -285,26 +316,63 @@ async fn maybe_redirect_or_append_path( None } -impl Service> for ServeDir { +impl Service> for ServeDir +where + F: Service, Response = Response> + Clone, + F::Error: Into, + F::Future: Send + 'static, + FResBody: http_body::Body + Send + 'static, + FResBody::Error: Into>, +{ type Response = Response; type Error = io::Error; - type Future = ResponseFuture; + type Future = ResponseFuture; #[inline] - fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { - Poll::Ready(Ok(())) + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + if let Some(fallback) = &mut self.fallback { + fallback.poll_ready(cx).map_err(Into::into) + } else { + Poll::Ready(Ok(())) + } } fn call(&mut self, req: Request) -> Self::Future { - let mut full_path = match self.variant.full_path(&self.base, req.uri().path()) { + // `ServeDir` doesn't care about the request body but the fallback might. So move out the + // body and pass it to the fallback, leaving an empty body in its place + // + // this is necessary because we cannot clone bodies + let (mut parts, body) = req.into_parts(); + // same goes for extensions + let extensions = std::mem::take(&mut parts.extensions); + let req = Request::from_parts(parts, Empty::::new()); + + let mut full_path = match self + .variant + .build_and_validate_path(&self.base, req.uri().path()) + { Some(full_path) => full_path, None => { return ResponseFuture { - inner: Inner::Invalid, - } + inner: ResponseFutureInner::InvalidPath, + }; } }; + let fallback_and_request = self.fallback.as_mut().map(|fallback| { + let mut req = Request::new(body); + *req.method_mut() = req.method().clone(); + *req.uri_mut() = req.uri().clone(); + *req.headers_mut() = req.headers().clone(); + *req.extensions_mut() = extensions; + + // get the ready fallback and leave a non-ready clone in its place + let clone = fallback.clone(); + let fallback = std::mem::replace(fallback, clone); + + (fallback, req) + }); + let buf_chunk_size = self.buf_chunk_size; let uri = req.uri().clone(); let range_header = req @@ -351,7 +419,7 @@ impl Service> for ServeDir { let guess = mime_guess::from_path(&full_path); guess .first_raw() - .map(|mime| HeaderValue::from_static(mime)) + .map(HeaderValue::from_static) .unwrap_or_else(|| { HeaderValue::from_str(mime::APPLICATION_OCTET_STREAM.as_ref()).unwrap() }) @@ -416,7 +484,10 @@ impl Service> for ServeDir { }); ResponseFuture { - inner: Inner::Valid(open_file_future), + inner: ResponseFutureInner::OpenFileFuture { + future: open_file_future, + fallback_and_request, + }, } } } @@ -466,6 +537,7 @@ fn append_slash_on_path(uri: Uri) -> Uri { builder.build().unwrap() } +#[allow(clippy::large_enum_variant)] enum Output { File(FileRequest), Redirect(HeaderValue), @@ -486,25 +558,50 @@ enum FileRequestExtent { Head(Metadata), } -type BoxFuture = Pin + Send + Sync + 'static>>; +type BoxFuture = Pin + Send + 'static>>; -enum Inner { - Valid(BoxFuture>), - Invalid, +pin_project! { + /// Response future of [`ServeDir`]. + pub struct ResponseFuture { + #[pin] + inner: ResponseFutureInner, + } } -/// Response future of [`ServeDir`]. -pub struct ResponseFuture { - inner: Inner, +pin_project! { + #[project = ResponseFutureInnerProj] + enum ResponseFutureInner { + OpenFileFuture { + #[pin] + future: BoxFuture>, + fallback_and_request: Option<(F, Request)>, + }, + FallbackFuture { + future: BoxFuture>>, + }, + InvalidPath, + } } -impl Future for ResponseFuture { +impl Future for ResponseFuture +where + F: Service, Response = Response> + Clone, + F::Error: Into, + F::Future: Send + 'static, + ResBody: http_body::Body + Send + 'static, + ResBody::Error: Into>, +{ type Output = io::Result>; fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - match &mut self.inner { - Inner::Valid(open_file_future) => { - return match ready!(Pin::new(open_file_future).poll(cx)) { + loop { + let mut this = self.as_mut().project(); + + let new_state = match this.inner.as_mut().project() { + ResponseFutureInnerProj::OpenFileFuture { + future: open_file_future, + fallback_and_request, + } => match ready!(open_file_future.poll(cx)) { Ok(Output::File(file_request)) => { let (maybe_file, size) = match file_request.extent { FileRequestExtent::Full(file, meta) => (Some(file), meta.len()), @@ -529,7 +626,7 @@ impl Future for ResponseFuture { file_request.chunk_size, size, ); - Poll::Ready(Ok(res.unwrap())) + return Poll::Ready(Ok(res.unwrap())); } Ok(Output::Redirect(location)) => { @@ -538,28 +635,37 @@ impl Future for ResponseFuture { .status(StatusCode::TEMPORARY_REDIRECT) .body(empty_body()) .unwrap(); - Poll::Ready(Ok(res)) + return Poll::Ready(Ok(res)); } Ok(Output::StatusCode(code)) => { let res = Response::builder().status(code).body(empty_body()).unwrap(); - Poll::Ready(Ok(res)) + return Poll::Ready(Ok(res)); } - Err(err) => Poll::Ready( - super::response_from_io_error(err).map(|res| res.map(ResponseBody::new)), - ), - }; - } - Inner::Invalid => { - let res = Response::builder() - .status(StatusCode::NOT_FOUND) - .body(empty_body()) - .unwrap(); + Err(err) => match err.kind() { + io::ErrorKind::NotFound | io::ErrorKind::PermissionDenied => { + if let Some((mut fallback, request)) = fallback_and_request.take() { + call_fallback(&mut fallback, request) + } else { + return Poll::Ready(not_found()); + } + } + _ => return Poll::Ready(Err(err)), + }, + }, - Poll::Ready(Ok(res)) - } + ResponseFutureInnerProj::FallbackFuture { future } => { + return Pin::new(future).poll(cx) + } + + ResponseFutureInnerProj::InvalidPath => { + return Poll::Ready(not_found()); + } + }; + + this.inner.set(new_state); } } } @@ -586,7 +692,7 @@ fn handle_file_request( let body = if let Some(file) = maybe_file { let body = AsyncReadBody::with_capacity_limited(file, chunk_size, range_size) - .boxed(); + .boxed_unsync(); ResponseBody::new(body) } else { empty_body() @@ -616,7 +722,7 @@ fn handle_file_request( // Not a range request None => { let body = if let Some(file) = maybe_file { - let box_body = AsyncReadBody::with_capacity(file, chunk_size).boxed(); + let box_body = AsyncReadBody::with_capacity(file, chunk_size).boxed_unsync(); ResponseBody::new(box_body) } else { empty_body() @@ -629,31 +735,86 @@ fn handle_file_request( } fn empty_body() -> ResponseBody { - let body = Empty::new().map_err(|err| match err {}).boxed(); + let body = Empty::new().map_err(|err| match err {}).boxed_unsync(); ResponseBody::new(body) } fn body_from_bytes(bytes: Bytes) -> ResponseBody { - let body = Full::from(bytes).map_err(|err| match err {}).boxed(); + let body = Full::from(bytes).map_err(|err| match err {}).boxed_unsync(); ResponseBody::new(body) } opaque_body! { /// Response body for [`ServeDir`] and [`ServeFile`]. - pub type ResponseBody = BoxBody; + pub type ResponseBody = UnsyncBoxBody; +} + +/// The default fallback service used with [`ServeDir`]. +#[derive(Debug, Clone, Copy)] +pub struct DefaultServeDirFallback(Infallible); + +impl Service> for DefaultServeDirFallback +where + ReqBody: Send + 'static, +{ + type Response = Response; + type Error = io::Error; + type Future = ResponseFuture; + + fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { + match self.0 {} + } + + fn call(&mut self, _req: Request) -> Self::Future { + match self.0 {} + } +} + +fn not_found() -> io::Result> { + let res = Response::builder() + .status(StatusCode::NOT_FOUND) + .body(empty_body()) + .unwrap(); + Ok(res) +} + +fn call_fallback(fallback: &mut F, req: Request) -> ResponseFutureInner +where + F: Service, Response = Response> + Clone, + F::Error: Into, + F::Future: Send + 'static, + FResBody: http_body::Body + Send + 'static, + FResBody::Error: Into, +{ + let future = fallback.call(req); + let future = async move { + let response = future.await.map_err(Into::into)?; + let response = response + .map(|body| { + body.map_err(|err| match err.into().downcast::() { + Ok(err) => *err, + Err(err) => io::Error::new(io::ErrorKind::Other, err), + }) + .boxed_unsync() + }) + .map(ResponseBody::new); + Ok(response) + }; + let future = Box::pin(future); + ResponseFutureInner::FallbackFuture { future } } #[cfg(test)] mod tests { - use std::io::Read; + use crate::services::ServeFile; - #[allow(unused_imports)] use super::*; use brotli::BrotliDecompress; use flate2::bufread::{DeflateDecoder, GzDecoder}; use http::{Request, StatusCode}; use http_body::Body as HttpBody; use hyper::Body; + use std::io::Read; use tower::ServiceExt; #[tokio::test] @@ -1229,4 +1390,43 @@ mod tests { let body = res.into_body().data().await; assert!(body.is_none()); } + + #[tokio::test] + async fn with_fallback_svc() { + async fn fallback(_: Request) -> io::Result> { + Ok(Response::new(Body::from("from fallback"))) + } + + let svc = ServeDir::new("..").fallback(tower::service_fn(fallback)); + + let req = Request::builder() + .uri("/doesnt-exist") + .body(Body::empty()) + .unwrap(); + let res = svc.oneshot(req).await.unwrap(); + + assert_eq!(res.status(), StatusCode::OK); + + let body = body_into_text(res.into_body()).await; + assert_eq!(body, "from fallback"); + } + + #[tokio::test] + async fn with_fallback_serve_file() { + let svc = ServeDir::new("..").fallback(ServeFile::new("../README.md")); + + let req = Request::builder() + .uri("/doesnt-exist") + .body(Body::empty()) + .unwrap(); + let res = svc.oneshot(req).await.unwrap(); + + assert_eq!(res.status(), StatusCode::OK); + assert_eq!(res.headers()["content-type"], "text/markdown"); + + let body = body_into_text(res.into_body()).await; + + let contents = std::fs::read_to_string("../README.md").unwrap(); + assert_eq!(body, contents); + } } diff --git a/tower-http/src/services/fs/serve_file.rs b/tower-http/src/services/fs/serve_file.rs index 08265615..45329597 100644 --- a/tower-http/src/services/fs/serve_file.rs +++ b/tower-http/src/services/fs/serve_file.rs @@ -12,6 +12,7 @@ use tower_service::Service; /// Service that serves a file. #[derive(Clone, Debug)] pub struct ServeFile(ServeDir); + // Note that this is just a special case of ServeDir impl ServeFile { /// Create a new [`ServeFile`]. @@ -21,7 +22,7 @@ impl ServeFile { let guess = mime_guess::from_path(path.as_ref()); let mime = guess .first_raw() - .map(|mime| HeaderValue::from_static(mime)) + .map(HeaderValue::from_static) .unwrap_or_else(|| { HeaderValue::from_str(mime::APPLICATION_OCTET_STREAM.as_ref()).unwrap() }); @@ -91,7 +92,10 @@ impl ServeFile { } } -impl Service> for ServeFile { +impl Service> for ServeFile +where + ReqBody: Send + 'static, +{ type Error = >>::Error; type Response = >>::Response; type Future = >>::Future; @@ -101,6 +105,7 @@ impl Service> for ServeFile { Poll::Ready(Ok(())) } + #[inline] fn call(&mut self, req: Request) -> Self::Future { self.0.call(req) } @@ -108,10 +113,6 @@ impl Service> for ServeFile { #[cfg(test)] mod tests { - use std::io::Read; - use std::str::FromStr; - - #[allow(unused_imports)] use super::*; use brotli::BrotliDecompress; use flate2::bufread::DeflateDecoder; @@ -121,6 +122,8 @@ mod tests { use http::{Request, StatusCode}; use http_body::Body as _; use hyper::Body; + use std::io::Read; + use std::str::FromStr; use tower::ServiceExt; #[tokio::test]