diff --git a/tower-http/CHANGELOG.md b/tower-http/CHANGELOG.md index 649ad63a..3245208a 100644 --- a/tower-http/CHANGELOG.md +++ b/tower-http/CHANGELOG.md @@ -10,6 +10,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## Added - Add `NormalizePath` middleware +- Add `ValidateRequest` middleware ## Changed diff --git a/tower-http/Cargo.toml b/tower-http/Cargo.toml index 2b54c2b3..c89afcc5 100644 --- a/tower-http/Cargo.toml +++ b/tower-http/Cargo.toml @@ -76,6 +76,7 @@ full = [ "timeout", "trace", "util", + "validate-request", ] add-extension = [] @@ -98,6 +99,7 @@ set-status = [] timeout = ["tokio/time"] trace = ["tracing"] util = ["tower"] +validate-request = ["mime"] compression-br = ["async-compression/brotli", "tokio-util", "tokio"] compression-deflate = ["async-compression/zlib", "tokio-util", "tokio"] diff --git a/tower-http/src/lib.rs b/tower-http/src/lib.rs index 0b413bae..8f0caef6 100644 --- a/tower-http/src/lib.rs +++ b/tower-http/src/lib.rs @@ -23,6 +23,7 @@ //! sensitive_headers::SetSensitiveRequestHeadersLayer, //! set_header::SetResponseHeaderLayer, //! trace::TraceLayer, +//! validate_request::ValidateRequestHeaderLayer, //! }; //! use tower::{ServiceBuilder, service_fn, make::Shared}; //! use http::{Request, Response, header::{HeaderName, CONTENT_TYPE, AUTHORIZATION}}; @@ -71,6 +72,8 @@ //! .layer(SetResponseHeaderLayer::overriding(CONTENT_TYPE, content_length_from_response)) //! // Authorize requests using a token //! .layer(RequireAuthorizationLayer::bearer("passwordlol")) +//! // Accept only application/json, application/* and */* in a request's ACCEPT header +//! .layer(ValidateRequestHeaderLayer::accept("application/json")) //! // Wrap a `Service` in our middleware stack //! .service_fn(handler); //! @@ -319,6 +322,9 @@ mod builder; #[doc(inline)] pub use self::builder::ServiceBuilderExt; +#[cfg(feature = "validate-request")] +pub mod validate_request; + /// The latency unit used to report latencies by middleware. #[non_exhaustive] #[derive(Copy, Clone, Debug)] diff --git a/tower-http/src/validate_request.rs b/tower-http/src/validate_request.rs new file mode 100644 index 00000000..c61c1bed --- /dev/null +++ b/tower-http/src/validate_request.rs @@ -0,0 +1,551 @@ +//! Middleware that validates requests. +//! +//! # Example +//! +//! ``` +//! use tower_http::validate_request::ValidateRequestHeaderLayer; +//! use hyper::{Request, Response, Body, Error}; +//! use http::{StatusCode, header::ACCEPT}; +//! use tower::{Service, ServiceExt, ServiceBuilder, service_fn}; +//! +//! async fn handle(request: Request) -> Result, Error> { +//! Ok(Response::new(Body::empty())) +//! } +//! +//! # #[tokio::main] +//! # async fn main() -> Result<(), Box> { +//! let mut service = ServiceBuilder::new() +//! // Require the `Accept` header to be `application/json`, `*/*` or `application/*` +//! .layer(ValidateRequestHeaderLayer::accept("application/json")) +//! .service_fn(handle); +//! +//! // Requests with the correct value are allowed through +//! let request = Request::builder() +//! .header(ACCEPT, "application/json") +//! .body(Body::empty()) +//! .unwrap(); +//! +//! let response = service +//! .ready() +//! .await? +//! .call(request) +//! .await?; +//! +//! assert_eq!(StatusCode::OK, response.status()); +//! +//! // Requests with an invalid value get a `406 Not Acceptable` response +//! let request = Request::builder() +//! .header(ACCEPT, "text/strings") +//! .body(Body::empty()) +//! .unwrap(); +//! +//! let response = service +//! .ready() +//! .await? +//! .call(request) +//! .await?; +//! +//! assert_eq!(StatusCode::NOT_ACCEPTABLE, response.status()); +//! # Ok(()) +//! # } +//! ``` +//! +//! Custom validation can be made by implementing [`ValidateRequest`]: +//! +//! ``` +//! use tower_http::validate_request::{ValidateRequestHeaderLayer, ValidateRequest}; +//! use hyper::{Request, Response, Body, Error}; +//! use http::{StatusCode, header::ACCEPT}; +//! use tower::{Service, ServiceExt, ServiceBuilder, service_fn}; +//! +//! #[derive(Clone, Copy)] +//! pub struct MyHeader { /* ... */ } +//! +//! impl ValidateRequest for MyHeader { +//! type ResponseBody = Body; +//! +//! fn validate( +//! &mut self, +//! request: &mut Request, +//! ) -> Result<(), Response> { +//! // validate the request... +//! # unimplemented!() +//! } +//! } +//! +//! async fn handle(request: Request) -> Result, Error> { +//! Ok(Response::new(Body::empty())) +//! } +//! +//! +//! # #[tokio::main] +//! # async fn main() -> Result<(), Box> { +//! let service = ServiceBuilder::new() +//! // Validate requests using `MyHeader` +//! .layer(ValidateRequestHeaderLayer::custom(MyHeader { /* ... */ })) +//! .service_fn(handle); +//! # Ok(()) +//! # } +//! ``` +//! +//! Or using a closure: +//! +//! ``` +//! use tower_http::validate_request::{ValidateRequestHeaderLayer, ValidateRequest}; +//! use hyper::{Request, Response, Body, Error}; +//! use http::{StatusCode, header::ACCEPT}; +//! use tower::{Service, ServiceExt, ServiceBuilder, service_fn}; +//! +//! async fn handle(request: Request) -> Result, Error> { +//! # todo!(); +//! // ... +//! } +//! +//! # #[tokio::main] +//! # async fn main() -> Result<(), Box> { +//! let service = ServiceBuilder::new() +//! .layer(ValidateRequestHeaderLayer::custom(|request: &mut Request| { +//! // Validate the request +//! # Ok::<_, Response>(()) +//! })) +//! .service_fn(handle); +//! # Ok(()) +//! # } +//! ``` + +use http::{header, Request, Response, StatusCode}; +use http_body::Body; +use mime::Mime; +use pin_project_lite::pin_project; +use std::{ + fmt, + future::Future, + marker::PhantomData, + pin::Pin, + sync::Arc, + task::{Context, Poll}, +}; +use tower_layer::Layer; +use tower_service::Service; + +/// Layer that applies [`ValidateRequestHeader`] which validates all requests. +/// +/// See the [module docs](crate::validate_request) for an example. +#[derive(Debug, Clone)] +pub struct ValidateRequestHeaderLayer { + validate: T, +} + +impl ValidateRequestHeaderLayer> { + /// Validate requests have the required Accept header. + /// + /// The `Accept` header is required to be `*/*`, `type/*` or `type/subtype`, + /// as configured. + /// + /// # Panics + /// + /// Panics if `header_value` is not in the form: `type/subtype`, such as `application/json` + /// See `AcceptHeader::new` for when this method panics. + /// + /// # Example + /// + /// ``` + /// use hyper::Body; + /// use tower_http::validate_request::{AcceptHeader, ValidateRequestHeaderLayer}; + /// + /// let layer = ValidateRequestHeaderLayer::>::accept("application/json"); + /// ``` + /// + /// [`Accept`]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Accept + pub fn accept(value: &str) -> Self + where + ResBody: Body + Default, + { + Self::custom(AcceptHeader::new(value)) + } +} + +impl ValidateRequestHeaderLayer { + /// Validate requests using a custom method. + pub fn custom(validate: T) -> ValidateRequestHeaderLayer { + Self { validate } + } +} + +impl Layer for ValidateRequestHeaderLayer +where + T: Clone, +{ + type Service = ValidateRequestHeader; + + fn layer(&self, inner: S) -> Self::Service { + ValidateRequestHeader::new(inner, self.validate.clone()) + } +} + +/// Middleware that validates requests. +/// +/// See the [module docs](crate::validate_request) for an example. +#[derive(Clone, Debug)] +pub struct ValidateRequestHeader { + inner: S, + validate: T, +} + +impl ValidateRequestHeader { + fn new(inner: S, validate: T) -> Self { + Self::custom(inner, validate) + } + + define_inner_service_accessors!(); +} + +impl ValidateRequestHeader> { + /// Validate requests have the required Accept header. + /// + /// The `Accept` header is required to be `*/*`, `type/*` or `type/subtype`, + /// as configured. + /// + /// # Panics + /// + /// See `AcceptHeader::new` for when this method panics. + pub fn accept(inner: S, value: &str) -> Self + where + ResBody: Body + Default, + { + Self::custom(inner, AcceptHeader::new(value)) + } +} + +impl ValidateRequestHeader { + /// Validate requests using a custom method. + pub fn custom(inner: S, validate: T) -> ValidateRequestHeader { + Self { inner, validate } + } +} + +impl Service> for ValidateRequestHeader +where + V: ValidateRequest, + S: Service, Response = Response>, +{ + type Response = Response; + type Error = S::Error; + type Future = ResponseFuture; + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.inner.poll_ready(cx) + } + + fn call(&mut self, mut req: Request) -> Self::Future { + match self.validate.validate(&mut req) { + Ok(_) => ResponseFuture::future(self.inner.call(req)), + Err(res) => ResponseFuture::invalid_header_value(res), + } + } +} + +pin_project! { + /// Response future for [`ValidateRequestHeader`]. + pub struct ResponseFuture { + #[pin] + kind: Kind, + } +} + +impl ResponseFuture { + fn future(future: F) -> Self { + Self { + kind: Kind::Future { future }, + } + } + + fn invalid_header_value(res: Response) -> Self { + Self { + kind: Kind::Error { + response: Some(res), + }, + } + } +} + +pin_project! { + #[project = KindProj] + enum Kind { + Future { + #[pin] + future: F, + }, + Error { + response: Option>, + }, + } +} + +impl Future for ResponseFuture +where + F: Future, E>>, +{ + type Output = F::Output; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + match self.project().kind.project() { + KindProj::Future { future } => future.poll(cx), + KindProj::Error { response } => { + let response = response.take().expect("future polled after completion"); + Poll::Ready(Ok(response)) + } + } + } +} + +/// Trait for validating requests. +pub trait ValidateRequest { + /// The body type used for responses to unvalidated requests. + type ResponseBody; + + /// Validate the request. + /// + /// If `Ok(())` is returned then the request is allowed through, otherwise not. + fn validate(&mut self, request: &mut Request) -> Result<(), Response>; +} + +impl ValidateRequest for F +where + F: FnMut(&mut Request) -> Result<(), Response>, +{ + type ResponseBody = ResBody; + + fn validate(&mut self, request: &mut Request) -> Result<(), Response> { + self(request) + } +} + +/// Type that performs validation of the Accept header. +pub struct AcceptHeader { + header_value: Arc, + _ty: PhantomData ResBody>, +} + +impl AcceptHeader { + /// Create a new `AcceptHeader`. + /// + /// # Panics + /// + /// Panics if `header_value` is not in the form: `type/subtype`, such as `application/json` + fn new(header_value: &str) -> Self + where + ResBody: Body + Default, + { + Self { + header_value: Arc::new( + header_value + .parse::() + .expect("value is not a valid header value"), + ), + _ty: PhantomData, + } + } +} + +impl Clone for AcceptHeader { + fn clone(&self) -> Self { + Self { + header_value: self.header_value.clone(), + _ty: PhantomData, + } + } +} + +impl fmt::Debug for AcceptHeader { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("AcceptHeader") + .field("header_value", &self.header_value) + .finish() + } +} + +impl ValidateRequest for AcceptHeader +where + ResBody: Body + Default, +{ + type ResponseBody = ResBody; + + fn validate(&mut self, req: &mut Request) -> Result<(), Response> { + if !req.headers().contains_key(header::ACCEPT) { + return Ok(()); + } + if req + .headers() + .get_all(header::ACCEPT) + .into_iter() + .flat_map(|header| { + header + .to_str() + .ok() + .into_iter() + .flat_map(|s| s.split(",").map(|typ| typ.trim())) + }) + .any(|h| { + h.parse::() + .map(|mim| { + let typ = self.header_value.type_(); + let subtype = self.header_value.subtype(); + match (mim.type_(), mim.subtype()) { + (t, s) if t == typ && s == subtype => true, + (t, mime::STAR) if t == typ => true, + (mime::STAR, mime::STAR) => true, + _ => false, + } + }) + .unwrap_or(false) + }) + { + return Ok(()); + } + let mut res = Response::new(ResBody::default()); + *res.status_mut() = StatusCode::NOT_ACCEPTABLE; + Err(res) + } +} + +#[cfg(test)] +mod tests { + #[allow(unused_imports)] + use super::*; + use http::header; + use hyper::Body; + use tower::{BoxError, ServiceBuilder, ServiceExt}; + + #[tokio::test] + async fn valid_accept_header() { + let mut service = ServiceBuilder::new() + .layer(ValidateRequestHeaderLayer::accept("application/json")) + .service_fn(echo); + + let request = Request::get("/") + .header(header::ACCEPT, "application/json") + .body(Body::empty()) + .unwrap(); + + let res = service.ready().await.unwrap().call(request).await.unwrap(); + + assert_eq!(res.status(), StatusCode::OK); + } + + #[tokio::test] + async fn valid_accept_header_accept_all_json() { + let mut service = ServiceBuilder::new() + .layer(ValidateRequestHeaderLayer::accept("application/json")) + .service_fn(echo); + + let request = Request::get("/") + .header(header::ACCEPT, "application/*") + .body(Body::empty()) + .unwrap(); + + let res = service.ready().await.unwrap().call(request).await.unwrap(); + + assert_eq!(res.status(), StatusCode::OK); + } + + #[tokio::test] + async fn valid_accept_header_accept_all() { + let mut service = ServiceBuilder::new() + .layer(ValidateRequestHeaderLayer::accept("application/json")) + .service_fn(echo); + + let request = Request::get("/") + .header(header::ACCEPT, "*/*") + .body(Body::empty()) + .unwrap(); + + let res = service.ready().await.unwrap().call(request).await.unwrap(); + + assert_eq!(res.status(), StatusCode::OK); + } + + #[tokio::test] + async fn invalid_accept_header() { + let mut service = ServiceBuilder::new() + .layer(ValidateRequestHeaderLayer::accept("application/json")) + .service_fn(echo); + + let request = Request::get("/") + .header(header::ACCEPT, "invalid") + .body(Body::empty()) + .unwrap(); + + let res = service.ready().await.unwrap().call(request).await.unwrap(); + + assert_eq!(res.status(), StatusCode::NOT_ACCEPTABLE); + } + #[tokio::test] + async fn not_accepted_accept_header_subtype() { + let mut service = ServiceBuilder::new() + .layer(ValidateRequestHeaderLayer::accept("application/json")) + .service_fn(echo); + + let request = Request::get("/") + .header(header::ACCEPT, "application/strings") + .body(Body::empty()) + .unwrap(); + + let res = service.ready().await.unwrap().call(request).await.unwrap(); + + assert_eq!(res.status(), StatusCode::NOT_ACCEPTABLE); + } + + #[tokio::test] + async fn not_accepted_accept_header() { + let mut service = ServiceBuilder::new() + .layer(ValidateRequestHeaderLayer::accept("application/json")) + .service_fn(echo); + + let request = Request::get("/") + .header(header::ACCEPT, "text/strings") + .body(Body::empty()) + .unwrap(); + + let res = service.ready().await.unwrap().call(request).await.unwrap(); + + assert_eq!(res.status(), StatusCode::NOT_ACCEPTABLE); + } + + #[tokio::test] + async fn accepted_multiple_header_value() { + let mut service = ServiceBuilder::new() + .layer(ValidateRequestHeaderLayer::accept("application/json")) + .service_fn(echo); + + let request = Request::get("/") + .header(header::ACCEPT, "text/strings") + .header(header::ACCEPT, "invalid, application/json") + .body(Body::empty()) + .unwrap(); + + let res = service.ready().await.unwrap().call(request).await.unwrap(); + + assert_eq!(res.status(), StatusCode::OK); + } + + #[tokio::test] + async fn accepted_inner_header_value() { + let mut service = ServiceBuilder::new() + .layer(ValidateRequestHeaderLayer::accept("application/json")) + .service_fn(echo); + + let request = Request::get("/") + .header(header::ACCEPT, "text/strings, invalid, application/json") + .body(Body::empty()) + .unwrap(); + + let res = service.ready().await.unwrap().call(request).await.unwrap(); + + assert_eq!(res.status(), StatusCode::OK); + } + + async fn echo(req: Request) -> Result, BoxError> { + Ok(Response::new(req.into_body())) + } +}