Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add layer that limits body size #271

Merged
merged 19 commits into from
Jun 6, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
4 changes: 3 additions & 1 deletion tower-http/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ bytes = "1"
futures-core = "0.3"
futures-util = { version = "0.3.14", default_features = false, features = [] }
http = "0.2.2"
http-body = "0.4.1"
http-body = "0.4.5"
pin-project-lite = "0.2.7"
tower-layer = "0.3"
tower-service = "0.3"
Expand Down Expand Up @@ -62,6 +62,7 @@ full = [
"decompression-full",
"follow-redirect",
"fs",
"limit",
"map-request-body",
"map-response-body",
"metrics",
Expand All @@ -82,6 +83,7 @@ catch-panic = ["tracing", "futures-util/std"]
cors = []
follow-redirect = ["iri-string", "tower/util"]
fs = ["tokio/fs", "tokio-util/io", "tokio/io-util", "mime_guess", "mime", "percent-encoding", "httpdate", "set-status", "futures-util/alloc"]
limit = []
map-request-body = []
map-response-body = []
metrics = ["tokio/time"]
Expand Down
20 changes: 20 additions & 0 deletions tower-http/src/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,18 @@ pub trait ServiceBuilderExt<L>: crate::sealed::Sealed<L> + Sized {
) -> ServiceBuilder<
Stack<crate::catch_panic::CatchPanicLayer<crate::catch_panic::DefaultResponseForPanic>, L>,
>;

/// Intercept requests with over-sized payloads and convert them into
/// `413 Payload Too Large` responses.
///
/// See [`tower_http::limit`] for more details.
///
/// [`tower_http::limit`]: crate::limit
#[cfg(feature = "limit")]
fn length_limit<B>(
self,
limit: usize,
) -> ServiceBuilder<Stack<crate::limit::LengthLimitedLayer<B>, L>>;
}

impl<L> crate::sealed::Sealed<L> for ServiceBuilder<L> {}
Expand Down Expand Up @@ -558,4 +570,12 @@ impl<L> ServiceBuilderExt<L> for ServiceBuilder<L> {
> {
self.layer(crate::catch_panic::CatchPanicLayer::new())
}

#[cfg(feature = "limit")]
fn length_limit<B>(
self,
limit: usize,
) -> ServiceBuilder<Stack<crate::limit::LengthLimitedLayer<B>, L>> {
self.layer(crate::limit::LengthLimitedLayer::new(limit))
}
}
3 changes: 3 additions & 0 deletions tower-http/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,9 @@ pub mod trace;
#[cfg(feature = "follow-redirect")]
pub mod follow_redirect;

#[cfg(feature = "limit")]
pub mod limit;

#[cfg(feature = "metrics")]
pub mod metrics;

Expand Down
246 changes: 246 additions & 0 deletions tower-http/src/limit.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,246 @@
//! Imposes a length limit on request bodies.
//!
//! # Example
//!
//! ```rust
//! use bytes::Bytes;
//! use http::{Request, Response, StatusCode};
//! use std::convert::Infallible;
//! use tower::{Service, ServiceExt, ServiceBuilder, service_fn};
//! use tower_http::limit::LengthLimitedLayer;
//! use hyper::Body;
//! use http_body::Limited;
//! use tower_http::BoxError;
//!
//! # #[tokio::main]
//! # async fn main() -> Result<(), BoxError> {
//! async fn handle(req: Request<Limited<Body>>) -> Result<Response<Body>, BoxError>
//! {
//! hyper::body::to_bytes(req.into_body()).await?;
//! Ok(Response::new(Body::empty()))
//! }
//!
//! let mut svc = ServiceBuilder::new()
//! // Limit incoming requests to 4096 bytes.
//! .layer(LengthLimitedLayer::new(4096))
//! .service_fn(handle);
//!
//! fn test_svc<S: Service<Request<Body>>>(s: &S) {}
//! test_svc(&svc);
//!
//! // Call the service.
//! let request = Request::new(Body::empty());
//!
//! let response = svc.ready().await?.call(request).await?;
//!
//! assert_eq!(response.status(), 200);
//!
//! // Call the service with a body that is too large.
//! let request = Request::new(Body::from(Bytes::from(vec![0u8; 4097])));
//!
//! let response = svc.ready().await?.call(request).await?;
//!
//! assert_eq!(response.status(), StatusCode::PAYLOAD_TOO_LARGE);
//!
//! #
//! # Ok(())
//! # }
//! ```

use crate::BoxError;
use bytes::Bytes;
use http::{HeaderValue, Request, Response, StatusCode};
use http_body::combinators::UnsyncBoxBody;
use http_body::{Body, Full, LengthLimitError, Limited};
use pin_project_lite::pin_project;
use std::error::Error;
use std::future::Future;
use std::pin::Pin;
use std::{
any, fmt,
marker::PhantomData,
task::{Context, Poll},
};
use tower_layer::Layer;
use tower_service::Service;

