Skip to content

Commit

Permalink
Add middleware::{from_fn_with_state, from_fn_with_state_arc} (#1342)
Browse files Browse the repository at this point in the history
  • Loading branch information
davidpdrsn committed Aug 31, 2022
1 parent 3f92f7d commit 4c9edb4
Show file tree
Hide file tree
Showing 4 changed files with 73 additions and 114 deletions.
7 changes: 7 additions & 0 deletions axum/CHANGELOG.md
Expand Up @@ -13,7 +13,14 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
without any routes will now result in a panic. Previously, this just did
nothing. [#1327]

## Middleware

- **added**: Add `middleware::from_fn_with_state` and
`middleware::from_fn_with_state_arc` to enable running extractors that require
state ([#1342])

[#1327]: https://github.com/tokio-rs/axum/pull/1327
[#1342]: https://github.com/tokio-rs/axum/pull/1342

# 0.6.0-rc.1 (23. August, 2022)

Expand Down
44 changes: 7 additions & 37 deletions axum/src/docs/middleware.md
Expand Up @@ -390,45 +390,12 @@ middleware you don't have to worry about any of this.

# Accessing state in middleware

Handlers can access state using the [`State`] extractor but this isn't available
to middleware. Instead you have to pass the state directly to middleware using
either closure captures (for [`axum::middleware::from_fn`]) or regular struct
fields (if you're implementing a [`tower::Layer`])
How to make state available to middleware depends on how the middleware is
written.

## Accessing state in `axum::middleware::from_fn`

```rust
use axum::{
Router,
routing::get,
middleware::{self, Next},
response::Response,
extract::State,
http::Request,
};

#[derive(Clone)]
struct AppState {}

async fn my_middleware<B>(
state: AppState,
req: Request<B>,
next: Next<B>,
) -> Response {
next.run(req).await
}

async fn handler(_: State<AppState>) {}

let state = AppState {};

let app = Router::with_state(state.clone())
.route("/", get(handler))
.layer(middleware::from_fn(move |req, next| {
my_middleware(state.clone(), req, next)
}));
# let _: Router<_> = app;
```
Use [`axum::middleware::from_fn_with_state`](crate::middleware::from_fn_with_state).

## Accessing state in custom `tower::Layer`s

Expand Down Expand Up @@ -482,7 +449,10 @@ where
}

fn call(&mut self, req: Request<B>) -> Self::Future {
// do something with `self.state`
// Do something with `self.state`.
//
// See `axum::RequestExt` for how to run extractors directly from
// a `Request`.

self.inner.call(req)
}
Expand Down
132 changes: 56 additions & 76 deletions axum/src/middleware/from_fn.rs
Expand Up @@ -9,6 +9,7 @@ use std::{
future::Future,
marker::PhantomData,
pin::Pin,
sync::Arc,
task::{Context, Poll},
};
use tower::{util::BoxCloneService, ServiceBuilder};
Expand Down Expand Up @@ -90,82 +91,57 @@ use tower_service::Service;
/// # let app: Router = app;
/// ```
///
/// # Passing state
///
/// State can be passed to the function like so:
///
/// ```rust
/// use axum::{
/// Router,
/// http::{Request, StatusCode},
/// routing::get,
/// response::{IntoResponse, Response},
/// middleware::{self, Next}
/// };
///
/// #[derive(Clone)]
/// struct State { /* ... */ }
///
/// async fn my_middleware<B>(
/// req: Request<B>,
/// next: Next<B>,
/// state: State,
/// ) -> Response {
/// // ...
/// # ().into_response()
/// }
///
/// let state = State { /* ... */ };
/// [extractors]: crate::extract::FromRequest
pub fn from_fn<F, T>(f: F) -> FromFnLayer<F, (), T> {
from_fn_with_state((), f)
}

/// Create a middleware from an async function with the given state.
///
/// let app = Router::new()
/// .route("/", get(|| async { /* ... */ }))
/// .route_layer(middleware::from_fn(move |req, next| {
/// my_middleware(req, next, state.clone())
/// }));
/// # let app: Router = app;
/// ```
/// See [`State`](crate::extract::State) for more details about accessing state.
///
/// Or via extensions:
/// # Example
///
/// ```rust
/// use axum::{
/// Router,
/// extract::Extension,
/// http::{Request, StatusCode},
/// routing::get,
/// response::{IntoResponse, Response},
/// middleware::{self, Next},
/// extract::State,
/// };
/// use tower::ServiceBuilder;
///
/// #[derive(Clone)]
/// struct State { /* ... */ }
/// struct AppState { /* ... */ }
///
/// async fn my_middleware<B>(
/// Extension(state): Extension<State>,
/// State(state): State<AppState>,
/// req: Request<B>,
/// next: Next<B>,
/// ) -> Response {
/// // ...
/// # ().into_response()
/// }
///
/// let state = State { /* ... */ };
/// let state = AppState { /* ... */ };
///
/// let app = Router::new()
/// let app = Router::with_state(state.clone())
/// .route("/", get(|| async { /* ... */ }))
/// .layer(
/// ServiceBuilder::new()
/// .layer(Extension(state))
/// .layer(middleware::from_fn(my_middleware)),
/// );
/// # let app: Router = app;
/// .route_layer(middleware::from_fn_with_state(state, my_middleware));
/// # let app: Router<_> = app;
/// ```
pub fn from_fn_with_state<F, S, T>(state: S, f: F) -> FromFnLayer<F, S, T> {
from_fn_with_state_arc(Arc::new(state), f)
}

/// Create a middleware from an async function with the given [`Arc`]'ed state.
///
/// [extractors]: crate::extract::FromRequest
pub fn from_fn<F, T>(f: F) -> FromFnLayer<F, T> {
/// See [`State`](crate::extract::State) for more details about accessing state.
pub fn from_fn_with_state_arc<F, S, T>(state: Arc<S>, f: F) -> FromFnLayer<F, S, T> {
FromFnLayer {
f,
state,
_extractor: PhantomData,
}
}
Expand All @@ -175,98 +151,99 @@ pub fn from_fn<F, T>(f: F) -> FromFnLayer<F, T> {
/// [`tower::Layer`] is used to apply middleware to [`Router`](crate::Router)'s.
///
/// Created with [`from_fn`]. See that function for more details.
pub struct FromFnLayer<F, T> {
pub struct FromFnLayer<F, S, T> {
f: F,
state: Arc<S>,
_extractor: PhantomData<fn() -> T>,
}

impl<F, T> Clone for FromFnLayer<F, T>
impl<F, S, T> Clone for FromFnLayer<F, S, T>
where
F: Clone,
{
fn clone(&self) -> Self {
Self {
f: self.f.clone(),
state: Arc::clone(&self.state),
_extractor: self._extractor,
}
}
}

impl<F, T> Copy for FromFnLayer<F, T> where F: Copy {}

impl<S, F, T> Layer<S> for FromFnLayer<F, T>
impl<S, I, F, T> Layer<I> for FromFnLayer<F, S, T>
where
F: Clone,
{
type Service = FromFn<F, S, T>;
type Service = FromFn<F, S, I, T>;

fn layer(&self, inner: S) -> Self::Service {
fn layer(&self, inner: I) -> Self::Service {
FromFn {
f: self.f.clone(),
state: Arc::clone(&self.state),
inner,
_extractor: PhantomData,
}
}
}

impl<F, T> fmt::Debug for FromFnLayer<F, T> {
impl<F, S, T> fmt::Debug for FromFnLayer<F, S, T>
where
S: fmt::Debug,
{
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("FromFnLayer")
// Write out the type name, without quoting it as `&type_name::<F>()` would
.field("f", &format_args!("{}", type_name::<F>()))
.field("state", &self.state)
.finish()
}
}

/// A middleware created from an async function.
///
/// Created with [`from_fn`]. See that function for more details.
pub struct FromFn<F, S, T> {
pub struct FromFn<F, S, I, T> {
f: F,
inner: S,
inner: I,
state: Arc<S>,
_extractor: PhantomData<fn() -> T>,
}

impl<F, S, T> Clone for FromFn<F, S, T>
impl<F, S, I, T> Clone for FromFn<F, S, I, T>
where
F: Clone,
S: Clone,
I: Clone,
{
fn clone(&self) -> Self {
Self {
f: self.f.clone(),
inner: self.inner.clone(),
state: Arc::clone(&self.state),
_extractor: self._extractor,
}
}
}

impl<F, S, T> Copy for FromFn<F, S, T>
where
F: Copy,
S: Copy,
{
}

macro_rules! impl_service {
(
[$($ty:ident),*], $last:ident
) => {
#[allow(non_snake_case, unused_mut)]
impl<F, Fut, Out, S, B, $($ty,)* $last> Service<Request<B>> for FromFn<F, S, ($($ty,)* $last,)>
impl<F, Fut, Out, S, I, B, $($ty,)* $last> Service<Request<B>> for FromFn<F, S, I, ($($ty,)* $last,)>
where
F: FnMut($($ty,)* $last, Next<B>) -> Fut + Clone + Send + 'static,
$( $ty: FromRequestParts<()> + Send, )*
$last: FromRequest<(), B> + Send,
$( $ty: FromRequestParts<S> + Send, )*
$last: FromRequest<S, B> + Send,
Fut: Future<Output = Out> + Send + 'static,
Out: IntoResponse + 'static,
S: Service<Request<B>, Error = Infallible>
I: Service<Request<B>, Error = Infallible>
+ Clone
+ Send
+ 'static,
S::Response: IntoResponse,
S::Future: Send + 'static,
I::Response: IntoResponse,
I::Future: Send + 'static,
B: Send + 'static,
S: Send + Sync + 'static,
{
type Response = Response;
type Error = Infallible;
Expand All @@ -281,20 +258,21 @@ macro_rules! impl_service {
let ready_inner = std::mem::replace(&mut self.inner, not_ready_inner);

let mut f = self.f.clone();
let state = Arc::clone(&self.state);

let future = Box::pin(async move {
let (mut parts, body) = req.into_parts();

$(
let $ty = match $ty::from_request_parts(&mut parts, &()).await {
let $ty = match $ty::from_request_parts(&mut parts, &state).await {
Ok(value) => value,
Err(rejection) => return rejection.into_response(),
};
)*

let req = Request::from_parts(parts, body);

let $last = match $last::from_request(req, &()).await {
let $last = match $last::from_request(req, &state).await {
Ok(value) => value,
Err(rejection) => return rejection.into_response(),
};
Expand Down Expand Up @@ -342,14 +320,16 @@ impl_service!(
T16
);

impl<F, S, T> fmt::Debug for FromFn<F, S, T>
impl<F, S, I, T> fmt::Debug for FromFn<F, S, I, T>
where
S: fmt::Debug,
I: fmt::Debug,
{
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("FromFnLayer")
.field("f", &format_args!("{}", type_name::<F>()))
.field("inner", &self.inner)
.field("state", &self.state)
.finish()
}
}
Expand Down
4 changes: 3 additions & 1 deletion axum/src/middleware/mod.rs
Expand Up @@ -6,7 +6,9 @@ mod from_extractor;
mod from_fn;

pub use self::from_extractor::{from_extractor, FromExtractor, FromExtractorLayer};
pub use self::from_fn::{from_fn, FromFn, FromFnLayer, Next};
pub use self::from_fn::{
from_fn, from_fn_with_state, from_fn_with_state_arc, FromFn, FromFnLayer, Next,
};
pub use crate::extension::AddExtension;

pub mod future {
Expand Down

0 comments on commit 4c9edb4

Please sign in to comment.