Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add middleware::{from_fn_with_state, from_fn_with_state_arc} #1342

Merged
merged 2 commits into from Aug 31, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
7 changes: 7 additions & 0 deletions axum/CHANGELOG.md
Expand Up @@ -13,7 +13,14 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
without any routes will now result in a panic. Previously, this just did
nothing. [#1327]

## Middleware

- **added**: Add `middleware::from_fn_with_state` and
`middleware::from_fn_with_state_arc` to enable running extractors that require
state ([#1342])

[#1327]: https://github.com/tokio-rs/axum/pull/1327
[#1342]: https://github.com/tokio-rs/axum/pull/1342

# 0.6.0-rc.1 (23. August, 2022)

Expand Down
44 changes: 7 additions & 37 deletions axum/src/docs/middleware.md
Expand Up @@ -390,45 +390,12 @@ middleware you don't have to worry about any of this.

# Accessing state in middleware

Handlers can access state using the [`State`] extractor but this isn't available
to middleware. Instead you have to pass the state directly to middleware using
either closure captures (for [`axum::middleware::from_fn`]) or regular struct
fields (if you're implementing a [`tower::Layer`])
How to make state available to middleware depends on how the middleware is
written.

## Accessing state in `axum::middleware::from_fn`

```rust
use axum::{
Router,
routing::get,
middleware::{self, Next},
response::Response,
extract::State,
http::Request,
};

#[derive(Clone)]
struct AppState {}

async fn my_middleware<B>(
state: AppState,
req: Request<B>,
next: Next<B>,
) -> Response {
next.run(req).await
}

async fn handler(_: State<AppState>) {}

let state = AppState {};

let app = Router::with_state(state.clone())
.route("/", get(handler))
.layer(middleware::from_fn(move |req, next| {
my_middleware(state.clone(), req, next)
}));
# let _: Router<_> = app;
```
Use [`axum::middleware::from_fn_with_state`](crate::middleware::from_fn_with_state).

## Accessing state in custom `tower::Layer`s

Expand Down Expand Up @@ -482,7 +449,10 @@ where
}

fn call(&mut self, req: Request<B>) -> Self::Future {
// do something with `self.state`
// Do something with `self.state`.
//
// See `axum::RequestExt` for how to run extractors directly from
// a `Request`.

self.inner.call(req)
}
Expand Down
132 changes: 56 additions & 76 deletions axum/src/middleware/from_fn.rs
Expand Up @@ -9,6 +9,7 @@ use std::{
future::Future,
marker::PhantomData,
pin::Pin,
sync::Arc,
task::{Context, Poll},
};
use tower::{util::BoxCloneService, ServiceBuilder};
Expand Down Expand Up @@ -90,82 +91,57 @@ use tower_service::Service;
/// # let app: Router = app;
/// ```
///
/// # Passing state
///
/// State can be passed to the function like so:
///
/// ```rust
/// use axum::{
/// Router,
/// http::{Request, StatusCode},
/// routing::get,
/// response::{IntoResponse, Response},
/// middleware::{self, Next}
/// };
///
/// #[derive(Clone)]
/// struct State { /* ... */ }
///
/// async fn my_middleware<B>(
/// req: Request<B>,
/// next: Next<B>,
/// state: State,
/// ) -> Response {
/// // ...
/// # ().into_response()
/// }
///
/// let state = State { /* ... */ };
/// [extractors]: crate::extract::FromRequest
pub fn from_fn<F, T>(f: F) -> FromFnLayer<F, (), T> {
from_fn_with_state((), f)
}

/// Create a middleware from an async function with the given state.
///
/// let app = Router::new()
/// .route("/", get(|| async { /* ... */ }))
/// .route_layer(middleware::from_fn(move |req, next| {
/// my_middleware(req, next, state.clone())
/// }));
/// # let app: Router = app;
/// ```
/// See [`State`](crate::extract::State) for more details about accessing state.
///
/// Or via extensions:
/// # Example
///
/// ```rust
/// use axum::{
/// Router,
/// extract::Extension,
/// http::{Request, StatusCode},
/// routing::get,
/// response::{IntoResponse, Response},
/// middleware::{self, Next},
/// extract::State,
/// };
/// use tower::ServiceBuilder;
///
/// #[derive(Clone)]
/// struct State { /* ... */ }
/// struct AppState { /* ... */ }
///
/// async fn my_middleware<B>(
/// Extension(state): Extension<State>,
/// State(state): State<AppState>,
/// req: Request<B>,
/// next: Next<B>,
/// ) -> Response {
/// // ...
/// # ().into_response()
/// }
///
/// let state = State { /* ... */ };
/// let state = AppState { /* ... */ };
///
/// let app = Router::new()
/// let app = Router::with_state(state.clone())
/// .route("/", get(|| async { /* ... */ }))
/// .layer(
/// ServiceBuilder::new()
/// .layer(Extension(state))
/// .layer(middleware::from_fn(my_middleware)),
/// );
/// # let app: Router = app;
/// .route_layer(middleware::from_fn_with_state(state, my_middleware));
/// # let app: Router<_> = app;
/// ```
pub fn from_fn_with_state<F, S, T>(state: S, f: F) -> FromFnLayer<F, S, T> {
from_fn_with_state_arc(Arc::new(state), f)
}

/// Create a middleware from an async function with the given [`Arc`]'ed state.
///
/// [extractors]: crate::extract::FromRequest
pub fn from_fn<F, T>(f: F) -> FromFnLayer<F, T> {
/// See [`State`](crate::extract::State) for more details about accessing state.
pub fn from_fn_with_state_arc<F, S, T>(state: Arc<S>, f: F) -> FromFnLayer<F, S, T> {
FromFnLayer {
f,
state,
_extractor: PhantomData,
}
}
Expand All @@ -175,98 +151,99 @@ pub fn from_fn<F, T>(f: F) -> FromFnLayer<F, T> {
/// [`tower::Layer`] is used to apply middleware to [`Router`](crate::Router)'s.
///
/// Created with [`from_fn`]. See that function for more details.
pub struct FromFnLayer<F, T> {
pub struct FromFnLayer<F, S, T> {
f: F,
state: Arc<S>,
_extractor: PhantomData<fn() -> T>,
}

impl<F, T> Clone for FromFnLayer<F, T>
impl<F, S, T> Clone for FromFnLayer<F, S, T>
where
F: Clone,
{
fn clone(&self) -> Self {
Self {
f: self.f.clone(),
state: Arc::clone(&self.state),
_extractor: self._extractor,
}
}
}

impl<F, T> Copy for FromFnLayer<F, T> where F: Copy {}

impl<S, F, T> Layer<S> for FromFnLayer<F, T>
impl<S, I, F, T> Layer<I> for FromFnLayer<F, S, T>
where
F: Clone,
{
type Service = FromFn<F, S, T>;
type Service = FromFn<F, S, I, T>;

fn layer(&self, inner: S) -> Self::Service {
fn layer(&self, inner: I) -> Self::Service {
FromFn {
f: self.f.clone(),
state: Arc::clone(&self.state),
inner,
_extractor: PhantomData,
}
}
}

impl<F, T> fmt::Debug for FromFnLayer<F, T> {
impl<F, S, T> fmt::Debug for FromFnLayer<F, S, T>
where
S: fmt::Debug,
{
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("FromFnLayer")
// Write out the type name, without quoting it as `&type_name::<F>()` would
.field("f", &format_args!("{}", type_name::<F>()))
.field("state", &self.state)
.finish()
}
}

/// A middleware created from an async function.
///
/// Created with [`from_fn`]. See that function for more details.
pub struct FromFn<F, S, T> {
pub struct FromFn<F, S, I, T> {
f: F,
inner: S,
inner: I,
state: Arc<S>,
_extractor: PhantomData<fn() -> T>,
}

impl<F, S, T> Clone for FromFn<F, S, T>
impl<F, S, I, T> Clone for FromFn<F, S, I, T>
where
F: Clone,
S: Clone,
I: Clone,
{
fn clone(&self) -> Self {
Self {
f: self.f.clone(),
inner: self.inner.clone(),
state: Arc::clone(&self.state),
_extractor: self._extractor,
}
}
}

impl<F, S, T> Copy for FromFn<F, S, T>
where
F: Copy,
S: Copy,
{
}

macro_rules! impl_service {
(
[$($ty:ident),*], $last:ident
) => {
#[allow(non_snake_case, unused_mut)]
impl<F, Fut, Out, S, B, $($ty,)* $last> Service<Request<B>> for FromFn<F, S, ($($ty,)* $last,)>
impl<F, Fut, Out, S, I, B, $($ty,)* $last> Service<Request<B>> for FromFn<F, S, I, ($($ty,)* $last,)>
where
F: FnMut($($ty,)* $last, Next<B>) -> Fut + Clone + Send + 'static,
$( $ty: FromRequestParts<()> + Send, )*
$last: FromRequest<(), B> + Send,
$( $ty: FromRequestParts<S> + Send, )*
$last: FromRequest<S, B> + Send,
Fut: Future<Output = Out> + Send + 'static,
Out: IntoResponse + 'static,
S: Service<Request<B>, Error = Infallible>
I: Service<Request<B>, Error = Infallible>
+ Clone
+ Send
+ 'static,
S::Response: IntoResponse,
S::Future: Send + 'static,
I::Response: IntoResponse,
I::Future: Send + 'static,
B: Send + 'static,
S: Send + Sync + 'static,
{
type Response = Response;
type Error = Infallible;
Expand All @@ -281,20 +258,21 @@ macro_rules! impl_service {
let ready_inner = std::mem::replace(&mut self.inner, not_ready_inner);

let mut f = self.f.clone();
let state = Arc::clone(&self.state);

let future = Box::pin(async move {
let (mut parts, body) = req.into_parts();

$(
let $ty = match $ty::from_request_parts(&mut parts, &()).await {
let $ty = match $ty::from_request_parts(&mut parts, &state).await {
Ok(value) => value,
Err(rejection) => return rejection.into_response(),
};
)*

let req = Request::from_parts(parts, body);

let $last = match $last::from_request(req, &()).await {
let $last = match $last::from_request(req, &state).await {
Ok(value) => value,
Err(rejection) => return rejection.into_response(),
};
Expand Down Expand Up @@ -342,14 +320,16 @@ impl_service!(
T16
);

impl<F, S, T> fmt::Debug for FromFn<F, S, T>
impl<F, S, I, T> fmt::Debug for FromFn<F, S, I, T>
where
S: fmt::Debug,
I: fmt::Debug,
{
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("FromFnLayer")
.field("f", &format_args!("{}", type_name::<F>()))
.field("inner", &self.inner)
.field("state", &self.state)
.finish()
}
}
Expand Down
4 changes: 3 additions & 1 deletion axum/src/middleware/mod.rs
Expand Up @@ -6,7 +6,9 @@ mod from_extractor;
mod from_fn;

pub use self::from_extractor::{from_extractor, FromExtractor, FromExtractorLayer};
pub use self::from_fn::{from_fn, FromFn, FromFnLayer, Next};
pub use self::from_fn::{
from_fn, from_fn_with_state, from_fn_with_state_arc, FromFn, FromFnLayer, Next,
};
pub use crate::extension::AddExtension;

pub mod future {
Expand Down