diff --git a/axum-core/CHANGELOG.md b/axum-core/CHANGELOG.md index a2d3e0c783..536ea88863 100644 --- a/axum-core/CHANGELOG.md +++ b/axum-core/CHANGELOG.md @@ -7,7 +7,20 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 # Unreleased -- None. +- **breaking:** Added default limit to how much data `Bytes::from_request` will + consume. Previously it would attempt to consume the entire request body + without checking its length. This meant if a malicious peer sent an large (or + infinite) request body your server might run out of memory and crash. + + The default limit is at 2 MB and can be disabled by adding the new + `DefaultBodyLimit::disable()` middleware. See its documentation for more + details. + + This also applies to `String` which used `Bytes::from_request` internally. + + ([#1346]) + +[#1346]: https://github.com/tokio-rs/axum/pull/1346 # 0.2.7 (10. July, 2022) diff --git a/axum-core/Cargo.toml b/axum-core/Cargo.toml index 110061964c..99fcefecfb 100644 --- a/axum-core/Cargo.toml +++ b/axum-core/Cargo.toml @@ -17,9 +17,12 @@ futures-util = { version = "0.3", default-features = false, features = ["alloc"] http = "0.2.7" http-body = "0.4.5" mime = "0.3.16" +tower-layer = "0.3" +tower-service = "0.3" [dev-dependencies] axum = { path = "../axum", version = "0.5" } futures-util = "0.3" hyper = "0.14" tokio = { version = "1.0", features = ["macros"] } +tower-http = { version = "0.3.4", features = ["limit"] } diff --git a/axum-core/src/extract/default_body_limit.rs b/axum-core/src/extract/default_body_limit.rs new file mode 100644 index 0000000000..7f12bc9c8e --- /dev/null +++ b/axum-core/src/extract/default_body_limit.rs @@ -0,0 +1,101 @@ +use self::private::DefaultBodyLimitService; +use tower_layer::Layer; + +/// Layer for configuring the default request body limit. +/// +/// For security reasons, [`Bytes`] will, by default, not accept bodies larger than 2MB. This also +/// applies to extractors that uses [`Bytes`] internally such as `String`, [`Json`], and [`Form`]. +/// +/// This middleware provides ways to configure that. +/// +/// Note that if an extractor consumes the body directly with [`Body::data`], or similar, the +/// default limit is _not_ applied. +/// +/// [`Body::data`]: http_body::Body::data +/// [`Bytes`]: bytes::Bytes +/// [`Json`]: https://docs.rs/axum/0.5/axum/struct.Json.html +/// [`Form`]: https://docs.rs/axum/0.5/axum/struct.Form.html +#[derive(Debug, Clone)] +#[non_exhaustive] +pub struct DefaultBodyLimit; + +impl DefaultBodyLimit { + /// Disable the default request body limit. + /// + /// This must be used to receive bodies larger than the default limit of 2MB using [`Bytes`] or + /// an extractor built on it such as `String`, [`Json`], [`Form`]. + /// + /// Note that if you're accepting data from untrusted remotes it is recommend to add your own + /// limit such as [`tower_http::limit`]. + /// + /// # Example + /// + /// ``` + /// use axum::{ + /// Router, + /// routing::get, + /// body::{Bytes, Body}, + /// extract::DefaultBodyLimit, + /// }; + /// use tower_http::limit::RequestBodyLimitLayer; + /// use http_body::Limited; + /// + /// let app: Router> = Router::new() + /// .route("/", get(|body: Bytes| async {})) + /// // Disable the default limit + /// .layer(DefaultBodyLimit::disable()) + /// // Set a different limit + /// .layer(RequestBodyLimitLayer::new(10 * 1000 * 1000)); + /// ``` + /// + /// [`tower_http::limit`]: https://docs.rs/tower-http/0.3.4/tower_http/limit/index.html + /// [`Bytes`]: bytes::Bytes + /// [`Json`]: https://docs.rs/axum/0.5/axum/struct.Json.html + /// [`Form`]: https://docs.rs/axum/0.5/axum/struct.Form.html + pub fn disable() -> Self { + Self + } +} + +impl Layer for DefaultBodyLimit { + type Service = DefaultBodyLimitService; + + fn layer(&self, inner: S) -> Self::Service { + DefaultBodyLimitService { inner } + } +} + +#[derive(Copy, Clone, Debug)] +pub(crate) struct DefaultBodyLimitDisabled; + +mod private { + use super::DefaultBodyLimitDisabled; + use http::Request; + use std::task::Context; + use tower_service::Service; + + #[derive(Debug, Clone, Copy)] + pub struct DefaultBodyLimitService { + pub(super) inner: S, + } + + impl Service> for DefaultBodyLimitService + where + S: Service>, + { + type Response = S::Response; + type Error = S::Error; + type Future = S::Future; + + #[inline] + fn poll_ready(&mut self, cx: &mut Context<'_>) -> std::task::Poll> { + self.inner.poll_ready(cx) + } + + #[inline] + fn call(&mut self, mut req: Request) -> Self::Future { + req.extensions_mut().insert(DefaultBodyLimitDisabled); + self.inner.call(req) + } + } +} diff --git a/axum-core/src/extract/mod.rs b/axum-core/src/extract/mod.rs index 2316633be5..5887029649 100644 --- a/axum-core/src/extract/mod.rs +++ b/axum-core/src/extract/mod.rs @@ -12,9 +12,12 @@ use std::convert::Infallible; pub mod rejection; +mod default_body_limit; mod request_parts; mod tuple; +pub use self::default_body_limit::DefaultBodyLimit; + /// Types that can be created from requests. /// /// See [`axum::extract`] for more details. diff --git a/axum-core/src/extract/request_parts.rs b/axum-core/src/extract/request_parts.rs index 33383a7d8d..66753c1be4 100644 --- a/axum-core/src/extract/request_parts.rs +++ b/axum-core/src/extract/request_parts.rs @@ -1,4 +1,6 @@ -use super::{rejection::*, FromRequest, RequestParts}; +use super::{ + default_body_limit::DefaultBodyLimitDisabled, rejection::*, FromRequest, RequestParts, +}; use crate::BoxError; use async_trait::async_trait; use bytes::Bytes; @@ -92,11 +94,22 @@ where type Rejection = BytesRejection; async fn from_request(req: &mut RequestParts) -> Result { + // update docs in `axum-core/src/extract/default_body_limit.rs` and + // `axum/src/docs/extract.md` if this changes + const DEFAULT_LIMIT: usize = 2_097_152; // 2 mb + let body = take_body(req)?; - let bytes = crate::body::to_bytes(body) - .await - .map_err(FailedToBufferBody::from_err)?; + let bytes = if req.extensions().get::().is_some() { + crate::body::to_bytes(body) + .await + .map_err(FailedToBufferBody::from_err)? + } else { + let body = http_body::Limited::new(body, DEFAULT_LIMIT); + crate::body::to_bytes(body) + .await + .map_err(FailedToBufferBody::from_err)? + }; Ok(bytes) } @@ -112,14 +125,16 @@ where type Rejection = StringRejection; async fn from_request(req: &mut RequestParts) -> Result { - let body = take_body(req)?; - - let bytes = crate::body::to_bytes(body) - .await - .map_err(FailedToBufferBody::from_err)? - .to_vec(); - - let string = String::from_utf8(bytes).map_err(InvalidUtf8::from_err)?; + let bytes = Bytes::from_request(req).await.map_err(|err| match err { + BytesRejection::FailedToBufferBody(inner) => StringRejection::FailedToBufferBody(inner), + BytesRejection::BodyAlreadyExtracted(inner) => { + StringRejection::BodyAlreadyExtracted(inner) + } + })?; + + let string = std::str::from_utf8(&bytes) + .map_err(InvalidUtf8::from_err)? + .to_owned(); Ok(string) } diff --git a/axum-macros/src/debug_handler.rs b/axum-macros/src/debug_handler.rs index 7945e8bb84..cc0a682242 100644 --- a/axum-macros/src/debug_handler.rs +++ b/axum-macros/src/debug_handler.rs @@ -7,10 +7,26 @@ pub(crate) fn expand(attr: Attrs, item_fn: ItemFn) -> TokenStream { let check_request_last_extractor = check_request_last_extractor(&item_fn); let check_path_extractor = check_path_extractor(&item_fn); let check_multiple_body_extractors = check_multiple_body_extractors(&item_fn); - - let check_inputs_impls_from_request = check_inputs_impls_from_request(&item_fn, &attr.body_ty); let check_output_impls_into_response = check_output_impls_into_response(&item_fn); - let check_future_send = check_future_send(&item_fn); + + // If the function is generic, we can't reliably check its inputs or whether the future it + // returns is `Send`. Skip those checks to avoid unhelpful additional compiler errors. + let check_inputs_and_future_send = if item_fn.sig.generics.params.is_empty() { + let check_inputs_impls_from_request = + check_inputs_impls_from_request(&item_fn, &attr.body_ty); + let check_future_send = check_future_send(&item_fn); + + quote! { + #check_inputs_impls_from_request + #check_future_send + } + } else { + syn::Error::new_spanned( + &item_fn.sig.generics, + "`#[axum_macros::debug_handler]` doesn't support generic functions", + ) + .into_compile_error() + }; quote! { #item_fn @@ -18,9 +34,8 @@ pub(crate) fn expand(attr: Attrs, item_fn: ItemFn) -> TokenStream { #check_request_last_extractor #check_path_extractor #check_multiple_body_extractors - #check_inputs_impls_from_request #check_output_impls_into_response - #check_future_send + #check_inputs_and_future_send } } @@ -153,14 +168,6 @@ fn check_multiple_body_extractors(item_fn: &ItemFn) -> TokenStream { } fn check_inputs_impls_from_request(item_fn: &ItemFn, body_ty: &Type) -> TokenStream { - if !item_fn.sig.generics.params.is_empty() { - return syn::Error::new_spanned( - &item_fn.sig.generics, - "`#[axum_macros::debug_handler]` doesn't support generic functions", - ) - .into_compile_error(); - } - item_fn .sig .inputs diff --git a/axum-macros/tests/debug_handler/fail/generics.rs b/axum-macros/tests/debug_handler/fail/generics.rs index 310de31867..dd15076761 100644 --- a/axum-macros/tests/debug_handler/fail/generics.rs +++ b/axum-macros/tests/debug_handler/fail/generics.rs @@ -1,6 +1,6 @@ use axum_macros::debug_handler; #[debug_handler] -async fn handler() {} +async fn handler(extract: T) {} fn main() {} diff --git a/axum-macros/tests/debug_handler/fail/generics.stderr b/axum-macros/tests/debug_handler/fail/generics.stderr index 52b705983e..4a96a0e3cd 100644 --- a/axum-macros/tests/debug_handler/fail/generics.stderr +++ b/axum-macros/tests/debug_handler/fail/generics.stderr @@ -1,13 +1,5 @@ error: `#[axum_macros::debug_handler]` doesn't support generic functions --> tests/debug_handler/fail/generics.rs:4:17 | -4 | async fn handler() {} +4 | async fn handler(extract: T) {} | ^^^ - -error[E0282]: type annotations needed - --> tests/debug_handler/fail/generics.rs:4:10 - | -4 | async fn handler() {} - | ----- ^^^^^^^ cannot infer type for type parameter `T` declared on the function `handler` - | | - | consider giving `future` a type diff --git a/axum/CHANGELOG.md b/axum/CHANGELOG.md index 5af2df555a..ae117a4295 100644 --- a/axum/CHANGELOG.md +++ b/axum/CHANGELOG.md @@ -7,7 +7,24 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 # Unreleased -- None. +## Security + +- **breaking:** Added default limit to how much data `Bytes::from_request` will + consume. Previously it would attempt to consume the entire request body + without checking its length. This meant if a malicious peer sent an large (or + infinite) request body your server might run out of memory and crash. + + The default limit is at 2 MB and can be disabled by adding the new + `DefaultBodyLimit::disable()` middleware. See its documentation for more + details. + + This also applies to these extractors which used `Bytes::from_request` + internally: + - `Form` + - `Json` + - `String` + + ([#1346]) # 0.5.15 (9. August, 2022) diff --git a/axum/src/docs/extract.md b/axum/src/docs/extract.md index 92784a65e5..ec2a1b8295 100644 --- a/axum/src/docs/extract.md +++ b/axum/src/docs/extract.md @@ -11,6 +11,7 @@ Types and traits for extracting data from requests. - [Accessing inner errors](#accessing-inner-errors) - [Defining custom extractors](#defining-custom-extractors) - [Accessing other extractors in `FromRequest` implementations](#accessing-other-extractors-in-fromrequest-implementations) +- [Request body limits](#request-body-limits) - [Request body extractors](#request-body-extractors) - [Running extractors from middleware](#running-extractors-from-middleware) @@ -505,6 +506,14 @@ let app = Router::new().route("/", get(handler)).layer(Extension(state)); # }; ``` +# Request body limits + +For security reasons, [`Bytes`] will, by default, not accept bodies larger than +2MB. This also applies to extractors that uses [`Bytes`] internally such as +`String`, [`Json`], and [`Form`]. + +For more details, including how to disable this limit, see [`DefaultBodyLimit`]. + # Request body extractors Most of the time your request body type will be [`body::Body`] (a re-export @@ -637,6 +646,7 @@ let app = Router::new().layer(middleware::from_fn(auth_middleware)); ``` [`body::Body`]: crate::body::Body +[`Bytes`]: crate::body::Bytes [customize-extractor-error]: https://github.com/tokio-rs/axum/blob/main/examples/customize-extractor-error/src/main.rs [`HeaderMap`]: https://docs.rs/http/latest/http/header/struct.HeaderMap.html [`Request`]: https://docs.rs/http/latest/http/struct.Request.html diff --git a/axum/src/extract/mod.rs b/axum/src/extract/mod.rs index 41a255116a..b7bb8d2d9e 100644 --- a/axum/src/extract/mod.rs +++ b/axum/src/extract/mod.rs @@ -17,7 +17,7 @@ mod raw_query; mod request_parts; #[doc(inline)] -pub use axum_core::extract::{FromRequest, RequestParts}; +pub use axum_core::extract::{DefaultBodyLimit, FromRequest, RequestParts}; #[doc(inline)] #[allow(deprecated)] diff --git a/axum/src/routing/tests/mod.rs b/axum/src/routing/tests/mod.rs index b253bc5282..311126f4fe 100644 --- a/axum/src/routing/tests/mod.rs +++ b/axum/src/routing/tests/mod.rs @@ -1,13 +1,14 @@ use crate::{ body::{Bytes, Empty}, error_handling::HandleErrorLayer, - extract::{self, Path}, + extract::{self, DefaultBodyLimit, Path}, handler::Handler, response::IntoResponse, routing::{delete, get, get_service, on, on_service, patch, patch_service, post, MethodFilter}, test_helpers::*, BoxError, Json, Router, }; +use futures_util::stream::StreamExt; use http::{header::CONTENT_LENGTH, HeaderMap, Method, Request, Response, StatusCode, Uri}; use hyper::Body; use serde::Deserialize; @@ -700,6 +701,50 @@ async fn routes_must_start_with_slash() { TestClient::new(app); } +#[tokio::test] +async fn body_limited_by_default() { + let app = Router::new() + .route("/bytes", post(|_: Bytes| async {})) + .route("/string", post(|_: String| async {})) + .route("/json", post(|_: Json| async {})); + + let client = TestClient::new(app); + + for uri in ["/bytes", "/string", "/json"] { + println!("calling {}", uri); + + let stream = futures_util::stream::repeat("a".repeat(1000)).map(Ok::<_, hyper::Error>); + let body = Body::wrap_stream(stream); + + let res_future = client + .post(uri) + .header("content-type", "application/json") + .body(body) + .send(); + let res = tokio::time::timeout(Duration::from_secs(3), res_future) + .await + .expect("never got response"); + + assert_eq!(res.status(), StatusCode::PAYLOAD_TOO_LARGE); + } +} + +#[tokio::test] +async fn disabling_the_default_limit() { + let app = Router::new() + .route("/", post(|_: Bytes| async {})) + .layer(DefaultBodyLimit::disable()); + + let client = TestClient::new(app); + + // `DEFAULT_LIMIT` is 2mb so make a body larger than that + let body = Body::from("a".repeat(3_000_000)); + + let res = client.post("/").body(body).send().await; + + assert_eq!(res.status(), StatusCode::OK); +} + #[tokio::test] async fn limited_body_with_content_length() { const LIMIT: usize = 3;