Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add middleware::from_extractor_with_state #1396

Merged
merged 2 commits into from
Sep 20, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 2 additions & 0 deletions axum/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- **fixed:** Support streaming/chunked requests in `ContentLengthLimit` ([#1389])
- **fixed:** Used `400 Bad Request` for `FailedToDeserializeQueryString`
rejections, instead of `422 Unprocessable Entity` ([#1387])
- **added:** Add `middleware::from_extractor_with_state` and
`middleware::from_extractor_with_state_arc`

[#1371]: https://github.com/tokio-rs/axum/pull/1371
[#1387]: https://github.com/tokio-rs/axum/pull/1387
Expand Down
125 changes: 85 additions & 40 deletions axum/src/middleware/from_extractor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ use std::{
future::Future,
marker::PhantomData,
pin::Pin,
sync::Arc,
task::{Context, Poll},
};
use tower_layer::Layer;
Expand Down Expand Up @@ -90,8 +91,25 @@ use tower_service::Service;
/// ```
///
/// [`Bytes`]: bytes::Bytes
pub fn from_extractor<E>() -> FromExtractorLayer<E> {
FromExtractorLayer(PhantomData)
pub fn from_extractor<E>() -> FromExtractorLayer<E, ()> {
from_extractor_with_state(())
}

/// Create a middleware from an extractor with the given state.
///
/// See [`State`](crate::extract::State) for more details about accessing state.
pub fn from_extractor_with_state<E, S>(state: S) -> FromExtractorLayer<E, S> {
from_extractor_with_state_arc(Arc::new(state))
}

/// Create a middleware from an extractor with the given [`Arc`]'ed state.
///
/// See [`State`](crate::extract::State) for more details about accessing state.
pub fn from_extractor_with_state_arc<E, S>(state: Arc<S>) -> FromExtractorLayer<E, S> {
FromExtractorLayer {
state,
_marker: PhantomData,
}
}

/// [`Layer`] that applies [`FromExtractor`] that runs an extractor and
Expand All @@ -100,28 +118,39 @@ pub fn from_extractor<E>() -> FromExtractorLayer<E> {
/// See [`from_extractor`] for more details.
///
/// [`Layer`]: tower::Layer
pub struct FromExtractorLayer<E>(PhantomData<fn() -> E>);
pub struct FromExtractorLayer<E, S> {
state: Arc<S>,
_marker: PhantomData<fn() -> E>,
}

impl<E> Clone for FromExtractorLayer<E> {
impl<E, S> Clone for FromExtractorLayer<E, S> {
fn clone(&self) -> Self {
Self(PhantomData)
Self {
state: Arc::clone(&self.state),
_marker: PhantomData,
}
}
}

impl<E> fmt::Debug for FromExtractorLayer<E> {
impl<E, S> fmt::Debug for FromExtractorLayer<E, S>
where
S: fmt::Debug,
{
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("FromExtractorLayer")
.field("state", &self.state)
.field("extractor", &format_args!("{}", std::any::type_name::<E>()))
.finish()
}
}

impl<E, S> Layer<S> for FromExtractorLayer<E> {
type Service = FromExtractor<S, E>;
impl<E, T, S> Layer<T> for FromExtractorLayer<E, S> {
type Service = FromExtractor<T, E, S>;

fn layer(&self, inner: S) -> Self::Service {
fn layer(&self, inner: T) -> Self::Service {
FromExtractor {
inner,
state: Arc::clone(&self.state),
_extractor: PhantomData,
}
}
Expand All @@ -130,62 +159,68 @@ impl<E, S> Layer<S> for FromExtractorLayer<E> {
/// Middleware that runs an extractor and discards the value.
///
/// See [`from_extractor`] for more details.
pub struct FromExtractor<S, E> {
inner: S,
pub struct FromExtractor<T, E, S> {
inner: T,
state: Arc<S>,
_extractor: PhantomData<fn() -> E>,
}

#[test]
fn traits() {
use crate::test_helpers::*;
assert_send::<FromExtractor<(), NotSendSync>>();
assert_sync::<FromExtractor<(), NotSendSync>>();
assert_send::<FromExtractor<(), NotSendSync, ()>>();
assert_sync::<FromExtractor<(), NotSendSync, ()>>();
}

impl<S, E> Clone for FromExtractor<S, E>
impl<T, E, S> Clone for FromExtractor<T, E, S>
where
S: Clone,
T: Clone,
{
fn clone(&self) -> Self {
Self {
inner: self.inner.clone(),
state: Arc::clone(&self.state),
_extractor: PhantomData,
}
}
}

impl<S, E> fmt::Debug for FromExtractor<S, E>
impl<T, E, S> fmt::Debug for FromExtractor<T, E, S>
where
T: fmt::Debug,
S: fmt::Debug,
{
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("FromExtractor")
.field("inner", &self.inner)
.field("state", &self.state)
.field("extractor", &format_args!("{}", std::any::type_name::<E>()))
.finish()
}
}

impl<S, E, B> Service<Request<B>> for FromExtractor<S, E>
impl<T, E, B, S> Service<Request<B>> for FromExtractor<T, E, S>
where
E: FromRequestParts<()> + 'static,
E: FromRequestParts<S> + 'static,
B: Default + Send + 'static,
S: Service<Request<B>> + Clone,
S::Response: IntoResponse,
T: Service<Request<B>> + Clone,
T::Response: IntoResponse,
S: Send + Sync + 'static,
{
type Response = Response;
type Error = S::Error;
type Future = ResponseFuture<B, S, E>;
type Error = T::Error;
type Future = ResponseFuture<B, T, E, S>;

#[inline]
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.inner.poll_ready(cx)
}

fn call(&mut self, req: Request<B>) -> Self::Future {
let state = Arc::clone(&self.state);
let extract_future = Box::pin(async move {
let (mut parts, body) = req.into_parts();
let extracted = E::from_request_parts(&mut parts, &()).await;
let extracted = E::from_request_parts(&mut parts, &state).await;
let req = Request::from_parts(parts, body);
(req, extracted)
});
Expand All @@ -202,39 +237,39 @@ where
pin_project! {
/// Response future for [`FromExtractor`].
#[allow(missing_debug_implementations)]
pub struct ResponseFuture<B, S, E>
pub struct ResponseFuture<B, T, E, S>
where
E: FromRequestParts<()>,
S: Service<Request<B>>,
E: FromRequestParts<S>,
T: Service<Request<B>>,
{
#[pin]
state: State<B, S, E>,
svc: Option<S>,
state: State<B, T, E, S>,
svc: Option<T>,
}
}

pin_project! {
#[project = StateProj]
enum State<B, S, E>
enum State<B, T, E, S>
where
E: FromRequestParts<()>,
S: Service<Request<B>>,
E: FromRequestParts<S>,
T: Service<Request<B>>,
{
Extracting {
future: BoxFuture<'static, (Request<B>, Result<E, E::Rejection>)>,
},
Call { #[pin] future: S::Future },
Call { #[pin] future: T::Future },
}
}

impl<B, S, E> Future for ResponseFuture<B, S, E>
impl<B, T, E, S> Future for ResponseFuture<B, T, E, S>
where
E: FromRequestParts<()>,
S: Service<Request<B>>,
S::Response: IntoResponse,
E: FromRequestParts<S>,
T: Service<Request<B>>,
T::Response: IntoResponse,
B: Default,
{
type Output = Result<Response, S::Error>;
type Output = Result<Response, T::Error>;

fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
loop {
Expand Down Expand Up @@ -272,29 +307,35 @@ where
mod tests {
use super::*;
use crate::{handler::Handler, routing::get, test_helpers::*, Router};
use axum_core::extract::FromRef;
use http::{header, request::Parts, StatusCode};

#[tokio::test]
async fn test_from_extractor() {
#[derive(Clone)]
struct Secret(&'static str);

struct RequireAuth;

#[async_trait::async_trait]
impl<S> FromRequestParts<S> for RequireAuth
where
S: Send + Sync,
Secret: FromRef<S>,
{
type Rejection = StatusCode;

async fn from_request_parts(
parts: &mut Parts,
_state: &S,
state: &S,
) -> Result<Self, Self::Rejection> {
let Secret(secret) = Secret::from_ref(state);
if let Some(auth) = parts
.headers
.get(header::AUTHORIZATION)
.and_then(|v| v.to_str().ok())
{
if auth == "secret" {
if auth == secret {
return Ok(Self);
}
}
Expand All @@ -305,7 +346,11 @@ mod tests {

async fn handler() {}

let app = Router::new().route("/", get(handler.layer(from_extractor::<RequireAuth>())));
let state = Secret("secret");
let app = Router::new().route(
"/",
get(handler.layer(from_extractor_with_state::<RequireAuth, _>(state))),
);

let client = TestClient::new(app);

Expand Down
5 changes: 4 additions & 1 deletion axum/src/middleware/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,10 @@
mod from_extractor;
mod from_fn;

pub use self::from_extractor::{from_extractor, FromExtractor, FromExtractorLayer};
pub use self::from_extractor::{
from_extractor, from_extractor_with_state, from_extractor_with_state_arc, FromExtractor,
FromExtractorLayer,
};
pub use self::from_fn::{
from_fn, from_fn_with_state, from_fn_with_state_arc, FromFn, FromFnLayer, Next,
};
Expand Down