Skip to content

Commit

Permalink
Add Host extractor (#827)
Browse files Browse the repository at this point in the history
  • Loading branch information
tbillington committed Mar 6, 2022
1 parent b05a5c6 commit 843437b
Show file tree
Hide file tree
Showing 4 changed files with 132 additions and 0 deletions.
1 change: 1 addition & 0 deletions axum/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- **added:** `Extension<_>` can now be used in tuples for building responses, and will set an
extension on the response ([#797])
- **added:** Implement `tower::Layer` for `Extension` ([#801])
- **added:** `extract::Host` for extracting the hostname of a request ([#827])
- **changed:** `Router::merge` now accepts `Into<Router>` ([#819])
- **breaking:** `sse::Event` now accepts types implementing `AsRef<str>` instead of `Into<String>`
as field values.
Expand Down
111 changes: 111 additions & 0 deletions axum/src/extract/host.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
use super::{
rejection::{FailedToResolveHost, HostRejection},
FromRequest, RequestParts,
};
use async_trait::async_trait;

const X_FORWARDED_HOST_HEADER_KEY: &'static str = "X-Forwarded-Host";

/// Extractor that resolves the hostname of the request.
///
/// Hostname is resolved through the following, in order:
/// - `X-Forwarded-Host` header
/// - `Host` header
/// - request target / URI
#[derive(Debug, Clone)]
pub struct Host(pub String);

#[async_trait]
impl<B> FromRequest<B> for Host
where
B: Send,
{
type Rejection = HostRejection;

async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
// todo: extract host from http::header::FORWARDED

if let Some(host) = req
.headers()
.get(X_FORWARDED_HOST_HEADER_KEY)
.and_then(|host| host.to_str().ok())
{
return Ok(Host(host.to_owned()));
}

if let Some(host) = req
.headers()
.get(http::header::HOST)
.and_then(|host| host.to_str().ok())
{
return Ok(Host(host.to_owned()));
}

if let Some(host) = req.uri().host() {
return Ok(Host(host.to_owned()));
}

Err(HostRejection::FailedToResolveHost(FailedToResolveHost))
}
}

#[cfg(test)]
mod tests {
use super::*;
use crate::{routing::get, test_helpers::TestClient, Router};

fn test_client() -> TestClient {
async fn host_as_body(Host(host): Host) -> String {
host
}

TestClient::new(Router::new().route("/", get(host_as_body)))
}

#[tokio::test]
async fn host_header() {
let original_host = "some-domain:123";
let host = test_client()
.get("/")
.header(http::header::HOST, original_host)
.send()
.await
.text()
.await;
assert_eq!(host, original_host);
}

#[tokio::test]
async fn x_forwarded_host_header() {
let original_host = "some-domain:456";
let host = test_client()
.get("/")
.header(X_FORWARDED_HOST_HEADER_KEY, original_host)
.send()
.await
.text()
.await;
assert_eq!(host, original_host);
}

#[tokio::test]
async fn x_forwarded_host_precedence_over_host_header() {
let x_forwarded_host_header = "some-domain:456";
let host_header = "some-domain:123";
let host = test_client()
.get("/")
.header(X_FORWARDED_HOST_HEADER_KEY, x_forwarded_host_header)
.header(http::header::HOST, host_header)
.send()
.await
.text()
.await;
assert_eq!(host, x_forwarded_host_header);
}

#[tokio::test]
async fn uri_host() {
let host = test_client().get("/").send().await.text().await;
assert!(host.contains("127.0.0.1"));
}
}
2 changes: 2 additions & 0 deletions axum/src/extract/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ pub mod rejection;
pub mod ws;

mod content_length_limit;
mod host;
mod raw_query;
mod request_parts;

Expand All @@ -23,6 +24,7 @@ pub use self::{
connect_info::ConnectInfo,
content_length_limit::ContentLengthLimit,
extractor_middleware::extractor_middleware,
host::Host,
path::Path,
raw_query::RawQuery,
request_parts::{BodyStream, RawBody},
Expand Down
18 changes: 18 additions & 0 deletions axum/src/extract/rejection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,14 @@ define_rejection! {
pub struct InvalidFormContentType;
}

define_rejection! {
#[status = BAD_REQUEST]
#[body = "No host found in request"]
/// Rejection type used if the [`Host`](super::Host) extractor is unable to
/// resolve a host.
pub struct FailedToResolveHost;
}

/// Rejection type for extractors that deserialize query strings if the input
/// couldn't be deserialized into the target type.
#[derive(Debug)]
Expand Down Expand Up @@ -160,6 +168,16 @@ composite_rejection! {
}
}

composite_rejection! {
/// Rejection used for [`Host`](super::Host).
///
/// Contains one variant for each way the [`Host`](super::Host) extractor
/// can fail.
pub enum HostRejection {
FailedToResolveHost,
}
}

#[cfg(feature = "matched-path")]
define_rejection! {
#[status = INTERNAL_SERVER_ERROR]
Expand Down

0 comments on commit 843437b

Please sign in to comment.