Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rework Form and Query rejections #1496

Merged
merged 6 commits into from Nov 8, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
94 changes: 55 additions & 39 deletions axum-extra/src/extract/form.rs
@@ -1,16 +1,13 @@
use axum::{
async_trait,
body::HttpBody,
extract::{
rejection::{FailedToDeserializeQueryString, FormRejection, InvalidFormContentType},
FromRequest,
},
BoxError,
extract::{rejection::RawFormRejection, FromRequest, RawForm},
response::{IntoResponse, Response},
BoxError, Error, RequestExt,
};
use bytes::Bytes;
use http::{header, HeaderMap, Method, Request};
use http::{Request, StatusCode};
use serde::de::DeserializeOwned;
use std::ops::Deref;
use std::{fmt, ops::Deref};

/// Extractor that deserializes `application/x-www-form-urlencoded` requests
/// into some type.
Expand Down Expand Up @@ -65,41 +62,60 @@ where
{
type Rejection = FormRejection;

async fn from_request(req: Request<B>, state: &S) -> Result<Self, Self::Rejection> {
if req.method() == Method::GET {
let query = req.uri().query().unwrap_or_default();
let value = serde_html_form::from_str(query)
.map_err(FailedToDeserializeQueryString::__private_new)?;
Ok(Form(value))
} else {
if !has_content_type(req.headers(), &mime::APPLICATION_WWW_FORM_URLENCODED) {
return Err(InvalidFormContentType::default().into());
}

let bytes = Bytes::from_request(req, state).await?;
let value = serde_html_form::from_bytes(&bytes)
.map_err(FailedToDeserializeQueryString::__private_new)?;

Ok(Form(value))
async fn from_request(req: Request<B>, _state: &S) -> Result<Self, Self::Rejection> {
let RawForm(bytes) = req
.extract()
.await
.map_err(FormRejection::RawFormRejection)?;

serde_html_form::from_bytes::<T>(&bytes)
.map(Self)
.map_err(|err| FormRejection::FailedToDeserializeForm(Error::new(err)))
}
}

/// Rejection used for [`Form`].
///
/// Contains one variant for each way the [`Form`] extractor can fail.
#[derive(Debug)]
#[non_exhaustive]
#[cfg(feature = "form")]
pub enum FormRejection {
#[allow(missing_docs)]
RawFormRejection(RawFormRejection),
#[allow(missing_docs)]
FailedToDeserializeForm(Error),
}

impl IntoResponse for FormRejection {
fn into_response(self) -> Response {
match self {
Self::RawFormRejection(inner) => inner.into_response(),
Self::FailedToDeserializeForm(inner) => (
StatusCode::BAD_REQUEST,
format!("Failed to deserialize form: {}", inner),
)
.into_response(),
}
}
}

// this is duplicated in `axum/src/extract/mod.rs`
fn has_content_type(headers: &HeaderMap, expected_content_type: &mime::Mime) -> bool {
let content_type = if let Some(content_type) = headers.get(header::CONTENT_TYPE) {
content_type
} else {
return false;
};

let content_type = if let Ok(content_type) = content_type.to_str() {
content_type
} else {
return false;
};

content_type.starts_with(expected_content_type.as_ref())
impl fmt::Display for FormRejection {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::RawFormRejection(inner) => inner.fmt(f),
Self::FailedToDeserializeForm(inner) => inner.fmt(f),
}
}
}

impl std::error::Error for FormRejection {
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
match self {
Self::RawFormRejection(inner) => Some(inner),
Self::FailedToDeserializeForm(inner) => Some(inner),
}
}
}

#[cfg(test)]
Expand Down
4 changes: 2 additions & 2 deletions axum-extra/src/extract/mod.rs
Expand Up @@ -25,10 +25,10 @@ pub use self::cookie::PrivateCookieJar;
pub use self::cookie::SignedCookieJar;

#[cfg(feature = "form")]
pub use self::form::Form;
pub use self::form::{Form, FormRejection};

#[cfg(feature = "query")]
pub use self::query::Query;
pub use self::query::{Query, QueryRejection};

#[cfg(feature = "json-lines")]
#[doc(no_inline)]
Expand Down
52 changes: 45 additions & 7 deletions axum-extra/src/extract/query.rs
@@ -1,13 +1,12 @@
use axum::{
async_trait,
extract::{
rejection::{FailedToDeserializeQueryString, QueryRejection},
FromRequestParts,
},
extract::FromRequestParts,
response::{IntoResponse, Response},
Error,
};
use http::request::Parts;
use http::{request::Parts, StatusCode};
use serde::de::DeserializeOwned;
use std::ops::Deref;
use std::{fmt, ops::Deref};

/// Extractor that deserializes query strings into some type.
///
Expand Down Expand Up @@ -69,7 +68,7 @@ where
async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
let query = parts.uri.query().unwrap_or_default();
let value = serde_html_form::from_str(query)
.map_err(FailedToDeserializeQueryString::__private_new)?;
.map_err(|err| QueryRejection::FailedToDeserializeQueryString(Error::new(err)))?;
Ok(Query(value))
}
}
Expand All @@ -82,6 +81,45 @@ impl<T> Deref for Query<T> {
}
}

/// Rejection used for [`Query`].
///
/// Contains one variant for each way the [`Query`] extractor can fail.
#[derive(Debug)]
#[non_exhaustive]
#[cfg(feature = "query")]
pub enum QueryRejection {
#[allow(missing_docs)]
FailedToDeserializeQueryString(Error),
}

