From 611c50ec8be6262888510ae6fece7f324e02f0d8 Mon Sep 17 00:00:00 2001 From: David Pedersen Date: Tue, 20 Sep 2022 10:13:06 +0200 Subject: [PATCH] Add `middleware::from_extractor_with_state` (#1396) Fixes https://github.com/tokio-rs/axum/issues/1373 --- axum/CHANGELOG.md | 3 + axum/src/middleware/from_extractor.rs | 125 +++++++++++++++++--------- axum/src/middleware/mod.rs | 5 +- 3 files changed, 92 insertions(+), 41 deletions(-) diff --git a/axum/CHANGELOG.md b/axum/CHANGELOG.md index 9405265397..a1323b47f3 100644 --- a/axum/CHANGELOG.md +++ b/axum/CHANGELOG.md @@ -15,11 +15,14 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - **fixed:** Support streaming/chunked requests in `ContentLengthLimit` ([#1389]) - **fixed:** Used `400 Bad Request` for `FailedToDeserializeQueryString` rejections, instead of `422 Unprocessable Entity` ([#1387]) +- **added:** Add `middleware::from_extractor_with_state` and + `middleware::from_extractor_with_state_arc` ([#1396]) - **added:** Add `DefaultBodyLimit::max` for changing the default body limit ([#1397]) [#1371]: https://github.com/tokio-rs/axum/pull/1371 [#1387]: https://github.com/tokio-rs/axum/pull/1387 [#1389]: https://github.com/tokio-rs/axum/pull/1389 +[#1396]: https://github.com/tokio-rs/axum/pull/1396 [#1397]: https://github.com/tokio-rs/axum/pull/1397 # 0.6.0-rc.2 (10. September, 2022) diff --git a/axum/src/middleware/from_extractor.rs b/axum/src/middleware/from_extractor.rs index 042c872068..6fd15f73af 100644 --- a/axum/src/middleware/from_extractor.rs +++ b/axum/src/middleware/from_extractor.rs @@ -10,6 +10,7 @@ use std::{ future::Future, marker::PhantomData, pin::Pin, + sync::Arc, task::{Context, Poll}, }; use tower_layer::Layer; @@ -90,8 +91,25 @@ use tower_service::Service; /// ``` /// /// [`Bytes`]: bytes::Bytes -pub fn from_extractor() -> FromExtractorLayer { - FromExtractorLayer(PhantomData) +pub fn from_extractor() -> FromExtractorLayer { + from_extractor_with_state(()) +} + +/// Create a middleware from an extractor with the given state. +/// +/// See [`State`](crate::extract::State) for more details about accessing state. +pub fn from_extractor_with_state(state: S) -> FromExtractorLayer { + from_extractor_with_state_arc(Arc::new(state)) +} + +/// Create a middleware from an extractor with the given [`Arc`]'ed state. +/// +/// See [`State`](crate::extract::State) for more details about accessing state. +pub fn from_extractor_with_state_arc(state: Arc) -> FromExtractorLayer { + FromExtractorLayer { + state, + _marker: PhantomData, + } } /// [`Layer`] that applies [`FromExtractor`] that runs an extractor and @@ -100,28 +118,39 @@ pub fn from_extractor() -> FromExtractorLayer { /// See [`from_extractor`] for more details. /// /// [`Layer`]: tower::Layer -pub struct FromExtractorLayer(PhantomData E>); +pub struct FromExtractorLayer { + state: Arc, + _marker: PhantomData E>, +} -impl Clone for FromExtractorLayer { +impl Clone for FromExtractorLayer { fn clone(&self) -> Self { - Self(PhantomData) + Self { + state: Arc::clone(&self.state), + _marker: PhantomData, + } } } -impl fmt::Debug for FromExtractorLayer { +impl fmt::Debug for FromExtractorLayer +where + S: fmt::Debug, +{ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("FromExtractorLayer") + .field("state", &self.state) .field("extractor", &format_args!("{}", std::any::type_name::())) .finish() } } -impl Layer for FromExtractorLayer { - type Service = FromExtractor; +impl Layer for FromExtractorLayer { + type Service = FromExtractor; - fn layer(&self, inner: S) -> Self::Service { + fn layer(&self, inner: T) -> Self::Service { FromExtractor { inner, + state: Arc::clone(&self.state), _extractor: PhantomData, } } @@ -130,52 +159,57 @@ impl Layer for FromExtractorLayer { /// Middleware that runs an extractor and discards the value. /// /// See [`from_extractor`] for more details. -pub struct FromExtractor { - inner: S, +pub struct FromExtractor { + inner: T, + state: Arc, _extractor: PhantomData E>, } #[test] fn traits() { use crate::test_helpers::*; - assert_send::>(); - assert_sync::>(); + assert_send::>(); + assert_sync::>(); } -impl Clone for FromExtractor +impl Clone for FromExtractor where - S: Clone, + T: Clone, { fn clone(&self) -> Self { Self { inner: self.inner.clone(), + state: Arc::clone(&self.state), _extractor: PhantomData, } } } -impl fmt::Debug for FromExtractor +impl fmt::Debug for FromExtractor where + T: fmt::Debug, S: fmt::Debug, { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("FromExtractor") .field("inner", &self.inner) + .field("state", &self.state) .field("extractor", &format_args!("{}", std::any::type_name::())) .finish() } } -impl Service> for FromExtractor +impl Service> for FromExtractor where - E: FromRequestParts<()> + 'static, + E: FromRequestParts + 'static, B: Default + Send + 'static, - S: Service> + Clone, - S::Response: IntoResponse, + T: Service> + Clone, + T::Response: IntoResponse, + S: Send + Sync + 'static, { type Response = Response; - type Error = S::Error; - type Future = ResponseFuture; + type Error = T::Error; + type Future = ResponseFuture; #[inline] fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { @@ -183,9 +217,10 @@ where } fn call(&mut self, req: Request) -> Self::Future { + let state = Arc::clone(&self.state); let extract_future = Box::pin(async move { let (mut parts, body) = req.into_parts(); - let extracted = E::from_request_parts(&mut parts, &()).await; + let extracted = E::from_request_parts(&mut parts, &state).await; let req = Request::from_parts(parts, body); (req, extracted) }); @@ -202,39 +237,39 @@ where pin_project! { /// Response future for [`FromExtractor`]. #[allow(missing_debug_implementations)] - pub struct ResponseFuture + pub struct ResponseFuture where - E: FromRequestParts<()>, - S: Service>, + E: FromRequestParts, + T: Service>, { #[pin] - state: State, - svc: Option, + state: State, + svc: Option, } } pin_project! { #[project = StateProj] - enum State + enum State where - E: FromRequestParts<()>, - S: Service>, + E: FromRequestParts, + T: Service>, { Extracting { future: BoxFuture<'static, (Request, Result)>, }, - Call { #[pin] future: S::Future }, + Call { #[pin] future: T::Future }, } } -impl Future for ResponseFuture +impl Future for ResponseFuture where - E: FromRequestParts<()>, - S: Service>, - S::Response: IntoResponse, + E: FromRequestParts, + T: Service>, + T::Response: IntoResponse, B: Default, { - type Output = Result; + type Output = Result; fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { loop { @@ -272,29 +307,35 @@ where mod tests { use super::*; use crate::{handler::Handler, routing::get, test_helpers::*, Router}; + use axum_core::extract::FromRef; use http::{header, request::Parts, StatusCode}; #[tokio::test] async fn test_from_extractor() { + #[derive(Clone)] + struct Secret(&'static str); + struct RequireAuth; #[async_trait::async_trait] impl FromRequestParts for RequireAuth where S: Send + Sync, + Secret: FromRef, { type Rejection = StatusCode; async fn from_request_parts( parts: &mut Parts, - _state: &S, + state: &S, ) -> Result { + let Secret(secret) = Secret::from_ref(state); if let Some(auth) = parts .headers .get(header::AUTHORIZATION) .and_then(|v| v.to_str().ok()) { - if auth == "secret" { + if auth == secret { return Ok(Self); } } @@ -305,7 +346,11 @@ mod tests { async fn handler() {} - let app = Router::new().route("/", get(handler.layer(from_extractor::()))); + let state = Secret("secret"); + let app = Router::new().route( + "/", + get(handler.layer(from_extractor_with_state::(state))), + ); let client = TestClient::new(app); diff --git a/axum/src/middleware/mod.rs b/axum/src/middleware/mod.rs index 15132da4d3..6dde14894e 100644 --- a/axum/src/middleware/mod.rs +++ b/axum/src/middleware/mod.rs @@ -5,7 +5,10 @@ mod from_extractor; mod from_fn; -pub use self::from_extractor::{from_extractor, FromExtractor, FromExtractorLayer}; +pub use self::from_extractor::{ + from_extractor, from_extractor_with_state, from_extractor_with_state_arc, FromExtractor, + FromExtractorLayer, +}; pub use self::from_fn::{ from_fn, from_fn_with_state, from_fn_with_state_arc, FromFn, FromFnLayer, Next, };