Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add FromExtractor and deprecate extractor_middleware #957

Merged
merged 8 commits into from
Apr 24, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
4 changes: 3 additions & 1 deletion axum/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

- **added:** Add `response::ErrorResponse` and `response::Result` for
`IntoResponse`-based error handling ([#921])
- **added:** Add `middleware::from_extractor` and deprecate `extract::extractor_middleware` ([#957])

[#921]: https://github.com/tokio-rs/axum/pull/921
[#921]: https://github.com/tokio-rs/axum/pull/921
[#957]: https://github.com/tokio-rs/axum/pull/957

# 0.5.3 (19. April, 2022)

Expand Down
6 changes: 3 additions & 3 deletions axum/src/docs/middleware.md
Original file line number Diff line number Diff line change
Expand Up @@ -168,9 +168,9 @@ Use [`axum::middleware::from_fn`] to write your middleware when:
- You don't intend to publish your middleware as a crate for others to use.
Middleware written like this are only compatible with axum.

## `axum::extract::extractor_middleware`
## `axum::middleware::from_extractor`

Use [`axum::extract::extractor_middleware`] to write your middleware when:
Use [`axum::middleware::from_extractor`] to write your middleware when:

- You have a type that you sometimes want to use as an extractor and sometimes
as a middleware. If you only need your type as a middleware prefer
Expand Down Expand Up @@ -442,7 +442,7 @@ extensions you need.
[`ServiceBuilder::map_response`]: tower::ServiceBuilder::map_response
[`ServiceBuilder::then`]: tower::ServiceBuilder::then
[`ServiceBuilder::and_then`]: tower::ServiceBuilder::and_then
[`axum::extract::extractor_middleware`]: crate::extract::extractor_middleware()
[`axum::middleware::from_extractor`]: crate::extract::extractor_middleware()
[`Handler::layer`]: crate::handler::Handler::layer
[`Router::layer`]: crate::routing::Router::layer
[`MethodRouter::layer`]: crate::routing::MethodRouter::layer
Expand Down
323 changes: 7 additions & 316 deletions axum/src/extract/extractor_middleware.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,324 +2,15 @@
//!
//! See [`extractor_middleware`] for more details.

use super::{FromRequest, RequestParts};
use crate::{
body::{Bytes, HttpBody},
response::{IntoResponse, Response},
BoxError,
};
use futures_util::{future::BoxFuture, ready};
use http::Request;
use pin_project_lite::pin_project;
use std::{
fmt,
future::Future,
marker::PhantomData,
pin::Pin,
task::{Context, Poll},
use crate::middleware::from_extractor;

pub use crate::middleware::{
future::FromExtractorResponseFuture as ResponseFuture, FromExtractor as ExtractorMiddleware,
FromExtractorLayer as ExtractorMiddlewareLayer,
};
use tower_layer::Layer;
use tower_service::Service;

/// Convert an extractor into a middleware.
///
/// If the extractor succeeds the value will be discarded and the inner service
/// will be called. If the extractor fails the rejection will be returned and
/// the inner service will _not_ be called.
///
/// This can be used to perform validation of requests if the validation doesn't
/// produce any useful output, and run the extractor for several handlers
/// without repeating it in the function signature.
///
/// Note that if the extractor consumes the request body, as `String` or
/// [`Bytes`] does, an empty body will be left in its place. Thus wont be
/// accessible to subsequent extractors or handlers.
///
/// # Example
///
/// ```rust
/// use axum::{
/// extract::{extractor_middleware, FromRequest, RequestParts},
/// routing::{get, post},
/// Router,
/// };
/// use http::StatusCode;
/// use async_trait::async_trait;
///
/// // An extractor that performs authorization.
/// struct RequireAuth;
///
/// #[async_trait]
/// impl<B> FromRequest<B> for RequireAuth
/// where
/// B: Send,
/// {
/// type Rejection = StatusCode;
///
/// async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
/// let auth_header = req
/// .headers()
/// .get(http::header::AUTHORIZATION)
/// .and_then(|value| value.to_str().ok());
///
/// match auth_header {
/// Some(auth_header) if token_is_valid(auth_header) => {
/// Ok(Self)
/// }
/// _ => Err(StatusCode::UNAUTHORIZED),
/// }
/// }
/// }
///
/// fn token_is_valid(token: &str) -> bool {
/// // ...
/// # false
/// }
///
/// async fn handler() {
/// // If we get here the request has been authorized
/// }
///
/// async fn other_handler() {
/// // If we get here the request has been authorized
/// }
///
/// let app = Router::new()
/// .route("/", get(handler))
/// .route("/foo", post(other_handler))
/// // The extractor will run before all routes
/// .route_layer(extractor_middleware::<RequireAuth>());
/// # async {
/// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
/// # };
/// ```
#[deprecated(note = "Please use `axum::middleware::from_extractor` instead")]
pub fn extractor_middleware<E>() -> ExtractorMiddlewareLayer<E> {
ExtractorMiddlewareLayer(PhantomData)
}

/// [`Layer`] that applies [`ExtractorMiddleware`] that runs an extractor and
/// discards the value.
///
/// See [`extractor_middleware`] for more details.
///
/// [`Layer`]: tower::Layer
pub struct ExtractorMiddlewareLayer<E>(PhantomData<fn() -> E>);

impl<E> Clone for ExtractorMiddlewareLayer<E> {
fn clone(&self) -> Self {
Self(PhantomData)
}
}

impl<E> fmt::Debug for ExtractorMiddlewareLayer<E> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("ExtractorMiddleware")
.field("extractor", &format_args!("{}", std::any::type_name::<E>()))
.finish()
}
}

impl<E, S> Layer<S> for ExtractorMiddlewareLayer<E> {
type Service = ExtractorMiddleware<S, E>;

fn layer(&self, inner: S) -> Self::Service {
ExtractorMiddleware {
inner,
_extractor: PhantomData,
}
}
}

/// Middleware that runs an extractor and discards the value.
///
/// See [`extractor_middleware`] for more details.
pub struct ExtractorMiddleware<S, E> {
inner: S,
_extractor: PhantomData<fn() -> E>,
}

#[test]
fn traits() {
use crate::test_helpers::*;
assert_send::<ExtractorMiddleware<(), NotSendSync>>();
assert_sync::<ExtractorMiddleware<(), NotSendSync>>();
}

impl<S, E> Clone for ExtractorMiddleware<S, E>
where
S: Clone,
{
fn clone(&self) -> Self {
Self {
inner: self.inner.clone(),
_extractor: PhantomData,
}
}
}

impl<S, E> fmt::Debug for ExtractorMiddleware<S, E>
where
S: fmt::Debug,
{
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("ExtractorMiddleware")
.field("inner", &self.inner)
.field("extractor", &format_args!("{}", std::any::type_name::<E>()))
.finish()
}
}

impl<S, E, ReqBody, ResBody> Service<Request<ReqBody>> for ExtractorMiddleware<S, E>
where
E: FromRequest<ReqBody> + 'static,
ReqBody: Default + Send + 'static,
S: Service<Request<ReqBody>, Response = Response<ResBody>> + Clone,
ResBody: HttpBody<Data = Bytes> + Send + 'static,
ResBody::Error: Into<BoxError>,
{
type Response = Response;
type Error = S::Error;
type Future = ResponseFuture<ReqBody, S, E>;

#[inline]
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.inner.poll_ready(cx)
}

fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
let extract_future = Box::pin(async move {
let mut req = super::RequestParts::new(req);
let extracted = E::from_request(&mut req).await;
(req, extracted)
});

ResponseFuture {
state: State::Extracting {
future: extract_future,
},
svc: Some(self.inner.clone()),
}
}
}

pin_project! {
/// Response future for [`ExtractorMiddleware`].
#[allow(missing_debug_implementations)]
pub struct ResponseFuture<ReqBody, S, E>
where
E: FromRequest<ReqBody>,
S: Service<Request<ReqBody>>,
{
#[pin]
state: State<ReqBody, S, E>,
svc: Option<S>,
}
}

pin_project! {
#[project = StateProj]
enum State<ReqBody, S, E>
where
E: FromRequest<ReqBody>,
S: Service<Request<ReqBody>>,
{
Extracting { future: BoxFuture<'static, (RequestParts<ReqBody>, Result<E, E::Rejection>)> },
Call { #[pin] future: S::Future },
}
}

impl<ReqBody, S, E, ResBody> Future for ResponseFuture<ReqBody, S, E>
where
E: FromRequest<ReqBody>,
S: Service<Request<ReqBody>, Response = Response<ResBody>>,
ReqBody: Default,
ResBody: HttpBody<Data = Bytes> + Send + 'static,
ResBody::Error: Into<BoxError>,
{
type Output = Result<Response, S::Error>;

fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
loop {
let mut this = self.as_mut().project();

let new_state = match this.state.as_mut().project() {
StateProj::Extracting { future } => {
let (req, extracted) = ready!(future.as_mut().poll(cx));

match extracted {
Ok(_) => {
let mut svc = this.svc.take().expect("future polled after completion");
let req = req.try_into_request().unwrap_or_default();
let future = svc.call(req);
State::Call { future }
}
Err(err) => {
let res = err.into_response();
return Poll::Ready(Ok(res));
}
}
}
StateProj::Call { future } => {
return future
.poll(cx)
.map(|result| result.map(|response| response.map(crate::body::boxed)));
}
};

this.state.set(new_state);
}
}
}

#[cfg(test)]
mod tests {
use super::*;
use crate::{handler::Handler, routing::get, test_helpers::*, Router};
use http::StatusCode;

#[tokio::test]
async fn test_extractor_middleware() {
struct RequireAuth;

#[async_trait::async_trait]
impl<B> FromRequest<B> for RequireAuth
where
B: Send,
{
type Rejection = StatusCode;

async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
if let Some(auth) = req
.headers()
.get("authorization")
.and_then(|v| v.to_str().ok())
{
if auth == "secret" {
return Ok(Self);
}
}

Err(StatusCode::UNAUTHORIZED)
}
}

async fn handler() {}

let app = Router::new().route(
"/",
get(handler.layer(extractor_middleware::<RequireAuth>())),
);

let client = TestClient::new(app);

let res = client.get("/").send().await;
assert_eq!(res.status(), StatusCode::UNAUTHORIZED);

let res = client
.get("/")
.header(http::header::AUTHORIZATION, "secret")
.send()
.await;
assert_eq!(res.status(), StatusCode::OK);
}
from_extractor()
}
1 change: 1 addition & 0 deletions axum/src/extract/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ mod request_parts;
pub use axum_core::extract::{FromRequest, RequestParts};

#[doc(inline)]
#[allow(deprecated)]
pub use self::{
connect_info::ConnectInfo,
content_length_limit::ContentLengthLimit,
Expand Down