Skip to content

Commit

Permalink
Add DefaultBodyLimit::max to change the body size limit (#1397)
Browse files Browse the repository at this point in the history
  • Loading branch information
davidpdrsn committed Sep 19, 2022
1 parent 7105805 commit de9909d
Show file tree
Hide file tree
Showing 5 changed files with 99 additions and 19 deletions.
4 changes: 3 additions & 1 deletion axum-core/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

# Unreleased

- None.
- **added:** Add `DefaultBodyLimit::max` for changing the default body limit ([#1397])

[#1397]: https://github.com/tokio-rs/axum/pull/1397

# 0.3.0-rc.2 (10. September, 2022)

Expand Down
61 changes: 52 additions & 9 deletions axum-core/src/extract/default_body_limit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,15 @@ use tower_layer::Layer;
/// [`Json`]: https://docs.rs/axum/0.6.0-rc.2/axum/struct.Json.html
/// [`Form`]: https://docs.rs/axum/0.6.0-rc.2/axum/struct.Form.html
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct DefaultBodyLimit;
pub struct DefaultBodyLimit {
kind: DefaultBodyLimitKind,
}

#[derive(Debug, Clone, Copy)]
pub(crate) enum DefaultBodyLimitKind {
Disable,
Limit(usize),
}

impl DefaultBodyLimit {
/// Disable the default request body limit.
Expand Down Expand Up @@ -53,30 +60,66 @@ impl DefaultBodyLimit {
/// [`Json`]: https://docs.rs/axum/0.6.0-rc.2/axum/struct.Json.html
/// [`Form`]: https://docs.rs/axum/0.6.0-rc.2/axum/struct.Form.html
pub fn disable() -> Self {
Self
Self {
kind: DefaultBodyLimitKind::Disable,
}
}

/// Set the default request body limit.
///
/// By default the limit of request body sizes that [`Bytes::from_request`] (and other
/// extractors built on top of it such as `String`, [`Json`], and [`Form`]) is 2MB. This method
/// can be used to change that limit.
///
/// # Example
///
/// ```
/// use axum::{
/// Router,
/// routing::get,
/// body::{Bytes, Body},
/// extract::DefaultBodyLimit,
/// };
/// use tower_http::limit::RequestBodyLimitLayer;
/// use http_body::Limited;
///
/// let app: Router<_, Limited<Body>> = Router::new()
/// .route("/", get(|body: Bytes| async {}))
/// // Replace the default of 2MB with 1024 bytes.
/// .layer(DefaultBodyLimit::max(1024));
/// ```
///
/// [`Bytes::from_request`]: bytes::Bytes
/// [`Json`]: https://docs.rs/axum/0.6.0-rc.2/axum/struct.Json.html
/// [`Form`]: https://docs.rs/axum/0.6.0-rc.2/axum/struct.Form.html
pub fn max(limit: usize) -> Self {
Self {
kind: DefaultBodyLimitKind::Limit(limit),
}
}
}

impl<S> Layer<S> for DefaultBodyLimit {
type Service = DefaultBodyLimitService<S>;

fn layer(&self, inner: S) -> Self::Service {
DefaultBodyLimitService { inner }
DefaultBodyLimitService {
inner,
kind: self.kind,
}
}
}

#[derive(Copy, Clone, Debug)]
pub(crate) struct DefaultBodyLimitDisabled;

mod private {
use super::DefaultBodyLimitDisabled;
use super::DefaultBodyLimitKind;
use http::Request;
use std::task::Context;
use tower_service::Service;

#[derive(Debug, Clone, Copy)]
pub struct DefaultBodyLimitService<S> {
pub(super) inner: S,
pub(super) kind: DefaultBodyLimitKind,
}

impl<B, S> Service<Request<B>> for DefaultBodyLimitService<S>
Expand All @@ -94,7 +137,7 @@ mod private {

#[inline]
fn call(&mut self, mut req: Request<B>) -> Self::Future {
req.extensions_mut().insert(DefaultBodyLimitDisabled);
req.extensions_mut().insert(self.kind);
self.inner.call(req)
}
}
Expand Down
26 changes: 17 additions & 9 deletions axum-core/src/extract/request_parts.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use super::{
default_body_limit::DefaultBodyLimitDisabled, rejection::*, FromRequest, FromRequestParts,
default_body_limit::DefaultBodyLimitKind, rejection::*, FromRequest, FromRequestParts,
};
use crate::BoxError;
use async_trait::async_trait;
Expand Down Expand Up @@ -88,15 +88,23 @@ where
// `axum/src/docs/extract.md` if this changes
const DEFAULT_LIMIT: usize = 2_097_152; // 2 mb

let bytes = if req.extensions().get::<DefaultBodyLimitDisabled>().is_some() {
crate::body::to_bytes(req.into_body())
let limit_kind = req.extensions().get::<DefaultBodyLimitKind>().copied();
let bytes = match limit_kind {
Some(DefaultBodyLimitKind::Disable) => crate::body::to_bytes(req.into_body())
.await
.map_err(FailedToBufferBody::from_err)?
} else {
let body = http_body::Limited::new(req.into_body(), DEFAULT_LIMIT);
crate::body::to_bytes(body)
.await
.map_err(FailedToBufferBody::from_err)?
.map_err(FailedToBufferBody::from_err)?,
Some(DefaultBodyLimitKind::Limit(limit)) => {
let body = http_body::Limited::new(req.into_body(), limit);
crate::body::to_bytes(body)
.await
.map_err(FailedToBufferBody::from_err)?
}
None => {
let body = http_body::Limited::new(req.into_body(), DEFAULT_LIMIT);
crate::body::to_bytes(body)
.await
.map_err(FailedToBufferBody::from_err)?
}
};

Ok(bytes)
Expand Down
2 changes: 2 additions & 0 deletions axum/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- **fixed:** Support streaming/chunked requests in `ContentLengthLimit` ([#1389])
- **fixed:** Used `400 Bad Request` for `FailedToDeserializeQueryString`
rejections, instead of `422 Unprocessable Entity` ([#1387])
- **added:** Add `DefaultBodyLimit::max` for changing the default body limit ([#1397])

[#1371]: https://github.com/tokio-rs/axum/pull/1371
[#1387]: https://github.com/tokio-rs/axum/pull/1387
[#1389]: https://github.com/tokio-rs/axum/pull/1389
[#1397]: https://github.com/tokio-rs/axum/pull/1397

# 0.6.0-rc.2 (10. September, 2022)

Expand Down
25 changes: 25 additions & 0 deletions axum/src/routing/tests/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -671,6 +671,31 @@ async fn limited_body_with_content_length() {
assert_eq!(res.status(), StatusCode::PAYLOAD_TOO_LARGE);
}

#[tokio::test]
async fn changing_the_default_limit() {
let new_limit = 2;

let app = Router::new()
.route("/", post(|_: Bytes| async {}))
.layer(DefaultBodyLimit::max(new_limit));

let client = TestClient::new(app);

let res = client
.post("/")
.body(Body::from("a".repeat(new_limit)))
.send()
.await;
assert_eq!(res.status(), StatusCode::OK);

let res = client
.post("/")
.body(Body::from("a".repeat(new_limit + 1)))
.send()
.await;
assert_eq!(res.status(), StatusCode::PAYLOAD_TOO_LARGE);
}

#[tokio::test]
async fn limited_body_with_streaming_body() {
const LIMIT: usize = 3;
Expand Down

0 comments on commit de9909d

Please sign in to comment.