From 54fec129c7c6160d5ac92560effc6d6a99f130c4 Mon Sep 17 00:00:00 2001
From: 82marbag <69267416+82marbag@users.noreply.github.com>
Date: Mon, 15 Aug 2022 12:01:18 -0700
Subject: [PATCH 01/10] Add layer to validate requests
---
tower-http/src/lib.rs | 2 +
tower-http/src/validate_request.rs | 514 +++++++++++++++++++++++++++++
2 files changed, 516 insertions(+)
create mode 100644 tower-http/src/validate_request.rs
diff --git a/tower-http/src/lib.rs b/tower-http/src/lib.rs
index 0b413bae..9d78e53b 100644
--- a/tower-http/src/lib.rs
+++ b/tower-http/src/lib.rs
@@ -319,6 +319,8 @@ mod builder;
#[doc(inline)]
pub use self::builder::ServiceBuilderExt;
+pub mod validate_request;
+
/// The latency unit used to report latencies by middleware.
#[non_exhaustive]
#[derive(Copy, Clone, Debug)]
diff --git a/tower-http/src/validate_request.rs b/tower-http/src/validate_request.rs
new file mode 100644
index 00000000..3bafc618
--- /dev/null
+++ b/tower-http/src/validate_request.rs
@@ -0,0 +1,514 @@
+//! Middleware that validates the requests a service can handle.
+//!
+//! # Example
+//!
+//! ```
+//! use tower_http::validate_request::ValidateRequestHeaderLayer;
+//! use hyper::{Request, Response, Body, Error};
+//! use http::{StatusCode, header::ACCEPT};
+//! use tower::{Service, ServiceExt, ServiceBuilder, service_fn};
+//!
+//! async fn handle(request: Request
) -> Result, Error> {
+//! Ok(Response::new(Body::empty()))
+//! }
+//!
+//! # #[tokio::main]
+//! # async fn main() -> Result<(), Box> {
+//! let mut service = ServiceBuilder::new()
+//! // Require the `Accept` header to be `application/json`, `*/*` or `application/*`
+//! .layer(ValidateRequestHeaderLayer::accept("application/json"))
+//! .service_fn(handle);
+//!
+//! // Requests with the correct value are allowed through
+//! let request = Request::builder()
+//! .header(ACCEPT, "application/json")
+//! .body(Body::empty())
+//! .unwrap();
+//!
+//! let response = service
+//! .ready()
+//! .await?
+//! .call(request)
+//! .await?;
+//!
+//! assert_eq!(StatusCode::OK, response.status());
+//!
+//! // Requests with an invalid value get a `406 Not Acceptable` response
+//! let request = Request::builder()
+//! .header(ACCEPT, "text/strings")
+//! .body(Body::empty())
+//! .unwrap();
+//!
+//! let response = service
+//! .ready()
+//! .await?
+//! .call(request)
+//! .await?;
+//!
+//! assert_eq!(StatusCode::NOT_ACCEPTABLE, response.status());
+//! # Ok(())
+//! # }
+//! ```
+//!
+//! Custom validation can be made by implementing [`ValidateRequest`]:
+//!
+//! ```
+//! use tower_http::validate_request::{ValidateRequestHeaderLayer, ValidateRequest};
+//! use hyper::{Request, Response, Body, Error};
+//! use http::{StatusCode, header::ACCEPT};
+//! use tower::{Service, ServiceExt, ServiceBuilder, service_fn};
+//!
+//! #[derive(Clone, Copy)]
+//! pub struct MyHeader { }
+//!
+//! impl ValidateRequest for MyHeader {
+//! type ResponseBody = Body;
+//!
+//! fn validate(
+//! &mut self,
+//! request: &mut Request,
+//! ) -> Result<(), Response> {
+//! # unimplemented!()
+//! }
+//! }
+//!
+//! async fn handle(request: Request) -> Result, Error> {
+//! Ok(Response::new(Body::empty()))
+//! }
+//!
+//!
+//! # #[tokio::main]
+//! # async fn main() -> Result<(), Box> {
+//! let service = ServiceBuilder::new()
+//! // Validate requests using `MyHeader`
+//! .layer(ValidateRequestHeaderLayer::custom(MyHeader{}))
+//! .service_fn(handle);
+//! # Ok(())
+//! # }
+//! ```
+//!
+//! Or using a closure:
+//!
+//! ```
+//! use tower_http::validate_request::{ValidateRequestHeaderLayer, ValidateRequest};
+//! use hyper::{Request, Response, Body, Error};
+//! use http::{StatusCode, header::ACCEPT};
+//! use tower::{Service, ServiceExt, ServiceBuilder, service_fn};
+//!
+//! async fn handle(request: Request) -> Result, Error> {
+//! # todo!();
+//! // ...
+//! }
+//!
+//! # #[tokio::main]
+//! # async fn main() -> Result<(), Box> {
+//! let service = ServiceBuilder::new()
+//! .layer(ValidateRequestHeaderLayer::custom(|request: &mut Request| {
+//! // Validate the request
+//! # Ok::<_, Response>(())
+//! }))
+//! .service_fn(handle);
+//! # Ok(())
+//! # }
+//! ```
+
+use http::{
+ header::{self,HeaderValue},
+ Request, Response, StatusCode,
+};
+use http_body::Body;
+use pin_project_lite::pin_project;
+use std::{
+ fmt,
+ future::Future,
+ marker::PhantomData,
+ pin::Pin,
+ task::{Context, Poll},
+};
+use tower_layer::Layer;
+use tower_service::Service;
+
+/// Layer that applies [`ValidateRequestHeader`] which validates all requests using the
+/// [`ValidateRequest`] header.
+#[derive(Debug, Clone)]
+pub struct ValidateRequestHeaderLayer {
+ valid: T,
+}
+
+impl ValidateRequestHeaderLayer> {
+ /// Validate requests have the required Accept header.
+ ///
+ /// The `Accept` header is required to be `*/*`, `type/*` or `type/subtype`,
+ /// as configured.
+ ///
+ /// [`Accept`]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Accept
+ pub fn accept(value: &str) -> Self
+ where
+ ResBody: Body + Default,
+ {
+ Self::custom(AcceptHeader::new(value))
+ }
+}
+
+impl ValidateRequestHeaderLayer {
+ /// Validate requests using a custom method.
+ pub fn custom(valid: T) -> ValidateRequestHeaderLayer {
+ Self { valid }
+ }
+}
+
+impl Layer for ValidateRequestHeaderLayer
+where
+ T: Clone,
+{
+ type Service = ValidateRequestHeader;
+
+ fn layer(&self, inner: S) -> Self::Service {
+ ValidateRequestHeader::new(inner, self.valid.clone())
+ }
+}
+
+/// Middleware that validates requests.
+#[derive(Clone, Debug)]
+pub struct ValidateRequestHeader {
+ inner: S,
+ valid: T,
+}
+
+impl ValidateRequestHeader {
+ fn new(inner: S, valid: T) -> Self {
+ Self { inner, valid }
+ }
+
+ define_inner_service_accessors!();
+}
+
+impl ValidateRequestHeader> {
+ /// Validate requests have the required Accept header.
+ ///
+ /// The `Accept` header is required to be `*/*`, `type/*` or `type/subtype`,
+ /// as configured.
+ pub fn accept(inner: S, value: &str) -> Self
+ where
+ ResBody: Body + Default,
+ {
+ Self::custom(inner,AcceptHeader::new(value))
+ }
+}
+
+impl ValidateRequestHeader {
+ /// Validate requests using a custom method.
+ pub fn custom(inner: S, valid: T) -> ValidateRequestHeader {
+ Self { inner, valid }
+ }
+}
+
+impl Service> for ValidateRequestHeader
+where
+ V: ValidateRequest,
+ S: Service, Response = Response>,
+{
+ type Response = Response;
+ type Error = S::Error;
+ type Future = ResponseFuture;
+
+ fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> {
+ self.inner.poll_ready(cx)
+ }
+
+ fn call(&mut self, mut req: Request) -> Self::Future {
+ match self.valid.validate(&mut req) {
+ Ok(_) => ResponseFuture::future(self.inner.call(req)),
+ Err(res) => ResponseFuture::invalid_header_value(res),
+ }
+ }
+}
+
+pin_project! {
+ /// Response future for [`ValidateRequestHeader`].
+ pub struct ResponseFuture {
+ #[pin]
+ kind: Kind,
+ }
+}
+
+impl ResponseFuture {
+ fn future(future: F) -> Self {
+ Self {
+ kind: Kind::Future { future },
+ }
+ }
+
+ fn invalid_header_value(res: Response) -> Self {
+ Self {
+ kind: Kind::Error {
+ response: Some(res),
+ },
+ }
+ }
+}
+
+pin_project! {
+ #[project = KindProj]
+ enum Kind {
+ Future {
+ #[pin]
+ future: F,
+ },
+ Error {
+ response: Option>,
+ },
+ }
+}
+
+impl Future for ResponseFuture
+where
+ F: Future