From 979876e7f83809cf7e84f62af31dec2e92cee703 Mon Sep 17 00:00:00 2001 From: Jonas Platte Date: Sat, 26 Mar 2022 15:05:53 +0100 Subject: [PATCH 01/12] cors: Remove Kind::InvalidOrigin It was never constructed. --- tower-http/src/cors.rs | 11 ++--------- 1 file changed, 2 insertions(+), 9 deletions(-) diff --git a/tower-http/src/cors.rs b/tower-http/src/cors.rs index 9e11cb61..d6b405d8 100644 --- a/tower-http/src/cors.rs +++ b/tower-http/src/cors.rs @@ -46,6 +46,8 @@ //! //! [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/CORS +#![allow(clippy::enum_variant_names)] + use bytes::{BufMut, BytesMut}; use futures_core::ready; use http::{ @@ -747,7 +749,6 @@ pin_project! { headers: HeaderMap, }, InvalidCorsCall, - InvalidOrigin, } } @@ -783,14 +784,6 @@ where Poll::Ready(Ok(response)) } - KindProj::InvalidOrigin => { - let response = Response::builder() - .status(StatusCode::UNAUTHORIZED) - .body(B::default()) - .unwrap(); - - Poll::Ready(Ok(response)) - } } } } From ed48d6f771eb9a9dc06c0d0b8e3237bd83355ea6 Mon Sep 17 00:00:00 2001 From: Jonas Platte Date: Mon, 21 Mar 2022 11:38:42 +0100 Subject: [PATCH 02/12] cors: Don't special-case preflight requests with non-matching CORS method --- tower-http/src/cors.rs | 65 +----------------------------------------- 1 file changed, 1 insertion(+), 64 deletions(-) diff --git a/tower-http/src/cors.rs b/tower-http/src/cors.rs index d6b405d8..f72d9a69 100644 --- a/tower-http/src/cors.rs +++ b/tower-http/src/cors.rs @@ -53,7 +53,7 @@ use futures_core::ready; use http::{ header::{self, HeaderName, HeaderValue}, request::Parts, - HeaderMap, Method, Request, Response, StatusCode, + HeaderMap, Method, Request, Response, }; use pin_project_lite::pin_project; use std::{ @@ -524,22 +524,6 @@ impl Cors { self } - fn is_valid_request_method(&self, method: &HeaderValue) -> bool { - if let Some(allow_methods) = &self.layer.allow_methods { - #[allow(clippy::borrow_interior_mutable_const)] - if allow_methods == WILDCARD { - return true; - } - - allow_methods - .as_bytes() - .split(|&byte| byte == b',') - .any(|bytes| bytes == method.as_bytes()) - } else { - false - } - } - fn make_response_header_map(&self) -> HeaderMap { #[allow(clippy::declare_interior_mutable_const)] const TRUE: HeaderValue = HeaderValue::from_static("true"); @@ -689,16 +673,6 @@ where // Return results immediately upon preflight request if parts.method == Method::OPTIONS { - // the method the real request will be made with - match parts.headers.get(header::ACCESS_CONTROL_REQUEST_METHOD) { - Some(request_method) if self.is_valid_request_method(request_method) => {} - _ => { - return ResponseFuture { - inner: Kind::InvalidCorsCall, - }; - } - } - return ResponseFuture { inner: Kind::PreflightCall { headers: self.make_preflight_header_map(origin, &parts), @@ -748,7 +722,6 @@ pin_project! { PreflightCall { headers: HeaderMap, }, - InvalidCorsCall, } } @@ -774,14 +747,6 @@ where let mut response = Response::new(B::default()); mem::swap(response.headers_mut(), headers); - Poll::Ready(Ok(response)) - } - KindProj::InvalidCorsCall => { - let response = Response::builder() - .status(StatusCode::OK) - .body(B::default()) - .unwrap(); - Poll::Ready(Ok(response)) } } @@ -807,31 +772,3 @@ fn response_origin(allow_origin: &AnyOr, origin: &HeaderValue) -> Header origin.clone() } } - -#[cfg(test)] -mod tests { - #[allow(unused_imports)] - use super::*; - - #[test] - fn test_is_valid_request_method() { - let cors = Cors::new(()).allow_methods(vec![Method::GET, Method::POST]); - assert!(cors.is_valid_request_method(&HeaderValue::from_static("GET"))); - assert!(cors.is_valid_request_method(&HeaderValue::from_static("POST"))); - - let cors = Cors::new(()); - assert!(!cors.is_valid_request_method(&HeaderValue::from_static("GET"))); - assert!(!cors.is_valid_request_method(&HeaderValue::from_static("POST"))); - assert!(!cors.is_valid_request_method(&HeaderValue::from_static("OPTIONS"))); - - let cors = Cors::new(()).allow_methods(Any); - assert!(cors.is_valid_request_method(&HeaderValue::from_static("GET"))); - assert!(cors.is_valid_request_method(&HeaderValue::from_static("POST"))); - assert!(cors.is_valid_request_method(&HeaderValue::from_static("OPTIONS"))); - - let cors = Cors::new(()).allow_methods(Any); - assert!(cors.is_valid_request_method(&HeaderValue::from_static("GET"))); - assert!(cors.is_valid_request_method(&HeaderValue::from_static("POST"))); - assert!(cors.is_valid_request_method(&HeaderValue::from_static("OPTIONS"))); - } -} From af16d2dcb4df055b6b9f1f8a51fed18e478af860 Mon Sep 17 00:00:00 2001 From: Jonas Platte Date: Mon, 21 Mar 2022 12:33:11 +0100 Subject: [PATCH 03/12] cors: Consistently apply Vary headers before creating ResponseFuture --- tower-http/src/cors.rs | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/tower-http/src/cors.rs b/tower-http/src/cors.rs index f72d9a69..5f335d1e 100644 --- a/tower-http/src/cors.rs +++ b/tower-http/src/cors.rs @@ -538,6 +538,8 @@ impl Cors { headers.insert(header::ACCESS_CONTROL_EXPOSE_HEADERS, expose_headers); } + apply_vary_headers(&mut headers); + headers } @@ -688,8 +690,6 @@ where response_origin(self.layer.allow_origin.as_ref().unwrap(), &origin), ); - apply_vary_headers(&mut headers); - ResponseFuture { inner: Kind::CorsCall { future: self.inner.call(req), @@ -742,8 +742,6 @@ where } KindProj::NonCorsCall { future } => future.poll(cx), KindProj::PreflightCall { headers } => { - apply_vary_headers(headers); - let mut response = Response::new(B::default()); mem::swap(response.headers_mut(), headers); From 972d37a9b3750bdb8da7923cd15f2edb4e246156 Mon Sep 17 00:00:00 2001 From: Jonas Platte Date: Mon, 21 Mar 2022 12:35:26 +0100 Subject: [PATCH 04/12] cors: Inline / simplify apply_vary_headers --- tower-http/src/cors.rs | 16 +++------------- 1 file changed, 3 insertions(+), 13 deletions(-) diff --git a/tower-http/src/cors.rs b/tower-http/src/cors.rs index 5f335d1e..94e4de41 100644 --- a/tower-http/src/cors.rs +++ b/tower-http/src/cors.rs @@ -538,7 +538,9 @@ impl Cors { headers.insert(header::ACCESS_CONTROL_EXPOSE_HEADERS, expose_headers); } - apply_vary_headers(&mut headers); + headers.append(header::VARY, header::ORIGIN.into()); + headers.append(header::VARY, header::ACCESS_CONTROL_REQUEST_METHOD.into()); + headers.append(header::VARY, header::ACCESS_CONTROL_REQUEST_HEADERS.into()); headers } @@ -751,18 +753,6 @@ where } } -fn apply_vary_headers(headers: &mut http::HeaderMap) { - const VARY_HEADERS: [HeaderName; 3] = [ - header::ORIGIN, - header::ACCESS_CONTROL_REQUEST_METHOD, - header::ACCESS_CONTROL_REQUEST_HEADERS, - ]; - - for h in &VARY_HEADERS { - headers.append(header::VARY, HeaderValue::from_static(h.as_str())); - } -} - fn response_origin(allow_origin: &AnyOr, origin: &HeaderValue) -> HeaderValue { if let AnyOrInner::Any = &allow_origin.0 { WILDCARD From c124cbec914a6ad47c8ba9405f78550f6df3d224 Mon Sep 17 00:00:00 2001 From: Jonas Platte Date: Fri, 25 Mar 2022 21:53:27 +0100 Subject: [PATCH 05/12] cors: Rewrite the CORS service / layer --- tower-http/src/cors/allow_credentials.rs | 86 +++++ tower-http/src/cors/allow_headers.rs | 105 ++++++ tower-http/src/cors/allow_methods.rs | 120 +++++++ tower-http/src/cors/allow_origin.rs | 131 ++++++++ tower-http/src/cors/expose_headers.rs | 86 +++++ tower-http/src/cors/max_age.rs | 68 ++++ tower-http/src/{cors.rs => cors/mod.rs} | 405 ++++++++--------------- 7 files changed, 740 insertions(+), 261 deletions(-) create mode 100644 tower-http/src/cors/allow_credentials.rs create mode 100644 tower-http/src/cors/allow_headers.rs create mode 100644 tower-http/src/cors/allow_methods.rs create mode 100644 tower-http/src/cors/allow_origin.rs create mode 100644 tower-http/src/cors/expose_headers.rs create mode 100644 tower-http/src/cors/max_age.rs rename tower-http/src/{cors.rs => cors/mod.rs} (63%) diff --git a/tower-http/src/cors/allow_credentials.rs b/tower-http/src/cors/allow_credentials.rs new file mode 100644 index 00000000..b09bcdb6 --- /dev/null +++ b/tower-http/src/cors/allow_credentials.rs @@ -0,0 +1,86 @@ +use std::{fmt, sync::Arc}; + +use http::{request::Parts as RequestParts, HeaderValue}; + +/// Holds configuration for how to set the [`Access-Control-Allow-Credentials`][mdn] header. +/// +/// See [`CorsLayer::allow_credentials`] for more details. +/// +/// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Credentials +/// [`CorsLayer::allow_credentials`]: super::CorsLayer::allow_credentials +#[derive(Clone, Default)] +pub struct AllowCredentials(AllowCredentialsInner); + +impl AllowCredentials { + /// Allow credentials for all requests + /// + /// See [`CorsLayer::allow_credentials`] for more details. + /// + /// [`CorsLayer::allow_credentials`]: super::CorsLayer::allow_credentials + pub fn yes() -> Self { + Self(AllowCredentialsInner::Yes) + } + + /// Allow credentials for some requests, based on a given predicate + /// + /// See [`CorsLayer::allow_credentials`] for more details. + /// + /// [`CorsLayer::allow_credentials`]: super::CorsLayer::allow_credentials + pub fn predicate(f: F) -> Self + where + F: Fn(&HeaderValue, &RequestParts) -> bool + Send + Sync + 'static, + { + Self(AllowCredentialsInner::Predicate(Arc::new(f))) + } + + pub(super) fn to_header_val( + &self, + origin: &HeaderValue, + parts: &RequestParts, + ) -> Option { + #[allow(clippy::declare_interior_mutable_const)] + const TRUE: HeaderValue = HeaderValue::from_static("true"); + + let allow_creds = match &self.0 { + AllowCredentialsInner::Yes => true, + AllowCredentialsInner::No => false, + AllowCredentialsInner::Predicate(c) => c(origin, parts), + }; + + allow_creds.then(|| TRUE) + } +} + +impl From for AllowCredentials { + fn from(v: bool) -> Self { + match v { + true => Self(AllowCredentialsInner::Yes), + false => Self(AllowCredentialsInner::No), + } + } +} + +impl fmt::Debug for AllowCredentials { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self.0 { + AllowCredentialsInner::Yes => f.debug_tuple("Yes").finish(), + AllowCredentialsInner::No => f.debug_tuple("No").finish(), + AllowCredentialsInner::Predicate(_) => f.debug_tuple("Predicate").finish(), + } + } +} + +#[derive(Clone)] +enum AllowCredentialsInner { + Yes, + No, + Predicate( + Arc Fn(&'a HeaderValue, &'a RequestParts) -> bool + Send + Sync + 'static>, + ), +} + +impl Default for AllowCredentialsInner { + fn default() -> Self { + Self::No + } +} diff --git a/tower-http/src/cors/allow_headers.rs b/tower-http/src/cors/allow_headers.rs new file mode 100644 index 00000000..3bbd507a --- /dev/null +++ b/tower-http/src/cors/allow_headers.rs @@ -0,0 +1,105 @@ +use std::{array, fmt}; + +use http::{ + header::{self, HeaderName}, + request::Parts as RequestParts, + HeaderValue, +}; + +use super::{separated_by_commas, Any, WILDCARD}; + +/// Holds configuration for how to set the [`Access-Control-Allow-Headers`][mdn] header. +/// +/// See [`CorsLayer::allow_headers`] for more details. +/// +/// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Headers +/// [`CorsLayer::allow_headers`]: super::CorsLayer::allow_headers +#[derive(Clone, Default)] +pub struct AllowHeaders(AllowHeadersInner); + +impl AllowHeaders { + /// Allow any headers by sending a wildcard (`*`) + /// + /// See [`CorsLayer::allow_headers`] for more details. + /// + /// [`CorsLayer::allow_headers`]: super::CorsLayer::allow_headers + pub fn any() -> Self { + Self(AllowHeadersInner::Const(Some(WILDCARD))) + } + + /// Set multiple allowed headers + /// + /// See [`CorsLayer::allow_headers`] for more details. + /// + /// [`CorsLayer::allow_headers`]: super::CorsLayer::allow_headers + pub fn list(headers: I) -> Self + where + I: IntoIterator, + { + Self(AllowHeadersInner::Const(separated_by_commas( + headers.into_iter().map(Into::into), + ))) + } + + /// Allow any headers, by mirroring the preflight [`Access-Control-Request-Headers`][mdn] + /// header. + /// + /// See [`CorsLayer::allow_headers`] for more details. + /// + /// [`CorsLayer::allow_headers`]: super::CorsLayer::allow_headers + /// + /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Request-Headers + pub fn mirror_request() -> Self { + Self(AllowHeadersInner::MirrorRequest) + } + + pub(super) fn to_header_val(&self, parts: &RequestParts) -> Option { + match &self.0 { + AllowHeadersInner::Const(v) => v.clone(), + AllowHeadersInner::MirrorRequest => parts + .headers + .get(header::ACCESS_CONTROL_REQUEST_HEADERS) + .cloned(), + } + } +} + +impl fmt::Debug for AllowHeaders { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match &self.0 { + AllowHeadersInner::Const(inner) => f.debug_tuple("Const").field(inner).finish(), + AllowHeadersInner::MirrorRequest => f.debug_tuple("MirrorRequest").finish(), + } + } +} + +impl From for AllowHeaders { + fn from(_: Any) -> Self { + Self::any() + } +} + +impl From<[HeaderName; N]> for AllowHeaders { + fn from(arr: [HeaderName; N]) -> Self { + #[allow(deprecated)] // Can be changed when MSRV >= 1.53 + Self::list(array::IntoIter::new(arr)) + } +} + +impl From> for AllowHeaders { + fn from(vec: Vec) -> Self { + Self::list(vec) + } +} + +#[derive(Clone)] +enum AllowHeadersInner { + Const(Option), + MirrorRequest, +} + +impl Default for AllowHeadersInner { + fn default() -> Self { + Self::Const(None) + } +} diff --git a/tower-http/src/cors/allow_methods.rs b/tower-http/src/cors/allow_methods.rs new file mode 100644 index 00000000..0eacadda --- /dev/null +++ b/tower-http/src/cors/allow_methods.rs @@ -0,0 +1,120 @@ +use std::{array, fmt}; + +use http::{header, request::Parts as RequestParts, HeaderValue, Method}; + +use super::{separated_by_commas, Any, WILDCARD}; + +/// Holds configuration for how to set the [`Access-Control-Allow-Methods`][mdn] header. +/// +/// See [`CorsLayer::allow_methods`] for more details. +/// +/// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Methods +/// [`CorsLayer::allow_methods`]: super::CorsLayer::allow_methods +#[derive(Clone, Default)] +pub struct AllowMethods(AllowMethodsInner); + +impl AllowMethods { + /// Allow any method by sending a wildcard (`*`) + /// + /// See [`CorsLayer::allow_methods`] for more details. + /// + /// [`CorsLayer::allow_methods`]: super::CorsLayer::allow_methods + pub fn any() -> Self { + Self(AllowMethodsInner::Const(Some(WILDCARD))) + } + + /// Set a single allowed method + /// + /// See [`CorsLayer::allow_methods`] for more details. + /// + /// [`CorsLayer::allow_methods`]: super::CorsLayer::allow_methods + pub fn exact(method: Method) -> Self { + Self(AllowMethodsInner::Const(Some( + HeaderValue::from_str(method.as_str()).unwrap(), + ))) + } + + /// Set multiple allowed methods + /// + /// See [`CorsLayer::allow_methods`] for more details. + /// + /// [`CorsLayer::allow_methods`]: super::CorsLayer::allow_methods + pub fn list(methods: I) -> Self + where + I: IntoIterator, + { + Self(AllowMethodsInner::Const(separated_by_commas( + methods + .into_iter() + .map(|m| HeaderValue::from_str(m.as_str()).unwrap()), + ))) + } + + /// Allow any method, by mirroring the preflight [`Access-Control-Request-Method`][mdn] + /// header. + /// + /// See [`CorsLayer::allow_methods`] for more details. + /// + /// [`CorsLayer::allow_methods`]: super::CorsLayer::allow_methods + /// + /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Request-Method + pub fn mirror_request() -> Self { + Self(AllowMethodsInner::MirrorRequest) + } + + pub(super) fn to_header_val(&self, parts: &RequestParts) -> Option { + match &self.0 { + AllowMethodsInner::Const(v) => v.clone(), + AllowMethodsInner::MirrorRequest => parts + .headers + .get(header::ACCESS_CONTROL_REQUEST_METHOD) + .cloned(), + } + } +} + +impl fmt::Debug for AllowMethods { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match &self.0 { + AllowMethodsInner::Const(inner) => f.debug_tuple("Const").field(inner).finish(), + AllowMethodsInner::MirrorRequest => f.debug_tuple("MirrorRequest").finish(), + } + } +} + +impl From for AllowMethods { + fn from(_: Any) -> Self { + Self::any() + } +} + +impl From for AllowMethods { + fn from(method: Method) -> Self { + Self::exact(method) + } +} + +impl From<[Method; N]> for AllowMethods { + fn from(arr: [Method; N]) -> Self { + #[allow(deprecated)] // Can be changed when MSRV >= 1.53 + Self::list(array::IntoIter::new(arr)) + } +} + +impl From> for AllowMethods { + fn from(vec: Vec) -> Self { + Self::list(vec) + } +} + +#[derive(Clone)] +enum AllowMethodsInner { + Const(Option), + MirrorRequest, +} + +impl Default for AllowMethodsInner { + fn default() -> Self { + Self::Const(None) + } +} diff --git a/tower-http/src/cors/allow_origin.rs b/tower-http/src/cors/allow_origin.rs new file mode 100644 index 00000000..5f694291 --- /dev/null +++ b/tower-http/src/cors/allow_origin.rs @@ -0,0 +1,131 @@ +use std::{array, fmt, sync::Arc}; + +use http::{request::Parts as RequestParts, HeaderValue}; + +use super::{separated_by_commas, Any, WILDCARD}; + +/// Holds configuration for how to set the [`Access-Control-Allow-Origin`][mdn] header. +/// +/// See [`CorsLayer::allow_origin`] for more details. +/// +/// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Origin +/// [`CorsLayer::allow_origin`]: super::CorsLayer::allow_origin +#[derive(Clone, Default)] +pub struct AllowOrigin(OriginInner); + +impl AllowOrigin { + /// Allow any origin by sending a wildcard (`*`) + /// + /// See [`CorsLayer::allow_origin`] for more details. + /// + /// [`CorsLayer::allow_origin`]: super::CorsLayer::allow_origin + pub fn any() -> Self { + Self(OriginInner::Const(Some(WILDCARD))) + } + + /// Set a single allowed origin + /// + /// See [`CorsLayer::allow_origin`] for more details. + /// + /// [`CorsLayer::allow_origin`]: super::CorsLayer::allow_origin + pub fn exact(origin: HeaderValue) -> Self { + Self(OriginInner::Const(Some(origin))) + } + + /// Set multiple allowed origins + /// + /// See [`CorsLayer::allow_origin`] for more details. + /// + /// [`CorsLayer::allow_origin`]: super::CorsLayer::allow_origin + pub fn list(origins: I) -> Self + where + I: IntoIterator, + { + Self(OriginInner::Const(separated_by_commas( + origins.into_iter().map(Into::into), + ))) + } + + /// Set the allowed origins from a predicate + /// + /// See [`CorsLayer::allow_origin`] for more details. + /// + /// [`CorsLayer::allow_origin`]: super::CorsLayer::allow_origin + pub fn predicate(f: F) -> Self + where + F: Fn(&HeaderValue, &RequestParts) -> bool + Send + Sync + 'static, + { + Self(OriginInner::Predicate(Arc::new(f))) + } + + /// Allow any origin, by mirroring the request origin + /// + /// This is equivalent to + /// [`AllowOrigin::predicate(|_, _| true)`][Self::predicate]. + /// + /// See [`CorsLayer::allow_origin`] for more details. + /// + /// [`CorsLayer::allow_origin`]: super::CorsLayer::allow_origin + pub fn mirror_request() -> Self { + Self::predicate(|_, _| true) + } + + pub(super) fn to_header_val( + &self, + origin: &HeaderValue, + parts: &RequestParts, + ) -> Option { + match &self.0 { + OriginInner::Const(v) => v.clone(), + OriginInner::Predicate(c) => c(origin, parts).then(|| origin.to_owned()), + } + } +} + +impl fmt::Debug for AllowOrigin { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match &self.0 { + OriginInner::Const(inner) => f.debug_tuple("Const").field(inner).finish(), + OriginInner::Predicate(_) => f.debug_tuple("Predicate").finish(), + } + } +} + +impl From for AllowOrigin { + fn from(_: Any) -> Self { + Self::any() + } +} + +impl From for AllowOrigin { + fn from(val: HeaderValue) -> Self { + Self::exact(val) + } +} + +impl From<[HeaderValue; N]> for AllowOrigin { + fn from(arr: [HeaderValue; N]) -> Self { + #[allow(deprecated)] // Can be changed when MSRV >= 1.53 + Self::list(array::IntoIter::new(arr)) + } +} + +impl From> for AllowOrigin { + fn from(vec: Vec) -> Self { + Self::list(vec) + } +} + +#[derive(Clone)] +enum OriginInner { + Const(Option), + Predicate( + Arc Fn(&'a HeaderValue, &'a RequestParts) -> bool + Send + Sync + 'static>, + ), +} + +impl Default for OriginInner { + fn default() -> Self { + Self::Const(None) + } +} diff --git a/tower-http/src/cors/expose_headers.rs b/tower-http/src/cors/expose_headers.rs new file mode 100644 index 00000000..91327524 --- /dev/null +++ b/tower-http/src/cors/expose_headers.rs @@ -0,0 +1,86 @@ +use std::{array, fmt}; + +use http::{ + header::{HeaderName, HeaderValue}, + request::Parts as RequestParts, +}; + +use super::{separated_by_commas, Any, WILDCARD}; + +/// Holds configuration for how to set the [`Access-Control-Expose-Headers`][mdn] header. +/// +/// See [`CorsLayer::expose_headers`] for more details. +/// +/// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Expose-Headers +/// [`CorsLayer::expose_headers`]: super::CorsLayer::expose_headers +#[derive(Clone, Default)] +pub struct ExposeHeaders(ExposeHeadersInner); + +impl ExposeHeaders { + /// Expose any / all headers by sending a wildcard (`*`) + /// + /// See [`CorsLayer::expose_headers`] for more details. + /// + /// [`CorsLayer::expose_headers`]: super::CorsLayer::expose_headers + pub fn any() -> Self { + Self(ExposeHeadersInner::Const(Some(WILDCARD))) + } + + /// Set multiple exposed header names + /// + /// See [`CorsLayer::expose_headers`] for more details. + /// + /// [`CorsLayer::expose_headers`]: super::CorsLayer::expose_headers + pub fn list(headers: I) -> Self + where + I: IntoIterator, + { + Self(ExposeHeadersInner::Const(separated_by_commas( + headers.into_iter().map(Into::into), + ))) + } + + pub(super) fn to_header_val(&self, _parts: &RequestParts) -> Option { + match &self.0 { + ExposeHeadersInner::Const(v) => v.clone(), + } + } +} + +impl fmt::Debug for ExposeHeaders { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match &self.0 { + ExposeHeadersInner::Const(inner) => f.debug_tuple("Const").field(inner).finish(), + } + } +} + +impl From for ExposeHeaders { + fn from(_: Any) -> Self { + Self::any() + } +} + +impl From<[HeaderName; N]> for ExposeHeaders { + fn from(arr: [HeaderName; N]) -> Self { + #[allow(deprecated)] // Can be changed when MSRV >= 1.53 + Self::list(array::IntoIter::new(arr)) + } +} + +impl From> for ExposeHeaders { + fn from(vec: Vec) -> Self { + Self::list(vec) + } +} + +#[derive(Clone)] +enum ExposeHeadersInner { + Const(Option), +} + +impl Default for ExposeHeadersInner { + fn default() -> Self { + ExposeHeadersInner::Const(None) + } +} diff --git a/tower-http/src/cors/max_age.rs b/tower-http/src/cors/max_age.rs new file mode 100644 index 00000000..5b3426f4 --- /dev/null +++ b/tower-http/src/cors/max_age.rs @@ -0,0 +1,68 @@ +use std::{fmt, sync::Arc, time::Duration}; + +use http::{request::Parts as RequestParts, HeaderValue}; + +/// Holds configuration for how to set the [`Access-Control-Max-Age`][mdn] header. +/// +/// See [`CorsLayer::max_age`][super::CorsLayer::max_age] for more details. +/// +/// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Max-Age +#[derive(Clone, Default)] +pub struct MaxAge(MaxAgeInner); + +impl MaxAge { + /// Set a static max-age value + /// + /// See [`CorsLayer::max_age`][super::CorsLayer::max_age] for more details. + pub fn exact(max_age: Duration) -> Self { + Self(MaxAgeInner::Exact(Some(max_age.as_secs().into()))) + } + + /// Set the max-age based on the preflight request parts + /// + /// See [`CorsLayer::max_age`][super::CorsLayer::max_age] for more details. + pub fn dynamic(f: F) -> Self + where + F: Fn(&HeaderValue, &RequestParts) -> Duration + Send + Sync + 'static, + { + Self(MaxAgeInner::Fn(Arc::new(f))) + } + + pub(super) fn to_header_val( + &self, + origin: &HeaderValue, + parts: &RequestParts, + ) -> Option { + match &self.0 { + MaxAgeInner::Exact(v) => v.clone(), + MaxAgeInner::Fn(c) => Some(c(origin, parts).as_secs().into()), + } + } +} + +impl fmt::Debug for MaxAge { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match &self.0 { + MaxAgeInner::Exact(inner) => f.debug_tuple("Exact").field(inner).finish(), + MaxAgeInner::Fn(_) => f.debug_tuple("Fn").finish(), + } + } +} + +impl From for MaxAge { + fn from(max_age: Duration) -> Self { + Self::exact(max_age) + } +} + +#[derive(Clone)] +enum MaxAgeInner { + Exact(Option), + Fn(Arc Fn(&'a HeaderValue, &'a RequestParts) -> Duration + Send + Sync + 'static>), +} + +impl Default for MaxAgeInner { + fn default() -> Self { + Self::Exact(None) + } +} diff --git a/tower-http/src/cors.rs b/tower-http/src/cors/mod.rs similarity index 63% rename from tower-http/src/cors.rs rename to tower-http/src/cors/mod.rs index 94e4de41..b88f3721 100644 --- a/tower-http/src/cors.rs +++ b/tower-http/src/cors/mod.rs @@ -17,7 +17,7 @@ //! # async fn main() -> Result<(), Box> { //! let cors = CorsLayer::new() //! // allow `GET` and `POST` when accessing the resource -//! .allow_methods(vec![Method::GET, Method::POST]) +//! .allow_methods([Method::GET, Method::POST]) //! // allow requests from any origin //! .allow_origin(Any); //! @@ -50,24 +50,30 @@ use bytes::{BufMut, BytesMut}; use futures_core::ready; -use http::{ - header::{self, HeaderName, HeaderValue}, - request::Parts, - HeaderMap, Method, Request, Response, -}; +use http::{header, HeaderMap, HeaderValue, Method, Request, Response}; use pin_project_lite::pin_project; use std::{ - fmt, future::Future, mem, pin::Pin, - sync::Arc, task::{Context, Poll}, time::Duration, }; use tower_layer::Layer; use tower_service::Service; +mod allow_credentials; +mod allow_headers; +mod allow_methods; +mod allow_origin; +mod expose_headers; +mod max_age; + +pub use self::{ + allow_credentials::AllowCredentials, allow_headers::AllowHeaders, allow_methods::AllowMethods, + allow_origin::AllowOrigin, expose_headers::ExposeHeaders, max_age::MaxAge, +}; + /// Layer that applies the [`Cors`] middleware which adds headers for [CORS][mdn]. /// /// See the [module docs](crate::cors) for an example. @@ -75,12 +81,12 @@ use tower_service::Service; /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/CORS #[derive(Debug, Clone)] pub struct CorsLayer { - allow_credentials: bool, - allow_headers: Option, - allow_methods: Option, - allow_origin: Option>, - expose_headers: Option, - max_age: Option, + allow_credentials: AllowCredentials, + allow_headers: AllowHeaders, + allow_methods: AllowMethods, + allow_origin: AllowOrigin, + expose_headers: ExposeHeaders, + max_age: MaxAge, } #[allow(clippy::declare_interior_mutable_const)] @@ -93,12 +99,12 @@ impl CorsLayer { /// customize the behavior. pub fn new() -> Self { Self { - allow_credentials: false, - allow_headers: None, - allow_methods: None, - allow_origin: None, - expose_headers: None, - max_age: None, + allow_credentials: Default::default(), + allow_headers: Default::default(), + allow_methods: Default::default(), + allow_origin: Default::default(), + expose_headers: Default::default(), + max_age: Default::default(), } } @@ -131,8 +137,11 @@ impl CorsLayer { /// ``` /// /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Credentials - pub fn allow_credentials(mut self, allow_credentials: bool) -> Self { - self.allow_credentials = allow_credentials; + pub fn allow_credentials(mut self, allow_credentials: T) -> Self + where + T: Into, + { + self.allow_credentials = allow_credentials.into(); self } @@ -142,7 +151,7 @@ impl CorsLayer { /// use tower_http::cors::CorsLayer; /// use http::header::{AUTHORIZATION, ACCEPT}; /// - /// let layer = CorsLayer::new().allow_headers(vec![AUTHORIZATION, ACCEPT]); + /// let layer = CorsLayer::new().allow_headers([AUTHORIZATION, ACCEPT]); /// ``` /// /// All headers can be allowed with @@ -160,22 +169,19 @@ impl CorsLayer { /// `Access-Control-Request-Headers`. /// /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Headers - pub fn allow_headers(mut self, headers: I) -> Self + pub fn allow_headers(mut self, headers: T) -> Self where - I: Into>>, + T: Into, { - self.allow_headers = match headers.into().0 { - AnyOrInner::Any => Some(WILDCARD), - AnyOrInner::Value(headers) => separated_by_commas(headers.into_iter().map(Into::into)), - }; + self.allow_headers = headers.into(); self } /// Set the value of the [`Access-Control-Max-Age`][mdn] header. /// /// ``` - /// use tower_http::cors::CorsLayer; /// use std::time::Duration; + /// use tower_http::cors::CorsLayer; /// /// let layer = CorsLayer::new().max_age(Duration::from_secs(60) * 10); /// ``` @@ -187,9 +193,34 @@ impl CorsLayer { /// precedence when the Access-Control-Max-Age is greater. For more details /// see [mdn]. /// + /// If you need more flexibility, you can use supply a function which can + /// dynamically decide the max-age based on the origin and other parts of + /// each preflight request: + /// + /// ``` + /// # struct MyServerConfig { cors_max_age: Duration } + /// use std::time::Duration; + /// + /// use http::{request::Parts as RequestParts, HeaderValue}; + /// use tower_http::cors::{CorsLayer, MaxAge}; + /// + /// let layer = CorsLayer::new().max_age(MaxAge::dynamic( + /// |_origin: &HeaderValue, parts: &RequestParts| -> Duration { + /// // Let's say you want to be able to reload your config at + /// // runtime and have another middleware that always inserts + /// // the current config into the request extensions + /// let config = parts.extensions.get::().unwrap(); + /// config.cors_max_age + /// }, + /// )); + /// ``` + /// /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Max-Age - pub fn max_age(mut self, max_age: Duration) -> Self { - self.max_age = Some(max_age.as_secs().into()); + pub fn max_age(mut self, max_age: T) -> Self + where + T: Into, + { + self.max_age = max_age.into(); self } @@ -199,7 +230,7 @@ impl CorsLayer { /// use tower_http::cors::CorsLayer; /// use http::Method; /// - /// let layer = CorsLayer::new().allow_methods(vec![Method::GET, Method::POST]); + /// let layer = CorsLayer::new().allow_methods([Method::GET, Method::POST]); /// ``` /// /// All methods can be allowed with @@ -216,40 +247,34 @@ impl CorsLayer { /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Methods pub fn allow_methods(mut self, methods: T) -> Self where - T: Into>>, + T: Into, { - self.allow_methods = match methods.into().0 { - AnyOrInner::Any => Some(WILDCARD), - AnyOrInner::Value(methods) => separated_by_commas( - methods - .into_iter() - .map(|m| HeaderValue::from_str(m.as_str()).unwrap()), - ), - }; + self.allow_methods = methods.into(); self } /// Set the value of the [`Access-Control-Allow-Origin`][mdn] header. /// /// ``` - /// use tower_http::cors::{CorsLayer, Origin}; + /// use http::HeaderValue; + /// use tower_http::cors::CorsLayer; /// - /// let layer = CorsLayer::new().allow_origin(Origin::exact( - /// "http://example.com".parse().unwrap(), - /// )); + /// let layer = CorsLayer::new().allow_origin( + /// "http://example.com".parse::().unwrap(), + /// ); /// ``` /// /// Multiple origins can be allowed with /// /// ``` - /// use tower_http::cors::{CorsLayer, Origin}; + /// use tower_http::cors::CorsLayer; /// - /// let origins = vec![ + /// let origins = [ /// "http://example.com".parse().unwrap(), /// "http://api.example.com".parse().unwrap(), /// ]; /// - /// let layer = CorsLayer::new().allow_origin(Origin::list(origins)); + /// let layer = CorsLayer::new().allow_origin(origins); /// ``` /// /// All origins can be allowed with @@ -263,14 +288,14 @@ impl CorsLayer { /// You can also use a closure /// /// ``` - /// use tower_http::cors::{CorsLayer, Origin}; - /// use http::{HeaderValue, request::Parts}; + /// use tower_http::cors::{CorsLayer, AllowOrigin}; + /// use http::{request::Parts as RequestParts, HeaderValue}; /// - /// let layer = CorsLayer::new().allow_origin( - /// Origin::predicate(|origin: &HeaderValue, _request_head: &Parts| { + /// let layer = CorsLayer::new().allow_origin(AllowOrigin::predicate( + /// |origin: &HeaderValue, _request_parts: &RequestParts| { /// origin.as_bytes().ends_with(b".rust-lang.org") - /// }) - /// ); + /// }, + /// )); /// ``` /// /// Note that multiple calls to this method will override any previous @@ -279,9 +304,9 @@ impl CorsLayer { /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Origin pub fn allow_origin(mut self, origin: T) -> Self where - T: Into>, + T: Into, { - self.allow_origin = Some(origin.into()); + self.allow_origin = origin.into(); self } @@ -291,7 +316,7 @@ impl CorsLayer { /// use tower_http::cors::CorsLayer; /// use http::header::CONTENT_ENCODING; /// - /// let layer = CorsLayer::new().expose_headers(vec![CONTENT_ENCODING]); + /// let layer = CorsLayer::new().expose_headers([CONTENT_ENCODING]); /// ``` /// /// All headers can be allowed with @@ -306,14 +331,11 @@ impl CorsLayer { /// calls. /// /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Expose-Headers - pub fn expose_headers(mut self, headers: I) -> Self + pub fn expose_headers(mut self, headers: T) -> Self where - I: Into>>, + T: Into, { - self.expose_headers = match headers.into().0 { - AnyOrInner::Any => Some(WILDCARD), - AnyOrInner::Value(headers) => separated_by_commas(headers.into_iter().map(Into::into)), - }; + self.expose_headers = headers.into(); self } } @@ -330,48 +352,6 @@ pub fn any() -> Any { Any } -/// Used to make methods like [`CorsLayer::allow_methods`] more convenient to call. -/// -/// You shouldn't have to use this type directly. -#[derive(Debug, Clone, Copy)] -pub struct AnyOr(AnyOrInner); - -#[derive(Debug, Clone, Copy)] -enum AnyOrInner { - Any, - Value(T), -} - -impl From for AnyOr { - fn from(origin: Origin) -> Self { - AnyOr(AnyOrInner::Value(origin)) - } -} - -impl From for AnyOr { - fn from(_: Any) -> Self { - AnyOr(AnyOrInner::Any) - } -} - -impl From for AnyOr> -where - I: IntoIterator, -{ - fn from(methods: I) -> Self { - AnyOr(AnyOrInner::Value(methods.into_iter().collect())) - } -} - -impl From for AnyOr> -where - I: IntoIterator, -{ - fn from(headers: I) -> Self { - AnyOr(AnyOrInner::Value(headers.into_iter().collect())) - } -} - fn separated_by_commas(mut iter: I) -> Option where I: Iterator, @@ -455,7 +435,10 @@ impl Cors { /// See [`CorsLayer::allow_credentials`] for more details. /// /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Credentials - pub fn allow_credentials(self, allow_credentials: bool) -> Self { + pub fn allow_credentials(self, allow_credentials: T) -> Self + where + T: Into, + { self.map_layer(|layer| layer.allow_credentials(allow_credentials)) } @@ -464,9 +447,9 @@ impl Cors { /// See [`CorsLayer::allow_headers`] for more details. /// /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Headers - pub fn allow_headers(self, headers: I) -> Self + pub fn allow_headers(self, headers: T) -> Self where - I: Into>>, + T: Into, { self.map_layer(|layer| layer.allow_headers(headers)) } @@ -476,7 +459,10 @@ impl Cors { /// See [`CorsLayer::max_age`] for more details. /// /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Max-Age - pub fn max_age(self, max_age: Duration) -> Self { + pub fn max_age(self, max_age: T) -> Self + where + T: Into, + { self.map_layer(|layer| layer.max_age(max_age)) } @@ -487,7 +473,7 @@ impl Cors { /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Methods pub fn allow_methods(self, methods: T) -> Self where - T: Into>>, + T: Into, { self.map_layer(|layer| layer.allow_methods(methods)) } @@ -499,7 +485,7 @@ impl Cors { /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Origin pub fn allow_origin(self, origin: T) -> Self where - T: Into>, + T: Into, { self.map_layer(|layer| layer.allow_origin(origin)) } @@ -509,9 +495,9 @@ impl Cors { /// See [`CorsLayer::expose_headers`] for more details. /// /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Expose-Headers - pub fn expose_headers(self, headers: I) -> Self + pub fn expose_headers(self, headers: T) -> Self where - I: Into>>, + T: Into, { self.map_layer(|layer| layer.expose_headers(headers)) } @@ -523,127 +509,6 @@ impl Cors { self.layer = f(self.layer); self } - - fn make_response_header_map(&self) -> HeaderMap { - #[allow(clippy::declare_interior_mutable_const)] - const TRUE: HeaderValue = HeaderValue::from_static("true"); - - let mut headers = HeaderMap::new(); - - if self.layer.allow_credentials { - headers.insert(header::ACCESS_CONTROL_ALLOW_CREDENTIALS, TRUE); - } - - if let Some(expose_headers) = self.layer.expose_headers.clone() { - headers.insert(header::ACCESS_CONTROL_EXPOSE_HEADERS, expose_headers); - } - - headers.append(header::VARY, header::ORIGIN.into()); - headers.append(header::VARY, header::ACCESS_CONTROL_REQUEST_METHOD.into()); - headers.append(header::VARY, header::ACCESS_CONTROL_REQUEST_HEADERS.into()); - - headers - } - - fn make_preflight_header_map(&self, origin: HeaderValue, parts: &Parts) -> HeaderMap { - let mut headers = self.make_response_header_map(); - - if let Some(allow_origin) = &self.layer.allow_origin { - if let Some(origin) = allow_origin.to_header_val(origin, parts) { - headers.insert(header::ACCESS_CONTROL_ALLOW_ORIGIN, origin); - } - } - - if let Some(allow_methods) = &self.layer.allow_methods { - headers.insert(header::ACCESS_CONTROL_ALLOW_METHODS, allow_methods.clone()); - } - - if let Some(allow_headers) = &self.layer.allow_headers { - headers.insert(header::ACCESS_CONTROL_ALLOW_HEADERS, allow_headers.clone()); - } - - if let Some(max_age) = self.layer.max_age.clone() { - headers.insert(header::ACCESS_CONTROL_MAX_AGE, max_age); - } - - headers - } -} - -/// Represents a [`Access-Control-Allow-Origin`][mdn] header. -/// -/// See [`CorsLayer::allow_origin`] for more details. -/// -/// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Origin -#[derive(Clone)] -pub struct Origin(OriginInner); - -impl Origin { - /// Set a single allow origin target - /// - /// See [`CorsLayer::allow_origin`] for more details. - pub fn exact(origin: HeaderValue) -> Self { - Self(OriginInner::Const(Some(origin))) - } - - /// Set multiple allow origin targets - /// - /// See [`CorsLayer::allow_origin`] for more details. - pub fn list(origins: I) -> Self - where - I: IntoIterator, - { - Self(OriginInner::Const(separated_by_commas( - origins.into_iter().map(Into::into), - ))) - } - - /// Set the allowed origins from a predicate - /// - /// See [`CorsLayer::allow_origin`] for more details. - pub fn predicate(f: F) -> Self - where - F: Fn(&HeaderValue, &Parts) -> bool + Send + Sync + 'static, - { - Self(OriginInner::Closure(Arc::new(f))) - } - - fn to_header_val(&self, origin: HeaderValue, parts: &Parts) -> Option { - match &self.0 { - OriginInner::Const(v) => v.clone(), - OriginInner::Closure(c) => { - if c(&origin, parts) { - Some(origin) - } else { - None - } - } - } - } -} - -impl AnyOr { - fn to_header_val(&self, origin: HeaderValue, parts: &Parts) -> Option { - match &self.0 { - AnyOrInner::Any => Some(WILDCARD), - AnyOrInner::Value(o) => o.to_header_val(origin, parts), - } - } -} - -impl fmt::Debug for Origin { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match &self.0 { - OriginInner::Const(inner) => f.debug_tuple("Const").field(inner).finish(), - OriginInner::Closure(_) => f.debug_tuple("Closure").finish(), - } - } -} - -#[derive(Clone)] -enum OriginInner { - Const(Option), - Closure(Arc Fn(&'a HeaderValue, &'a Parts) -> bool + Send + Sync + 'static>), } impl Service> for Cors @@ -662,10 +527,11 @@ where fn call(&mut self, req: Request) -> Self::Future { let origin = req.headers().get(&header::ORIGIN).cloned(); + // Only requests with an origin can be considered CORS requests: + // https://fetch.spec.whatwg.org/#http-requests let origin = if let Some(origin) = origin { origin } else { - // This is not a CORS request if there is no Origin header return ResponseFuture { inner: Kind::NonCorsCall { future: self.inner.call(req), @@ -675,28 +541,53 @@ where let (parts, body) = req.into_parts(); + let mut headers = HeaderMap::new(); + + // These headers are applied to both preflight and subsequent regular CORS requests: + // https://fetch.spec.whatwg.org/#http-responses + if let Some(value) = self.layer.allow_origin.to_header_val(&origin, &parts) { + headers.insert(header::ACCESS_CONTROL_ALLOW_ORIGIN, value); + } + + if let Some(value) = self.layer.allow_credentials.to_header_val(&origin, &parts) { + headers.insert(header::ACCESS_CONTROL_ALLOW_CREDENTIALS, value); + } + + headers.append(header::VARY, header::ORIGIN.into()); + headers.append(header::VARY, header::ACCESS_CONTROL_REQUEST_METHOD.into()); + headers.append(header::VARY, header::ACCESS_CONTROL_REQUEST_HEADERS.into()); + // Return results immediately upon preflight request if parts.method == Method::OPTIONS { - return ResponseFuture { - inner: Kind::PreflightCall { - headers: self.make_preflight_header_map(origin, &parts), - }, - }; - } + // These headers are applied only to preflight requests + if let Some(value) = self.layer.allow_methods.to_header_val(&parts) { + headers.insert(header::ACCESS_CONTROL_ALLOW_METHODS, value); + } - let req = Request::from_parts(parts, body); + if let Some(value) = self.layer.allow_headers.to_header_val(&parts) { + headers.insert(header::ACCESS_CONTROL_ALLOW_HEADERS, value); + } - let mut headers = self.make_response_header_map(); - headers.insert( - header::ACCESS_CONTROL_ALLOW_ORIGIN, - response_origin(self.layer.allow_origin.as_ref().unwrap(), &origin), - ); + if let Some(value) = self.layer.max_age.to_header_val(&origin, &parts) { + headers.insert(header::ACCESS_CONTROL_MAX_AGE, value); + } - ResponseFuture { - inner: Kind::CorsCall { - future: self.inner.call(req), - headers, - }, + ResponseFuture { + inner: Kind::PreflightCall { headers }, + } + } else { + // This header is applied only to non-preflight requests + if let Some(value) = self.layer.expose_headers.to_header_val(&parts) { + headers.insert(header::ACCESS_CONTROL_EXPOSE_HEADERS, value); + } + + let req = Request::from_parts(parts, body); + ResponseFuture { + inner: Kind::CorsCall { + future: self.inner.call(req), + headers, + }, + } } } } @@ -752,11 +643,3 @@ where } } } - -fn response_origin(allow_origin: &AnyOr, origin: &HeaderValue) -> HeaderValue { - if let AnyOrInner::Any = &allow_origin.0 { - WILDCARD - } else { - origin.clone() - } -} From 70bdf061985b9f5b9ef47592d52035f5c96019fd Mon Sep 17 00:00:00 2001 From: Jonas Platte Date: Sat, 26 Mar 2022 14:58:57 +0100 Subject: [PATCH 06/12] cors: Add #[must_use] to CORS types --- tower-http/src/cors/allow_credentials.rs | 1 + tower-http/src/cors/allow_headers.rs | 1 + tower-http/src/cors/allow_methods.rs | 1 + tower-http/src/cors/allow_origin.rs | 1 + tower-http/src/cors/expose_headers.rs | 1 + tower-http/src/cors/max_age.rs | 1 + tower-http/src/cors/mod.rs | 3 +++ 7 files changed, 9 insertions(+) diff --git a/tower-http/src/cors/allow_credentials.rs b/tower-http/src/cors/allow_credentials.rs index b09bcdb6..72bcbb21 100644 --- a/tower-http/src/cors/allow_credentials.rs +++ b/tower-http/src/cors/allow_credentials.rs @@ -9,6 +9,7 @@ use http::{request::Parts as RequestParts, HeaderValue}; /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Credentials /// [`CorsLayer::allow_credentials`]: super::CorsLayer::allow_credentials #[derive(Clone, Default)] +#[must_use] pub struct AllowCredentials(AllowCredentialsInner); impl AllowCredentials { diff --git a/tower-http/src/cors/allow_headers.rs b/tower-http/src/cors/allow_headers.rs index 3bbd507a..8b9d9625 100644 --- a/tower-http/src/cors/allow_headers.rs +++ b/tower-http/src/cors/allow_headers.rs @@ -15,6 +15,7 @@ use super::{separated_by_commas, Any, WILDCARD}; /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Headers /// [`CorsLayer::allow_headers`]: super::CorsLayer::allow_headers #[derive(Clone, Default)] +#[must_use] pub struct AllowHeaders(AllowHeadersInner); impl AllowHeaders { diff --git a/tower-http/src/cors/allow_methods.rs b/tower-http/src/cors/allow_methods.rs index 0eacadda..8a2df18d 100644 --- a/tower-http/src/cors/allow_methods.rs +++ b/tower-http/src/cors/allow_methods.rs @@ -11,6 +11,7 @@ use super::{separated_by_commas, Any, WILDCARD}; /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Methods /// [`CorsLayer::allow_methods`]: super::CorsLayer::allow_methods #[derive(Clone, Default)] +#[must_use] pub struct AllowMethods(AllowMethodsInner); impl AllowMethods { diff --git a/tower-http/src/cors/allow_origin.rs b/tower-http/src/cors/allow_origin.rs index 5f694291..4df52f71 100644 --- a/tower-http/src/cors/allow_origin.rs +++ b/tower-http/src/cors/allow_origin.rs @@ -11,6 +11,7 @@ use super::{separated_by_commas, Any, WILDCARD}; /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Origin /// [`CorsLayer::allow_origin`]: super::CorsLayer::allow_origin #[derive(Clone, Default)] +#[must_use] pub struct AllowOrigin(OriginInner); impl AllowOrigin { diff --git a/tower-http/src/cors/expose_headers.rs b/tower-http/src/cors/expose_headers.rs index 91327524..00751bc9 100644 --- a/tower-http/src/cors/expose_headers.rs +++ b/tower-http/src/cors/expose_headers.rs @@ -14,6 +14,7 @@ use super::{separated_by_commas, Any, WILDCARD}; /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Expose-Headers /// [`CorsLayer::expose_headers`]: super::CorsLayer::expose_headers #[derive(Clone, Default)] +#[must_use] pub struct ExposeHeaders(ExposeHeadersInner); impl ExposeHeaders { diff --git a/tower-http/src/cors/max_age.rs b/tower-http/src/cors/max_age.rs index 5b3426f4..8410eee9 100644 --- a/tower-http/src/cors/max_age.rs +++ b/tower-http/src/cors/max_age.rs @@ -8,6 +8,7 @@ use http::{request::Parts as RequestParts, HeaderValue}; /// /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Max-Age #[derive(Clone, Default)] +#[must_use] pub struct MaxAge(MaxAgeInner); impl MaxAge { diff --git a/tower-http/src/cors/mod.rs b/tower-http/src/cors/mod.rs index b88f3721..43022d45 100644 --- a/tower-http/src/cors/mod.rs +++ b/tower-http/src/cors/mod.rs @@ -80,6 +80,7 @@ pub use self::{ /// /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/CORS #[derive(Debug, Clone)] +#[must_use] pub struct CorsLayer { allow_credentials: AllowCredentials, allow_headers: AllowHeaders, @@ -343,6 +344,7 @@ impl CorsLayer { /// Represents a wildcard value (`*`) used with some CORS headers such as /// [`CorsLayer::allow_methods`]. #[derive(Debug, Clone, Copy)] +#[must_use] pub struct Any; /// Represents a wildcard value (`*`) used with some CORS headers such as @@ -394,6 +396,7 @@ impl Layer for CorsLayer { /// /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/CORS #[derive(Debug, Clone)] +#[must_use] pub struct Cors { inner: S, layer: CorsLayer, From 9f6882fa599e2ec251850d032e4867f9ff2c93cc Mon Sep 17 00:00:00 2001 From: Jonas Platte Date: Sat, 26 Mar 2022 15:16:14 +0100 Subject: [PATCH 07/12] cors: Revamp constructors --- tower-http/src/cors/mod.rs | 47 +++++++++++++++++++++++++++----------- 1 file changed, 34 insertions(+), 13 deletions(-) diff --git a/tower-http/src/cors/mod.rs b/tower-http/src/cors/mod.rs index 43022d45..f212fbca 100644 --- a/tower-http/src/cors/mod.rs +++ b/tower-http/src/cors/mod.rs @@ -57,7 +57,6 @@ use std::{ mem, pin::Pin, task::{Context, Poll}, - time::Duration, }; use tower_layer::Layer; use tower_service::Service; @@ -96,8 +95,11 @@ const WILDCARD: HeaderValue = HeaderValue::from_static("*"); impl CorsLayer { /// Create a new `CorsLayer`. /// - /// This creates a restrictive configuration. Use the builder methods to - /// customize the behavior. + /// No headers are sent by default. Use the builder methods to customize + /// the behavior. + /// + /// You need to set at least an allowed origin for browsers to make + /// successful cross-origin requests to your service. pub fn new() -> Self { Self { allow_credentials: Default::default(), @@ -109,24 +111,34 @@ impl CorsLayer { } } - /// A very permissive configuration suitable for development: + /// A permissive configuration: /// - /// - Credentials allowed. /// - All request headers allowed. /// - All methods allowed. /// - All origins allowed. /// - All headers exposed. - /// - Max age set to 1 hour. - /// - /// Note this is not recommended for production use. pub fn permissive() -> Self { Self::new() - .allow_credentials(true) .allow_headers(Any) .allow_methods(Any) .allow_origin(Any) .expose_headers(Any) - .max_age(Duration::from_secs(60 * 60)) + } + + /// A very permissive configuration: + /// + /// - **Credentials allowed.** + /// - The method received in `Access-Control-Request-Method` is sent back + /// as an allowed method. + /// - The origin of the preflight request is sent back as an allowed origin. + /// - The header names received in `Access-Control-Request-Headers` are sent + /// back as allowed headers. + /// - No headers are currently exposed, but this may change in the future. + pub fn very_permissive() -> Self { + Self::new() + .allow_headers(AllowHeaders::mirror_request()) + .allow_methods(AllowMethods::mirror_request()) + .allow_origin(AllowOrigin::mirror_request()) } /// Set the [`Access-Control-Allow-Credentials`][mdn] header. @@ -405,8 +417,7 @@ pub struct Cors { impl Cors { /// Create a new `Cors`. /// - /// This creates a restrictive configuration. Use the builder methods to - /// customize the behavior. + /// See [`CorsLayer::new`] for more details. pub fn new(inner: S) -> Self { Self { inner, @@ -414,7 +425,7 @@ impl Cors { } } - /// A very permissive configuration suitable for development. + /// A permissive configuration. /// /// See [`CorsLayer::permissive`] for more details. pub fn permissive(inner: S) -> Self { @@ -424,6 +435,16 @@ impl Cors { } } + /// A very permissive configuration. + /// + /// See [`CorsLayer::very_permissive`] for more details. + pub fn very_permissive(inner: S) -> Self { + Self { + inner, + layer: CorsLayer::very_permissive(), + } + } + define_inner_service_accessors!(); /// Returns a new [`Layer`] that wraps services with a [`Cors`] middleware. From 28a071e2166e495f5a4e64cad666bd814626edf8 Mon Sep 17 00:00:00 2001 From: Jonas Platte Date: Sat, 26 Mar 2022 15:26:34 +0100 Subject: [PATCH 08/12] cors: Panic if configuration is invalid --- tower-http/src/cors/allow_credentials.rs | 4 +++ tower-http/src/cors/allow_headers.rs | 5 ++++ tower-http/src/cors/allow_methods.rs | 5 ++++ tower-http/src/cors/allow_origin.rs | 5 ++++ tower-http/src/cors/expose_headers.rs | 5 ++++ tower-http/src/cors/mod.rs | 31 ++++++++++++++++++++++++ 6 files changed, 55 insertions(+) diff --git a/tower-http/src/cors/allow_credentials.rs b/tower-http/src/cors/allow_credentials.rs index 72bcbb21..d7ea770d 100644 --- a/tower-http/src/cors/allow_credentials.rs +++ b/tower-http/src/cors/allow_credentials.rs @@ -34,6 +34,10 @@ impl AllowCredentials { Self(AllowCredentialsInner::Predicate(Arc::new(f))) } + pub(super) fn is_true(&self) -> bool { + matches!(&self.0, AllowCredentialsInner::Yes) + } + pub(super) fn to_header_val( &self, origin: &HeaderValue, diff --git a/tower-http/src/cors/allow_headers.rs b/tower-http/src/cors/allow_headers.rs index 8b9d9625..b6c70061 100644 --- a/tower-http/src/cors/allow_headers.rs +++ b/tower-http/src/cors/allow_headers.rs @@ -54,6 +54,11 @@ impl AllowHeaders { Self(AllowHeadersInner::MirrorRequest) } + #[allow(clippy::borrow_interior_mutable_const)] + pub(super) fn is_wildcard(&self) -> bool { + matches!(&self.0, AllowHeadersInner::Const(Some(v)) if v == WILDCARD) + } + pub(super) fn to_header_val(&self, parts: &RequestParts) -> Option { match &self.0 { AllowHeadersInner::Const(v) => v.clone(), diff --git a/tower-http/src/cors/allow_methods.rs b/tower-http/src/cors/allow_methods.rs index 8a2df18d..28392888 100644 --- a/tower-http/src/cors/allow_methods.rs +++ b/tower-http/src/cors/allow_methods.rs @@ -63,6 +63,11 @@ impl AllowMethods { Self(AllowMethodsInner::MirrorRequest) } + #[allow(clippy::borrow_interior_mutable_const)] + pub(super) fn is_wildcard(&self) -> bool { + matches!(&self.0, AllowMethodsInner::Const(Some(v)) if v == WILDCARD) + } + pub(super) fn to_header_val(&self, parts: &RequestParts) -> Option { match &self.0 { AllowMethodsInner::Const(v) => v.clone(), diff --git a/tower-http/src/cors/allow_origin.rs b/tower-http/src/cors/allow_origin.rs index 4df52f71..509eb212 100644 --- a/tower-http/src/cors/allow_origin.rs +++ b/tower-http/src/cors/allow_origin.rs @@ -71,6 +71,11 @@ impl AllowOrigin { Self::predicate(|_, _| true) } + #[allow(clippy::borrow_interior_mutable_const)] + pub(super) fn is_wildcard(&self) -> bool { + matches!(&self.0, OriginInner::Const(Some(v)) if v == WILDCARD) + } + pub(super) fn to_header_val( &self, origin: &HeaderValue, diff --git a/tower-http/src/cors/expose_headers.rs b/tower-http/src/cors/expose_headers.rs index 00751bc9..6ea05b35 100644 --- a/tower-http/src/cors/expose_headers.rs +++ b/tower-http/src/cors/expose_headers.rs @@ -41,6 +41,11 @@ impl ExposeHeaders { ))) } + #[allow(clippy::borrow_interior_mutable_const)] + pub(super) fn is_wildcard(&self) -> bool { + matches!(&self.0, ExposeHeadersInner::Const(Some(v)) if v == WILDCARD) + } + pub(super) fn to_header_val(&self, _parts: &RequestParts) -> Option { match &self.0 { ExposeHeadersInner::Const(v) => v.clone(), diff --git a/tower-http/src/cors/mod.rs b/tower-http/src/cors/mod.rs index f212fbca..595d0591 100644 --- a/tower-http/src/cors/mod.rs +++ b/tower-http/src/cors/mod.rs @@ -395,6 +395,8 @@ impl Layer for CorsLayer { type Service = Cors; fn layer(&self, inner: S) -> Self::Service { + ensure_usable_cors_rules(self); + Cors { inner, layer: self.clone(), @@ -545,6 +547,7 @@ where type Future = ResponseFuture; fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + ensure_usable_cors_rules(&self.layer); self.inner.poll_ready(cx) } @@ -667,3 +670,31 @@ where } } } + +fn ensure_usable_cors_rules(layer: &CorsLayer) { + if layer.allow_credentials.is_true() { + assert!( + !layer.allow_headers.is_wildcard(), + "Invalid CORS configuration: Cannot combine `Access-Control-Allow-Credentials: true` \ + with `Access-Control-Allow-Headers: *`" + ); + + assert!( + !layer.allow_methods.is_wildcard(), + "Invalid CORS configuration: Cannot combine `Access-Control-Allow-Credentials: true` \ + with `Access-Control-Allow-Methods: *`" + ); + + assert!( + !layer.allow_origin.is_wildcard(), + "Invalid CORS configuration: Cannot combine `Access-Control-Allow-Credentials: true` \ + with `Access-Control-Allow-Origin: *`" + ); + + assert!( + !layer.expose_headers.is_wildcard(), + "Invalid CORS configuration: Cannot combine `Access-Control-Allow-Credentials: true` \ + with `Access-Control-Expose-Headers: *`" + ); + } +} From 953f8e3204db1b05e507ec8681339f7b7216844a Mon Sep 17 00:00:00 2001 From: Jonas Platte Date: Thu, 7 Apr 2022 18:52:26 +0200 Subject: [PATCH 09/12] cors: Move more response header bits into sub-modules --- tower-http/src/cors/allow_credentials.rs | 11 ++++++---- tower-http/src/cors/allow_headers.rs | 17 ++++++++------- tower-http/src/cors/allow_methods.rs | 20 ++++++++++++------ tower-http/src/cors/allow_origin.rs | 19 +++++++++++------ tower-http/src/cors/expose_headers.rs | 12 ++++++----- tower-http/src/cors/max_age.rs | 19 +++++++++++------ tower-http/src/cors/mod.rs | 27 ++++++------------------ 7 files changed, 66 insertions(+), 59 deletions(-) diff --git a/tower-http/src/cors/allow_credentials.rs b/tower-http/src/cors/allow_credentials.rs index d7ea770d..a0c0a6e5 100644 --- a/tower-http/src/cors/allow_credentials.rs +++ b/tower-http/src/cors/allow_credentials.rs @@ -1,6 +1,9 @@ use std::{fmt, sync::Arc}; -use http::{request::Parts as RequestParts, HeaderValue}; +use http::{ + header::{self, HeaderName, HeaderValue}, + request::Parts as RequestParts, +}; /// Holds configuration for how to set the [`Access-Control-Allow-Credentials`][mdn] header. /// @@ -38,11 +41,11 @@ impl AllowCredentials { matches!(&self.0, AllowCredentialsInner::Yes) } - pub(super) fn to_header_val( + pub(super) fn to_header( &self, origin: &HeaderValue, parts: &RequestParts, - ) -> Option { + ) -> Option<(HeaderName, HeaderValue)> { #[allow(clippy::declare_interior_mutable_const)] const TRUE: HeaderValue = HeaderValue::from_static("true"); @@ -52,7 +55,7 @@ impl AllowCredentials { AllowCredentialsInner::Predicate(c) => c(origin, parts), }; - allow_creds.then(|| TRUE) + allow_creds.then(|| (header::ACCESS_CONTROL_ALLOW_CREDENTIALS, TRUE)) } } diff --git a/tower-http/src/cors/allow_headers.rs b/tower-http/src/cors/allow_headers.rs index b6c70061..06c19928 100644 --- a/tower-http/src/cors/allow_headers.rs +++ b/tower-http/src/cors/allow_headers.rs @@ -1,9 +1,8 @@ use std::{array, fmt}; use http::{ - header::{self, HeaderName}, + header::{self, HeaderName, HeaderValue}, request::Parts as RequestParts, - HeaderValue, }; use super::{separated_by_commas, Any, WILDCARD}; @@ -59,14 +58,16 @@ impl AllowHeaders { matches!(&self.0, AllowHeadersInner::Const(Some(v)) if v == WILDCARD) } - pub(super) fn to_header_val(&self, parts: &RequestParts) -> Option { - match &self.0 { - AllowHeadersInner::Const(v) => v.clone(), + pub(super) fn to_header(&self, parts: &RequestParts) -> Option<(HeaderName, HeaderValue)> { + let allow_headers = match &self.0 { + AllowHeadersInner::Const(v) => v.clone()?, AllowHeadersInner::MirrorRequest => parts .headers - .get(header::ACCESS_CONTROL_REQUEST_HEADERS) - .cloned(), - } + .get(header::ACCESS_CONTROL_REQUEST_HEADERS)? + .clone(), + }; + + Some((header::ACCESS_CONTROL_ALLOW_HEADERS, allow_headers)) } } diff --git a/tower-http/src/cors/allow_methods.rs b/tower-http/src/cors/allow_methods.rs index 28392888..df1a3cbd 100644 --- a/tower-http/src/cors/allow_methods.rs +++ b/tower-http/src/cors/allow_methods.rs @@ -1,6 +1,10 @@ use std::{array, fmt}; -use http::{header, request::Parts as RequestParts, HeaderValue, Method}; +use http::{ + header::{self, HeaderName, HeaderValue}, + request::Parts as RequestParts, + Method, +}; use super::{separated_by_commas, Any, WILDCARD}; @@ -68,14 +72,16 @@ impl AllowMethods { matches!(&self.0, AllowMethodsInner::Const(Some(v)) if v == WILDCARD) } - pub(super) fn to_header_val(&self, parts: &RequestParts) -> Option { - match &self.0 { - AllowMethodsInner::Const(v) => v.clone(), + pub(super) fn to_header(&self, parts: &RequestParts) -> Option<(HeaderName, HeaderValue)> { + let allow_methods = match &self.0 { + AllowMethodsInner::Const(v) => v.clone()?, AllowMethodsInner::MirrorRequest => parts .headers - .get(header::ACCESS_CONTROL_REQUEST_METHOD) - .cloned(), - } + .get(header::ACCESS_CONTROL_REQUEST_METHOD)? + .clone(), + }; + + Some((header::ACCESS_CONTROL_ALLOW_METHODS, allow_methods)) } } diff --git a/tower-http/src/cors/allow_origin.rs b/tower-http/src/cors/allow_origin.rs index 509eb212..0b27b657 100644 --- a/tower-http/src/cors/allow_origin.rs +++ b/tower-http/src/cors/allow_origin.rs @@ -1,6 +1,9 @@ use std::{array, fmt, sync::Arc}; -use http::{request::Parts as RequestParts, HeaderValue}; +use http::{ + header::{self, HeaderName, HeaderValue}, + request::Parts as RequestParts, +}; use super::{separated_by_commas, Any, WILDCARD}; @@ -76,15 +79,17 @@ impl AllowOrigin { matches!(&self.0, OriginInner::Const(Some(v)) if v == WILDCARD) } - pub(super) fn to_header_val( + pub(super) fn to_header( &self, origin: &HeaderValue, parts: &RequestParts, - ) -> Option { - match &self.0 { - OriginInner::Const(v) => v.clone(), - OriginInner::Predicate(c) => c(origin, parts).then(|| origin.to_owned()), - } + ) -> Option<(HeaderName, HeaderValue)> { + let allow_origin = match &self.0 { + OriginInner::Const(v) => v.clone()?, + OriginInner::Predicate(c) => c(origin, parts).then(|| origin.to_owned())?, + }; + + Some((header::ACCESS_CONTROL_ALLOW_ORIGIN, allow_origin)) } } diff --git a/tower-http/src/cors/expose_headers.rs b/tower-http/src/cors/expose_headers.rs index 6ea05b35..2b1a2267 100644 --- a/tower-http/src/cors/expose_headers.rs +++ b/tower-http/src/cors/expose_headers.rs @@ -1,7 +1,7 @@ use std::{array, fmt}; use http::{ - header::{HeaderName, HeaderValue}, + header::{self, HeaderName, HeaderValue}, request::Parts as RequestParts, }; @@ -46,10 +46,12 @@ impl ExposeHeaders { matches!(&self.0, ExposeHeadersInner::Const(Some(v)) if v == WILDCARD) } - pub(super) fn to_header_val(&self, _parts: &RequestParts) -> Option { - match &self.0 { - ExposeHeadersInner::Const(v) => v.clone(), - } + pub(super) fn to_header(&self, _parts: &RequestParts) -> Option<(HeaderName, HeaderValue)> { + let expose_headers = match &self.0 { + ExposeHeadersInner::Const(v) => v.clone()?, + }; + + Some((header::ACCESS_CONTROL_EXPOSE_HEADERS, expose_headers)) } } diff --git a/tower-http/src/cors/max_age.rs b/tower-http/src/cors/max_age.rs index 8410eee9..b2dea75b 100644 --- a/tower-http/src/cors/max_age.rs +++ b/tower-http/src/cors/max_age.rs @@ -1,6 +1,9 @@ use std::{fmt, sync::Arc, time::Duration}; -use http::{request::Parts as RequestParts, HeaderValue}; +use http::{ + header::{self, HeaderName, HeaderValue}, + request::Parts as RequestParts, +}; /// Holds configuration for how to set the [`Access-Control-Max-Age`][mdn] header. /// @@ -29,15 +32,17 @@ impl MaxAge { Self(MaxAgeInner::Fn(Arc::new(f))) } - pub(super) fn to_header_val( + pub(super) fn to_header( &self, origin: &HeaderValue, parts: &RequestParts, - ) -> Option { - match &self.0 { - MaxAgeInner::Exact(v) => v.clone(), - MaxAgeInner::Fn(c) => Some(c(origin, parts).as_secs().into()), - } + ) -> Option<(HeaderName, HeaderValue)> { + let max_age = match &self.0 { + MaxAgeInner::Exact(v) => v.clone()?, + MaxAgeInner::Fn(c) => c(origin, parts).as_secs().into(), + }; + + Some((header::ACCESS_CONTROL_MAX_AGE, max_age)) } } diff --git a/tower-http/src/cors/mod.rs b/tower-http/src/cors/mod.rs index 595d0591..255f2628 100644 --- a/tower-http/src/cors/mod.rs +++ b/tower-http/src/cors/mod.rs @@ -572,13 +572,8 @@ where // These headers are applied to both preflight and subsequent regular CORS requests: // https://fetch.spec.whatwg.org/#http-responses - if let Some(value) = self.layer.allow_origin.to_header_val(&origin, &parts) { - headers.insert(header::ACCESS_CONTROL_ALLOW_ORIGIN, value); - } - - if let Some(value) = self.layer.allow_credentials.to_header_val(&origin, &parts) { - headers.insert(header::ACCESS_CONTROL_ALLOW_CREDENTIALS, value); - } + headers.extend(self.layer.allow_origin.to_header(&origin, &parts)); + headers.extend(self.layer.allow_credentials.to_header(&origin, &parts)); headers.append(header::VARY, header::ORIGIN.into()); headers.append(header::VARY, header::ACCESS_CONTROL_REQUEST_METHOD.into()); @@ -587,26 +582,16 @@ where // Return results immediately upon preflight request if parts.method == Method::OPTIONS { // These headers are applied only to preflight requests - if let Some(value) = self.layer.allow_methods.to_header_val(&parts) { - headers.insert(header::ACCESS_CONTROL_ALLOW_METHODS, value); - } - - if let Some(value) = self.layer.allow_headers.to_header_val(&parts) { - headers.insert(header::ACCESS_CONTROL_ALLOW_HEADERS, value); - } - - if let Some(value) = self.layer.max_age.to_header_val(&origin, &parts) { - headers.insert(header::ACCESS_CONTROL_MAX_AGE, value); - } + headers.extend(self.layer.allow_methods.to_header(&parts)); + headers.extend(self.layer.allow_headers.to_header(&parts)); + headers.extend(self.layer.max_age.to_header(&origin, &parts)); ResponseFuture { inner: Kind::PreflightCall { headers }, } } else { // This header is applied only to non-preflight requests - if let Some(value) = self.layer.expose_headers.to_header_val(&parts) { - headers.insert(header::ACCESS_CONTROL_EXPOSE_HEADERS, value); - } + headers.extend(self.layer.expose_headers.to_header(&parts)); let req = Request::from_parts(parts, body); ResponseFuture { From 2350873718c4602788402835111aa262ea2388c9 Mon Sep 17 00:00:00 2001 From: Jonas Platte Date: Thu, 7 Apr 2022 19:00:22 +0200 Subject: [PATCH 10/12] cors: Return constant CORS headers even if there is no origin in the request While this should generally be unnecessary, it doesn't hurt apart from the small amount of additional header bytes transmitted and can be helpful in weird edge cases. --- tower-http/src/cors/allow_credentials.rs | 4 ++-- tower-http/src/cors/allow_origin.rs | 4 ++-- tower-http/src/cors/max_age.rs | 4 ++-- tower-http/src/cors/mod.rs | 26 ++++-------------------- 4 files changed, 10 insertions(+), 28 deletions(-) diff --git a/tower-http/src/cors/allow_credentials.rs b/tower-http/src/cors/allow_credentials.rs index a0c0a6e5..3843def8 100644 --- a/tower-http/src/cors/allow_credentials.rs +++ b/tower-http/src/cors/allow_credentials.rs @@ -43,7 +43,7 @@ impl AllowCredentials { pub(super) fn to_header( &self, - origin: &HeaderValue, + origin: Option<&HeaderValue>, parts: &RequestParts, ) -> Option<(HeaderName, HeaderValue)> { #[allow(clippy::declare_interior_mutable_const)] @@ -52,7 +52,7 @@ impl AllowCredentials { let allow_creds = match &self.0 { AllowCredentialsInner::Yes => true, AllowCredentialsInner::No => false, - AllowCredentialsInner::Predicate(c) => c(origin, parts), + AllowCredentialsInner::Predicate(c) => c(origin?, parts), }; allow_creds.then(|| (header::ACCESS_CONTROL_ALLOW_CREDENTIALS, TRUE)) diff --git a/tower-http/src/cors/allow_origin.rs b/tower-http/src/cors/allow_origin.rs index 0b27b657..c14f7356 100644 --- a/tower-http/src/cors/allow_origin.rs +++ b/tower-http/src/cors/allow_origin.rs @@ -81,12 +81,12 @@ impl AllowOrigin { pub(super) fn to_header( &self, - origin: &HeaderValue, + origin: Option<&HeaderValue>, parts: &RequestParts, ) -> Option<(HeaderName, HeaderValue)> { let allow_origin = match &self.0 { OriginInner::Const(v) => v.clone()?, - OriginInner::Predicate(c) => c(origin, parts).then(|| origin.to_owned())?, + OriginInner::Predicate(c) => origin.filter(|origin| c(origin, parts))?.clone(), }; Some((header::ACCESS_CONTROL_ALLOW_ORIGIN, allow_origin)) diff --git a/tower-http/src/cors/max_age.rs b/tower-http/src/cors/max_age.rs index b2dea75b..98189926 100644 --- a/tower-http/src/cors/max_age.rs +++ b/tower-http/src/cors/max_age.rs @@ -34,12 +34,12 @@ impl MaxAge { pub(super) fn to_header( &self, - origin: &HeaderValue, + origin: Option<&HeaderValue>, parts: &RequestParts, ) -> Option<(HeaderName, HeaderValue)> { let max_age = match &self.0 { MaxAgeInner::Exact(v) => v.clone()?, - MaxAgeInner::Fn(c) => c(origin, parts).as_secs().into(), + MaxAgeInner::Fn(c) => c(origin?, parts).as_secs().into(), }; Some((header::ACCESS_CONTROL_MAX_AGE, max_age)) diff --git a/tower-http/src/cors/mod.rs b/tower-http/src/cors/mod.rs index 255f2628..134f0653 100644 --- a/tower-http/src/cors/mod.rs +++ b/tower-http/src/cors/mod.rs @@ -552,28 +552,15 @@ where } fn call(&mut self, req: Request) -> Self::Future { - let origin = req.headers().get(&header::ORIGIN).cloned(); - - // Only requests with an origin can be considered CORS requests: - // https://fetch.spec.whatwg.org/#http-requests - let origin = if let Some(origin) = origin { - origin - } else { - return ResponseFuture { - inner: Kind::NonCorsCall { - future: self.inner.call(req), - }, - }; - }; - let (parts, body) = req.into_parts(); + let origin = parts.headers.get(&header::ORIGIN); let mut headers = HeaderMap::new(); // These headers are applied to both preflight and subsequent regular CORS requests: // https://fetch.spec.whatwg.org/#http-responses - headers.extend(self.layer.allow_origin.to_header(&origin, &parts)); - headers.extend(self.layer.allow_credentials.to_header(&origin, &parts)); + headers.extend(self.layer.allow_origin.to_header(origin, &parts)); + headers.extend(self.layer.allow_credentials.to_header(origin, &parts)); headers.append(header::VARY, header::ORIGIN.into()); headers.append(header::VARY, header::ACCESS_CONTROL_REQUEST_METHOD.into()); @@ -584,7 +571,7 @@ where // These headers are applied only to preflight requests headers.extend(self.layer.allow_methods.to_header(&parts)); headers.extend(self.layer.allow_headers.to_header(&parts)); - headers.extend(self.layer.max_age.to_header(&origin, &parts)); + headers.extend(self.layer.max_age.to_header(origin, &parts)); ResponseFuture { inner: Kind::PreflightCall { headers }, @@ -615,10 +602,6 @@ pin_project! { pin_project! { #[project = KindProj] enum Kind { - NonCorsCall { - #[pin] - future: F, - }, CorsCall { #[pin] future: F, @@ -645,7 +628,6 @@ where Poll::Ready(Ok(response)) } - KindProj::NonCorsCall { future } => future.poll(cx), KindProj::PreflightCall { headers } => { let mut response = Response::new(B::default()); mem::swap(response.headers_mut(), headers); From 0e34a8daa487a435c79d0900f125e5dbc7ce5819 Mon Sep 17 00:00:00 2001 From: Jonas Platte Date: Thu, 7 Apr 2022 21:35:25 +0200 Subject: [PATCH 11/12] cors: Make vary headers configurable --- tower-http/src/cors/mod.rs | 58 +++++++++++++++++++++++++++++++++---- tower-http/src/cors/vary.rs | 51 ++++++++++++++++++++++++++++++++ 2 files changed, 104 insertions(+), 5 deletions(-) create mode 100644 tower-http/src/cors/vary.rs diff --git a/tower-http/src/cors/mod.rs b/tower-http/src/cors/mod.rs index 134f0653..32f18461 100644 --- a/tower-http/src/cors/mod.rs +++ b/tower-http/src/cors/mod.rs @@ -50,9 +50,13 @@ use bytes::{BufMut, BytesMut}; use futures_core::ready; -use http::{header, HeaderMap, HeaderValue, Method, Request, Response}; +use http::{ + header::{self, HeaderName}, + HeaderMap, HeaderValue, Method, Request, Response, +}; use pin_project_lite::pin_project; use std::{ + array, future::Future, mem, pin::Pin, @@ -67,10 +71,11 @@ mod allow_methods; mod allow_origin; mod expose_headers; mod max_age; +mod vary; pub use self::{ allow_credentials::AllowCredentials, allow_headers::AllowHeaders, allow_methods::AllowMethods, - allow_origin::AllowOrigin, expose_headers::ExposeHeaders, max_age::MaxAge, + allow_origin::AllowOrigin, expose_headers::ExposeHeaders, max_age::MaxAge, vary::Vary, }; /// Layer that applies the [`Cors`] middleware which adds headers for [CORS][mdn]. @@ -87,6 +92,7 @@ pub struct CorsLayer { allow_origin: AllowOrigin, expose_headers: ExposeHeaders, max_age: MaxAge, + vary: Vary, } #[allow(clippy::declare_interior_mutable_const)] @@ -108,6 +114,7 @@ impl CorsLayer { allow_origin: Default::default(), expose_headers: Default::default(), max_age: Default::default(), + vary: Default::default(), } } @@ -351,6 +358,24 @@ impl CorsLayer { self.expose_headers = headers.into(); self } + + /// Set the value(s) of the [`Vary`][mdn] header. + /// + /// In contrast to the other headers, this one has a non-empty default of + /// [`preflight_request_headers()`]. + /// + /// You only need to set this is you want to remove some of these defaults, + /// or if you use a closure for one of the other headers and want to add a + /// vary header accordingly. + /// + /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Vary + pub fn vary(mut self, headers: T) -> Self + where + T: Into, + { + self.vary = headers.into(); + self + } } /// Represents a wildcard value (`*`) used with some CORS headers such as @@ -559,12 +584,23 @@ where // These headers are applied to both preflight and subsequent regular CORS requests: // https://fetch.spec.whatwg.org/#http-responses + headers.extend(self.layer.allow_origin.to_header(origin, &parts)); headers.extend(self.layer.allow_credentials.to_header(origin, &parts)); - headers.append(header::VARY, header::ORIGIN.into()); - headers.append(header::VARY, header::ACCESS_CONTROL_REQUEST_METHOD.into()); - headers.append(header::VARY, header::ACCESS_CONTROL_REQUEST_HEADERS.into()); + let mut vary_headers = self.layer.vary.values(); + if let Some(first) = vary_headers.next() { + let mut header = match headers.entry(header::VARY) { + header::Entry::Occupied(_) => { + unreachable!("no vary header inserted up to this point") + } + header::Entry::Vacant(v) => v.insert_entry(first), + }; + + for val in vary_headers { + header.append(val); + } + } // Return results immediately upon preflight request if parts.method == Method::OPTIONS { @@ -665,3 +701,15 @@ fn ensure_usable_cors_rules(layer: &CorsLayer) { ); } } + +/// Returns an iterator over the three request headers that may be involved in a CORS preflight request. +/// +/// This is the default set of header names returned in the `vary` header +pub fn preflight_request_headers() -> impl Iterator { + #[allow(deprecated)] // Can be changed when MSRV >= 1.53 + array::IntoIter::new([ + header::ORIGIN, + header::ACCESS_CONTROL_REQUEST_METHOD, + header::ACCESS_CONTROL_REQUEST_HEADERS, + ]) +} diff --git a/tower-http/src/cors/vary.rs b/tower-http/src/cors/vary.rs new file mode 100644 index 00000000..b1dddc36 --- /dev/null +++ b/tower-http/src/cors/vary.rs @@ -0,0 +1,51 @@ +use std::array; + +use http::{header::HeaderName, HeaderValue}; + +use super::preflight_request_headers; + +/// Holds configuration for how to set the [`Vary`][mdn] header. +/// +/// See [`CorsLayer::vary`] for more details. +/// +/// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Vary +/// [`CorsLayer::vary`]: super::CorsLayer::vary +#[derive(Clone, Debug)] +pub struct Vary(Vec); + +impl Vary { + /// Set the list of header names to return as vary header values + /// + /// See [`CorsLayer::vary`] for more details. + /// + /// [`CorsLayer::vary`]: super::CorsLayer::vary + pub fn list(headers: I) -> Self + where + I: IntoIterator, + { + Self(headers.into_iter().map(Into::into).collect()) + } + + pub(super) fn values(&self) -> impl Iterator + '_ { + self.0.iter().cloned() + } +} + +impl Default for Vary { + fn default() -> Self { + Self::list(preflight_request_headers()) + } +} + +impl From<[HeaderName; N]> for Vary { + fn from(arr: [HeaderName; N]) -> Self { + #[allow(deprecated)] // Can be changed when MSRV >= 1.53 + Self::list(array::IntoIter::new(arr)) + } +} + +impl From> for Vary { + fn from(vec: Vec) -> Self { + Self::list(vec) + } +} From 98ee61d58e75059f42191a984164eb20b6eef4db Mon Sep 17 00:00:00 2001 From: Jonas Platte Date: Thu, 21 Apr 2022 21:34:30 +0200 Subject: [PATCH 12/12] Update changelog --- tower-http/CHANGELOG.md | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/tower-http/CHANGELOG.md b/tower-http/CHANGELOG.md index 27f55daa..d7295757 100644 --- a/tower-http/CHANGELOG.md +++ b/tower-http/CHANGELOG.md @@ -9,11 +9,23 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## Added -- None. +- **cors**: Added `CorsLayer::very_permissive` which is like + `CorsLayer::permissive` except it (truly) allows credentials. This is made + possible by mirroring the request's origin as well as method and headers + back as CORS-whitelisted ones +* **cors**: Allow customizing the value(s) for the `Vary` header ## Changed -- None. +- **cors**: Removed `allow-credentials: true` from `CorsLayer::permissive`. + It never actually took effect in compliant browsers because it is mutually + exclusive with the `*` wildcard (`Any`) on origins, methods and headers +- **cors**: Rewrote the CORS middleware. Almost all existing usage patterns + will continue to work. (BREAKING) +- **cors**: The CORS middleware will now panic if you try to use `Any` in + combination with `.allow_credentials(true)`. This configuration worked + before, but resulted in browsers ignoring the `allow-credentials` header, + which defeats the purpose of setting it and can be very annoying to debug. ## Removed