Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support Forwarded in Host extractor #1078

Merged
merged 4 commits into from Jun 10, 2022
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
5 changes: 4 additions & 1 deletion axum/CHANGELOG.md
Expand Up @@ -7,7 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

# Unreleased

- None.
- **added:** Support resolving host name via `Forwarded` header in `Host`
extractor ([#1078])

[#1078]: https://github.com/tokio-rs/axum/pull/1078

# 0.5.7 (08. June, 2022)

Expand Down
62 changes: 61 additions & 1 deletion axum/src/extract/host.rs
Expand Up @@ -3,12 +3,14 @@ use super::{
FromRequest, RequestParts,
};
use async_trait::async_trait;
use http::header::{HeaderMap, FORWARDED};

const X_FORWARDED_HOST_HEADER_KEY: &str = "X-Forwarded-Host";

/// Extractor that resolves the hostname of the request.
///
/// Hostname is resolved through the following, in order:
/// - `Forwarded` header
/// - `X-Forwarded-Host` header
/// - `Host` header
/// - request target / URI
Expand All @@ -26,7 +28,9 @@ where
type Rejection = HostRejection;

async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
// TODO: extract host from http::header::FORWARDED
if let Some(host) = parse_forwarded(req.headers()) {
return Ok(Host(host.to_owned()));
}

if let Some(host) = req
.headers()
Expand All @@ -52,10 +56,30 @@ where
}
}

#[allow(warnings)]
fn parse_forwarded(headers: &HeaderMap) -> Option<&str> {
// if there are multiple `Forwarded` `HeaderMap::get` will return the first one
let forwarded_values = headers.get(FORWARDED)?.to_str().ok()?;

// get the first set of values
let first_value = forwarded_values.split(',').nth(0)?;

// find the value of the `for` field
first_value.split(';').find_map(|pair| {
let (key, value) = pair.split_once('=')?;
if key.trim().eq_ignore_ascii_case("for") {
Some(value.trim().trim_matches('"'))
} else {
None
}
davidpdrsn marked this conversation as resolved.
Show resolved Hide resolved
})
}

#[cfg(test)]
mod tests {
use super::*;
use crate::{routing::get, test_helpers::TestClient, Router};
use http::header::HeaderName;

fn test_client() -> TestClient {
async fn host_as_body(Host(host): Host) -> String {
Expand Down Expand Up @@ -111,4 +135,40 @@ mod tests {
let host = test_client().get("/").send().await.text().await;
assert!(host.contains("127.0.0.1"));
}

#[test]
fn forwarded_parsing() {
// the basic case
let headers = header_map(&[(FORWARDED, "for=192.0.2.60;proto=http;by=203.0.113.43")]);
let value = parse_forwarded(&headers).unwrap();
assert_eq!(value, "192.0.2.60");

// is case insensitive
let headers = header_map(&[(FORWARDED, "For=192.0.2.60;proto=http;by=203.0.113.43")]);
let value = parse_forwarded(&headers).unwrap();
assert_eq!(value, "192.0.2.60");

// ipv6
let headers = header_map(&[(FORWARDED, "For=\"[2001:db8:cafe::17]:4711\"")]);
let value = parse_forwarded(&headers).unwrap();
assert_eq!(value, "[2001:db8:cafe::17]:4711");

// multiple values in one header
let headers = header_map(&[(FORWARDED, "for=192.0.2.60, for=127.0.0.1")]);
let value = parse_forwarded(&headers).unwrap();
assert_eq!(value, "192.0.2.60");

// multiple header values
let headers = header_map(&[(FORWARDED, "for=192.0.2.60"), (FORWARDED, "for=127.0.0.1")]);
let value = parse_forwarded(&headers).unwrap();
assert_eq!(value, "192.0.2.60");
}

fn header_map(values: &[(HeaderName, &str)]) -> HeaderMap {
let mut headers = HeaderMap::new();
for (key, value) in values {
headers.append(key, value.parse().unwrap());
}
headers
}
}