diff --git a/axum/CHANGELOG.md b/axum/CHANGELOG.md index 6ac52260c4..3433529090 100644 --- a/axum/CHANGELOG.md +++ b/axum/CHANGELOG.md @@ -9,13 +9,15 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - **added:** Support resolving host name via `Forwarded` header in `Host` extractor ([#1078]) +- **added:** Support running extractors from `middleware::from_fn` functions ([#1088]) - **breaking:** Allow `Error: Into` for `Route::{layer, route_layer}` ([#924]) - **breaking:** Remove `extractor_middleware` which was previously deprecated. Use `axum::middleware::from_extractor` instead ([#1077]) [#924]: https://github.com/tokio-rs/axum/pull/924 -[#1078]: https://github.com/tokio-rs/axum/pull/1078 [#1077]: https://github.com/tokio-rs/axum/pull/1077 +[#1078]: https://github.com/tokio-rs/axum/pull/1078 +[#1088]: https://github.com/tokio-rs/axum/pull/1088 # 0.5.7 (08. June, 2022) diff --git a/axum/src/middleware/from_fn.rs b/axum/src/middleware/from_fn.rs index 9f9563b590..8e46fa690f 100644 --- a/axum/src/middleware/from_fn.rs +++ b/axum/src/middleware/from_fn.rs @@ -3,13 +3,15 @@ use crate::{ response::{IntoResponse, Response}, BoxError, }; +use axum_core::extract::{FromRequest, RequestParts}; +use futures_util::future::BoxFuture; use http::Request; -use pin_project_lite::pin_project; use std::{ any::type_name, convert::Infallible, fmt, future::Future, + marker::PhantomData, pin::Pin, task::{Context, Poll}, }; @@ -23,8 +25,8 @@ use tower_service::Service; /// `from_fn` requires the function given to /// /// 1. Be an `async fn`. -/// 2. Take [`Request`](http::Request) as the first argument. -/// 3. Take [`Next`](Next) as the second argument. +/// 2. Take one or more [extractors] as the first arguments. +/// 3. Take [`Next`](Next) as the final argument. /// 4. Return something that implements [`IntoResponse`]. /// /// # Example @@ -62,6 +64,37 @@ use tower_service::Service; /// # let app: Router = app; /// ``` /// +/// # Running extractors +/// +/// ```rust +/// use axum::{ +/// Router, +/// extract::{TypedHeader, Query}, +/// headers::authorization::{Authorization, Bearer}, +/// http::Request, +/// middleware::{self, Next}, +/// response::Response, +/// routing::get, +/// }; +/// use std::collections::HashMap; +/// +/// async fn my_middleware( +/// TypedHeader(auth): TypedHeader>, +/// Query(query_params): Query>, +/// req: Request, +/// next: Next, +/// ) -> Response { +/// // do something with `auth` and `query_params`... +/// +/// next.run(req).await +/// } +/// +/// let app = Router::new() +/// .route("/", get(|| async { /* ... */ })) +/// .route_layer(middleware::from_fn(my_middleware)); +/// # let app: Router = app; +/// ``` +/// /// # Passing state /// /// State can be passed to the function like so: @@ -114,11 +147,10 @@ use tower_service::Service; /// struct State { /* ... */ } /// /// async fn my_middleware( +/// Extension(state): Extension, /// req: Request, /// next: Next, /// ) -> Response { -/// let state: &State = req.extensions().get().unwrap(); -/// /// // ... /// # ().into_response() /// } @@ -134,8 +166,13 @@ use tower_service::Service; /// ); /// # let app: Router = app; /// ``` -pub fn from_fn(f: F) -> FromFnLayer { - FromFnLayer { f } +/// +/// [extractors]: crate::extract::FromRequest +pub fn from_fn(f: F) -> FromFnLayer { + FromFnLayer { + f, + _extractor: PhantomData, + } } /// A [`tower::Layer`] from an async function. @@ -143,26 +180,41 @@ pub fn from_fn(f: F) -> FromFnLayer { /// [`tower::Layer`] is used to apply middleware to [`Router`](crate::Router)'s. /// /// Created with [`from_fn`]. See that function for more details. -#[derive(Clone, Copy)] -pub struct FromFnLayer { +pub struct FromFnLayer { f: F, + _extractor: PhantomData T>, } -impl Layer for FromFnLayer +impl Clone for FromFnLayer where F: Clone, { - type Service = FromFn; + fn clone(&self) -> Self { + Self { + f: self.f.clone(), + _extractor: self._extractor, + } + } +} + +impl Copy for FromFnLayer where F: Copy {} + +impl Layer for FromFnLayer +where + F: Clone, +{ + type Service = FromFn; fn layer(&self, inner: S) -> Self::Service { FromFn { f: self.f.clone(), inner, + _extractor: PhantomData, } } } -impl fmt::Debug for FromFnLayer { +impl fmt::Debug for FromFnLayer { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("FromFnLayer") // Write out the type name, without quoting it as `&type_name::()` would @@ -174,50 +226,94 @@ impl fmt::Debug for FromFnLayer { /// A middleware created from an async function. /// /// Created with [`from_fn`]. See that function for more details. -#[derive(Clone, Copy)] -pub struct FromFn { +pub struct FromFn { f: F, inner: S, + _extractor: PhantomData T>, } -impl Service> for FromFn +impl Clone for FromFn where - F: FnMut(Request, Next) -> Fut, - Fut: Future, - Out: IntoResponse, - S: Service, Response = Response, Error = Infallible> - + Clone - + Send - + 'static, - S::Future: Send + 'static, - ResBody: HttpBody + Send + 'static, - ResBody::Error: Into, + F: Clone, + S: Clone, { - type Response = Response; - type Error = Infallible; - type Future = ResponseFuture; - - fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { - self.inner.poll_ready(cx) + fn clone(&self) -> Self { + Self { + f: self.f.clone(), + inner: self.inner.clone(), + _extractor: self._extractor, + } } +} + +impl Copy for FromFn +where + F: Copy, + S: Copy, +{ +} + +macro_rules! impl_service { + ( $($ty:ident),* $(,)? ) => { + #[allow(non_snake_case)] + impl Service> for FromFn + where + F: FnMut($($ty),*, Next) -> Fut + Clone + Send + 'static, + $( $ty: FromRequest + Send, )* + Fut: Future + Send + 'static, + Out: IntoResponse + 'static, + S: Service, Response = Response, Error = Infallible> + + Clone + + Send + + 'static, + S::Future: Send + 'static, + ReqBody: Send + 'static, + ResBody: HttpBody + Send + 'static, + ResBody::Error: Into, + { + type Response = Response; + type Error = Infallible; + type Future = ResponseFuture; - fn call(&mut self, req: Request) -> Self::Future { - let not_ready_inner = self.inner.clone(); - let ready_inner = std::mem::replace(&mut self.inner, not_ready_inner); + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.inner.poll_ready(cx) + } - let inner = ServiceBuilder::new() - .boxed_clone() - .map_response_body(body::boxed) - .service(ready_inner); - let next = Next { inner }; + fn call(&mut self, req: Request) -> Self::Future { + let not_ready_inner = self.inner.clone(); + let ready_inner = std::mem::replace(&mut self.inner, not_ready_inner); - ResponseFuture { - inner: (self.f)(req, next), + let mut f = self.f.clone(); + + let future = Box::pin(async move { + let mut parts = RequestParts::new(req); + $( + let $ty = match $ty::from_request(&mut parts).await { + Ok(value) => value, + Err(rejection) => return rejection.into_response(), + }; + )* + + let inner = ServiceBuilder::new() + .boxed_clone() + .map_response_body(body::boxed) + .service(ready_inner); + let next = Next { inner }; + + f($($ty),*, next).await.into_response() + }); + + ResponseFuture { + inner: future + } + } } - } + }; } -impl fmt::Debug for FromFn +all_the_tuples!(impl_service); + +impl fmt::Debug for FromFn where S: fmt::Debug, { @@ -252,27 +348,22 @@ impl fmt::Debug for Next { } } -pin_project! { - /// Response future for [`FromFn`]. - pub struct ResponseFuture { - #[pin] - inner: F, - } +/// Response future for [`FromFn`]. +pub struct ResponseFuture { + inner: BoxFuture<'static, Response>, } -impl Future for ResponseFuture -where - F: Future, - Out: IntoResponse, -{ +impl Future for ResponseFuture { type Output = Result; - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - self.project() - .inner - .poll(cx) - .map(IntoResponse::into_response) - .map(Ok) + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + self.inner.as_mut().poll(cx).map(Ok) + } +} + +impl fmt::Debug for ResponseFuture { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("ResponseFuture").finish() } }