Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Host extractor #827

Merged
merged 7 commits into from
Mar 6, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions axum/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- **added:** `Extension<_>` can now be used in tuples for building responses, and will set an
extension on the response ([#797])
- **added:** Implement `tower::Layer` for `Extension` ([#801])
- **added:** `extract::Host` for extracting the hostname of a request ([#827])
- **changed:** `Router::merge` now accepts `Into<Router>` ([#819])
- **breaking:** `sse::Event` now accepts types implementing `AsRef<str>` instead of `Into<String>`
as field values.
Expand Down
111 changes: 111 additions & 0 deletions axum/src/extract/host.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
use super::{
rejection::{FailedToResolveHost, HostRejection},
FromRequest, RequestParts,
};
use async_trait::async_trait;

const X_FORWARDED_HOST_HEADER_KEY: &'static str = "X-Forwarded-Host";

/// Extractor that resolves the hostname of the request.
///
/// Hostname is resolved through the following, in order:
/// - `X-Forwarded-Host` header
/// - `Host` header
/// - request target / URI
#[derive(Debug, Clone)]
pub struct Host(pub String);

#[async_trait]
impl<B> FromRequest<B> for Host
where
B: Send,
{
type Rejection = HostRejection;

async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
// todo: extract host from http::header::FORWARDED

if let Some(host) = req
.headers()
.get(X_FORWARDED_HOST_HEADER_KEY)
.and_then(|host| host.to_str().ok())
{
return Ok(Host(host.to_owned()));
}

if let Some(host) = req
.headers()
.get(http::header::HOST)
.and_then(|host| host.to_str().ok())
{
return Ok(Host(host.to_owned()));
}

if let Some(host) = req.uri().host() {
return Ok(Host(host.to_owned()));
}
davidpdrsn marked this conversation as resolved.
Show resolved Hide resolved

Err(HostRejection::FailedToResolveHost(FailedToResolveHost))
}
}

#[cfg(test)]
mod tests {
use super::*;
use crate::{routing::get, test_helpers::TestClient, Router};

fn test_client() -> TestClient {
async fn host_as_body(Host(host): Host) -> String {
host
}

TestClient::new(Router::new().route("/", get(host_as_body)))
}

#[tokio::test]
async fn host_header() {
let original_host = "some-domain:123";
let host = test_client()
.get("/")
.header(http::header::HOST, original_host)
.send()
.await
.text()
.await;
assert_eq!(host, original_host);
}

#[tokio::test]
async fn x_forwarded_host_header() {
let original_host = "some-domain:456";
let host = test_client()
.get("/")
.header(X_FORWARDED_HOST_HEADER_KEY, original_host)
.send()
.await
.text()
.await;
assert_eq!(host, original_host);
}

#[tokio::test]
async fn x_forwarded_host_precedence_over_host_header() {
let x_forwarded_host_header = "some-domain:456";
let host_header = "some-domain:123";
let host = test_client()
.get("/")
.header(X_FORWARDED_HOST_HEADER_KEY, x_forwarded_host_header)
.header(http::header::HOST, host_header)
.send()
.await
.text()
.await;
assert_eq!(host, x_forwarded_host_header);
}

#[tokio::test]
async fn uri_host() {
let host = test_client().get("/").send().await.text().await;
assert!(host.contains("127.0.0.1"));
}
davidpdrsn marked this conversation as resolved.
Show resolved Hide resolved
}
2 changes: 2 additions & 0 deletions axum/src/extract/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ pub mod rejection;
pub mod ws;

mod content_length_limit;
mod host;
mod raw_query;
mod request_parts;

Expand All @@ -23,6 +24,7 @@ pub use self::{
connect_info::ConnectInfo,
content_length_limit::ContentLengthLimit,
extractor_middleware::extractor_middleware,
host::Host,
path::Path,
raw_query::RawQuery,
request_parts::{BodyStream, RawBody},
Expand Down
18 changes: 18 additions & 0 deletions axum/src/extract/rejection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,14 @@ define_rejection! {
pub struct InvalidFormContentType;
}

define_rejection! {
#[status = BAD_REQUEST]
davidpdrsn marked this conversation as resolved.
Show resolved Hide resolved
#[body = "No host found in request"]
/// Rejection type used if the [`Host`](super::Host) extractor is unable to
/// resolve a host.
pub struct FailedToResolveHost;
}

/// Rejection type for extractors that deserialize query strings if the input
/// couldn't be deserialized into the target type.
#[derive(Debug)]
Expand Down Expand Up @@ -160,6 +168,16 @@ composite_rejection! {
}
}

composite_rejection! {
/// Rejection used for [`Host`](super::Host).
///
/// Contains one variant for each way the [`Host`](super::Host) extractor
/// can fail.
pub enum HostRejection {
FailedToResolveHost,
}
}

#[cfg(feature = "matched-path")]
define_rejection! {
#[status = INTERNAL_SERVER_ERROR]
Expand Down