Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rewrite the CORS middleware #237

Merged
merged 12 commits into from Apr 21, 2022
94 changes: 94 additions & 0 deletions tower-http/src/cors/allow_credentials.rs
@@ -0,0 +1,94 @@
use std::{fmt, sync::Arc};

use http::{
header::{self, HeaderName, HeaderValue},
request::Parts as RequestParts,
};

/// Holds configuration for how to set the [`Access-Control-Allow-Credentials`][mdn] header.
///
/// See [`CorsLayer::allow_credentials`] for more details.
///
/// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Credentials
/// [`CorsLayer::allow_credentials`]: super::CorsLayer::allow_credentials
#[derive(Clone, Default)]
#[must_use]
pub struct AllowCredentials(AllowCredentialsInner);

impl AllowCredentials {
/// Allow credentials for all requests
///
/// See [`CorsLayer::allow_credentials`] for more details.
///
/// [`CorsLayer::allow_credentials`]: super::CorsLayer::allow_credentials
pub fn yes() -> Self {
Self(AllowCredentialsInner::Yes)
}

/// Allow credentials for some requests, based on a given predicate
///
/// See [`CorsLayer::allow_credentials`] for more details.
///
/// [`CorsLayer::allow_credentials`]: super::CorsLayer::allow_credentials
pub fn predicate<F>(f: F) -> Self
where
F: Fn(&HeaderValue, &RequestParts) -> bool + Send + Sync + 'static,
{
Self(AllowCredentialsInner::Predicate(Arc::new(f)))
}

pub(super) fn is_true(&self) -> bool {
matches!(&self.0, AllowCredentialsInner::Yes)
}

pub(super) fn to_header(
&self,
origin: Option<&HeaderValue>,
parts: &RequestParts,
) -> Option<(HeaderName, HeaderValue)> {
#[allow(clippy::declare_interior_mutable_const)]
const TRUE: HeaderValue = HeaderValue::from_static("true");

let allow_creds = match &self.0 {
AllowCredentialsInner::Yes => true,
AllowCredentialsInner::No => false,
AllowCredentialsInner::Predicate(c) => c(origin?, parts),
};

allow_creds.then(|| (header::ACCESS_CONTROL_ALLOW_CREDENTIALS, TRUE))
}
}

impl From<bool> for AllowCredentials {
fn from(v: bool) -> Self {
match v {
true => Self(AllowCredentialsInner::Yes),
false => Self(AllowCredentialsInner::No),
}
}
}

impl fmt::Debug for AllowCredentials {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self.0 {
AllowCredentialsInner::Yes => f.debug_tuple("Yes").finish(),
AllowCredentialsInner::No => f.debug_tuple("No").finish(),
AllowCredentialsInner::Predicate(_) => f.debug_tuple("Predicate").finish(),
}
}
}

#[derive(Clone)]
enum AllowCredentialsInner {
Yes,
No,
Predicate(
Arc<dyn for<'a> Fn(&'a HeaderValue, &'a RequestParts) -> bool + Send + Sync + 'static>,
),
}

impl Default for AllowCredentialsInner {
fn default() -> Self {
Self::No
}
}
112 changes: 112 additions & 0 deletions tower-http/src/cors/allow_headers.rs
@@ -0,0 +1,112 @@
use std::{array, fmt};

use http::{
header::{self, HeaderName, HeaderValue},
request::Parts as RequestParts,
};

use super::{separated_by_commas, Any, WILDCARD};

/// Holds configuration for how to set the [`Access-Control-Allow-Headers`][mdn] header.
///
/// See [`CorsLayer::allow_headers`] for more details.
///
/// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Headers
/// [`CorsLayer::allow_headers`]: super::CorsLayer::allow_headers
#[derive(Clone, Default)]
#[must_use]
pub struct AllowHeaders(AllowHeadersInner);

impl AllowHeaders {
/// Allow any headers by sending a wildcard (`*`)
///
/// See [`CorsLayer::allow_headers`] for more details.
///
/// [`CorsLayer::allow_headers`]: super::CorsLayer::allow_headers
pub fn any() -> Self {
Self(AllowHeadersInner::Const(Some(WILDCARD)))
}

/// Set multiple allowed headers
///
/// See [`CorsLayer::allow_headers`] for more details.
///
/// [`CorsLayer::allow_headers`]: super::CorsLayer::allow_headers
pub fn list<I>(headers: I) -> Self
where
I: IntoIterator<Item = HeaderName>,
{
Self(AllowHeadersInner::Const(separated_by_commas(
headers.into_iter().map(Into::into),
)))
}

/// Allow any headers, by mirroring the preflight [`Access-Control-Request-Headers`][mdn]
/// header.
///
/// See [`CorsLayer::allow_headers`] for more details.
///
/// [`CorsLayer::allow_headers`]: super::CorsLayer::allow_headers
///
/// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Request-Headers
pub fn mirror_request() -> Self {
Self(AllowHeadersInner::MirrorRequest)
}

#[allow(clippy::borrow_interior_mutable_const)]
pub(super) fn is_wildcard(&self) -> bool {
matches!(&self.0, AllowHeadersInner::Const(Some(v)) if v == WILDCARD)
}

pub(super) fn to_header(&self, parts: &RequestParts) -> Option<(HeaderName, HeaderValue)> {
let allow_headers = match &self.0 {
AllowHeadersInner::Const(v) => v.clone()?,
AllowHeadersInner::MirrorRequest => parts
.headers
.get(header::ACCESS_CONTROL_REQUEST_HEADERS)?
.clone(),
};

Some((header::ACCESS_CONTROL_ALLOW_HEADERS, allow_headers))
}
}