/// Layer that applies the [`LengthLimit`] middleware that intercepts requests
/// with body lengths greater than the configured limit and converts them into
/// `413 Payload Too Large` responses.
///
/// See the [module docs](self) for an example.
pub struct LengthLimitedLayer<B> {
limit: usize,
_ty: PhantomData<fn() -> B>,
}

impl<B> LengthLimitedLayer<B> {
/// Create a new `LengthLimitedLayer` with the given body length limit.
pub fn new(limit: usize) -> Self {
Self {
limit,
_ty: PhantomData,
}
}
}

impl<B> Clone for LengthLimitedLayer<B> {
fn clone(&self) -> Self {
Self {
limit: self.limit,
_ty: PhantomData,
}
}
}

impl<B> Copy for LengthLimitedLayer<B> {}

impl<B> fmt::Debug for LengthLimitedLayer<B> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("LengthLimitedLayer")
.field("body", &any::type_name::<B>())
.field("limit", &self.limit)
.finish()
}
}

impl<B, S> Layer<S> for LengthLimitedLayer<B> {
type Service = LengthLimited<S, B>;

fn layer(&self, inner: S) -> Self::Service {
LengthLimited {
inner,
limit: self.limit,
_ty: PhantomData,
}
}
}

/// Middleware that intercepts requests with body lengths greater than the
/// configured limit and converts them into `413 Payload Too Large` responses.
///
/// See the [module docs](self) for an example.
pub struct LengthLimited<S, B> {
neoeinstein marked this conversation as resolved.
Show resolved Hide resolved
inner: S,
limit: usize,
_ty: PhantomData<fn() -> B>,
neoeinstein marked this conversation as resolved.
Show resolved Hide resolved
}

impl<S, B> LengthLimited<S, B> {
define_inner_service_accessors!();

/// Create a new `LengthLimited` with the given body length limit.
pub fn new(inner: S, limit: usize) -> Self {
Self {
inner,
limit,
_ty: PhantomData,
}
}
}

impl<S, B> Clone for LengthLimited<S, B>
where
S: Clone,
{
fn clone(&self) -> Self {
Self {
inner: self.inner.clone(),
limit: self.limit,
_ty: PhantomData,
}
}
}

impl<S, B> Copy for LengthLimited<S, B> where S: Copy {}

impl<S, B> fmt::Debug for LengthLimited<S, B>
where
S: fmt::Debug,
{
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("LengthLimited")
.field("inner", &self.inner)
.field("service", &format_args!("{}", any::type_name::<B>()))
neoeinstein marked this conversation as resolved.
Show resolved Hide resolved
.finish()
}
}

impl<ReqBody, ResBody, S> Service<Request<ReqBody>> for LengthLimited<S, ReqBody>
where
S: Service<Request<Limited<ReqBody>>, Response = Response<ResBody>>,
S::Error: Into<BoxError>,
ResBody: Body<Data = Bytes> + Send + 'static,
ResBody::Error: Into<BoxError>,
{
type Response = Response<UnsyncBoxBody<Bytes, BoxError>>;
type Error = BoxError;
type Future = ResponseFuture<S::Future>;

#[inline]
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.inner.poll_ready(cx).map_err(|e| e.into())
}

fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
let (parts, body) = req.into_parts();
let body = Limited::new(body, self.limit);
let req = Request::from_parts(parts, body);

ResponseFuture {
future: self.inner.call(req),
}
}
neoeinstein marked this conversation as resolved.
Show resolved Hide resolved
}

pin_project! {
/// Response future for [`LengthLimit`].
pub struct ResponseFuture<F> {
#[pin]
future: F,
}
}

impl<F, ResBody, E> Future for ResponseFuture<F>
where
F: Future<Output = Result<Response<ResBody>, E>>,
E: Into<BoxError>,
ResBody: Body<Data = Bytes> + Send + 'static,
ResBody::Error: Into<BoxError>,
{
type Output = Result<Response<UnsyncBoxBody<Bytes, BoxError>>, BoxError>;

fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
match self.project().future.poll(cx) {
Poll::Pending => Poll::Pending,
Poll::Ready(Ok(data)) => {
let (parts, body) = data.into_parts();
let body = body.map_err(|err| err.into()).boxed_unsync();
let resp = Response::from_parts(parts, body);

Poll::Ready(Ok(resp))
}
Poll::Ready(Err(err)) => {
let err = err.into();
if let Some(_) = err.downcast_ref::<LengthLimitError>() {
davidpdrsn marked this conversation as resolved.
Show resolved Hide resolved
let mut res = Response::new(
Full::from("length limit exceeded")
.map_err(|err| err.into())
.boxed_unsync(),
);
*res.status_mut() = StatusCode::PAYLOAD_TOO_LARGE;

#[allow(clippy::declare_interior_mutable_const)]
const TEXT_PLAIN: HeaderValue =
HeaderValue::from_static("text/plain; charset=utf-8");
res.headers_mut()
.insert(http::header::CONTENT_TYPE, TEXT_PLAIN);

Poll::Ready(Ok(res))
} else {
Poll::Ready(Err(err))
}
}
}
}
}