diff --git a/tower-http/CHANGELOG.md b/tower-http/CHANGELOG.md
index 649ad63a..3245208a 100644
--- a/tower-http/CHANGELOG.md
+++ b/tower-http/CHANGELOG.md
@@ -10,6 +10,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## Added
- Add `NormalizePath` middleware
+- Add `ValidateRequest` middleware
## Changed
diff --git a/tower-http/Cargo.toml b/tower-http/Cargo.toml
index 2b54c2b3..c89afcc5 100644
--- a/tower-http/Cargo.toml
+++ b/tower-http/Cargo.toml
@@ -76,6 +76,7 @@ full = [
"timeout",
"trace",
"util",
+ "validate-request",
]
add-extension = []
@@ -98,6 +99,7 @@ set-status = []
timeout = ["tokio/time"]
trace = ["tracing"]
util = ["tower"]
+validate-request = ["mime"]
compression-br = ["async-compression/brotli", "tokio-util", "tokio"]
compression-deflate = ["async-compression/zlib", "tokio-util", "tokio"]
diff --git a/tower-http/src/lib.rs b/tower-http/src/lib.rs
index 0b413bae..8f0caef6 100644
--- a/tower-http/src/lib.rs
+++ b/tower-http/src/lib.rs
@@ -23,6 +23,7 @@
//! sensitive_headers::SetSensitiveRequestHeadersLayer,
//! set_header::SetResponseHeaderLayer,
//! trace::TraceLayer,
+//! validate_request::ValidateRequestHeaderLayer,
//! };
//! use tower::{ServiceBuilder, service_fn, make::Shared};
//! use http::{Request, Response, header::{HeaderName, CONTENT_TYPE, AUTHORIZATION}};
@@ -71,6 +72,8 @@
//! .layer(SetResponseHeaderLayer::overriding(CONTENT_TYPE, content_length_from_response))
//! // Authorize requests using a token
//! .layer(RequireAuthorizationLayer::bearer("passwordlol"))
+//! // Accept only application/json, application/* and */* in a request's ACCEPT header
+//! .layer(ValidateRequestHeaderLayer::accept("application/json"))
//! // Wrap a `Service` in our middleware stack
//! .service_fn(handler);
//!
@@ -319,6 +322,9 @@ mod builder;
#[doc(inline)]
pub use self::builder::ServiceBuilderExt;
+#[cfg(feature = "validate-request")]
+pub mod validate_request;
+
/// The latency unit used to report latencies by middleware.
#[non_exhaustive]
#[derive(Copy, Clone, Debug)]
diff --git a/tower-http/src/validate_request.rs b/tower-http/src/validate_request.rs
new file mode 100644
index 00000000..c61c1bed
--- /dev/null
+++ b/tower-http/src/validate_request.rs
@@ -0,0 +1,551 @@
+//! Middleware that validates requests.
+//!
+//! # Example
+//!
+//! ```
+//! use tower_http::validate_request::ValidateRequestHeaderLayer;
+//! use hyper::{Request, Response, Body, Error};
+//! use http::{StatusCode, header::ACCEPT};
+//! use tower::{Service, ServiceExt, ServiceBuilder, service_fn};
+//!
+//! async fn handle(request: Request
) -> Result, Error> {
+//! Ok(Response::new(Body::empty()))
+//! }
+//!
+//! # #[tokio::main]
+//! # async fn main() -> Result<(), Box> {
+//! let mut service = ServiceBuilder::new()
+//! // Require the `Accept` header to be `application/json`, `*/*` or `application/*`
+//! .layer(ValidateRequestHeaderLayer::accept("application/json"))
+//! .service_fn(handle);
+//!
+//! // Requests with the correct value are allowed through
+//! let request = Request::builder()
+//! .header(ACCEPT, "application/json")
+//! .body(Body::empty())
+//! .unwrap();
+//!
+//! let response = service
+//! .ready()
+//! .await?
+//! .call(request)
+//! .await?;
+//!
+//! assert_eq!(StatusCode::OK, response.status());
+//!
+//! // Requests with an invalid value get a `406 Not Acceptable` response
+//! let request = Request::builder()
+//! .header(ACCEPT, "text/strings")
+//! .body(Body::empty())
+//! .unwrap();
+//!
+//! let response = service
+//! .ready()
+//! .await?
+//! .call(request)
+//! .await?;
+//!
+//! assert_eq!(StatusCode::NOT_ACCEPTABLE, response.status());
+//! # Ok(())
+//! # }
+//! ```
+//!
+//! Custom validation can be made by implementing [`ValidateRequest`]:
+//!
+//! ```
+//! use tower_http::validate_request::{ValidateRequestHeaderLayer, ValidateRequest};
+//! use hyper::{Request, Response, Body, Error};
+//! use http::{StatusCode, header::ACCEPT};
+//! use tower::{Service, ServiceExt, ServiceBuilder, service_fn};
+//!
+//! #[derive(Clone, Copy)]
+//! pub struct MyHeader { /* ... */ }
+//!
+//! impl ValidateRequest for MyHeader {
+//! type ResponseBody = Body;
+//!
+//! fn validate(
+//! &mut self,
+//! request: &mut Request,
+//! ) -> Result<(), Response> {
+//! // validate the request...
+//! # unimplemented!()
+//! }
+//! }
+//!
+//! async fn handle(request: Request) -> Result, Error> {
+//! Ok(Response::new(Body::empty()))
+//! }
+//!
+//!
+//! # #[tokio::main]
+//! # async fn main() -> Result<(), Box> {
+//! let service = ServiceBuilder::new()
+//! // Validate requests using `MyHeader`
+//! .layer(ValidateRequestHeaderLayer::custom(MyHeader { /* ... */ }))
+//! .service_fn(handle);
+//! # Ok(())
+//! # }
+//! ```
+//!
+//! Or using a closure:
+//!
+//! ```
+//! use tower_http::validate_request::{ValidateRequestHeaderLayer, ValidateRequest};
+//! use hyper::{Request, Response, Body, Error};
+//! use http::{StatusCode, header::ACCEPT};
+//! use tower::{Service, ServiceExt, ServiceBuilder, service_fn};
+//!
+//! async fn handle(request: Request) -> Result, Error> {
+//! # todo!();
+//! // ...
+//! }
+//!
+//! # #[tokio::main]
+//! # async fn main() -> Result<(), Box> {
+//! let service = ServiceBuilder::new()
+//! .layer(ValidateRequestHeaderLayer::custom(|request: &mut Request| {
+//! // Validate the request
+//! # Ok::<_, Response>(())
+//! }))
+//! .service_fn(handle);
+//! # Ok(())
+//! # }
+//! ```
+
+use http::{header, Request, Response, StatusCode};
+use http_body::Body;
+use mime::Mime;
+use pin_project_lite::pin_project;
+use std::{
+ fmt,
+ future::Future,
+ marker::PhantomData,
+ pin::Pin,
+ sync::Arc,
+ task::{Context, Poll},
+};
+use tower_layer::Layer;
+use tower_service::Service;
+
+/// Layer that applies [`ValidateRequestHeader`] which validates all requests.
+///
+/// See the [module docs](crate::validate_request) for an example.
+#[derive(Debug, Clone)]
+pub struct ValidateRequestHeaderLayer {
+ validate: T,
+}
+
+impl ValidateRequestHeaderLayer> {
+ /// Validate requests have the required Accept header.
+ ///
+ /// The `Accept` header is required to be `*/*`, `type/*` or `type/subtype`,
+ /// as configured.
+ ///
+ /// # Panics
+ ///
+ /// Panics if `header_value` is not in the form: `type/subtype`, such as `application/json`
+ /// See `AcceptHeader::new` for when this method panics.
+ ///
+ /// # Example
+ ///
+ /// ```
+ /// use hyper::Body;
+ /// use tower_http::validate_request::{AcceptHeader, ValidateRequestHeaderLayer};
+ ///
+ /// let layer = ValidateRequestHeaderLayer::>::accept("application/json");
+ /// ```
+ ///
+ /// [`Accept`]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Accept
+ pub fn accept(value: &str) -> Self
+ where
+ ResBody: Body + Default,
+ {
+ Self::custom(AcceptHeader::new(value))
+ }
+}
+
+impl ValidateRequestHeaderLayer {
+ /// Validate requests using a custom method.
+ pub fn custom(validate: T) -> ValidateRequestHeaderLayer {
+ Self { validate }
+ }
+}
+
+impl Layer for ValidateRequestHeaderLayer
+where
+ T: Clone,
+{
+ type Service = ValidateRequestHeader;
+
+ fn layer(&self, inner: S) -> Self::Service {
+ ValidateRequestHeader::new(inner, self.validate.clone())
+ }
+}
+
+/// Middleware that validates requests.
+///
+/// See the [module docs](crate::validate_request) for an example.
+#[derive(Clone, Debug)]
+pub struct ValidateRequestHeader {
+ inner: S,
+ validate: T,
+}
+
+impl ValidateRequestHeader {
+ fn new(inner: S, validate: T) -> Self {
+ Self::custom(inner, validate)
+ }
+
+ define_inner_service_accessors!();
+}
+
+impl ValidateRequestHeader> {
+ /// Validate requests have the required Accept header.
+ ///
+ /// The `Accept` header is required to be `*/*`, `type/*` or `type/subtype`,
+ /// as configured.
+ ///
+ /// # Panics
+ ///
+ /// See `AcceptHeader::new` for when this method panics.
+ pub fn accept(inner: S, value: &str) -> Self
+ where
+ ResBody: Body + Default,
+ {
+ Self::custom(inner, AcceptHeader::new(value))
+ }
+}
+
+impl ValidateRequestHeader {
+ /// Validate requests using a custom method.
+ pub fn custom(inner: S, validate: T) -> ValidateRequestHeader {
+ Self { inner, validate }
+ }
+}
+
+impl Service> for ValidateRequestHeader
+where
+ V: ValidateRequest,
+ S: Service, Response = Response>,
+{
+ type Response = Response;
+ type Error = S::Error;
+ type Future = ResponseFuture;
+
+ fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> {
+ self.inner.poll_ready(cx)
+ }
+
+ fn call(&mut self, mut req: Request) -> Self::Future {
+ match self.validate.validate(&mut req) {
+ Ok(_) => ResponseFuture::future(self.inner.call(req)),
+ Err(res) => ResponseFuture::invalid_header_value(res),
+ }
+ }
+}
+
+pin_project! {
+ /// Response future for [`ValidateRequestHeader`].
+ pub struct ResponseFuture {
+ #[pin]
+ kind: Kind,
+ }
+}
+
+impl ResponseFuture {
+ fn future(future: F) -> Self {
+ Self {
+ kind: Kind::Future { future },
+ }
+ }
+
+ fn invalid_header_value(res: Response) -> Self {
+ Self {
+ kind: Kind::Error {
+ response: Some(res),
+ },
+ }
+ }
+}
+
+pin_project! {
+ #[project = KindProj]
+ enum Kind {
+ Future {
+ #[pin]
+ future: F,
+ },
+ Error {
+ response: Option>,
+ },
+ }
+}
+
+impl Future for ResponseFuture
+where
+ F: Future