Skip to content

Commit

Permalink
Add DefaultBodyLimit::max to change the body size limit (#1397)
Browse files Browse the repository at this point in the history
  • Loading branch information
davidpdrsn authored and jplatte committed Oct 19, 2022
1 parent b19cdab commit 59475f1
Show file tree
Hide file tree
Showing 5 changed files with 99 additions and 21 deletions.
4 changes: 3 additions & 1 deletion axum-core/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

# Unreleased

- None.
- **added:** Add `DefaultBodyLimit::max` for changing the default body limit ([#1397])

[#1397]: https://github.com/tokio-rs/axum/pull/1397

# 0.2.8 (10. September, 2022)

Expand Down
61 changes: 52 additions & 9 deletions axum-core/src/extract/default_body_limit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,15 @@ use tower_layer::Layer;
/// [`Json`]: https://docs.rs/axum/0.5/axum/struct.Json.html
/// [`Form`]: https://docs.rs/axum/0.5/axum/struct.Form.html
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct DefaultBodyLimit;
pub struct DefaultBodyLimit {
kind: DefaultBodyLimitKind,
}

#[derive(Debug, Clone, Copy)]
pub(crate) enum DefaultBodyLimitKind {
Disable,
Limit(usize),
}

impl DefaultBodyLimit {
/// Disable the default request body limit.
Expand Down Expand Up @@ -53,30 +60,66 @@ impl DefaultBodyLimit {
/// [`Json`]: https://docs.rs/axum/0.5/axum/struct.Json.html
/// [`Form`]: https://docs.rs/axum/0.5/axum/struct.Form.html
pub fn disable() -> Self {
Self
Self {
kind: DefaultBodyLimitKind::Disable,
}
}

/// Set the default request body limit.
///
/// By default the limit of request body sizes that [`Bytes::from_request`] (and other
/// extractors built on top of it such as `String`, [`Json`], and [`Form`]) is 2MB. This method
/// can be used to change that limit.
///
/// # Example
///
/// ```
/// use axum::{
/// Router,
/// routing::get,
/// body::{Bytes, Body},
/// extract::DefaultBodyLimit,
/// };
/// use tower_http::limit::RequestBodyLimitLayer;
/// use http_body::Limited;
///
/// let app: Router<_, Limited<Body>> = Router::new()
/// .route("/", get(|body: Bytes| async {}))
/// // Replace the default of 2MB with 1024 bytes.
/// .layer(DefaultBodyLimit::max(1024));
/// ```
///
/// [`Bytes::from_request`]: bytes::Bytes
/// [`Json`]: https://docs.rs/axum/0.6.0-rc.2/axum/struct.Json.html
/// [`Form`]: https://docs.rs/axum/0.6.0-rc.2/axum/struct.Form.html
pub fn max(limit: usize) -> Self {
Self {
kind: DefaultBodyLimitKind::Limit(limit),
}
}
}

impl<S> Layer<S> for DefaultBodyLimit {
type Service = DefaultBodyLimitService<S>;

fn layer(&self, inner: S) -> Self::Service {
DefaultBodyLimitService { inner }
DefaultBodyLimitService {
inner,
kind: self.kind,
}
}
}

#[derive(Copy, Clone, Debug)]
pub(crate) struct DefaultBodyLimitDisabled;

mod private {
use super::DefaultBodyLimitDisabled;
use super::DefaultBodyLimitKind;
use http::Request;
use std::task::Context;
use tower_service::Service;

#[derive(Debug, Clone, Copy)]
pub struct DefaultBodyLimitService<S> {
pub(super) inner: S,
pub(super) kind: DefaultBodyLimitKind,
}

impl<B, S> Service<Request<B>> for DefaultBodyLimitService<S>
Expand All @@ -94,7 +137,7 @@ mod private {

#[inline]
fn call(&mut self, mut req: Request<B>) -> Self::Future {
req.extensions_mut().insert(DefaultBodyLimitDisabled);
req.extensions_mut().insert(self.kind);
self.inner.call(req)
}
}
Expand Down
28 changes: 17 additions & 11 deletions axum-core/src/extract/request_parts.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
use super::{
default_body_limit::DefaultBodyLimitDisabled, rejection::*, FromRequest, RequestParts,
};
use super::{default_body_limit::DefaultBodyLimitKind, rejection::*, FromRequest, RequestParts};
use crate::BoxError;
use async_trait::async_trait;
use bytes::Bytes;
Expand Down Expand Up @@ -100,15 +98,23 @@ where

let body = take_body(req)?;

let bytes = if req.extensions().get::<DefaultBodyLimitDisabled>().is_some() {
crate::body::to_bytes(body)
let limit_kind = req.extensions().get::<DefaultBodyLimitKind>().copied();
let bytes = match limit_kind {
Some(DefaultBodyLimitKind::Disable) => crate::body::to_bytes(body)
.await
.map_err(FailedToBufferBody::from_err)?
} else {
let body = http_body::Limited::new(body, DEFAULT_LIMIT);
crate::body::to_bytes(body)
.await
.map_err(FailedToBufferBody::from_err)?
.map_err(FailedToBufferBody::from_err)?,
Some(DefaultBodyLimitKind::Limit(limit)) => {
let body = http_body::Limited::new(body, limit);
crate::body::to_bytes(body)
.await
.map_err(FailedToBufferBody::from_err)?
}
None => {
let body = http_body::Limited::new(body, DEFAULT_LIMIT);
crate::body::to_bytes(body)
.await
.map_err(FailedToBufferBody::from_err)?
}
};

Ok(bytes)
Expand Down
2 changes: 2 additions & 0 deletions axum/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- **fixed:** Annotate panicking functions with `#[track_caller]` so the error
message points to where the user added the invalid router, rather than
somewhere internally in axum ([#1248])
- **added:** Add `DefaultBodyLimit::max` for changing the default body limit ([#1397])

[#1248]: https://github.com/tokio-rs/axum/pull/1248
[#1397]: https://github.com/tokio-rs/axum/pull/1397

# 0.5.16 (10. September, 2022)

Expand Down
25 changes: 25 additions & 0 deletions axum/src/routing/tests/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -767,6 +767,31 @@ async fn limited_body_with_content_length() {
assert_eq!(res.status(), StatusCode::PAYLOAD_TOO_LARGE);
}

#[tokio::test]
async fn changing_the_default_limit() {
let new_limit = 2;

let app = Router::new()
.route("/", post(|_: Bytes| async {}))
.layer(DefaultBodyLimit::max(new_limit));

let client = TestClient::new(app);

let res = client
.post("/")
.body(Body::from("a".repeat(new_limit)))
.send()
.await;
assert_eq!(res.status(), StatusCode::OK);

let res = client
.post("/")
.body(Body::from("a".repeat(new_limit + 1)))
.send()
.await;
assert_eq!(res.status(), StatusCode::PAYLOAD_TOO_LARGE);
}

#[tokio::test]
async fn limited_body_with_streaming_body() {
const LIMIT: usize = 3;
Expand Down

0 comments on commit 59475f1

Please sign in to comment.