From 54fec129c7c6160d5ac92560effc6d6a99f130c4 Mon Sep 17 00:00:00 2001 From: 82marbag <69267416+82marbag@users.noreply.github.com> Date: Mon, 15 Aug 2022 12:01:18 -0700 Subject: [PATCH 01/10] Add layer to validate requests --- tower-http/src/lib.rs | 2 + tower-http/src/validate_request.rs | 514 +++++++++++++++++++++++++++++ 2 files changed, 516 insertions(+) create mode 100644 tower-http/src/validate_request.rs diff --git a/tower-http/src/lib.rs b/tower-http/src/lib.rs index 0b413bae..9d78e53b 100644 --- a/tower-http/src/lib.rs +++ b/tower-http/src/lib.rs @@ -319,6 +319,8 @@ mod builder; #[doc(inline)] pub use self::builder::ServiceBuilderExt; +pub mod validate_request; + /// The latency unit used to report latencies by middleware. #[non_exhaustive] #[derive(Copy, Clone, Debug)] diff --git a/tower-http/src/validate_request.rs b/tower-http/src/validate_request.rs new file mode 100644 index 00000000..3bafc618 --- /dev/null +++ b/tower-http/src/validate_request.rs @@ -0,0 +1,514 @@ +//! Middleware that validates the requests a service can handle. +//! +//! # Example +//! +//! ``` +//! use tower_http::validate_request::ValidateRequestHeaderLayer; +//! use hyper::{Request, Response, Body, Error}; +//! use http::{StatusCode, header::ACCEPT}; +//! use tower::{Service, ServiceExt, ServiceBuilder, service_fn}; +//! +//! async fn handle(request: Request) -> Result, Error> { +//! Ok(Response::new(Body::empty())) +//! } +//! +//! # #[tokio::main] +//! # async fn main() -> Result<(), Box> { +//! let mut service = ServiceBuilder::new() +//! // Require the `Accept` header to be `application/json`, `*/*` or `application/*` +//! .layer(ValidateRequestHeaderLayer::accept("application/json")) +//! .service_fn(handle); +//! +//! // Requests with the correct value are allowed through +//! let request = Request::builder() +//! .header(ACCEPT, "application/json") +//! .body(Body::empty()) +//! .unwrap(); +//! +//! let response = service +//! .ready() +//! .await? +//! .call(request) +//! .await?; +//! +//! assert_eq!(StatusCode::OK, response.status()); +//! +//! // Requests with an invalid value get a `406 Not Acceptable` response +//! let request = Request::builder() +//! .header(ACCEPT, "text/strings") +//! .body(Body::empty()) +//! .unwrap(); +//! +//! let response = service +//! .ready() +//! .await? +//! .call(request) +//! .await?; +//! +//! assert_eq!(StatusCode::NOT_ACCEPTABLE, response.status()); +//! # Ok(()) +//! # } +//! ``` +//! +//! Custom validation can be made by implementing [`ValidateRequest`]: +//! +//! ``` +//! use tower_http::validate_request::{ValidateRequestHeaderLayer, ValidateRequest}; +//! use hyper::{Request, Response, Body, Error}; +//! use http::{StatusCode, header::ACCEPT}; +//! use tower::{Service, ServiceExt, ServiceBuilder, service_fn}; +//! +//! #[derive(Clone, Copy)] +//! pub struct MyHeader { } +//! +//! impl ValidateRequest for MyHeader { +//! type ResponseBody = Body; +//! +//! fn validate( +//! &mut self, +//! request: &mut Request, +//! ) -> Result<(), Response> { +//! # unimplemented!() +//! } +//! } +//! +//! async fn handle(request: Request) -> Result, Error> { +//! Ok(Response::new(Body::empty())) +//! } +//! +//! +//! # #[tokio::main] +//! # async fn main() -> Result<(), Box> { +//! let service = ServiceBuilder::new() +//! // Validate requests using `MyHeader` +//! .layer(ValidateRequestHeaderLayer::custom(MyHeader{})) +//! .service_fn(handle); +//! # Ok(()) +//! # } +//! ``` +//! +//! Or using a closure: +//! +//! ``` +//! use tower_http::validate_request::{ValidateRequestHeaderLayer, ValidateRequest}; +//! use hyper::{Request, Response, Body, Error}; +//! use http::{StatusCode, header::ACCEPT}; +//! use tower::{Service, ServiceExt, ServiceBuilder, service_fn}; +//! +//! async fn handle(request: Request) -> Result, Error> { +//! # todo!(); +//! // ... +//! } +//! +//! # #[tokio::main] +//! # async fn main() -> Result<(), Box> { +//! let service = ServiceBuilder::new() +//! .layer(ValidateRequestHeaderLayer::custom(|request: &mut Request| { +//! // Validate the request +//! # Ok::<_, Response>(()) +//! })) +//! .service_fn(handle); +//! # Ok(()) +//! # } +//! ``` + +use http::{ + header::{self,HeaderValue}, + Request, Response, StatusCode, +}; +use http_body::Body; +use pin_project_lite::pin_project; +use std::{ + fmt, + future::Future, + marker::PhantomData, + pin::Pin, + task::{Context, Poll}, +}; +use tower_layer::Layer; +use tower_service::Service; + +/// Layer that applies [`ValidateRequestHeader`] which validates all requests using the +/// [`ValidateRequest`] header. +#[derive(Debug, Clone)] +pub struct ValidateRequestHeaderLayer { + valid: T, +} + +impl ValidateRequestHeaderLayer> { + /// Validate requests have the required Accept header. + /// + /// The `Accept` header is required to be `*/*`, `type/*` or `type/subtype`, + /// as configured. + /// + /// [`Accept`]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Accept + pub fn accept(value: &str) -> Self + where + ResBody: Body + Default, + { + Self::custom(AcceptHeader::new(value)) + } +} + +impl ValidateRequestHeaderLayer { + /// Validate requests using a custom method. + pub fn custom(valid: T) -> ValidateRequestHeaderLayer { + Self { valid } + } +} + +impl Layer for ValidateRequestHeaderLayer +where + T: Clone, +{ + type Service = ValidateRequestHeader; + + fn layer(&self, inner: S) -> Self::Service { + ValidateRequestHeader::new(inner, self.valid.clone()) + } +} + +/// Middleware that validates requests. +#[derive(Clone, Debug)] +pub struct ValidateRequestHeader { + inner: S, + valid: T, +} + +impl ValidateRequestHeader { + fn new(inner: S, valid: T) -> Self { + Self { inner, valid } + } + + define_inner_service_accessors!(); +} + +impl ValidateRequestHeader> { + /// Validate requests have the required Accept header. + /// + /// The `Accept` header is required to be `*/*`, `type/*` or `type/subtype`, + /// as configured. + pub fn accept(inner: S, value: &str) -> Self + where + ResBody: Body + Default, + { + Self::custom(inner,AcceptHeader::new(value)) + } +} + +impl ValidateRequestHeader { + /// Validate requests using a custom method. + pub fn custom(inner: S, valid: T) -> ValidateRequestHeader { + Self { inner, valid } + } +} + +impl Service> for ValidateRequestHeader +where + V: ValidateRequest, + S: Service, Response = Response>, +{ + type Response = Response; + type Error = S::Error; + type Future = ResponseFuture; + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.inner.poll_ready(cx) + } + + fn call(&mut self, mut req: Request) -> Self::Future { + match self.valid.validate(&mut req) { + Ok(_) => ResponseFuture::future(self.inner.call(req)), + Err(res) => ResponseFuture::invalid_header_value(res), + } + } +} + +pin_project! { + /// Response future for [`ValidateRequestHeader`]. + pub struct ResponseFuture { + #[pin] + kind: Kind, + } +} + +impl ResponseFuture { + fn future(future: F) -> Self { + Self { + kind: Kind::Future { future }, + } + } + + fn invalid_header_value(res: Response) -> Self { + Self { + kind: Kind::Error { + response: Some(res), + }, + } + } +} + +pin_project! { + #[project = KindProj] + enum Kind { + Future { + #[pin] + future: F, + }, + Error { + response: Option>, + }, + } +} + +impl Future for ResponseFuture +where + F: Future, E>>, +{ + type Output = F::Output; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + match self.project().kind.project() { + KindProj::Future { future } => future.poll(cx), + KindProj::Error { response } => { + let response = response.take().unwrap(); + Poll::Ready(Ok(response)) + } + } + } +} + +/// Trait for validating requests. +pub trait ValidateRequest { + /// The body type used for responses to unvalidated requests. + type ResponseBody; + + /// Validate the request. + /// + /// If `Ok(())` is returned then the request is allowed through, otherwise not. + fn validate(&mut self, request: &mut Request) -> Result<(), Response>; +} + +impl ValidateRequest for F +where + F: FnMut(&mut Request) -> Result<(), Response>, +{ + type ResponseBody = ResBody; + + fn validate(&mut self, request: &mut Request) -> Result<(), Response> { + self(request) + } +} + +/// Type that performs validation of the Accept header. +pub struct AcceptHeader { + header_value: HeaderValue, + _ty: PhantomData ResBody>, +} + +impl AcceptHeader { + fn new(header_value: &str) -> Self + where + ResBody: Body + Default, + { + Self { + header_value: header_value + .parse() + .expect("token is not a valid header value"), + _ty: PhantomData, + } + } +} + +impl Clone for AcceptHeader { + fn clone(&self) -> Self { + Self { + header_value: self.header_value.clone(), + _ty: PhantomData, + } + } +} + +impl fmt::Debug for AcceptHeader { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("AcceptHeader") + .field("header_value", &self.header_value) + .finish() + } +} + +impl ValidateRequest for AcceptHeader +where + ResBody: Body + Default, +{ + type ResponseBody = ResBody; + + fn validate(&mut self, req: &mut Request) -> Result<(), Response> { + if !req.headers().contains_key(header::ACCEPT) { + return Ok(()) + } + if req.headers().get_all(header::ACCEPT) + .into_iter() + .flat_map(|header| header.to_str().ok().into_iter().flat_map(|s| s.split(",").map(|typ| typ.trim()))) + .any(|h| { + let value = self.header_value.to_str().unwrap(); + let primary = format!("{}/*", value.split("/").nth(0).unwrap()); + h == "*/*" || h == primary || h == value + }) { + return Ok(()) + } + let mut res = Response::new(ResBody::default()); + *res.status_mut() = StatusCode::NOT_ACCEPTABLE; + Err(res) + } +} + +#[cfg(test)] +mod tests { + #[allow(unused_imports)] + use super::*; + use http::header; + use hyper::Body; + use tower::{BoxError, ServiceBuilder, ServiceExt}; + + #[tokio::test] + async fn valid_accept_header() { + let mut service = ServiceBuilder::new() + .layer(ValidateRequestHeaderLayer::accept("application/json")) + .service_fn(echo); + + let request = Request::get("/") + .header( + header::ACCEPT, + "application/json" + ) + .body(Body::empty()) + .unwrap(); + + let res = service.ready().await.unwrap().call(request).await.unwrap(); + + assert_eq!(res.status(), StatusCode::OK); + } + + #[tokio::test] + async fn valid_accept_header_accept_all_json() { + let mut service = ServiceBuilder::new() + .layer(ValidateRequestHeaderLayer::accept("application/json")) + .service_fn(echo); + + let request = Request::get("/") + .header( + header::ACCEPT, + "application/*" + ) + .body(Body::empty()) + .unwrap(); + + let res = service.ready().await.unwrap().call(request).await.unwrap(); + + assert_eq!(res.status(), StatusCode::OK); + } + + #[tokio::test] + async fn valid_accept_header_accept_all() { + let mut service = ServiceBuilder::new() + .layer(ValidateRequestHeaderLayer::accept("application/json")) + .service_fn(echo); + + let request = Request::get("/") + .header( + header::ACCEPT, + "*/*" + ) + .body(Body::empty()) + .unwrap(); + + let res = service.ready().await.unwrap().call(request).await.unwrap(); + + assert_eq!(res.status(), StatusCode::OK); + } + + #[tokio::test] + async fn invalid_accept_header() { + let mut service = ServiceBuilder::new() + .layer(ValidateRequestHeaderLayer::accept("application/json")) + .service_fn(echo); + + let request = Request::get("/") + .header( + header::ACCEPT, + "invalid" + ) + .body(Body::empty()) + .unwrap(); + + let res = service.ready().await.unwrap().call(request).await.unwrap(); + + assert_eq!(res.status(), StatusCode::NOT_ACCEPTABLE); + } + #[tokio::test] + async fn not_accepted_accept_header() { + let mut service = ServiceBuilder::new() + .layer(ValidateRequestHeaderLayer::accept("application/json")) + .service_fn(echo); + + let request = Request::get("/") + .header( + header::ACCEPT, + "text/strings" + ) + .body(Body::empty()) + .unwrap(); + + let res = service.ready().await.unwrap().call(request).await.unwrap(); + + assert_eq!(res.status(), StatusCode::NOT_ACCEPTABLE); + } + + #[tokio::test] + async fn accepted_multiple_header_value() { + let mut service = ServiceBuilder::new() + .layer(ValidateRequestHeaderLayer::accept("application/json")) + .service_fn(echo); + + let request = Request::get("/") + .header( + header::ACCEPT, + "text/strings" + ) + .header( + header::ACCEPT, + "invalid, application/json" + ) + .body(Body::empty()) + .unwrap(); + + let res = service.ready().await.unwrap().call(request).await.unwrap(); + + assert_eq!(res.status(), StatusCode::OK); + } + + #[tokio::test] + async fn accepted_inner_header_value() { + let mut service = ServiceBuilder::new() + .layer(ValidateRequestHeaderLayer::accept("application/json")) + .service_fn(echo); + + let request = Request::get("/") + .header( + header::ACCEPT, + "text/strings, invalid, application/json" + ) + .body(Body::empty()) + .unwrap(); + + let res = service.ready().await.unwrap().call(request).await.unwrap(); + + assert_eq!(res.status(), StatusCode::OK); + } + + async fn echo(req: Request) -> Result, BoxError> { + Ok(Response::new(req.into_body())) + } +} + From 2aea715cb93e767c42004bd9a5610c07483b0b5f Mon Sep 17 00:00:00 2001 From: 82marbag <69267416+82marbag@users.noreply.github.com> Date: Mon, 15 Aug 2022 15:31:25 -0700 Subject: [PATCH 02/10] Fix style --- tower-http/src/validate_request.rs | 71 ++++++++++++------------------ 1 file changed, 28 insertions(+), 43 deletions(-) diff --git a/tower-http/src/validate_request.rs b/tower-http/src/validate_request.rs index 3bafc618..67878bf8 100644 --- a/tower-http/src/validate_request.rs +++ b/tower-http/src/validate_request.rs @@ -113,7 +113,7 @@ //! ``` use http::{ - header::{self,HeaderValue}, + header::{self, HeaderValue}, Request, Response, StatusCode, }; use http_body::Body; @@ -192,7 +192,7 @@ impl ValidateRequestHeader> { where ResBody: Body + Default, { - Self::custom(inner,AcceptHeader::new(value)) + Self::custom(inner, AcceptHeader::new(value)) } } @@ -345,18 +345,27 @@ where fn validate(&mut self, req: &mut Request) -> Result<(), Response> { if !req.headers().contains_key(header::ACCEPT) { - return Ok(()) + return Ok(()); } - if req.headers().get_all(header::ACCEPT) + if req + .headers() + .get_all(header::ACCEPT) .into_iter() - .flat_map(|header| header.to_str().ok().into_iter().flat_map(|s| s.split(",").map(|typ| typ.trim()))) + .flat_map(|header| { + header + .to_str() + .ok() + .into_iter() + .flat_map(|s| s.split(",").map(|typ| typ.trim())) + }) .any(|h| { - let value = self.header_value.to_str().unwrap(); - let primary = format!("{}/*", value.split("/").nth(0).unwrap()); - h == "*/*" || h == primary || h == value - }) { - return Ok(()) - } + let value = self.header_value.to_str().unwrap(); + let primary = format!("{}/*", value.split("/").nth(0).unwrap()); + h == "*/*" || h == primary || h == value + }) + { + return Ok(()); + } let mut res = Response::new(ResBody::default()); *res.status_mut() = StatusCode::NOT_ACCEPTABLE; Err(res) @@ -378,10 +387,7 @@ mod tests { .service_fn(echo); let request = Request::get("/") - .header( - header::ACCEPT, - "application/json" - ) + .header(header::ACCEPT, "application/json") .body(Body::empty()) .unwrap(); @@ -397,10 +403,7 @@ mod tests { .service_fn(echo); let request = Request::get("/") - .header( - header::ACCEPT, - "application/*" - ) + .header(header::ACCEPT, "application/*") .body(Body::empty()) .unwrap(); @@ -416,10 +419,7 @@ mod tests { .service_fn(echo); let request = Request::get("/") - .header( - header::ACCEPT, - "*/*" - ) + .header(header::ACCEPT, "*/*") .body(Body::empty()) .unwrap(); @@ -435,10 +435,7 @@ mod tests { .service_fn(echo); let request = Request::get("/") - .header( - header::ACCEPT, - "invalid" - ) + .header(header::ACCEPT, "invalid") .body(Body::empty()) .unwrap(); @@ -453,10 +450,7 @@ mod tests { .service_fn(echo); let request = Request::get("/") - .header( - header::ACCEPT, - "text/strings" - ) + .header(header::ACCEPT, "text/strings") .body(Body::empty()) .unwrap(); @@ -472,14 +466,8 @@ mod tests { .service_fn(echo); let request = Request::get("/") - .header( - header::ACCEPT, - "text/strings" - ) - .header( - header::ACCEPT, - "invalid, application/json" - ) + .header(header::ACCEPT, "text/strings") + .header(header::ACCEPT, "invalid, application/json") .body(Body::empty()) .unwrap(); @@ -495,10 +483,7 @@ mod tests { .service_fn(echo); let request = Request::get("/") - .header( - header::ACCEPT, - "text/strings, invalid, application/json" - ) + .header(header::ACCEPT, "text/strings, invalid, application/json") .body(Body::empty()) .unwrap(); From 3ab01c562ac77e23a8a91b11be1b3ce8a0e56364 Mon Sep 17 00:00:00 2001 From: 82marbag <69267416+82marbag@users.noreply.github.com> Date: Tue, 16 Aug 2022 12:13:34 -0700 Subject: [PATCH 03/10] Address comments --- tower-http/src/validate_request.rs | 66 ++++++++++++++++++++++-------- 1 file changed, 48 insertions(+), 18 deletions(-) diff --git a/tower-http/src/validate_request.rs b/tower-http/src/validate_request.rs index 67878bf8..a1c897fb 100644 --- a/tower-http/src/validate_request.rs +++ b/tower-http/src/validate_request.rs @@ -59,7 +59,7 @@ //! use tower::{Service, ServiceExt, ServiceBuilder, service_fn}; //! //! #[derive(Clone, Copy)] -//! pub struct MyHeader { } +//! pub struct MyHeader { /* ... */ } //! //! impl ValidateRequest for MyHeader { //! type ResponseBody = Body; @@ -68,6 +68,7 @@ //! &mut self, //! request: &mut Request, //! ) -> Result<(), Response> { +//! // validate the request... //! # unimplemented!() //! } //! } @@ -81,7 +82,7 @@ //! # async fn main() -> Result<(), Box> { //! let service = ServiceBuilder::new() //! // Validate requests using `MyHeader` -//! .layer(ValidateRequestHeaderLayer::custom(MyHeader{})) +//! .layer(ValidateRequestHeaderLayer::custom(MyHeader { /* ... */ })) //! .service_fn(handle); //! # Ok(()) //! # } @@ -128,11 +129,12 @@ use std::{ use tower_layer::Layer; use tower_service::Service; -/// Layer that applies [`ValidateRequestHeader`] which validates all requests using the -/// [`ValidateRequest`] header. +/// Layer that applies [`ValidateRequestHeader`] which validates all requests. +/// +/// See the [module docs](crate::validate_request) for an example. #[derive(Debug, Clone)] pub struct ValidateRequestHeaderLayer { - valid: T, + validate: T, } impl ValidateRequestHeaderLayer> { @@ -152,8 +154,8 @@ impl ValidateRequestHeaderLayer> { impl ValidateRequestHeaderLayer { /// Validate requests using a custom method. - pub fn custom(valid: T) -> ValidateRequestHeaderLayer { - Self { valid } + pub fn custom(validate: T) -> ValidateRequestHeaderLayer { + Self { validate } } } @@ -164,20 +166,22 @@ where type Service = ValidateRequestHeader; fn layer(&self, inner: S) -> Self::Service { - ValidateRequestHeader::new(inner, self.valid.clone()) + ValidateRequestHeader::new(inner, self.validate.clone()) } } /// Middleware that validates requests. +/// +/// See the [module docs](crate::validate_request) for an example. #[derive(Clone, Debug)] pub struct ValidateRequestHeader { inner: S, - valid: T, + validate: T, } impl ValidateRequestHeader { - fn new(inner: S, valid: T) -> Self { - Self { inner, valid } + fn new(inner: S, validate: T) -> Self { + Self::custom(inner, validate) } define_inner_service_accessors!(); @@ -198,8 +202,8 @@ impl ValidateRequestHeader> { impl ValidateRequestHeader { /// Validate requests using a custom method. - pub fn custom(inner: S, valid: T) -> ValidateRequestHeader { - Self { inner, valid } + pub fn custom(inner: S, validate: T) -> ValidateRequestHeader { + Self { inner, validate } } } @@ -217,7 +221,7 @@ where } fn call(&mut self, mut req: Request) -> Self::Future { - match self.valid.validate(&mut req) { + match self.validate.validate(&mut req) { Ok(_) => ResponseFuture::future(self.inner.call(req)), Err(res) => ResponseFuture::invalid_header_value(res), } @@ -271,6 +275,7 @@ where match self.project().kind.project() { KindProj::Future { future } => future.poll(cx), KindProj::Error { response } => { + /* Never panics unless polled after completion */ let response = response.take().unwrap(); Poll::Ready(Ok(response)) } @@ -303,18 +308,27 @@ where /// Type that performs validation of the Accept header. pub struct AcceptHeader { header_value: HeaderValue, + primary_type: String, _ty: PhantomData ResBody>, } impl AcceptHeader { + /// Panics if `header_value` is not in the form: `type/subtype`, such as `application/json` fn new(header_value: &str) -> Self where ResBody: Body + Default, { + let primary_type = format!( + "{}/*", header_value + .split("/") + .nth(0) + .expect("value is not valid for the Accept header of the form type/subtype") + ); Self { header_value: header_value .parse() - .expect("token is not a valid header value"), + .expect("value is not a valid header value"), + primary_type, _ty: PhantomData, } } @@ -324,6 +338,7 @@ impl Clone for AcceptHeader { fn clone(&self) -> Self { Self { header_value: self.header_value.clone(), + primary_type: self.primary_type.clone(), _ty: PhantomData, } } @@ -359,9 +374,8 @@ where .flat_map(|s| s.split(",").map(|typ| typ.trim())) }) .any(|h| { - let value = self.header_value.to_str().unwrap(); - let primary = format!("{}/*", value.split("/").nth(0).unwrap()); - h == "*/*" || h == primary || h == value + let value = self.header_value.to_str().unwrap(); /* cannot panic, checked at creation time */ + h == "*/*" || h == self.primary_type || h == value }) { return Ok(()); @@ -443,6 +457,22 @@ mod tests { assert_eq!(res.status(), StatusCode::NOT_ACCEPTABLE); } + #[tokio::test] + async fn not_accepted_accept_header_subtype() { + let mut service = ServiceBuilder::new() + .layer(ValidateRequestHeaderLayer::accept("application/json")) + .service_fn(echo); + + let request = Request::get("/") + .header(header::ACCEPT, "application/strings") + .body(Body::empty()) + .unwrap(); + + let res = service.ready().await.unwrap().call(request).await.unwrap(); + + assert_eq!(res.status(), StatusCode::NOT_ACCEPTABLE); + } + #[tokio::test] async fn not_accepted_accept_header() { let mut service = ServiceBuilder::new() From 1b1207206edaa80f20e073094fce4e166cd93465 Mon Sep 17 00:00:00 2001 From: 82marbag <69267416+82marbag@users.noreply.github.com> Date: Tue, 16 Aug 2022 12:37:21 -0700 Subject: [PATCH 04/10] Use mime --- tower-http/Cargo.toml | 4 ++-- tower-http/src/validate_request.rs | 30 ++++++++++++++++-------------- 2 files changed, 18 insertions(+), 16 deletions(-) diff --git a/tower-http/Cargo.toml b/tower-http/Cargo.toml index 2b54c2b3..9fbd1c2b 100644 --- a/tower-http/Cargo.toml +++ b/tower-http/Cargo.toml @@ -19,6 +19,7 @@ futures-core = "0.3" futures-util = { version = "0.3.14", default_features = false, features = [] } http = "0.2.2" http-body = "0.4.5" +mime = { version = "0.3", default_features = false } pin-project-lite = "0.2.7" tower-layer = "0.3" tower-service = "0.3" @@ -28,7 +29,6 @@ async-compression = { version = "0.3", optional = true, features = ["tokio"] } base64 = { version = "0.13", optional = true } http-range-header = "0.3.0" iri-string = { version = "0.4", optional = true } -mime = { version = "0.3", optional = true, default_features = false } mime_guess = { version = "2", optional = true, default_features = false } percent-encoding = { version = "2.1.0", optional = true } tokio = { version = "1.6", optional = true, default_features = false } @@ -83,7 +83,7 @@ auth = ["base64"] catch-panic = ["tracing", "futures-util/std"] cors = [] follow-redirect = ["iri-string", "tower/util"] -fs = ["tokio/fs", "tokio-util/io", "tokio/io-util", "mime_guess", "mime", "percent-encoding", "httpdate", "set-status", "futures-util/alloc"] +fs = ["tokio/fs", "tokio-util/io", "tokio/io-util", "mime_guess", "percent-encoding", "httpdate", "set-status", "futures-util/alloc"] limit = [] map-request-body = [] map-response-body = [] diff --git a/tower-http/src/validate_request.rs b/tower-http/src/validate_request.rs index a1c897fb..43d3bc97 100644 --- a/tower-http/src/validate_request.rs +++ b/tower-http/src/validate_request.rs @@ -114,10 +114,11 @@ //! ``` use http::{ - header::{self, HeaderValue}, + header::self, Request, Response, StatusCode, }; use http_body::Body; +use mime::Mime; use pin_project_lite::pin_project; use std::{ fmt, @@ -307,8 +308,7 @@ where /// Type that performs validation of the Accept header. pub struct AcceptHeader { - header_value: HeaderValue, - primary_type: String, + header_value: Mime, _ty: PhantomData ResBody>, } @@ -318,17 +318,10 @@ impl AcceptHeader { where ResBody: Body + Default, { - let primary_type = format!( - "{}/*", header_value - .split("/") - .nth(0) - .expect("value is not valid for the Accept header of the form type/subtype") - ); Self { header_value: header_value - .parse() + .parse::() .expect("value is not a valid header value"), - primary_type, _ty: PhantomData, } } @@ -338,7 +331,6 @@ impl Clone for AcceptHeader { fn clone(&self) -> Self { Self { header_value: self.header_value.clone(), - primary_type: self.primary_type.clone(), _ty: PhantomData, } } @@ -374,8 +366,18 @@ where .flat_map(|s| s.split(",").map(|typ| typ.trim())) }) .any(|h| { - let value = self.header_value.to_str().unwrap(); /* cannot panic, checked at creation time */ - h == "*/*" || h == self.primary_type || h == value + h.parse::() + .map(|mim| { + let typ = self.header_value.type_(); + let subtype = self.header_value.subtype(); + match (mim.type_(), mim.subtype()) { + (t, s) if t == typ && s == subtype => true, + (t, mime::STAR) if t == typ => true, + (mime::STAR, mime::STAR) => true, + _ => false + } + }) + .unwrap_or(false) }) { return Ok(()); From 2df20d489487f65a37285f6cf8aeb2b3fdca1ed9 Mon Sep 17 00:00:00 2001 From: 82marbag <69267416+82marbag@users.noreply.github.com> Date: Tue, 16 Aug 2022 15:36:49 -0700 Subject: [PATCH 05/10] Fix style --- tower-http/src/validate_request.rs | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/tower-http/src/validate_request.rs b/tower-http/src/validate_request.rs index 43d3bc97..654dd2a5 100644 --- a/tower-http/src/validate_request.rs +++ b/tower-http/src/validate_request.rs @@ -77,7 +77,7 @@ //! Ok(Response::new(Body::empty())) //! } //! -//! +//! //! # #[tokio::main] //! # async fn main() -> Result<(), Box> { //! let service = ServiceBuilder::new() @@ -113,10 +113,7 @@ //! # } //! ``` -use http::{ - header::self, - Request, Response, StatusCode, -}; +use http::{header, Request, Response, StatusCode}; use http_body::Body; use mime::Mime; use pin_project_lite::pin_project; @@ -131,7 +128,7 @@ use tower_layer::Layer; use tower_service::Service; /// Layer that applies [`ValidateRequestHeader`] which validates all requests. -/// +/// /// See the [module docs](crate::validate_request) for an example. #[derive(Debug, Clone)] pub struct ValidateRequestHeaderLayer { @@ -172,7 +169,7 @@ where } /// Middleware that validates requests. -/// +/// /// See the [module docs](crate::validate_request) for an example. #[derive(Clone, Debug)] pub struct ValidateRequestHeader { @@ -374,7 +371,7 @@ where (t, s) if t == typ && s == subtype => true, (t, mime::STAR) if t == typ => true, (mime::STAR, mime::STAR) => true, - _ => false + _ => false, } }) .unwrap_or(false) @@ -528,4 +525,3 @@ mod tests { Ok(Response::new(req.into_body())) } } - From a2bfb244715fb301aae2fe7de6898e40a86080b8 Mon Sep 17 00:00:00 2001 From: 82marbag <69267416+82marbag@users.noreply.github.com> Date: Tue, 16 Aug 2022 16:23:42 -0700 Subject: [PATCH 06/10] Update docs --- tower-http/src/lib.rs | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tower-http/src/lib.rs b/tower-http/src/lib.rs index 9d78e53b..d8d7ca19 100644 --- a/tower-http/src/lib.rs +++ b/tower-http/src/lib.rs @@ -23,6 +23,7 @@ //! sensitive_headers::SetSensitiveRequestHeadersLayer, //! set_header::SetResponseHeaderLayer, //! trace::TraceLayer, +//! validate_request::ValidateRequestHeaderLayer, //! }; //! use tower::{ServiceBuilder, service_fn, make::Shared}; //! use http::{Request, Response, header::{HeaderName, CONTENT_TYPE, AUTHORIZATION}}; @@ -71,6 +72,8 @@ //! .layer(SetResponseHeaderLayer::overriding(CONTENT_TYPE, content_length_from_response)) //! // Authorize requests using a token //! .layer(RequireAuthorizationLayer::bearer("passwordlol")) +//! // Accept only application/json, application/* and */* in a request's ACCEPT header +//! .layer(ValidateRequestHeaderLayer::accept("application/json")) //! // Wrap a `Service` in our middleware stack //! .service_fn(handler); //! From a33e30eccb96b9b5e1d55524cc88521da9aa0dc2 Mon Sep 17 00:00:00 2001 From: 82marbag <69267416+82marbag@users.noreply.github.com> Date: Wed, 17 Aug 2022 19:03:53 -0700 Subject: [PATCH 07/10] validate-request feature --- tower-http/Cargo.toml | 4 +++- tower-http/src/validate_request.rs | 34 +++++++++++++++++++++++++----- 2 files changed, 32 insertions(+), 6 deletions(-) diff --git a/tower-http/Cargo.toml b/tower-http/Cargo.toml index 9fbd1c2b..9922f9b1 100644 --- a/tower-http/Cargo.toml +++ b/tower-http/Cargo.toml @@ -76,6 +76,7 @@ full = [ "timeout", "trace", "util", + "validate-request", ] add-extension = [] @@ -83,7 +84,7 @@ auth = ["base64"] catch-panic = ["tracing", "futures-util/std"] cors = [] follow-redirect = ["iri-string", "tower/util"] -fs = ["tokio/fs", "tokio-util/io", "tokio/io-util", "mime_guess", "percent-encoding", "httpdate", "set-status", "futures-util/alloc"] +fs = ["tokio/fs", "tokio-util/io", "tokio/io-util", "mime_guess", "mime", "percent-encoding", "httpdate", "set-status", "futures-util/alloc"] limit = [] map-request-body = [] map-response-body = [] @@ -98,6 +99,7 @@ set-status = [] timeout = ["tokio/time"] trace = ["tracing"] util = ["tower"] +validate-request = ["mime"] compression-br = ["async-compression/brotli", "tokio-util", "tokio"] compression-deflate = ["async-compression/zlib", "tokio-util", "tokio"] diff --git a/tower-http/src/validate_request.rs b/tower-http/src/validate_request.rs index 654dd2a5..e91847e4 100644 --- a/tower-http/src/validate_request.rs +++ b/tower-http/src/validate_request.rs @@ -1,4 +1,4 @@ -//! Middleware that validates the requests a service can handle. +//! Middleware that validates requests. //! //! # Example //! @@ -122,6 +122,7 @@ use std::{ future::Future, marker::PhantomData, pin::Pin, + sync::Arc, task::{Context, Poll}, }; use tower_layer::Layer; @@ -141,6 +142,20 @@ impl ValidateRequestHeaderLayer> { /// The `Accept` header is required to be `*/*`, `type/*` or `type/subtype`, /// as configured. /// + /// # Panics + /// + /// Panics if `header_value` is not in the form: `type/subtype`, such as `application/json` + /// See `AcceptHeader::new` for when this method panics. + /// + /// # Example + /// + /// ``` + /// use hyper::Body; + /// use tower_http::validate_request::{AcceptHeader, ValidateRequestHeaderLayer}; + /// + /// let layer = ValidateRequestHeaderLayer::>::accept("application/json"); + /// ``` + /// /// [`Accept`]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Accept pub fn accept(value: &str) -> Self where @@ -190,6 +205,10 @@ impl ValidateRequestHeader> { /// /// The `Accept` header is required to be `*/*`, `type/*` or `type/subtype`, /// as configured. + /// + /// # Panics + /// + /// See `AcceptHeader::new` for when this method panics. pub fn accept(inner: S, value: &str) -> Self where ResBody: Body + Default, @@ -274,7 +293,7 @@ where KindProj::Future { future } => future.poll(cx), KindProj::Error { response } => { /* Never panics unless polled after completion */ - let response = response.take().unwrap(); + let response = response.take().expect("future polled after completion"); Poll::Ready(Ok(response)) } } @@ -305,20 +324,25 @@ where /// Type that performs validation of the Accept header. pub struct AcceptHeader { - header_value: Mime, + header_value: Arc, _ty: PhantomData ResBody>, } impl AcceptHeader { + /// Create a new `AcceptHeader`. + /// + /// # Panics + /// /// Panics if `header_value` is not in the form: `type/subtype`, such as `application/json` fn new(header_value: &str) -> Self where ResBody: Body + Default, { Self { - header_value: header_value + header_value: Arc::new(header_value .parse::() - .expect("value is not a valid header value"), + .expect("value is not a valid header value") + ), _ty: PhantomData, } } From 780377c51f8a4450880c3efa61e97774ac8486a2 Mon Sep 17 00:00:00 2001 From: 82marbag <69267416+82marbag@users.noreply.github.com> Date: Thu, 18 Aug 2022 12:29:16 -0400 Subject: [PATCH 08/10] Update changelog --- tower-http/CHANGELOG.md | 1 + tower-http/Cargo.toml | 2 +- tower-http/src/validate_request.rs | 1 - 3 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tower-http/CHANGELOG.md b/tower-http/CHANGELOG.md index 649ad63a..3245208a 100644 --- a/tower-http/CHANGELOG.md +++ b/tower-http/CHANGELOG.md @@ -10,6 +10,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## Added - Add `NormalizePath` middleware +- Add `ValidateRequest` middleware ## Changed diff --git a/tower-http/Cargo.toml b/tower-http/Cargo.toml index 9922f9b1..c89afcc5 100644 --- a/tower-http/Cargo.toml +++ b/tower-http/Cargo.toml @@ -19,7 +19,6 @@ futures-core = "0.3" futures-util = { version = "0.3.14", default_features = false, features = [] } http = "0.2.2" http-body = "0.4.5" -mime = { version = "0.3", default_features = false } pin-project-lite = "0.2.7" tower-layer = "0.3" tower-service = "0.3" @@ -29,6 +28,7 @@ async-compression = { version = "0.3", optional = true, features = ["tokio"] } base64 = { version = "0.13", optional = true } http-range-header = "0.3.0" iri-string = { version = "0.4", optional = true } +mime = { version = "0.3", optional = true, default_features = false } mime_guess = { version = "2", optional = true, default_features = false } percent-encoding = { version = "2.1.0", optional = true } tokio = { version = "1.6", optional = true, default_features = false } diff --git a/tower-http/src/validate_request.rs b/tower-http/src/validate_request.rs index e91847e4..8e052d4f 100644 --- a/tower-http/src/validate_request.rs +++ b/tower-http/src/validate_request.rs @@ -292,7 +292,6 @@ where match self.project().kind.project() { KindProj::Future { future } => future.poll(cx), KindProj::Error { response } => { - /* Never panics unless polled after completion */ let response = response.take().expect("future polled after completion"); Poll::Ready(Ok(response)) } From 20d80255d86f2a515d45c22f44f9418fa932e149 Mon Sep 17 00:00:00 2001 From: 82marbag <69267416+82marbag@users.noreply.github.com> Date: Thu, 18 Aug 2022 12:36:35 -0400 Subject: [PATCH 09/10] Fix feature gate --- tower-http/src/lib.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/tower-http/src/lib.rs b/tower-http/src/lib.rs index d8d7ca19..8f0caef6 100644 --- a/tower-http/src/lib.rs +++ b/tower-http/src/lib.rs @@ -322,6 +322,7 @@ mod builder; #[doc(inline)] pub use self::builder::ServiceBuilderExt; +#[cfg(feature = "validate-request")] pub mod validate_request; /// The latency unit used to report latencies by middleware. From 00f0d9b9f2886e40d8220861fee40f84390f43ff Mon Sep 17 00:00:00 2001 From: 82marbag <69267416+82marbag@users.noreply.github.com> Date: Thu, 18 Aug 2022 15:40:23 -0400 Subject: [PATCH 10/10] Fix style --- tower-http/src/validate_request.rs | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tower-http/src/validate_request.rs b/tower-http/src/validate_request.rs index 8e052d4f..c61c1bed 100644 --- a/tower-http/src/validate_request.rs +++ b/tower-http/src/validate_request.rs @@ -338,9 +338,10 @@ impl AcceptHeader { ResBody: Body + Default, { Self { - header_value: Arc::new(header_value - .parse::() - .expect("value is not a valid header value") + header_value: Arc::new( + header_value + .parse::() + .expect("value is not a valid header value"), ), _ty: PhantomData, }