diff --git a/tower-http/CHANGELOG.md b/tower-http/CHANGELOG.md index 27f55daa..d7295757 100644 --- a/tower-http/CHANGELOG.md +++ b/tower-http/CHANGELOG.md @@ -9,11 +9,23 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## Added -- None. +- **cors**: Added `CorsLayer::very_permissive` which is like + `CorsLayer::permissive` except it (truly) allows credentials. This is made + possible by mirroring the request's origin as well as method and headers + back as CORS-whitelisted ones +* **cors**: Allow customizing the value(s) for the `Vary` header ## Changed -- None. +- **cors**: Removed `allow-credentials: true` from `CorsLayer::permissive`. + It never actually took effect in compliant browsers because it is mutually + exclusive with the `*` wildcard (`Any`) on origins, methods and headers +- **cors**: Rewrote the CORS middleware. Almost all existing usage patterns + will continue to work. (BREAKING) +- **cors**: The CORS middleware will now panic if you try to use `Any` in + combination with `.allow_credentials(true)`. This configuration worked + before, but resulted in browsers ignoring the `allow-credentials` header, + which defeats the purpose of setting it and can be very annoying to debug. ## Removed diff --git a/tower-http/src/cors/allow_credentials.rs b/tower-http/src/cors/allow_credentials.rs new file mode 100644 index 00000000..3843def8 --- /dev/null +++ b/tower-http/src/cors/allow_credentials.rs @@ -0,0 +1,94 @@ +use std::{fmt, sync::Arc}; + +use http::{ + header::{self, HeaderName, HeaderValue}, + request::Parts as RequestParts, +}; + +/// Holds configuration for how to set the [`Access-Control-Allow-Credentials`][mdn] header. +/// +/// See [`CorsLayer::allow_credentials`] for more details. +/// +/// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Credentials +/// [`CorsLayer::allow_credentials`]: super::CorsLayer::allow_credentials +#[derive(Clone, Default)] +#[must_use] +pub struct AllowCredentials(AllowCredentialsInner); + +impl AllowCredentials { + /// Allow credentials for all requests + /// + /// See [`CorsLayer::allow_credentials`] for more details. + /// + /// [`CorsLayer::allow_credentials`]: super::CorsLayer::allow_credentials + pub fn yes() -> Self { + Self(AllowCredentialsInner::Yes) + } + + /// Allow credentials for some requests, based on a given predicate + /// + /// See [`CorsLayer::allow_credentials`] for more details. + /// + /// [`CorsLayer::allow_credentials`]: super::CorsLayer::allow_credentials + pub fn predicate(f: F) -> Self + where + F: Fn(&HeaderValue, &RequestParts) -> bool + Send + Sync + 'static, + { + Self(AllowCredentialsInner::Predicate(Arc::new(f))) + } + + pub(super) fn is_true(&self) -> bool { + matches!(&self.0, AllowCredentialsInner::Yes) + } + + pub(super) fn to_header( + &self, + origin: Option<&HeaderValue>, + parts: &RequestParts, + ) -> Option<(HeaderName, HeaderValue)> { + #[allow(clippy::declare_interior_mutable_const)] + const TRUE: HeaderValue = HeaderValue::from_static("true"); + + let allow_creds = match &self.0 { + AllowCredentialsInner::Yes => true, + AllowCredentialsInner::No => false, + AllowCredentialsInner::Predicate(c) => c(origin?, parts), + }; + + allow_creds.then(|| (header::ACCESS_CONTROL_ALLOW_CREDENTIALS, TRUE)) + } +} + +impl From for AllowCredentials { + fn from(v: bool) -> Self { + match v { + true => Self(AllowCredentialsInner::Yes), + false => Self(AllowCredentialsInner::No), + } + } +} + +impl fmt::Debug for AllowCredentials { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self.0 { + AllowCredentialsInner::Yes => f.debug_tuple("Yes").finish(), + AllowCredentialsInner::No => f.debug_tuple("No").finish(), + AllowCredentialsInner::Predicate(_) => f.debug_tuple("Predicate").finish(), + } + } +} + +#[derive(Clone)] +enum AllowCredentialsInner { + Yes, + No, + Predicate( + Arc Fn(&'a HeaderValue, &'a RequestParts) -> bool + Send + Sync + 'static>, + ), +} + +impl Default for AllowCredentialsInner { + fn default() -> Self { + Self::No + } +} diff --git a/tower-http/src/cors/allow_headers.rs b/tower-http/src/cors/allow_headers.rs new file mode 100644 index 00000000..06c19928 --- /dev/null +++ b/tower-http/src/cors/allow_headers.rs @@ -0,0 +1,112 @@ +use std::{array, fmt}; + +use http::{ + header::{self, HeaderName, HeaderValue}, + request::Parts as RequestParts, +}; + +use super::{separated_by_commas, Any, WILDCARD}; + +/// Holds configuration for how to set the [`Access-Control-Allow-Headers`][mdn] header. +/// +/// See [`CorsLayer::allow_headers`] for more details. +/// +/// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Headers +/// [`CorsLayer::allow_headers`]: super::CorsLayer::allow_headers +#[derive(Clone, Default)] +#[must_use] +pub struct AllowHeaders(AllowHeadersInner); + +impl AllowHeaders { + /// Allow any headers by sending a wildcard (`*`) + /// + /// See [`CorsLayer::allow_headers`] for more details. + /// + /// [`CorsLayer::allow_headers`]: super::CorsLayer::allow_headers + pub fn any() -> Self { + Self(AllowHeadersInner::Const(Some(WILDCARD))) + } + + /// Set multiple allowed headers + /// + /// See [`CorsLayer::allow_headers`] for more details. + /// + /// [`CorsLayer::allow_headers`]: super::CorsLayer::allow_headers + pub fn list(headers: I) -> Self + where + I: IntoIterator, + { + Self(AllowHeadersInner::Const(separated_by_commas( + headers.into_iter().map(Into::into), + ))) + } + + /// Allow any headers, by mirroring the preflight [`Access-Control-Request-Headers`][mdn] + /// header. + /// + /// See [`CorsLayer::allow_headers`] for more details. + /// + /// [`CorsLayer::allow_headers`]: super::CorsLayer::allow_headers + /// + /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Request-Headers + pub fn mirror_request() -> Self { + Self(AllowHeadersInner::MirrorRequest) + } + + #[allow(clippy::borrow_interior_mutable_const)] + pub(super) fn is_wildcard(&self) -> bool { + matches!(&self.0, AllowHeadersInner::Const(Some(v)) if v == WILDCARD) + } + + pub(super) fn to_header(&self, parts: &RequestParts) -> Option<(HeaderName, HeaderValue)> { + let allow_headers = match &self.0 { + AllowHeadersInner::Const(v) => v.clone()?, + AllowHeadersInner::MirrorRequest => parts + .headers + .get(header::ACCESS_CONTROL_REQUEST_HEADERS)? + .clone(), + }; + + Some((header::ACCESS_CONTROL_ALLOW_HEADERS, allow_headers)) + } +} + +impl fmt::Debug for AllowHeaders { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match &self.0 { + AllowHeadersInner::Const(inner) => f.debug_tuple("Const").field(inner).finish(), + AllowHeadersInner::MirrorRequest => f.debug_tuple("MirrorRequest").finish(), + } + } +} + +impl From for AllowHeaders { + fn from(_: Any) -> Self { + Self::any() + } +} + +impl From<[HeaderName; N]> for AllowHeaders { + fn from(arr: [HeaderName; N]) -> Self { + #[allow(deprecated)] // Can be changed when MSRV >= 1.53 + Self::list(array::IntoIter::new(arr)) + } +} + +impl From> for AllowHeaders { + fn from(vec: Vec) -> Self { + Self::list(vec) + } +} + +#[derive(Clone)] +enum AllowHeadersInner { + Const(Option), + MirrorRequest, +} + +impl Default for AllowHeadersInner { + fn default() -> Self { + Self::Const(None) + } +} diff --git a/tower-http/src/cors/allow_methods.rs b/tower-http/src/cors/allow_methods.rs new file mode 100644 index 00000000..df1a3cbd --- /dev/null +++ b/tower-http/src/cors/allow_methods.rs @@ -0,0 +1,132 @@ +use std::{array, fmt}; + +use http::{ + header::{self, HeaderName, HeaderValue}, + request::Parts as RequestParts, + Method, +}; + +use super::{separated_by_commas, Any, WILDCARD}; + +/// Holds configuration for how to set the [`Access-Control-Allow-Methods`][mdn] header. +/// +/// See [`CorsLayer::allow_methods`] for more details. +/// +/// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Methods +/// [`CorsLayer::allow_methods`]: super::CorsLayer::allow_methods +#[derive(Clone, Default)] +#[must_use] +pub struct AllowMethods(AllowMethodsInner); + +impl AllowMethods { + /// Allow any method by sending a wildcard (`*`) + /// + /// See [`CorsLayer::allow_methods`] for more details. + /// + /// [`CorsLayer::allow_methods`]: super::CorsLayer::allow_methods + pub fn any() -> Self { + Self(AllowMethodsInner::Const(Some(WILDCARD))) + } + + /// Set a single allowed method + /// + /// See [`CorsLayer::allow_methods`] for more details. + /// + /// [`CorsLayer::allow_methods`]: super::CorsLayer::allow_methods + pub fn exact(method: Method) -> Self { + Self(AllowMethodsInner::Const(Some( + HeaderValue::from_str(method.as_str()).unwrap(), + ))) + } + + /// Set multiple allowed methods + /// + /// See [`CorsLayer::allow_methods`] for more details. + /// + /// [`CorsLayer::allow_methods`]: super::CorsLayer::allow_methods + pub fn list(methods: I) -> Self + where + I: IntoIterator, + { + Self(AllowMethodsInner::Const(separated_by_commas( + methods + .into_iter() + .map(|m| HeaderValue::from_str(m.as_str()).unwrap()), + ))) + } + + /// Allow any method, by mirroring the preflight [`Access-Control-Request-Method`][mdn] + /// header. + /// + /// See [`CorsLayer::allow_methods`] for more details. + /// + /// [`CorsLayer::allow_methods`]: super::CorsLayer::allow_methods + /// + /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Request-Method + pub fn mirror_request() -> Self { + Self(AllowMethodsInner::MirrorRequest) + } + + #[allow(clippy::borrow_interior_mutable_const)] + pub(super) fn is_wildcard(&self) -> bool { + matches!(&self.0, AllowMethodsInner::Const(Some(v)) if v == WILDCARD) + } + + pub(super) fn to_header(&self, parts: &RequestParts) -> Option<(HeaderName, HeaderValue)> { + let allow_methods = match &self.0 { + AllowMethodsInner::Const(v) => v.clone()?, + AllowMethodsInner::MirrorRequest => parts + .headers + .get(header::ACCESS_CONTROL_REQUEST_METHOD)? + .clone(), + }; + + Some((header::ACCESS_CONTROL_ALLOW_METHODS, allow_methods)) + } +} + +impl fmt::Debug for AllowMethods { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match &self.0 { + AllowMethodsInner::Const(inner) => f.debug_tuple("Const").field(inner).finish(), + AllowMethodsInner::MirrorRequest => f.debug_tuple("MirrorRequest").finish(), + } + } +} + +impl From for AllowMethods { + fn from(_: Any) -> Self { + Self::any() + } +} + +impl From for AllowMethods { + fn from(method: Method) -> Self { + Self::exact(method) + } +} + +impl From<[Method; N]> for AllowMethods { + fn from(arr: [Method; N]) -> Self { + #[allow(deprecated)] // Can be changed when MSRV >= 1.53 + Self::list(array::IntoIter::new(arr)) + } +} + +impl From> for AllowMethods { + fn from(vec: Vec) -> Self { + Self::list(vec) + } +} + +#[derive(Clone)] +enum AllowMethodsInner { + Const(Option), + MirrorRequest, +} + +impl Default for AllowMethodsInner { + fn default() -> Self { + Self::Const(None) + } +} diff --git a/tower-http/src/cors/allow_origin.rs b/tower-http/src/cors/allow_origin.rs new file mode 100644 index 00000000..c14f7356 --- /dev/null +++ b/tower-http/src/cors/allow_origin.rs @@ -0,0 +1,142 @@ +use std::{array, fmt, sync::Arc}; + +use http::{ + header::{self, HeaderName, HeaderValue}, + request::Parts as RequestParts, +}; + +use super::{separated_by_commas, Any, WILDCARD}; + +/// Holds configuration for how to set the [`Access-Control-Allow-Origin`][mdn] header. +/// +/// See [`CorsLayer::allow_origin`] for more details. +/// +/// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Origin +/// [`CorsLayer::allow_origin`]: super::CorsLayer::allow_origin +#[derive(Clone, Default)] +#[must_use] +pub struct AllowOrigin(OriginInner); + +impl AllowOrigin { + /// Allow any origin by sending a wildcard (`*`) + /// + /// See [`CorsLayer::allow_origin`] for more details. + /// + /// [`CorsLayer::allow_origin`]: super::CorsLayer::allow_origin + pub fn any() -> Self { + Self(OriginInner::Const(Some(WILDCARD))) + } + + /// Set a single allowed origin + /// + /// See [`CorsLayer::allow_origin`] for more details. + /// + /// [`CorsLayer::allow_origin`]: super::CorsLayer::allow_origin + pub fn exact(origin: HeaderValue) -> Self { + Self(OriginInner::Const(Some(origin))) + } + + /// Set multiple allowed origins + /// + /// See [`CorsLayer::allow_origin`] for more details. + /// + /// [`CorsLayer::allow_origin`]: super::CorsLayer::allow_origin + pub fn list(origins: I) -> Self + where + I: IntoIterator, + { + Self(OriginInner::Const(separated_by_commas( + origins.into_iter().map(Into::into), + ))) + } + + /// Set the allowed origins from a predicate + /// + /// See [`CorsLayer::allow_origin`] for more details. + /// + /// [`CorsLayer::allow_origin`]: super::CorsLayer::allow_origin + pub fn predicate(f: F) -> Self + where + F: Fn(&HeaderValue, &RequestParts) -> bool + Send + Sync + 'static, + { + Self(OriginInner::Predicate(Arc::new(f))) + } + + /// Allow any origin, by mirroring the request origin + /// + /// This is equivalent to + /// [`AllowOrigin::predicate(|_, _| true)`][Self::predicate]. + /// + /// See [`CorsLayer::allow_origin`] for more details. + /// + /// [`CorsLayer::allow_origin`]: super::CorsLayer::allow_origin + pub fn mirror_request() -> Self { + Self::predicate(|_, _| true) + } + + #[allow(clippy::borrow_interior_mutable_const)] + pub(super) fn is_wildcard(&self) -> bool { + matches!(&self.0, OriginInner::Const(Some(v)) if v == WILDCARD) + } + + pub(super) fn to_header( + &self, + origin: Option<&HeaderValue>, + parts: &RequestParts, + ) -> Option<(HeaderName, HeaderValue)> { + let allow_origin = match &self.0 { + OriginInner::Const(v) => v.clone()?, + OriginInner::Predicate(c) => origin.filter(|origin| c(origin, parts))?.clone(), + }; + + Some((header::ACCESS_CONTROL_ALLOW_ORIGIN, allow_origin)) + } +} + +impl fmt::Debug for AllowOrigin { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match &self.0 { + OriginInner::Const(inner) => f.debug_tuple("Const").field(inner).finish(), + OriginInner::Predicate(_) => f.debug_tuple("Predicate").finish(), + } + } +} + +impl From for AllowOrigin { + fn from(_: Any) -> Self { + Self::any() + } +} + +impl From for AllowOrigin { + fn from(val: HeaderValue) -> Self { + Self::exact(val) + } +} + +impl From<[HeaderValue; N]> for AllowOrigin { + fn from(arr: [HeaderValue; N]) -> Self { + #[allow(deprecated)] // Can be changed when MSRV >= 1.53 + Self::list(array::IntoIter::new(arr)) + } +} + +impl From> for AllowOrigin { + fn from(vec: Vec) -> Self { + Self::list(vec) + } +} + +#[derive(Clone)] +enum OriginInner { + Const(Option), + Predicate( + Arc Fn(&'a HeaderValue, &'a RequestParts) -> bool + Send + Sync + 'static>, + ), +} + +impl Default for OriginInner { + fn default() -> Self { + Self::Const(None) + } +} diff --git a/tower-http/src/cors/expose_headers.rs b/tower-http/src/cors/expose_headers.rs new file mode 100644 index 00000000..2b1a2267 --- /dev/null +++ b/tower-http/src/cors/expose_headers.rs @@ -0,0 +1,94 @@ +use std::{array, fmt}; + +use http::{ + header::{self, HeaderName, HeaderValue}, + request::Parts as RequestParts, +}; + +use super::{separated_by_commas, Any, WILDCARD}; + +/// Holds configuration for how to set the [`Access-Control-Expose-Headers`][mdn] header. +/// +/// See [`CorsLayer::expose_headers`] for more details. +/// +/// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Expose-Headers +/// [`CorsLayer::expose_headers`]: super::CorsLayer::expose_headers +#[derive(Clone, Default)] +#[must_use] +pub struct ExposeHeaders(ExposeHeadersInner); + +impl ExposeHeaders { + /// Expose any / all headers by sending a wildcard (`*`) + /// + /// See [`CorsLayer::expose_headers`] for more details. + /// + /// [`CorsLayer::expose_headers`]: super::CorsLayer::expose_headers + pub fn any() -> Self { + Self(ExposeHeadersInner::Const(Some(WILDCARD))) + } + + /// Set multiple exposed header names + /// + /// See [`CorsLayer::expose_headers`] for more details. + /// + /// [`CorsLayer::expose_headers`]: super::CorsLayer::expose_headers + pub fn list(headers: I) -> Self + where + I: IntoIterator, + { + Self(ExposeHeadersInner::Const(separated_by_commas( + headers.into_iter().map(Into::into), + ))) + } + + #[allow(clippy::borrow_interior_mutable_const)] + pub(super) fn is_wildcard(&self) -> bool { + matches!(&self.0, ExposeHeadersInner::Const(Some(v)) if v == WILDCARD) + } + + pub(super) fn to_header(&self, _parts: &RequestParts) -> Option<(HeaderName, HeaderValue)> { + let expose_headers = match &self.0 { + ExposeHeadersInner::Const(v) => v.clone()?, + }; + + Some((header::ACCESS_CONTROL_EXPOSE_HEADERS, expose_headers)) + } +} + +impl fmt::Debug for ExposeHeaders { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match &self.0 { + ExposeHeadersInner::Const(inner) => f.debug_tuple("Const").field(inner).finish(), + } + } +} + +impl From for ExposeHeaders { + fn from(_: Any) -> Self { + Self::any() + } +} + +impl From<[HeaderName; N]> for ExposeHeaders { + fn from(arr: [HeaderName; N]) -> Self { + #[allow(deprecated)] // Can be changed when MSRV >= 1.53 + Self::list(array::IntoIter::new(arr)) + } +} + +impl From> for ExposeHeaders { + fn from(vec: Vec) -> Self { + Self::list(vec) + } +} + +#[derive(Clone)] +enum ExposeHeadersInner { + Const(Option), +} + +impl Default for ExposeHeadersInner { + fn default() -> Self { + ExposeHeadersInner::Const(None) + } +} diff --git a/tower-http/src/cors/max_age.rs b/tower-http/src/cors/max_age.rs new file mode 100644 index 00000000..98189926 --- /dev/null +++ b/tower-http/src/cors/max_age.rs @@ -0,0 +1,74 @@ +use std::{fmt, sync::Arc, time::Duration}; + +use http::{ + header::{self, HeaderName, HeaderValue}, + request::Parts as RequestParts, +}; + +/// Holds configuration for how to set the [`Access-Control-Max-Age`][mdn] header. +/// +/// See [`CorsLayer::max_age`][super::CorsLayer::max_age] for more details. +/// +/// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Max-Age +#[derive(Clone, Default)] +#[must_use] +pub struct MaxAge(MaxAgeInner); + +impl MaxAge { + /// Set a static max-age value + /// + /// See [`CorsLayer::max_age`][super::CorsLayer::max_age] for more details. + pub fn exact(max_age: Duration) -> Self { + Self(MaxAgeInner::Exact(Some(max_age.as_secs().into()))) + } + + /// Set the max-age based on the preflight request parts + /// + /// See [`CorsLayer::max_age`][super::CorsLayer::max_age] for more details. + pub fn dynamic(f: F) -> Self + where + F: Fn(&HeaderValue, &RequestParts) -> Duration + Send + Sync + 'static, + { + Self(MaxAgeInner::Fn(Arc::new(f))) + } + + pub(super) fn to_header( + &self, + origin: Option<&HeaderValue>, + parts: &RequestParts, + ) -> Option<(HeaderName, HeaderValue)> { + let max_age = match &self.0 { + MaxAgeInner::Exact(v) => v.clone()?, + MaxAgeInner::Fn(c) => c(origin?, parts).as_secs().into(), + }; + + Some((header::ACCESS_CONTROL_MAX_AGE, max_age)) + } +} + +impl fmt::Debug for MaxAge { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match &self.0 { + MaxAgeInner::Exact(inner) => f.debug_tuple("Exact").field(inner).finish(), + MaxAgeInner::Fn(_) => f.debug_tuple("Fn").finish(), + } + } +} + +impl From for MaxAge { + fn from(max_age: Duration) -> Self { + Self::exact(max_age) + } +} + +#[derive(Clone)] +enum MaxAgeInner { + Exact(Option), + Fn(Arc Fn(&'a HeaderValue, &'a RequestParts) -> Duration + Send + Sync + 'static>), +} + +impl Default for MaxAgeInner { + fn default() -> Self { + Self::Exact(None) + } +} diff --git a/tower-http/src/cors.rs b/tower-http/src/cors/mod.rs similarity index 51% rename from tower-http/src/cors.rs rename to tower-http/src/cors/mod.rs index 9e11cb61..32f18461 100644 --- a/tower-http/src/cors.rs +++ b/tower-http/src/cors/mod.rs @@ -17,7 +17,7 @@ //! # async fn main() -> Result<(), Box> { //! let cors = CorsLayer::new() //! // allow `GET` and `POST` when accessing the resource -//! .allow_methods(vec![Method::GET, Method::POST]) +//! .allow_methods([Method::GET, Method::POST]) //! // allow requests from any origin //! .allow_origin(Any); //! @@ -46,39 +46,53 @@ //! //! [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/CORS +#![allow(clippy::enum_variant_names)] + use bytes::{BufMut, BytesMut}; use futures_core::ready; use http::{ - header::{self, HeaderName, HeaderValue}, - request::Parts, - HeaderMap, Method, Request, Response, StatusCode, + header::{self, HeaderName}, + HeaderMap, HeaderValue, Method, Request, Response, }; use pin_project_lite::pin_project; use std::{ - fmt, + array, future::Future, mem, pin::Pin, - sync::Arc, task::{Context, Poll}, - time::Duration, }; use tower_layer::Layer; use tower_service::Service; +mod allow_credentials; +mod allow_headers; +mod allow_methods; +mod allow_origin; +mod expose_headers; +mod max_age; +mod vary; + +pub use self::{ + allow_credentials::AllowCredentials, allow_headers::AllowHeaders, allow_methods::AllowMethods, + allow_origin::AllowOrigin, expose_headers::ExposeHeaders, max_age::MaxAge, vary::Vary, +}; + /// Layer that applies the [`Cors`] middleware which adds headers for [CORS][mdn]. /// /// See the [module docs](crate::cors) for an example. /// /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/CORS #[derive(Debug, Clone)] +#[must_use] pub struct CorsLayer { - allow_credentials: bool, - allow_headers: Option, - allow_methods: Option, - allow_origin: Option>, - expose_headers: Option, - max_age: Option, + allow_credentials: AllowCredentials, + allow_headers: AllowHeaders, + allow_methods: AllowMethods, + allow_origin: AllowOrigin, + expose_headers: ExposeHeaders, + max_age: MaxAge, + vary: Vary, } #[allow(clippy::declare_interior_mutable_const)] @@ -87,37 +101,51 @@ const WILDCARD: HeaderValue = HeaderValue::from_static("*"); impl CorsLayer { /// Create a new `CorsLayer`. /// - /// This creates a restrictive configuration. Use the builder methods to - /// customize the behavior. + /// No headers are sent by default. Use the builder methods to customize + /// the behavior. + /// + /// You need to set at least an allowed origin for browsers to make + /// successful cross-origin requests to your service. pub fn new() -> Self { Self { - allow_credentials: false, - allow_headers: None, - allow_methods: None, - allow_origin: None, - expose_headers: None, - max_age: None, + allow_credentials: Default::default(), + allow_headers: Default::default(), + allow_methods: Default::default(), + allow_origin: Default::default(), + expose_headers: Default::default(), + max_age: Default::default(), + vary: Default::default(), } } - /// A very permissive configuration suitable for development: + /// A permissive configuration: /// - /// - Credentials allowed. /// - All request headers allowed. /// - All methods allowed. /// - All origins allowed. /// - All headers exposed. - /// - Max age set to 1 hour. - /// - /// Note this is not recommended for production use. pub fn permissive() -> Self { Self::new() - .allow_credentials(true) .allow_headers(Any) .allow_methods(Any) .allow_origin(Any) .expose_headers(Any) - .max_age(Duration::from_secs(60 * 60)) + } + + /// A very permissive configuration: + /// + /// - **Credentials allowed.** + /// - The method received in `Access-Control-Request-Method` is sent back + /// as an allowed method. + /// - The origin of the preflight request is sent back as an allowed origin. + /// - The header names received in `Access-Control-Request-Headers` are sent + /// back as allowed headers. + /// - No headers are currently exposed, but this may change in the future. + pub fn very_permissive() -> Self { + Self::new() + .allow_headers(AllowHeaders::mirror_request()) + .allow_methods(AllowMethods::mirror_request()) + .allow_origin(AllowOrigin::mirror_request()) } /// Set the [`Access-Control-Allow-Credentials`][mdn] header. @@ -129,8 +157,11 @@ impl CorsLayer { /// ``` /// /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Credentials - pub fn allow_credentials(mut self, allow_credentials: bool) -> Self { - self.allow_credentials = allow_credentials; + pub fn allow_credentials(mut self, allow_credentials: T) -> Self + where + T: Into, + { + self.allow_credentials = allow_credentials.into(); self } @@ -140,7 +171,7 @@ impl CorsLayer { /// use tower_http::cors::CorsLayer; /// use http::header::{AUTHORIZATION, ACCEPT}; /// - /// let layer = CorsLayer::new().allow_headers(vec![AUTHORIZATION, ACCEPT]); + /// let layer = CorsLayer::new().allow_headers([AUTHORIZATION, ACCEPT]); /// ``` /// /// All headers can be allowed with @@ -158,22 +189,19 @@ impl CorsLayer { /// `Access-Control-Request-Headers`. /// /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Headers - pub fn allow_headers(mut self, headers: I) -> Self + pub fn allow_headers(mut self, headers: T) -> Self where - I: Into>>, + T: Into, { - self.allow_headers = match headers.into().0 { - AnyOrInner::Any => Some(WILDCARD), - AnyOrInner::Value(headers) => separated_by_commas(headers.into_iter().map(Into::into)), - }; + self.allow_headers = headers.into(); self } /// Set the value of the [`Access-Control-Max-Age`][mdn] header. /// /// ``` - /// use tower_http::cors::CorsLayer; /// use std::time::Duration; + /// use tower_http::cors::CorsLayer; /// /// let layer = CorsLayer::new().max_age(Duration::from_secs(60) * 10); /// ``` @@ -185,9 +213,34 @@ impl CorsLayer { /// precedence when the Access-Control-Max-Age is greater. For more details /// see [mdn]. /// + /// If you need more flexibility, you can use supply a function which can + /// dynamically decide the max-age based on the origin and other parts of + /// each preflight request: + /// + /// ``` + /// # struct MyServerConfig { cors_max_age: Duration } + /// use std::time::Duration; + /// + /// use http::{request::Parts as RequestParts, HeaderValue}; + /// use tower_http::cors::{CorsLayer, MaxAge}; + /// + /// let layer = CorsLayer::new().max_age(MaxAge::dynamic( + /// |_origin: &HeaderValue, parts: &RequestParts| -> Duration { + /// // Let's say you want to be able to reload your config at + /// // runtime and have another middleware that always inserts + /// // the current config into the request extensions + /// let config = parts.extensions.get::().unwrap(); + /// config.cors_max_age + /// }, + /// )); + /// ``` + /// /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Max-Age - pub fn max_age(mut self, max_age: Duration) -> Self { - self.max_age = Some(max_age.as_secs().into()); + pub fn max_age(mut self, max_age: T) -> Self + where + T: Into, + { + self.max_age = max_age.into(); self } @@ -197,7 +250,7 @@ impl CorsLayer { /// use tower_http::cors::CorsLayer; /// use http::Method; /// - /// let layer = CorsLayer::new().allow_methods(vec![Method::GET, Method::POST]); + /// let layer = CorsLayer::new().allow_methods([Method::GET, Method::POST]); /// ``` /// /// All methods can be allowed with @@ -214,40 +267,34 @@ impl CorsLayer { /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Methods pub fn allow_methods(mut self, methods: T) -> Self where - T: Into>>, + T: Into, { - self.allow_methods = match methods.into().0 { - AnyOrInner::Any => Some(WILDCARD), - AnyOrInner::Value(methods) => separated_by_commas( - methods - .into_iter() - .map(|m| HeaderValue::from_str(m.as_str()).unwrap()), - ), - }; + self.allow_methods = methods.into(); self } /// Set the value of the [`Access-Control-Allow-Origin`][mdn] header. /// /// ``` - /// use tower_http::cors::{CorsLayer, Origin}; + /// use http::HeaderValue; + /// use tower_http::cors::CorsLayer; /// - /// let layer = CorsLayer::new().allow_origin(Origin::exact( - /// "http://example.com".parse().unwrap(), - /// )); + /// let layer = CorsLayer::new().allow_origin( + /// "http://example.com".parse::().unwrap(), + /// ); /// ``` /// /// Multiple origins can be allowed with /// /// ``` - /// use tower_http::cors::{CorsLayer, Origin}; + /// use tower_http::cors::CorsLayer; /// - /// let origins = vec![ + /// let origins = [ /// "http://example.com".parse().unwrap(), /// "http://api.example.com".parse().unwrap(), /// ]; /// - /// let layer = CorsLayer::new().allow_origin(Origin::list(origins)); + /// let layer = CorsLayer::new().allow_origin(origins); /// ``` /// /// All origins can be allowed with @@ -261,14 +308,14 @@ impl CorsLayer { /// You can also use a closure /// /// ``` - /// use tower_http::cors::{CorsLayer, Origin}; - /// use http::{HeaderValue, request::Parts}; + /// use tower_http::cors::{CorsLayer, AllowOrigin}; + /// use http::{request::Parts as RequestParts, HeaderValue}; /// - /// let layer = CorsLayer::new().allow_origin( - /// Origin::predicate(|origin: &HeaderValue, _request_head: &Parts| { + /// let layer = CorsLayer::new().allow_origin(AllowOrigin::predicate( + /// |origin: &HeaderValue, _request_parts: &RequestParts| { /// origin.as_bytes().ends_with(b".rust-lang.org") - /// }) - /// ); + /// }, + /// )); /// ``` /// /// Note that multiple calls to this method will override any previous @@ -277,9 +324,9 @@ impl CorsLayer { /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Origin pub fn allow_origin(mut self, origin: T) -> Self where - T: Into>, + T: Into, { - self.allow_origin = Some(origin.into()); + self.allow_origin = origin.into(); self } @@ -289,7 +336,7 @@ impl CorsLayer { /// use tower_http::cors::CorsLayer; /// use http::header::CONTENT_ENCODING; /// - /// let layer = CorsLayer::new().expose_headers(vec![CONTENT_ENCODING]); + /// let layer = CorsLayer::new().expose_headers([CONTENT_ENCODING]); /// ``` /// /// All headers can be allowed with @@ -304,14 +351,29 @@ impl CorsLayer { /// calls. /// /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Expose-Headers - pub fn expose_headers(mut self, headers: I) -> Self + pub fn expose_headers(mut self, headers: T) -> Self where - I: Into>>, + T: Into, { - self.expose_headers = match headers.into().0 { - AnyOrInner::Any => Some(WILDCARD), - AnyOrInner::Value(headers) => separated_by_commas(headers.into_iter().map(Into::into)), - }; + self.expose_headers = headers.into(); + self + } + + /// Set the value(s) of the [`Vary`][mdn] header. + /// + /// In contrast to the other headers, this one has a non-empty default of + /// [`preflight_request_headers()`]. + /// + /// You only need to set this is you want to remove some of these defaults, + /// or if you use a closure for one of the other headers and want to add a + /// vary header accordingly. + /// + /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Vary + pub fn vary(mut self, headers: T) -> Self + where + T: Into, + { + self.vary = headers.into(); self } } @@ -319,6 +381,7 @@ impl CorsLayer { /// Represents a wildcard value (`*`) used with some CORS headers such as /// [`CorsLayer::allow_methods`]. #[derive(Debug, Clone, Copy)] +#[must_use] pub struct Any; /// Represents a wildcard value (`*`) used with some CORS headers such as @@ -328,48 +391,6 @@ pub fn any() -> Any { Any } -/// Used to make methods like [`CorsLayer::allow_methods`] more convenient to call. -/// -/// You shouldn't have to use this type directly. -#[derive(Debug, Clone, Copy)] -pub struct AnyOr(AnyOrInner); - -#[derive(Debug, Clone, Copy)] -enum AnyOrInner { - Any, - Value(T), -} - -impl From for AnyOr { - fn from(origin: Origin) -> Self { - AnyOr(AnyOrInner::Value(origin)) - } -} - -impl From for AnyOr { - fn from(_: Any) -> Self { - AnyOr(AnyOrInner::Any) - } -} - -impl From for AnyOr> -where - I: IntoIterator, -{ - fn from(methods: I) -> Self { - AnyOr(AnyOrInner::Value(methods.into_iter().collect())) - } -} - -impl From for AnyOr> -where - I: IntoIterator, -{ - fn from(headers: I) -> Self { - AnyOr(AnyOrInner::Value(headers.into_iter().collect())) - } -} - fn separated_by_commas(mut iter: I) -> Option where I: Iterator, @@ -399,6 +420,8 @@ impl Layer for CorsLayer { type Service = Cors; fn layer(&self, inner: S) -> Self::Service { + ensure_usable_cors_rules(self); + Cors { inner, layer: self.clone(), @@ -412,6 +435,7 @@ impl Layer for CorsLayer { /// /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/CORS #[derive(Debug, Clone)] +#[must_use] pub struct Cors { inner: S, layer: CorsLayer, @@ -420,8 +444,7 @@ pub struct Cors { impl Cors { /// Create a new `Cors`. /// - /// This creates a restrictive configuration. Use the builder methods to - /// customize the behavior. + /// See [`CorsLayer::new`] for more details. pub fn new(inner: S) -> Self { Self { inner, @@ -429,7 +452,7 @@ impl Cors { } } - /// A very permissive configuration suitable for development. + /// A permissive configuration. /// /// See [`CorsLayer::permissive`] for more details. pub fn permissive(inner: S) -> Self { @@ -439,6 +462,16 @@ impl Cors { } } + /// A very permissive configuration. + /// + /// See [`CorsLayer::very_permissive`] for more details. + pub fn very_permissive(inner: S) -> Self { + Self { + inner, + layer: CorsLayer::very_permissive(), + } + } + define_inner_service_accessors!(); /// Returns a new [`Layer`] that wraps services with a [`Cors`] middleware. @@ -453,7 +486,10 @@ impl Cors { /// See [`CorsLayer::allow_credentials`] for more details. /// /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Credentials - pub fn allow_credentials(self, allow_credentials: bool) -> Self { + pub fn allow_credentials(self, allow_credentials: T) -> Self + where + T: Into, + { self.map_layer(|layer| layer.allow_credentials(allow_credentials)) } @@ -462,9 +498,9 @@ impl Cors { /// See [`CorsLayer::allow_headers`] for more details. /// /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Headers - pub fn allow_headers(self, headers: I) -> Self + pub fn allow_headers(self, headers: T) -> Self where - I: Into>>, + T: Into, { self.map_layer(|layer| layer.allow_headers(headers)) } @@ -474,7 +510,10 @@ impl Cors { /// See [`CorsLayer::max_age`] for more details. /// /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Max-Age - pub fn max_age(self, max_age: Duration) -> Self { + pub fn max_age(self, max_age: T) -> Self + where + T: Into, + { self.map_layer(|layer| layer.max_age(max_age)) } @@ -485,7 +524,7 @@ impl Cors { /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Methods pub fn allow_methods(self, methods: T) -> Self where - T: Into>>, + T: Into, { self.map_layer(|layer| layer.allow_methods(methods)) } @@ -497,7 +536,7 @@ impl Cors { /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Origin pub fn allow_origin(self, origin: T) -> Self where - T: Into>, + T: Into, { self.map_layer(|layer| layer.allow_origin(origin)) } @@ -507,9 +546,9 @@ impl Cors { /// See [`CorsLayer::expose_headers`] for more details. /// /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Expose-Headers - pub fn expose_headers(self, headers: I) -> Self + pub fn expose_headers(self, headers: T) -> Self where - I: Into>>, + T: Into, { self.map_layer(|layer| layer.expose_headers(headers)) } @@ -521,139 +560,6 @@ impl Cors { self.layer = f(self.layer); self } - - fn is_valid_request_method(&self, method: &HeaderValue) -> bool { - if let Some(allow_methods) = &self.layer.allow_methods { - #[allow(clippy::borrow_interior_mutable_const)] - if allow_methods == WILDCARD { - return true; - } - - allow_methods - .as_bytes() - .split(|&byte| byte == b',') - .any(|bytes| bytes == method.as_bytes()) - } else { - false - } - } - - fn make_response_header_map(&self) -> HeaderMap { - #[allow(clippy::declare_interior_mutable_const)] - const TRUE: HeaderValue = HeaderValue::from_static("true"); - - let mut headers = HeaderMap::new(); - - if self.layer.allow_credentials { - headers.insert(header::ACCESS_CONTROL_ALLOW_CREDENTIALS, TRUE); - } - - if let Some(expose_headers) = self.layer.expose_headers.clone() { - headers.insert(header::ACCESS_CONTROL_EXPOSE_HEADERS, expose_headers); - } - - headers - } - - fn make_preflight_header_map(&self, origin: HeaderValue, parts: &Parts) -> HeaderMap { - let mut headers = self.make_response_header_map(); - - if let Some(allow_origin) = &self.layer.allow_origin { - if let Some(origin) = allow_origin.to_header_val(origin, parts) { - headers.insert(header::ACCESS_CONTROL_ALLOW_ORIGIN, origin); - } - } - - if let Some(allow_methods) = &self.layer.allow_methods { - headers.insert(header::ACCESS_CONTROL_ALLOW_METHODS, allow_methods.clone()); - } - - if let Some(allow_headers) = &self.layer.allow_headers { - headers.insert(header::ACCESS_CONTROL_ALLOW_HEADERS, allow_headers.clone()); - } - - if let Some(max_age) = self.layer.max_age.clone() { - headers.insert(header::ACCESS_CONTROL_MAX_AGE, max_age); - } - - headers - } -} - -/// Represents a [`Access-Control-Allow-Origin`][mdn] header. -/// -/// See [`CorsLayer::allow_origin`] for more details. -/// -/// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Origin -#[derive(Clone)] -pub struct Origin(OriginInner); - -impl Origin { - /// Set a single allow origin target - /// - /// See [`CorsLayer::allow_origin`] for more details. - pub fn exact(origin: HeaderValue) -> Self { - Self(OriginInner::Const(Some(origin))) - } - - /// Set multiple allow origin targets - /// - /// See [`CorsLayer::allow_origin`] for more details. - pub fn list(origins: I) -> Self - where - I: IntoIterator, - { - Self(OriginInner::Const(separated_by_commas( - origins.into_iter().map(Into::into), - ))) - } - - /// Set the allowed origins from a predicate - /// - /// See [`CorsLayer::allow_origin`] for more details. - pub fn predicate(f: F) -> Self - where - F: Fn(&HeaderValue, &Parts) -> bool + Send + Sync + 'static, - { - Self(OriginInner::Closure(Arc::new(f))) - } - - fn to_header_val(&self, origin: HeaderValue, parts: &Parts) -> Option { - match &self.0 { - OriginInner::Const(v) => v.clone(), - OriginInner::Closure(c) => { - if c(&origin, parts) { - Some(origin) - } else { - None - } - } - } - } -} - -impl AnyOr { - fn to_header_val(&self, origin: HeaderValue, parts: &Parts) -> Option { - match &self.0 { - AnyOrInner::Any => Some(WILDCARD), - AnyOrInner::Value(o) => o.to_header_val(origin, parts), - } - } -} - -impl fmt::Debug for Origin { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match &self.0 { - OriginInner::Const(inner) => f.debug_tuple("Const").field(inner).finish(), - OriginInner::Closure(_) => f.debug_tuple("Closure").finish(), - } - } -} - -#[derive(Clone)] -enum OriginInner { - Const(Option), - Closure(Arc Fn(&'a HeaderValue, &'a Parts) -> bool + Send + Sync + 'static>), } impl Service> for Cors @@ -666,59 +572,57 @@ where type Future = ResponseFuture; fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + ensure_usable_cors_rules(&self.layer); self.inner.poll_ready(cx) } fn call(&mut self, req: Request) -> Self::Future { - let origin = req.headers().get(&header::ORIGIN).cloned(); + let (parts, body) = req.into_parts(); + let origin = parts.headers.get(&header::ORIGIN); - let origin = if let Some(origin) = origin { - origin - } else { - // This is not a CORS request if there is no Origin header - return ResponseFuture { - inner: Kind::NonCorsCall { - future: self.inner.call(req), - }, - }; - }; + let mut headers = HeaderMap::new(); - let (parts, body) = req.into_parts(); + // These headers are applied to both preflight and subsequent regular CORS requests: + // https://fetch.spec.whatwg.org/#http-responses - // Return results immediately upon preflight request - if parts.method == Method::OPTIONS { - // the method the real request will be made with - match parts.headers.get(header::ACCESS_CONTROL_REQUEST_METHOD) { - Some(request_method) if self.is_valid_request_method(request_method) => {} - _ => { - return ResponseFuture { - inner: Kind::InvalidCorsCall, - }; - } - } + headers.extend(self.layer.allow_origin.to_header(origin, &parts)); + headers.extend(self.layer.allow_credentials.to_header(origin, &parts)); - return ResponseFuture { - inner: Kind::PreflightCall { - headers: self.make_preflight_header_map(origin, &parts), - }, + let mut vary_headers = self.layer.vary.values(); + if let Some(first) = vary_headers.next() { + let mut header = match headers.entry(header::VARY) { + header::Entry::Occupied(_) => { + unreachable!("no vary header inserted up to this point") + } + header::Entry::Vacant(v) => v.insert_entry(first), }; - } - let req = Request::from_parts(parts, body); + for val in vary_headers { + header.append(val); + } + } - let mut headers = self.make_response_header_map(); - headers.insert( - header::ACCESS_CONTROL_ALLOW_ORIGIN, - response_origin(self.layer.allow_origin.as_ref().unwrap(), &origin), - ); + // Return results immediately upon preflight request + if parts.method == Method::OPTIONS { + // These headers are applied only to preflight requests + headers.extend(self.layer.allow_methods.to_header(&parts)); + headers.extend(self.layer.allow_headers.to_header(&parts)); + headers.extend(self.layer.max_age.to_header(origin, &parts)); - apply_vary_headers(&mut headers); + ResponseFuture { + inner: Kind::PreflightCall { headers }, + } + } else { + // This header is applied only to non-preflight requests + headers.extend(self.layer.expose_headers.to_header(&parts)); - ResponseFuture { - inner: Kind::CorsCall { - future: self.inner.call(req), - headers, - }, + let req = Request::from_parts(parts, body); + ResponseFuture { + inner: Kind::CorsCall { + future: self.inner.call(req), + headers, + }, + } } } } @@ -734,10 +638,6 @@ pin_project! { pin_project! { #[project = KindProj] enum Kind { - NonCorsCall { - #[pin] - future: F, - }, CorsCall { #[pin] future: F, @@ -746,8 +646,6 @@ pin_project! { PreflightCall { headers: HeaderMap, }, - InvalidCorsCall, - InvalidOrigin, } } @@ -766,79 +664,52 @@ where Poll::Ready(Ok(response)) } - KindProj::NonCorsCall { future } => future.poll(cx), KindProj::PreflightCall { headers } => { - apply_vary_headers(headers); - let mut response = Response::new(B::default()); mem::swap(response.headers_mut(), headers); - Poll::Ready(Ok(response)) - } - KindProj::InvalidCorsCall => { - let response = Response::builder() - .status(StatusCode::OK) - .body(B::default()) - .unwrap(); - - Poll::Ready(Ok(response)) - } - KindProj::InvalidOrigin => { - let response = Response::builder() - .status(StatusCode::UNAUTHORIZED) - .body(B::default()) - .unwrap(); - Poll::Ready(Ok(response)) } } } } -fn apply_vary_headers(headers: &mut http::HeaderMap) { - const VARY_HEADERS: [HeaderName; 3] = [ - header::ORIGIN, - header::ACCESS_CONTROL_REQUEST_METHOD, - header::ACCESS_CONTROL_REQUEST_HEADERS, - ]; +fn ensure_usable_cors_rules(layer: &CorsLayer) { + if layer.allow_credentials.is_true() { + assert!( + !layer.allow_headers.is_wildcard(), + "Invalid CORS configuration: Cannot combine `Access-Control-Allow-Credentials: true` \ + with `Access-Control-Allow-Headers: *`" + ); - for h in &VARY_HEADERS { - headers.append(header::VARY, HeaderValue::from_static(h.as_str())); - } -} + assert!( + !layer.allow_methods.is_wildcard(), + "Invalid CORS configuration: Cannot combine `Access-Control-Allow-Credentials: true` \ + with `Access-Control-Allow-Methods: *`" + ); -fn response_origin(allow_origin: &AnyOr, origin: &HeaderValue) -> HeaderValue { - if let AnyOrInner::Any = &allow_origin.0 { - WILDCARD - } else { - origin.clone() + assert!( + !layer.allow_origin.is_wildcard(), + "Invalid CORS configuration: Cannot combine `Access-Control-Allow-Credentials: true` \ + with `Access-Control-Allow-Origin: *`" + ); + + assert!( + !layer.expose_headers.is_wildcard(), + "Invalid CORS configuration: Cannot combine `Access-Control-Allow-Credentials: true` \ + with `Access-Control-Expose-Headers: *`" + ); } } -#[cfg(test)] -mod tests { - #[allow(unused_imports)] - use super::*; - - #[test] - fn test_is_valid_request_method() { - let cors = Cors::new(()).allow_methods(vec![Method::GET, Method::POST]); - assert!(cors.is_valid_request_method(&HeaderValue::from_static("GET"))); - assert!(cors.is_valid_request_method(&HeaderValue::from_static("POST"))); - - let cors = Cors::new(()); - assert!(!cors.is_valid_request_method(&HeaderValue::from_static("GET"))); - assert!(!cors.is_valid_request_method(&HeaderValue::from_static("POST"))); - assert!(!cors.is_valid_request_method(&HeaderValue::from_static("OPTIONS"))); - - let cors = Cors::new(()).allow_methods(Any); - assert!(cors.is_valid_request_method(&HeaderValue::from_static("GET"))); - assert!(cors.is_valid_request_method(&HeaderValue::from_static("POST"))); - assert!(cors.is_valid_request_method(&HeaderValue::from_static("OPTIONS"))); - - let cors = Cors::new(()).allow_methods(Any); - assert!(cors.is_valid_request_method(&HeaderValue::from_static("GET"))); - assert!(cors.is_valid_request_method(&HeaderValue::from_static("POST"))); - assert!(cors.is_valid_request_method(&HeaderValue::from_static("OPTIONS"))); - } +/// Returns an iterator over the three request headers that may be involved in a CORS preflight request. +/// +/// This is the default set of header names returned in the `vary` header +pub fn preflight_request_headers() -> impl Iterator { + #[allow(deprecated)] // Can be changed when MSRV >= 1.53 + array::IntoIter::new([ + header::ORIGIN, + header::ACCESS_CONTROL_REQUEST_METHOD, + header::ACCESS_CONTROL_REQUEST_HEADERS, + ]) } diff --git a/tower-http/src/cors/vary.rs b/tower-http/src/cors/vary.rs new file mode 100644 index 00000000..b1dddc36 --- /dev/null +++ b/tower-http/src/cors/vary.rs @@ -0,0 +1,51 @@ +use std::array; + +use http::{header::HeaderName, HeaderValue}; + +use super::preflight_request_headers; + +/// Holds configuration for how to set the [`Vary`][mdn] header. +/// +/// See [`CorsLayer::vary`] for more details. +/// +/// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Vary +/// [`CorsLayer::vary`]: super::CorsLayer::vary +#[derive(Clone, Debug)] +pub struct Vary(Vec); + +impl Vary { + /// Set the list of header names to return as vary header values + /// + /// See [`CorsLayer::vary`] for more details. + /// + /// [`CorsLayer::vary`]: super::CorsLayer::vary + pub fn list(headers: I) -> Self + where + I: IntoIterator, + { + Self(headers.into_iter().map(Into::into).collect()) + } + + pub(super) fn values(&self) -> impl Iterator + '_ { + self.0.iter().cloned() + } +} + +impl Default for Vary { + fn default() -> Self { + Self::list(preflight_request_headers()) + } +} + +impl From<[HeaderName; N]> for Vary { + fn from(arr: [HeaderName; N]) -> Self { + #[allow(deprecated)] // Can be changed when MSRV >= 1.53 + Self::list(array::IntoIter::new(arr)) + } +} + +impl From> for Vary { + fn from(vec: Vec) -> Self { + Self::list(vec) + } +}