impl fmt::Debug for AllowHeaders {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match &self.0 {
AllowHeadersInner::Const(inner) => f.debug_tuple("Const").field(inner).finish(),
AllowHeadersInner::MirrorRequest => f.debug_tuple("MirrorRequest").finish(),
}
}
}

impl From<Any> for AllowHeaders {
fn from(_: Any) -> Self {
Self::any()
}
}

impl<const N: usize> From<[HeaderName; N]> for AllowHeaders {
fn from(arr: [HeaderName; N]) -> Self {
#[allow(deprecated)] // Can be changed when MSRV >= 1.53
Self::list(array::IntoIter::new(arr))
}
}

impl From<Vec<HeaderName>> for AllowHeaders {
fn from(vec: Vec<HeaderName>) -> Self {
Self::list(vec)
}
}

#[derive(Clone)]
enum AllowHeadersInner {
Const(Option<HeaderValue>),
MirrorRequest,
}

impl Default for AllowHeadersInner {
fn default() -> Self {
Self::Const(None)
}
}
132 changes: 132 additions & 0 deletions tower-http/src/cors/allow_methods.rs
@@ -0,0 +1,132 @@
use std::{array, fmt};

use http::{
header::{self, HeaderName, HeaderValue},
request::Parts as RequestParts,
Method,
};

use super::{separated_by_commas, Any, WILDCARD};

/// Holds configuration for how to set the [`Access-Control-Allow-Methods`][mdn] header.
///
/// See [`CorsLayer::allow_methods`] for more details.
///
/// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Methods
/// [`CorsLayer::allow_methods`]: super::CorsLayer::allow_methods
#[derive(Clone, Default)]
#[must_use]
pub struct AllowMethods(AllowMethodsInner);

impl AllowMethods {
/// Allow any method by sending a wildcard (`*`)
///
/// See [`CorsLayer::allow_methods`] for more details.
///
/// [`CorsLayer::allow_methods`]: super::CorsLayer::allow_methods
pub fn any() -> Self {
Self(AllowMethodsInner::Const(Some(WILDCARD)))
}

/// Set a single allowed method
///
/// See [`CorsLayer::allow_methods`] for more details.
///
/// [`CorsLayer::allow_methods`]: super::CorsLayer::allow_methods
pub fn exact(method: Method) -> Self {
Self(AllowMethodsInner::Const(Some(
HeaderValue::from_str(method.as_str()).unwrap(),
)))
}

/// Set multiple allowed methods
///
/// See [`CorsLayer::allow_methods`] for more details.
///
/// [`CorsLayer::allow_methods`]: super::CorsLayer::allow_methods
pub fn list<I>(methods: I) -> Self
where
I: IntoIterator<Item = Method>,
{
Self(AllowMethodsInner::Const(separated_by_commas(
methods
.into_iter()
.map(|m| HeaderValue::from_str(m.as_str()).unwrap()),
)))
}

/// Allow any method, by mirroring the preflight [`Access-Control-Request-Method`][mdn]
/// header.
///
/// See [`CorsLayer::allow_methods`] for more details.
///
/// [`CorsLayer::allow_methods`]: super::CorsLayer::allow_methods
///
/// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Request-Method
pub fn mirror_request() -> Self {
Self(AllowMethodsInner::MirrorRequest)
}

#[allow(clippy::borrow_interior_mutable_const)]
pub(super) fn is_wildcard(&self) -> bool {
matches!(&self.0, AllowMethodsInner::Const(Some(v)) if v == WILDCARD)
}

pub(super) fn to_header(&self, parts: &RequestParts) -> Option<(HeaderName, HeaderValue)> {
let allow_methods = match &self.0 {
AllowMethodsInner::Const(v) => v.clone()?,
AllowMethodsInner::MirrorRequest => parts
.headers
.get(header::ACCESS_CONTROL_REQUEST_METHOD)?
.clone(),
};

Some((header::ACCESS_CONTROL_ALLOW_METHODS, allow_methods))
}
}

impl fmt::Debug for AllowMethods {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match &self.0 {
AllowMethodsInner::Const(inner) => f.debug_tuple("Const").field(inner).finish(),
AllowMethodsInner::MirrorRequest => f.debug_tuple("MirrorRequest").finish(),
}
}
}

impl From<Any> for AllowMethods {
fn from(_: Any) -> Self {
Self::any()
}
}

impl From<Method> for AllowMethods {
fn from(method: Method) -> Self {
Self::exact(method)
}
}

impl<const N: usize> From<[Method; N]> for AllowMethods {
fn from(arr: [Method; N]) -> Self {
#[allow(deprecated)] // Can be changed when MSRV >= 1.53
Self::list(array::IntoIter::new(arr))
}
}

impl From<Vec<Method>> for AllowMethods {
fn from(vec: Vec<Method>) -> Self {
Self::list(vec)
}
}

#[derive(Clone)]
enum AllowMethodsInner {
Const(Option<HeaderValue>),
MirrorRequest,
}

impl Default for AllowMethodsInner {
fn default() -> Self {
Self::Const(None)
}
}