Skip to content

Commit

Permalink
Fix Handler::with_state not working if request body was changed via…
Browse files Browse the repository at this point in the history
… layer (#1536)

Previously

```rust
handler.layer(RequestBodyLimitLayer::new(...)).with_state(...)
```

didn't work because we required the same request body all the way
through.
  • Loading branch information
davidpdrsn committed Nov 18, 2022
1 parent b1f894a commit 2e8a7e5
Show file tree
Hide file tree
Showing 7 changed files with 56 additions and 24 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ error[E0277]: the trait bound `fn(Extractor<()>) -> impl Future<Output = ()> {fo
| |
| required by a bound introduced by this call
|
= help: the trait `Handler<T, S, B>` is implemented for `Layered<L, H, T, S, B>`
= help: the trait `Handler<T, S, B2>` is implemented for `Layered<L, H, T, S, B, B2>`
note: required by a bound in `axum::routing::get`
--> $WORKSPACE/axum/src/routing/method_routing.rs
|
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ error[E0277]: the trait bound `fn(Extractor<()>) -> impl Future<Output = ()> {fo
| |
| required by a bound introduced by this call
|
= help: the trait `Handler<T, S, B>` is implemented for `Layered<L, H, T, S, B>`
= help: the trait `Handler<T, S, B2>` is implemented for `Layered<L, H, T, S, B, B2>`
note: required by a bound in `axum::routing::get`
--> $WORKSPACE/axum/src/routing/method_routing.rs
|
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ error[E0277]: the trait bound `fn(MyExtractor) -> impl Future<Output = ()> {hand
| |
| required by a bound introduced by this call
|
= help: the trait `Handler<T, S, B>` is implemented for `Layered<L, H, T, S, B>`
= help: the trait `Handler<T, S, B2>` is implemented for `Layered<L, H, T, S, B, B2>`
note: required by a bound in `axum::routing::get`
--> $WORKSPACE/axum/src/routing/method_routing.rs
|
Expand All @@ -28,7 +28,7 @@ error[E0277]: the trait bound `fn(Result<MyExtractor, MyRejection>) -> impl Futu
| |
| required by a bound introduced by this call
|
= help: the trait `Handler<T, S, B>` is implemented for `Layered<L, H, T, S, B>`
= help: the trait `Handler<T, S, B2>` is implemented for `Layered<L, H, T, S, B, B2>`
note: required by a bound in `MethodRouter::<S, B>::post`
--> $WORKSPACE/axum/src/routing/method_routing.rs
|
Expand Down
2 changes: 1 addition & 1 deletion axum/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ serde = "1.0"
sync_wrapper = "0.1.1"
tower = { version = "0.4.13", default-features = false, features = ["util"] }
tower-http = { version = "0.3.0", features = ["util", "map-response-body"] }
tower-layer = "0.3"
tower-layer = "0.3.2"
tower-service = "0.3"

# optional dependencies
Expand Down
57 changes: 43 additions & 14 deletions axum/src/handler/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -139,9 +139,10 @@ pub trait Handler<T, S, B = Body>: Clone + Send + Sized + 'static {
/// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
/// # };
/// ```
fn layer<L>(self, layer: L) -> Layered<L, Self, T, S, B>
fn layer<L, NewReqBody>(self, layer: L) -> Layered<L, Self, T, S, B, NewReqBody>
where
L: Layer<HandlerService<Self, T, S, B>> + Clone,
L::Service: Service<Request<NewReqBody>>,
{
Layered {
layer,
Expand Down Expand Up @@ -220,13 +221,13 @@ all_the_tuples!(impl_handler);
/// A [`Service`] created from a [`Handler`] by applying a Tower middleware.
///
/// Created with [`Handler::layer`]. See that method for more details.
pub struct Layered<L, H, T, S, B> {
pub struct Layered<L, H, T, S, B, B2> {
layer: L,
handler: H,
_marker: PhantomData<fn() -> (T, S, B)>,
_marker: PhantomData<fn() -> (T, S, B, B2)>,
}

impl<L, H, T, S, B> fmt::Debug for Layered<L, H, T, S, B>
impl<L, H, T, S, B, B2> fmt::Debug for Layered<L, H, T, S, B, B2>
where
L: fmt::Debug,
{
Expand All @@ -237,7 +238,7 @@ where
}
}

impl<L, H, T, S, B> Clone for Layered<L, H, T, S, B>
impl<L, H, T, S, B, B2> Clone for Layered<L, H, T, S, B, B2>
where
L: Clone,
H: Clone,
Expand All @@ -251,20 +252,21 @@ where
}
}

impl<H, S, T, B, L> Handler<T, S, B> for Layered<L, H, T, S, B>
impl<H, S, T, L, B, B2> Handler<T, S, B2> for Layered<L, H, T, S, B, B2>
where
L: Layer<HandlerService<H, T, S, B>> + Clone + Send + 'static,
H: Handler<T, S, B>,
L::Service: Service<Request<B>, Error = Infallible> + Clone + Send + 'static,
<L::Service as Service<Request<B>>>::Response: IntoResponse,
<L::Service as Service<Request<B>>>::Future: Send,
L::Service: Service<Request<B2>, Error = Infallible> + Clone + Send + 'static,
<L::Service as Service<Request<B2>>>::Response: IntoResponse,
<L::Service as Service<Request<B2>>>::Future: Send,
T: 'static,
S: 'static,
B: Send + 'static,
B2: Send + 'static,
{
type Future = future::LayeredFuture<B, L::Service>;
type Future = future::LayeredFuture<B2, L::Service>;

fn call(self, req: Request<B>, state: S) -> Self::Future {
fn call(self, req: Request<B2>, state: S) -> Self::Future {
use futures_util::future::{FutureExt, Map};

let svc = self.handler.with_state(state);
Expand All @@ -274,8 +276,8 @@ where
_,
fn(
Result<
<L::Service as Service<Request<B>>>::Response,
<L::Service as Service<Request<B>>>::Error,
<L::Service as Service<Request<B2>>>::Response,
<L::Service as Service<Request<B2>>>::Error,
>,
) -> _,
> = svc.oneshot(req).map(|result| match result {
Expand Down Expand Up @@ -338,8 +340,14 @@ where
#[cfg(test)]
mod tests {
use super::*;
use crate::test_helpers::*;
use crate::{body, extract::State, test_helpers::*};
use http::StatusCode;
use std::time::Duration;
use tower_http::{
compression::CompressionLayer, limit::RequestBodyLimitLayer,
map_request_body::MapRequestBodyLayer, map_response_body::MapResponseBodyLayer,
timeout::TimeoutLayer,
};

#[tokio::test]
async fn handler_into_service() {
Expand All @@ -353,4 +361,25 @@ mod tests {
assert_eq!(res.status(), StatusCode::OK);
assert_eq!(res.text().await, "you said: hi there!");
}

#[tokio::test]
async fn with_layer_that_changes_request_body_and_state() {
async fn handle(State(state): State<&'static str>) -> &'static str {
state
}

let svc = handle
.layer((
RequestBodyLimitLayer::new(1024),
TimeoutLayer::new(Duration::from_secs(10)),
MapResponseBodyLayer::new(body::boxed),
CompressionLayer::new(),
))
.layer(MapRequestBodyLayer::new(body::boxed))
.with_state("foo");

let client = TestClient::from_service(svc);
let res = client.get("/").send().await;
assert_eq!(res.text().await, "foo");
}
}
1 change: 1 addition & 0 deletions examples/key-value-store/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,5 +15,6 @@ tower-http = { version = "0.3.0", features = [
"limit",
"trace",
] }
tower-layer = "0.3.2"
tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
12 changes: 7 additions & 5 deletions examples/key-value-store/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ use std::{
use tower::{BoxError, ServiceBuilder};
use tower_http::{
auth::RequireAuthorizationLayer, compression::CompressionLayer, limit::RequestBodyLimitLayer,
trace::TraceLayer,
trace::TraceLayer, ServiceBuilderExt,
};
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};

Expand All @@ -50,10 +50,12 @@ async fn main() {
get(kv_get.layer(CompressionLayer::new()))
// But don't compress `kv_set`
.post_service(
ServiceBuilder::new()
.layer(DefaultBodyLimit::disable())
.layer(RequestBodyLimitLayer::new(1024 * 5_000 /* ~5mb */))
.service(kv_set.with_state(Arc::clone(&shared_state))),
kv_set
.layer((
DefaultBodyLimit::disable(),
RequestBodyLimitLayer::new(1024 * 5_000 /* ~5mb */),
))
.with_state(Arc::clone(&shared_state)),
),
)
.route("/keys", get(list_keys))
Expand Down

0 comments on commit 2e8a7e5

Please sign in to comment.