impl IntoResponse for QueryRejection {
fn into_response(self) -> Response {
match self {
Self::FailedToDeserializeQueryString(inner) => (
StatusCode::BAD_REQUEST,
format!("Failed to deserialize query string: {}", inner),
)
.into_response(),
}
}
}

impl fmt::Display for QueryRejection {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::FailedToDeserializeQueryString(inner) => inner.fmt(f),
}
}
}

impl std::error::Error for QueryRejection {
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
match self {
Self::FailedToDeserializeQueryString(inner) => Some(inner),
}
}
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down
Expand Up @@ -13,7 +13,7 @@ error[E0277]: the trait bound `bool: IntoResponse` is not satisfied
(Response<()>, T1, T2, R)
(Response<()>, T1, T2, T3, R)
(Response<()>, T1, T2, T3, T4, R)
and 119 others
and 120 others
note: required by a bound in `__axum_macros_check_handler_into_response::{closure#0}::check`
--> tests/debug_handler/fail/wrong_return_type.rs:4:23
|
Expand Down
3 changes: 3 additions & 0 deletions axum/CHANGELOG.md
Expand Up @@ -47,6 +47,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- **added:** `FromRequest` and `FromRequestParts` derive macro re-exports from
[`axum-macros`] behind the `macros` feature ([#1352])
- **added:** Add `extract::RawForm` for accessing raw urlencoded query bytes or request body ([#1487])
- **breaking:** Rename `FormRejection::FailedToDeserializeQueryString` to
`FormRejection::FailedToDeserializeForm` ([#1496])

[#1352]: https://github.com/tokio-rs/axum/pull/1352
[#1368]: https://github.com/tokio-rs/axum/pull/1368
Expand All @@ -63,6 +65,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
[#1420]: https://github.com/tokio-rs/axum/pull/1420
[#1421]: https://github.com/tokio-rs/axum/pull/1421
[#1487]: https://github.com/tokio-rs/axum/pull/1487
[#1496]: https://github.com/tokio-rs/axum/pull/1496

# 0.6.0-rc.2 (10. September, 2022)

Expand Down
4 changes: 2 additions & 2 deletions axum/src/extract/query.rs
Expand Up @@ -59,8 +59,8 @@ where

async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
let query = parts.uri.query().unwrap_or_default();
let value = serde_urlencoded::from_str(query)
.map_err(FailedToDeserializeQueryString::__private_new)?;
let value =
serde_urlencoded::from_str(query).map_err(FailedToDeserializeQueryString::from_err)?;
Ok(Query(value))
}
}
Expand Down
46 changes: 13 additions & 33 deletions axum/src/extract/rejection.rs
@@ -1,8 +1,5 @@
//! Rejection response types.

use crate::{BoxError, Error};
use axum_core::response::{IntoResponse, Response};

pub use crate::extract::path::FailedToDeserializePathParams;
pub use axum_core::extract::rejection::*;

Expand Down Expand Up @@ -73,39 +70,22 @@ define_rejection! {
pub struct FailedToResolveHost;
}

/// Rejection type for extractors that deserialize query strings if the input
/// couldn't be deserialized into the target type.
#[derive(Debug)]
pub struct FailedToDeserializeQueryString {
error: Error,
}

impl FailedToDeserializeQueryString {
#[doc(hidden)]
pub fn __private_new<E>(error: E) -> Self
where
E: Into<BoxError>,
{
FailedToDeserializeQueryString {
error: Error::new(error),
}
}
}

impl IntoResponse for FailedToDeserializeQueryString {
fn into_response(self) -> Response {
(http::StatusCode::BAD_REQUEST, self.to_string()).into_response()
}
define_rejection! {
#[status = BAD_REQUEST]
#[body = "Failed to deserialize form"]
/// Rejection type used if the [`Form`](super::Form) extractor is unable to
/// deserialize the form into the target type.
pub struct FailedToDeserializeForm(Error);
}

impl std::fmt::Display for FailedToDeserializeQueryString {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "Failed to deserialize query string: {}", self.error,)
}
define_rejection! {
#[status = BAD_REQUEST]
#[body = "Failed to deserialize query string"]
/// Rejection type used if the [`Query`](super::Query) extractor is unable to
/// deserialize the form into the target type.
pub struct FailedToDeserializeQueryString(Error);
}

impl std::error::Error for FailedToDeserializeQueryString {}

composite_rejection! {
/// Rejection used for [`Query`](super::Query).
///
Expand All @@ -123,7 +103,7 @@ composite_rejection! {
/// can fail.
pub enum FormRejection {
InvalidFormContentType,
FailedToDeserializeQueryString,
FailedToDeserializeForm,
BytesRejection,
}
}
Expand Down
2 changes: 1 addition & 1 deletion axum/src/form.rs
Expand Up @@ -76,7 +76,7 @@ where
match req.extract().await {
Ok(RawForm(bytes)) => {
let value = serde_urlencoded::from_bytes(&bytes)
.map_err(FailedToDeserializeQueryString::__private_new)?;
.map_err(FailedToDeserializeForm::from_err)?;
Ok(Form(value))
}
Err(RawFormRejection::BytesRejection(r)) => Err(FormRejection::BytesRejection(r)),
Expand Down