diff --git a/src/client/client.rs b/src/client/client.rs index 418f3fb4e9..96ddbaedff 100644 --- a/src/client/client.rs +++ b/src/client/client.rs @@ -972,6 +972,17 @@ impl Builder { self } + /// Set whether HTTP/1 connections will write header names as provided + /// at the socket level. + /// + /// Note that this setting does not affect HTTP/2. + /// + /// Default is false. + pub fn http1_preserve_header_case(&mut self, val: bool) -> &mut Self { + self.conn_builder.h1_preserve_header_case(val); + self + } + /// Set whether HTTP/0.9 responses should be tolerated. /// /// Default is false. diff --git a/src/client/conn.rs b/src/client/conn.rs index b87600d85a..029e958731 100644 --- a/src/client/conn.rs +++ b/src/client/conn.rs @@ -124,6 +124,7 @@ pub struct Builder { pub(super) exec: Exec, h09_responses: bool, h1_title_case_headers: bool, + h1_preserve_header_case: bool, h1_read_buf_exact_size: Option, h1_max_buf_size: Option, #[cfg(feature = "http2")] @@ -497,6 +498,7 @@ impl Builder { h09_responses: false, h1_read_buf_exact_size: None, h1_title_case_headers: false, + h1_preserve_header_case: false, h1_max_buf_size: None, #[cfg(feature = "http2")] h2_builder: Default::default(), @@ -526,6 +528,11 @@ impl Builder { self } + pub(crate) fn h1_preserve_header_case(&mut self, enabled: bool) -> &mut Builder { + self.h1_preserve_header_case = enabled; + self + } + pub(super) fn h1_read_buf_exact_size(&mut self, sz: Option) -> &mut Builder { self.h1_read_buf_exact_size = sz; self.h1_max_buf_size = None; @@ -707,6 +714,9 @@ impl Builder { if opts.h1_title_case_headers { conn.set_title_case_headers(); } + if opts.h1_preserve_header_case { + conn.set_preserve_header_case(); + } if opts.h09_responses { conn.set_h09_responses(); } diff --git a/src/ext.rs b/src/ext.rs new file mode 100644 index 0000000000..d31e3e0bd2 --- /dev/null +++ b/src/ext.rs @@ -0,0 +1,51 @@ +//! HTTP extensions + +use bytes::Bytes; +#[cfg(feature = "http1")] +use http::header::{GetAll, HeaderName, IntoHeaderName}; +use http::HeaderMap; + +/// A map from header names to their original casing as received in an HTTP response. +/// +/// If an HTTP/1 response `res` is parsed on a connection whose option +/// [`http1_preserve_header_case`] was set to true and the response included +/// the following headers: +/// +/// ```ignore +/// x-Bread: Baguette +/// X-BREAD: Pain +/// x-bread: Ficelle +/// ``` +/// +/// Then `res.extensions().get::()` will return a map with: +/// +/// ```ignore +/// HeaderCaseMap({ +/// "x-bread": ["x-Bread", "X-BREAD", "x-bread"], +/// }) +/// ``` +/// +/// [`http1_preserve_header_case`]: /client/struct.Client.html#method.http1_preserve_header_case +#[derive(Clone, Debug, Default)] +pub struct HeaderCaseMap(HeaderMap); + +#[cfg(feature = "http1")] +impl HeaderCaseMap { + /// Returns a view of all spellings associated with that header name, + /// in the order they were found. + pub fn get_all(&self, name: &HeaderName) -> GetAll<'_, Bytes> { + self.0.get_all(name) + } + + #[cfg(any(test, feature = "ffi"))] + pub(crate) fn insert(&mut self, name: HeaderName, orig: Bytes) { + self.0.insert(name, orig); + } + + pub(crate) fn append(&mut self, name: N, orig: Bytes) + where + N: IntoHeaderName, + { + self.0.append(name, orig); + } +} diff --git a/src/ffi/client.rs b/src/ffi/client.rs index 0351214e09..9be4f5a04d 100644 --- a/src/ffi/client.rs +++ b/src/ffi/client.rs @@ -106,8 +106,11 @@ unsafe impl AsTaskType for hyper_clientconn { ffi_fn! { /// Creates a new set of HTTP clientconn options to be used in a handshake. fn hyper_clientconn_options_new() -> *mut hyper_clientconn_options { + let mut builder = conn::Builder::new(); + builder.h1_preserve_header_case(true); + Box::into_raw(Box::new(hyper_clientconn_options { - builder: conn::Builder::new(), + builder, exec: WeakExec::new(), })) } ?= std::ptr::null_mut() diff --git a/src/ffi/http_types.rs b/src/ffi/http_types.rs index 1fce28902a..8dccbda0ef 100644 --- a/src/ffi/http_types.rs +++ b/src/ffi/http_types.rs @@ -6,6 +6,7 @@ use super::body::hyper_body; use super::error::hyper_code; use super::task::{hyper_task_return_type, AsTaskType}; use super::HYPER_ITER_CONTINUE; +use crate::ext::HeaderCaseMap; use crate::header::{HeaderName, HeaderValue}; use crate::{Body, HeaderMap, Method, Request, Response, Uri}; @@ -24,10 +25,6 @@ pub struct hyper_headers { orig_casing: HeaderCaseMap, } -// Will probably be moved to `hyper::ext::http1` -#[derive(Debug, Default)] -pub(crate) struct HeaderCaseMap(HeaderMap); - #[derive(Debug)] pub(crate) struct ReasonPhrase(pub(crate) Bytes); @@ -370,25 +367,6 @@ unsafe fn raw_name_value( Ok((name, value, orig_name)) } -// ===== impl HeaderCaseMap ===== - -impl HeaderCaseMap { - pub(crate) fn get_all(&self, name: &HeaderName) -> http::header::GetAll<'_, Bytes> { - self.0.get_all(name) - } - - pub(crate) fn insert(&mut self, name: HeaderName, orig: Bytes) { - self.0.insert(name, orig); - } - - pub(crate) fn append(&mut self, name: N, orig: Bytes) - where - N: http::header::IntoHeaderName, - { - self.0.append(name, orig); - } -} - #[cfg(test)] mod tests { use super::*; diff --git a/src/lib.rs b/src/lib.rs index 059f8821c6..132a054eff 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -80,6 +80,7 @@ mod cfg; mod common; pub mod body; mod error; +pub mod ext; #[cfg(test)] mod mock; #[cfg(any(feature = "http1", feature = "http2",))] diff --git a/src/proto/h1/conn.rs b/src/proto/h1/conn.rs index ce0848ddea..6c693182ec 100644 --- a/src/proto/h1/conn.rs +++ b/src/proto/h1/conn.rs @@ -44,7 +44,6 @@ where error: None, keep_alive: KA::Busy, method: None, - #[cfg(feature = "ffi")] preserve_header_case: false, title_case_headers: false, h09_responses: false, @@ -74,11 +73,15 @@ where self.io.set_read_buf_exact_size(sz); } - #[cfg(feature = "client")] pub(crate) fn set_title_case_headers(&mut self) { self.state.title_case_headers = true; } + #[cfg(feature = "client")] + pub(crate) fn set_preserve_header_case(&mut self) { + self.state.preserve_header_case = true; + } + #[cfg(feature = "client")] pub(crate) fn set_h09_responses(&mut self) { self.state.h09_responses = true; @@ -150,7 +153,6 @@ where ParseContext { cached_headers: &mut self.state.cached_headers, req_method: &mut self.state.method, - #[cfg(feature = "ffi")] preserve_header_case: self.state.preserve_header_case, h09_responses: self.state.h09_responses, } @@ -488,16 +490,6 @@ where self.enforce_version(&mut head); - // Maybe check if we should preserve header casing on received - // message headers... - #[cfg(feature = "ffi")] - { - if T::is_client() && !self.state.preserve_header_case { - self.state.preserve_header_case = - head.extensions.get::().is_some(); - } - } - let buf = self.io.headers_buf(); match super::role::encode_headers::( Encode { @@ -760,7 +752,6 @@ struct State { /// This is used to know things such as if the message can include /// a body or not. method: Option, - #[cfg(feature = "ffi")] preserve_header_case: bool, title_case_headers: bool, h09_responses: bool, diff --git a/src/proto/h1/io.rs b/src/proto/h1/io.rs index c7ce48664b..e35d3390b9 100644 --- a/src/proto/h1/io.rs +++ b/src/proto/h1/io.rs @@ -159,7 +159,6 @@ where ParseContext { cached_headers: parse_ctx.cached_headers, req_method: parse_ctx.req_method, - #[cfg(feature = "ffi")] preserve_header_case: parse_ctx.preserve_header_case, h09_responses: parse_ctx.h09_responses, }, @@ -639,7 +638,6 @@ mod tests { let parse_ctx = ParseContext { cached_headers: &mut None, req_method: &mut None, - #[cfg(feature = "ffi")] preserve_header_case: false, h09_responses: false, }; diff --git a/src/proto/h1/mod.rs b/src/proto/h1/mod.rs index 01a9253fa3..ec9691a216 100644 --- a/src/proto/h1/mod.rs +++ b/src/proto/h1/mod.rs @@ -70,7 +70,6 @@ pub(crate) struct ParsedMessage { pub(crate) struct ParseContext<'a> { cached_headers: &'a mut Option, req_method: &'a mut Option, - #[cfg(feature = "ffi")] preserve_header_case: bool, h09_responses: bool, } diff --git a/src/proto/h1/role.rs b/src/proto/h1/role.rs index ea9dc96be1..258222b7d9 100644 --- a/src/proto/h1/role.rs +++ b/src/proto/h1/role.rs @@ -5,16 +5,17 @@ use std::fmt::{self, Write}; use std::mem; -#[cfg(feature = "ffi")] +#[cfg(any(test, feature = "ffi", feature = "server"))] use bytes::Bytes; use bytes::BytesMut; -use http::header::{self, Entry, HeaderName, HeaderValue}; +use http::header::{self, Entry, HeaderName, HeaderValue, ValueIter}; use http::{HeaderMap, Method, StatusCode, Version}; use crate::body::DecodedLength; #[cfg(feature = "server")] use crate::common::date; use crate::error::Parse; +use crate::ext::HeaderCaseMap; use crate::headers; use crate::proto::h1::{ Encode, Encoder, Http1Transaction, ParseContext, ParseResult, ParsedMessage, @@ -190,6 +191,7 @@ impl Http1Transaction for Server { let mut is_te = false; let mut is_te_chunked = false; let mut wants_upgrade = subject.0 == Method::CONNECT; + let mut header_case_map = HeaderCaseMap::default(); let mut headers = ctx.cached_headers.take().unwrap_or_else(HeaderMap::new); @@ -260,6 +262,10 @@ impl Http1Transaction for Server { _ => (), } + if ctx.preserve_header_case { + header_case_map.append(&name, slice.slice(header.name.0..header.name.1)); + } + headers.append(name, value); } @@ -268,6 +274,12 @@ impl Http1Transaction for Server { return Err(Parse::Header); } + let mut extensions = http::Extensions::default(); + + if ctx.preserve_header_case { + extensions.insert(header_case_map); + } + *ctx.req_method = Some(subject.0.clone()); Ok(Some(ParsedMessage { @@ -275,7 +287,7 @@ impl Http1Transaction for Server { version, subject, headers, - extensions: http::Extensions::default(), + extensions, }, decode: decoder, expect_continue, @@ -284,20 +296,13 @@ impl Http1Transaction for Server { })) } - fn encode( - mut msg: Encode<'_, Self::Outgoing>, - mut dst: &mut Vec, - ) -> crate::Result { + fn encode(mut msg: Encode<'_, Self::Outgoing>, dst: &mut Vec) -> crate::Result { trace!( "Server::encode status={:?}, body={:?}, req_method={:?}", msg.head.subject, msg.body, msg.req_method ); - debug_assert!( - !msg.title_case_headers, - "no server config for title case headers" - ); let mut wrote_len = false; @@ -305,7 +310,7 @@ impl Http1Transaction for Server { // This is because Service only allows returning a single Response, and // so if you try to reply with a e.g. 100 Continue, you have no way of // replying with the latter status code response. - let (ret, mut is_last) = if msg.head.subject == StatusCode::SWITCHING_PROTOCOLS { + let (ret, is_last) = if msg.head.subject == StatusCode::SWITCHING_PROTOCOLS { (Ok(()), true) } else if msg.req_method == &Some(Method::CONNECT) && msg.head.subject.is_success() { // Sending content-length or transfer-encoding header on 2xx response @@ -326,9 +331,6 @@ impl Http1Transaction for Server { // pushing some bytes onto the `dst`. In those cases, we don't want to send // the half-pushed message, so rewind to before. let orig_len = dst.len(); - let rewind = |dst: &mut Vec| { - dst.truncate(orig_len); - }; let init_cap = 30 + msg.head.headers.len() * AVERAGE_HEADER_SIZE; dst.reserve(init_cap); @@ -359,6 +361,217 @@ impl Http1Transaction for Server { extend(dst, b"\r\n"); } + let orig_headers; + let extensions = mem::take(&mut msg.head.extensions); + let orig_headers = match extensions.get::() { + None if msg.title_case_headers => { + orig_headers = HeaderCaseMap::default(); + Some(&orig_headers) + } + orig_headers => orig_headers, + }; + let encoder = if let Some(orig_headers) = orig_headers { + Self::encode_headers_with_original_case( + msg, + dst, + is_last, + orig_len, + wrote_len, + orig_headers, + )? + } else { + Self::encode_headers_with_lower_case(msg, dst, is_last, orig_len, wrote_len)? + }; + + ret.map(|()| encoder) + } + + fn on_error(err: &crate::Error) -> Option> { + use crate::error::Kind; + let status = match *err.kind() { + Kind::Parse(Parse::Method) + | Kind::Parse(Parse::Header) + | Kind::Parse(Parse::Uri) + | Kind::Parse(Parse::Version) => StatusCode::BAD_REQUEST, + Kind::Parse(Parse::TooLarge) => StatusCode::REQUEST_HEADER_FIELDS_TOO_LARGE, + _ => return None, + }; + + debug!("sending automatic response ({}) for parse error", status); + let mut msg = MessageHead::default(); + msg.subject = status; + Some(msg) + } + + fn is_server() -> bool { + true + } + + fn update_date() { + date::update(); + } +} + +#[cfg(feature = "server")] +impl Server { + fn can_have_body(method: &Option, status: StatusCode) -> bool { + Server::can_chunked(method, status) + } + + fn can_chunked(method: &Option, status: StatusCode) -> bool { + if method == &Some(Method::HEAD) || method == &Some(Method::CONNECT) && status.is_success() + { + false + } else if status.is_informational() { + false + } else { + match status { + StatusCode::NO_CONTENT | StatusCode::NOT_MODIFIED => false, + _ => true, + } + } + } + + fn can_have_content_length(method: &Option, status: StatusCode) -> bool { + if status.is_informational() || method == &Some(Method::CONNECT) && status.is_success() { + false + } else { + match status { + StatusCode::NO_CONTENT | StatusCode::NOT_MODIFIED => false, + _ => true, + } + } + } + + fn encode_headers_with_lower_case( + msg: Encode<'_, StatusCode>, + dst: &mut Vec, + is_last: bool, + orig_len: usize, + wrote_len: bool, + ) -> crate::Result { + struct LowercaseWriter; + + impl HeaderNameWriter for LowercaseWriter { + #[inline] + fn write_full_header_line( + &mut self, + dst: &mut Vec, + line: &str, + _: (HeaderName, &str), + ) { + extend(dst, line.as_bytes()) + } + + #[inline] + fn write_header_name_with_colon( + &mut self, + dst: &mut Vec, + name_with_colon: &str, + _: HeaderName, + ) { + extend(dst, name_with_colon.as_bytes()) + } + + #[inline] + fn write_header_name(&mut self, dst: &mut Vec, name: &HeaderName) { + extend(dst, name.as_str().as_bytes()) + } + } + + Self::encode_headers(msg, dst, is_last, orig_len, wrote_len, LowercaseWriter) + } + + #[cold] + #[inline(never)] + fn encode_headers_with_original_case( + msg: Encode<'_, StatusCode>, + dst: &mut Vec, + is_last: bool, + orig_len: usize, + wrote_len: bool, + orig_headers: &HeaderCaseMap, + ) -> crate::Result { + struct OrigCaseWriter<'map> { + map: &'map HeaderCaseMap, + current: Option<(HeaderName, ValueIter<'map, Bytes>)>, + title_case_headers: bool, + } + + impl HeaderNameWriter for OrigCaseWriter<'_> { + #[inline] + fn write_full_header_line( + &mut self, + dst: &mut Vec, + _: &str, + (name, rest): (HeaderName, &str), + ) { + self.write_header_name(dst, &name); + extend(dst, rest.as_bytes()); + } + + #[inline] + fn write_header_name_with_colon( + &mut self, + dst: &mut Vec, + _: &str, + name: HeaderName, + ) { + self.write_header_name(dst, &name); + extend(dst, b": "); + } + + #[inline] + fn write_header_name(&mut self, dst: &mut Vec, name: &HeaderName) { + let Self { + map, + ref mut current, + title_case_headers, + } = *self; + if current.as_ref().map_or(true, |(last, _)| last != name) { + *current = None; + } + let (_, values) = + current.get_or_insert_with(|| (name.clone(), map.get_all(name).into_iter())); + + if let Some(orig_name) = values.next() { + extend(dst, orig_name); + } else if title_case_headers { + title_case(dst, name.as_str().as_bytes()); + } else { + extend(dst, name.as_str().as_bytes()); + } + } + } + + let header_name_writer = OrigCaseWriter { + map: orig_headers, + current: None, + title_case_headers: msg.title_case_headers, + }; + + Self::encode_headers(msg, dst, is_last, orig_len, wrote_len, header_name_writer) + } + + #[inline] + fn encode_headers( + msg: Encode<'_, StatusCode>, + mut dst: &mut Vec, + mut is_last: bool, + orig_len: usize, + mut wrote_len: bool, + mut header_name_writer: W, + ) -> crate::Result + where + W: HeaderNameWriter, + { + // In some error cases, we don't know about the invalid message until already + // pushing some bytes onto the `dst`. In those cases, we don't want to send + // the half-pushed message, so rewind to before. + let rewind = |dst: &mut Vec| { + dst.truncate(orig_len); + }; + let mut encoder = Encoder::length(0); let mut wrote_date = false; let mut cur_name = None; @@ -422,7 +635,11 @@ impl Http1Transaction for Server { if !is_name_written { encoder = Encoder::length(known_len); - extend(dst, b"content-length: "); + header_name_writer.write_header_name_with_colon( + dst, + "content-length: ", + header::CONTENT_LENGTH, + ); extend(dst, value.as_bytes()); wrote_len = true; is_name_written = true; @@ -450,7 +667,11 @@ impl Http1Transaction for Server { } else { // we haven't written content-length yet! encoder = Encoder::length(len); - extend(dst, b"content-length: "); + header_name_writer.write_header_name_with_colon( + dst, + "content-length: ", + header::CONTENT_LENGTH, + ); extend(dst, value.as_bytes()); wrote_len = true; is_name_written = true; @@ -505,7 +726,11 @@ impl Http1Transaction for Server { if !is_name_written { encoder = Encoder::chunked(); is_name_written = true; - extend(dst, b"transfer-encoding: "); + header_name_writer.write_header_name_with_colon( + dst, + "transfer-encoding: ", + header::TRANSFER_ENCODING, + ); extend(dst, value.as_bytes()); } else { extend(dst, b", "); @@ -519,7 +744,11 @@ impl Http1Transaction for Server { } if !is_name_written { is_name_written = true; - extend(dst, b"connection: "); + header_name_writer.write_header_name_with_colon( + dst, + "connection: ", + header::CONNECTION, + ); extend(dst, value.as_bytes()); } else { extend(dst, b", "); @@ -541,7 +770,7 @@ impl Http1Transaction for Server { "{:?} set is_name_written and didn't continue loop", name, ); - extend(dst, name.as_str().as_bytes()); + header_name_writer.write_header_name(dst, name); extend(dst, b": "); extend(dst, value.as_bytes()); extend(dst, b"\r\n"); @@ -557,13 +786,21 @@ impl Http1Transaction for Server { { Encoder::close_delimited() } else { - extend(dst, b"transfer-encoding: chunked\r\n"); + header_name_writer.write_full_header_line( + dst, + "transfer-encoding: chunked\r\n", + (header::TRANSFER_ENCODING, ": chunked\r\n"), + ); Encoder::chunked() } } None | Some(BodyLength::Known(0)) => { if Server::can_have_content_length(msg.req_method, msg.head.subject) { - extend(dst, b"content-length: 0\r\n"); + header_name_writer.write_full_header_line( + dst, + "content-length: 0\r\n", + (header::CONTENT_LENGTH, ": 0\r\n"), + ) } Encoder::length(0) } @@ -571,7 +808,11 @@ impl Http1Transaction for Server { if !Server::can_have_content_length(msg.req_method, msg.head.subject) { Encoder::length(0) } else { - extend(dst, b"content-length: "); + header_name_writer.write_header_name_with_colon( + dst, + "content-length: ", + header::CONTENT_LENGTH, + ); let _ = ::itoa::write(&mut dst, len); extend(dst, b"\r\n"); Encoder::length(len) @@ -592,72 +833,32 @@ impl Http1Transaction for Server { // cached date is much faster than formatting every request if !wrote_date { dst.reserve(date::DATE_VALUE_LENGTH + 8); - extend(dst, b"date: "); + header_name_writer.write_header_name_with_colon(dst, "date: ", header::DATE); date::extend(dst); extend(dst, b"\r\n\r\n"); } else { extend(dst, b"\r\n"); } - ret.map(|()| encoder.set_last(is_last)) - } - - fn on_error(err: &crate::Error) -> Option> { - use crate::error::Kind; - let status = match *err.kind() { - Kind::Parse(Parse::Method) - | Kind::Parse(Parse::Header) - | Kind::Parse(Parse::Uri) - | Kind::Parse(Parse::Version) => StatusCode::BAD_REQUEST, - Kind::Parse(Parse::TooLarge) => StatusCode::REQUEST_HEADER_FIELDS_TOO_LARGE, - _ => return None, - }; - - debug!("sending automatic response ({}) for parse error", status); - let mut msg = MessageHead::default(); - msg.subject = status; - Some(msg) - } - - fn is_server() -> bool { - true - } - - fn update_date() { - date::update(); + Ok(encoder.set_last(is_last)) } } #[cfg(feature = "server")] -impl Server { - fn can_have_body(method: &Option, status: StatusCode) -> bool { - Server::can_chunked(method, status) - } - - fn can_chunked(method: &Option, status: StatusCode) -> bool { - if method == &Some(Method::HEAD) || method == &Some(Method::CONNECT) && status.is_success() - { - false - } else if status.is_informational() { - false - } else { - match status { - StatusCode::NO_CONTENT | StatusCode::NOT_MODIFIED => false, - _ => true, - } - } - } - - fn can_have_content_length(method: &Option, status: StatusCode) -> bool { - if status.is_informational() || method == &Some(Method::CONNECT) && status.is_success() { - false - } else { - match status { - StatusCode::NO_CONTENT | StatusCode::NOT_MODIFIED => false, - _ => true, - } - } - } +trait HeaderNameWriter { + fn write_full_header_line( + &mut self, + dst: &mut Vec, + line: &str, + name_value_pair: (HeaderName, &str), + ); + fn write_header_name_with_colon( + &mut self, + dst: &mut Vec, + name_with_colon: &str, + name: HeaderName, + ); + fn write_header_name(&mut self, dst: &mut Vec, name: &HeaderName); } #[cfg(feature = "client")] @@ -731,8 +932,7 @@ impl Http1Transaction for Client { let mut keep_alive = version == Version::HTTP_11; - #[cfg(feature = "ffi")] - let mut header_case_map = crate::ffi::HeaderCaseMap::default(); + let mut header_case_map = HeaderCaseMap::default(); headers.reserve(headers_len); for header in &headers_indices[..headers_len] { @@ -750,7 +950,6 @@ impl Http1Transaction for Client { } } - #[cfg(feature = "ffi")] if ctx.preserve_header_case { header_case_map.append(&name, slice.slice(header.name.0..header.name.1)); } @@ -758,10 +957,8 @@ impl Http1Transaction for Client { headers.append(name, value); } - #[allow(unused_mut)] let mut extensions = http::Extensions::default(); - #[cfg(feature = "ffi")] if ctx.preserve_header_case { extensions.insert(header_case_map); } @@ -829,26 +1026,17 @@ impl Http1Transaction for Client { } extend(dst, b"\r\n"); - #[cfg(feature = "ffi")] - { - if msg.title_case_headers { - write_headers_title_case(&msg.head.headers, dst); - } else if let Some(orig_headers) = - msg.head.extensions.get::() - { - write_headers_original_case(&msg.head.headers, orig_headers, dst); - } else { - write_headers(&msg.head.headers, dst); - } - } - - #[cfg(not(feature = "ffi"))] - { - if msg.title_case_headers { - write_headers_title_case(&msg.head.headers, dst); - } else { - write_headers(&msg.head.headers, dst); - } + if let Some(orig_headers) = msg.head.extensions.get::() { + write_headers_original_case( + &msg.head.headers, + orig_headers, + dst, + msg.title_case_headers, + ); + } else if msg.title_case_headers { + write_headers_title_case(&msg.head.headers, dst); + } else { + write_headers(&msg.head.headers, dst); } extend(dst, b"\r\n"); @@ -1161,12 +1349,12 @@ fn write_headers(headers: &HeaderMap, dst: &mut Vec) { } } -#[cfg(feature = "ffi")] #[cold] fn write_headers_original_case( headers: &HeaderMap, - orig_case: &crate::ffi::HeaderCaseMap, + orig_case: &HeaderCaseMap, dst: &mut Vec, + title_case_headers: bool, ) { // For each header name/value pair, there may be a value in the casemap // that corresponds to the HeaderValue. So, we iterator all the keys, @@ -1179,6 +1367,8 @@ fn write_headers_original_case( for value in headers.get_all(name) { if let Some(orig_name) = names.next() { extend(dst, orig_name); + } else if title_case_headers { + title_case(dst, name.as_str().as_bytes()); } else { extend(dst, name.as_str().as_bytes()); } @@ -1231,7 +1421,6 @@ mod tests { ParseContext { cached_headers: &mut None, req_method: &mut method, - #[cfg(feature = "ffi")] preserve_header_case: false, h09_responses: false, }, @@ -1254,7 +1443,6 @@ mod tests { let ctx = ParseContext { cached_headers: &mut None, req_method: &mut Some(crate::Method::GET), - #[cfg(feature = "ffi")] preserve_header_case: false, h09_responses: false, }; @@ -1272,7 +1460,6 @@ mod tests { let ctx = ParseContext { cached_headers: &mut None, req_method: &mut None, - #[cfg(feature = "ffi")] preserve_header_case: false, h09_responses: false, }; @@ -1288,7 +1475,6 @@ mod tests { let ctx = ParseContext { cached_headers: &mut None, req_method: &mut Some(crate::Method::GET), - #[cfg(feature = "ffi")] preserve_header_case: false, h09_responses: true, }; @@ -1306,7 +1492,6 @@ mod tests { let ctx = ParseContext { cached_headers: &mut None, req_method: &mut Some(crate::Method::GET), - #[cfg(feature = "ffi")] preserve_header_case: false, h09_responses: false, }; @@ -1314,6 +1499,38 @@ mod tests { assert_eq!(raw, H09_RESPONSE); } + #[test] + fn test_parse_preserve_header_case_in_request() { + let mut raw = + BytesMut::from("GET / HTTP/1.1\r\nHost: hyper.rs\r\nX-BREAD: baguette\r\n\r\n"); + let ctx = ParseContext { + cached_headers: &mut None, + req_method: &mut None, + preserve_header_case: true, + h09_responses: false, + }; + let parsed_message = Server::parse(&mut raw, ctx).unwrap().unwrap(); + let orig_headers = parsed_message + .head + .extensions + .get::() + .unwrap(); + assert_eq!( + orig_headers + .get_all(&HeaderName::from_static("host")) + .into_iter() + .collect::>(), + vec![&Bytes::from("Host")] + ); + assert_eq!( + orig_headers + .get_all(&HeaderName::from_static("x-bread")) + .into_iter() + .collect::>(), + vec![&Bytes::from("X-BREAD")] + ); + } + #[test] fn test_decoder_request() { fn parse(s: &str) -> ParsedMessage { @@ -1323,7 +1540,6 @@ mod tests { ParseContext { cached_headers: &mut None, req_method: &mut None, - #[cfg(feature = "ffi")] preserve_header_case: false, h09_responses: false, }, @@ -1339,7 +1555,6 @@ mod tests { ParseContext { cached_headers: &mut None, req_method: &mut None, - #[cfg(feature = "ffi")] preserve_header_case: false, h09_responses: false, }, @@ -1554,7 +1769,6 @@ mod tests { ParseContext { cached_headers: &mut None, req_method: &mut Some(Method::GET), - #[cfg(feature = "ffi")] preserve_header_case: false, h09_responses: false, } @@ -1570,7 +1784,6 @@ mod tests { ParseContext { cached_headers: &mut None, req_method: &mut Some(m), - #[cfg(feature = "ffi")] preserve_header_case: false, h09_responses: false, }, @@ -1586,7 +1799,6 @@ mod tests { ParseContext { cached_headers: &mut None, req_method: &mut Some(Method::GET), - #[cfg(feature = "ffi")] preserve_header_case: false, h09_responses: false, }, @@ -1874,6 +2086,75 @@ mod tests { assert_eq!(vec, b"GET / HTTP/1.1\r\nContent-Length: 10\r\nContent-Type: application/json\r\n*-*: o_o\r\n\r\n".to_vec()); } + #[test] + fn test_client_request_encode_orig_case() { + use crate::proto::BodyLength; + use http::header::{HeaderValue, CONTENT_LENGTH}; + + let mut head = MessageHead::default(); + head.headers + .insert("content-length", HeaderValue::from_static("10")); + head.headers + .insert("content-type", HeaderValue::from_static("application/json")); + + let mut orig_headers = HeaderCaseMap::default(); + orig_headers.insert(CONTENT_LENGTH, "CONTENT-LENGTH".into()); + head.extensions.insert(orig_headers); + + let mut vec = Vec::new(); + Client::encode( + Encode { + head: &mut head, + body: Some(BodyLength::Known(10)), + keep_alive: true, + req_method: &mut None, + title_case_headers: false, + }, + &mut vec, + ) + .unwrap(); + + assert_eq!( + &*vec, + b"GET / HTTP/1.1\r\nCONTENT-LENGTH: 10\r\ncontent-type: application/json\r\n\r\n" + .as_ref(), + ); + } + #[test] + fn test_client_request_encode_orig_and_title_case() { + use crate::proto::BodyLength; + use http::header::{HeaderValue, CONTENT_LENGTH}; + + let mut head = MessageHead::default(); + head.headers + .insert("content-length", HeaderValue::from_static("10")); + head.headers + .insert("content-type", HeaderValue::from_static("application/json")); + + let mut orig_headers = HeaderCaseMap::default(); + orig_headers.insert(CONTENT_LENGTH, "CONTENT-LENGTH".into()); + head.extensions.insert(orig_headers); + + let mut vec = Vec::new(); + Client::encode( + Encode { + head: &mut head, + body: Some(BodyLength::Known(10)), + keep_alive: true, + req_method: &mut None, + title_case_headers: true, + }, + &mut vec, + ) + .unwrap(); + + assert_eq!( + &*vec, + b"GET / HTTP/1.1\r\nCONTENT-LENGTH: 10\r\nContent-Type: application/json\r\n\r\n" + .as_ref(), + ); + } + #[test] fn test_server_encode_connect_method() { let mut head = MessageHead::default(); @@ -1894,6 +2175,104 @@ mod tests { assert!(encoder.is_last()); } + #[test] + fn test_server_response_encode_title_case() { + use crate::proto::BodyLength; + use http::header::HeaderValue; + + let mut head = MessageHead::default(); + head.headers + .insert("content-length", HeaderValue::from_static("10")); + head.headers + .insert("content-type", HeaderValue::from_static("application/json")); + + let mut vec = Vec::new(); + Server::encode( + Encode { + head: &mut head, + body: Some(BodyLength::Known(10)), + keep_alive: true, + req_method: &mut None, + title_case_headers: true, + }, + &mut vec, + ) + .unwrap(); + + let expected_response = + b"HTTP/1.1 200 OK\r\nContent-Length: 10\r\nContent-Type: application/json\r\n"; + + assert_eq!(&vec[..expected_response.len()], &expected_response[..]); + } + + #[test] + fn test_server_response_encode_orig_case() { + use crate::proto::BodyLength; + use http::header::{HeaderValue, CONTENT_LENGTH}; + + let mut head = MessageHead::default(); + head.headers + .insert("content-length", HeaderValue::from_static("10")); + head.headers + .insert("content-type", HeaderValue::from_static("application/json")); + + let mut orig_headers = HeaderCaseMap::default(); + orig_headers.insert(CONTENT_LENGTH, "CONTENT-LENGTH".into()); + head.extensions.insert(orig_headers); + + let mut vec = Vec::new(); + Server::encode( + Encode { + head: &mut head, + body: Some(BodyLength::Known(10)), + keep_alive: true, + req_method: &mut None, + title_case_headers: false, + }, + &mut vec, + ) + .unwrap(); + + let expected_response = + b"HTTP/1.1 200 OK\r\nCONTENT-LENGTH: 10\r\ncontent-type: application/json\r\ndate: "; + + assert_eq!(&vec[..expected_response.len()], &expected_response[..]); + } + + #[test] + fn test_server_response_encode_orig_and_title_case() { + use crate::proto::BodyLength; + use http::header::{HeaderValue, CONTENT_LENGTH}; + + let mut head = MessageHead::default(); + head.headers + .insert("content-length", HeaderValue::from_static("10")); + head.headers + .insert("content-type", HeaderValue::from_static("application/json")); + + let mut orig_headers = HeaderCaseMap::default(); + orig_headers.insert(CONTENT_LENGTH, "CONTENT-LENGTH".into()); + head.extensions.insert(orig_headers); + + let mut vec = Vec::new(); + Server::encode( + Encode { + head: &mut head, + body: Some(BodyLength::Known(10)), + keep_alive: true, + req_method: &mut None, + title_case_headers: true, + }, + &mut vec, + ) + .unwrap(); + + let expected_response = + b"HTTP/1.1 200 OK\r\nCONTENT-LENGTH: 10\r\nContent-Type: application/json\r\nDate: "; + + assert_eq!(&vec[..expected_response.len()], &expected_response[..]); + } + #[test] fn parse_header_htabs() { let mut bytes = BytesMut::from("HTTP/1.1 200 OK\r\nserver: hello\tworld\r\n\r\n"); @@ -1902,7 +2281,6 @@ mod tests { ParseContext { cached_headers: &mut None, req_method: &mut Some(Method::GET), - #[cfg(feature = "ffi")] preserve_header_case: false, h09_responses: false, }, @@ -1913,17 +2291,16 @@ mod tests { assert_eq!(parsed.head.headers["server"], "hello\tworld"); } - #[cfg(feature = "ffi")] #[test] fn test_write_headers_orig_case_empty_value() { let mut headers = HeaderMap::new(); let name = http::header::HeaderName::from_static("x-empty"); headers.insert(&name, "".parse().expect("parse empty")); - let mut orig_cases = crate::ffi::HeaderCaseMap::default(); + let mut orig_cases = HeaderCaseMap::default(); orig_cases.insert(name, Bytes::from_static(b"X-EmptY")); let mut dst = Vec::new(); - super::write_headers_original_case(&headers, &orig_cases, &mut dst); + super::write_headers_original_case(&headers, &orig_cases, &mut dst, false); assert_eq!( dst, b"X-EmptY:\r\n", @@ -1931,7 +2308,6 @@ mod tests { ); } - #[cfg(feature = "ffi")] #[test] fn test_write_headers_orig_case_multiple_entries() { let mut headers = HeaderMap::new(); @@ -1939,12 +2315,12 @@ mod tests { headers.insert(&name, "a".parse().unwrap()); headers.append(&name, "b".parse().unwrap()); - let mut orig_cases = crate::ffi::HeaderCaseMap::default(); + let mut orig_cases = HeaderCaseMap::default(); orig_cases.insert(name.clone(), Bytes::from_static(b"X-Empty")); orig_cases.append(name, Bytes::from_static(b"X-EMPTY")); let mut dst = Vec::new(); - super::write_headers_original_case(&headers, &orig_cases, &mut dst); + super::write_headers_original_case(&headers, &orig_cases, &mut dst, false); assert_eq!(dst, b"X-Empty: a\r\nX-EMPTY: b\r\n"); } @@ -1984,7 +2360,6 @@ mod tests { ParseContext { cached_headers: &mut headers, req_method: &mut None, - #[cfg(feature = "ffi")] preserve_header_case: false, h09_responses: false, }, @@ -2020,7 +2395,6 @@ mod tests { ParseContext { cached_headers: &mut headers, req_method: &mut None, - #[cfg(feature = "ffi")] preserve_header_case: false, h09_responses: false, }, diff --git a/src/server/conn.rs b/src/server/conn.rs index 5137708fcb..0cef9d5e78 100644 --- a/src/server/conn.rs +++ b/src/server/conn.rs @@ -88,6 +88,7 @@ pub struct Http { exec: E, h1_half_close: bool, h1_keep_alive: bool, + h1_title_case_headers: bool, #[cfg(feature = "http2")] h2_builder: proto::h2::server::Config, mode: ConnectionMode, @@ -234,6 +235,7 @@ impl Http { exec: Exec::Default, h1_half_close: false, h1_keep_alive: true, + h1_title_case_headers: false, #[cfg(feature = "http2")] h2_builder: Default::default(), mode: ConnectionMode::default(), @@ -286,6 +288,19 @@ impl Http { self } + /// Set whether HTTP/1 connections will write header names as title case at + /// the socket level. + /// + /// Note that this setting does not affect HTTP/2. + /// + /// Default is false. + #[cfg(feature = "http1")] + #[cfg_attr(docsrs, doc(cfg(feature = "http1")))] + pub fn http1_title_case_headers(&mut self, enabled: bool) -> &mut Self { + self.h1_title_case_headers = enabled; + self + } + /// Sets whether HTTP2 is required. /// /// Default is false @@ -459,6 +474,7 @@ impl Http { exec, h1_half_close: self.h1_half_close, h1_keep_alive: self.h1_keep_alive, + h1_title_case_headers: self.h1_title_case_headers, #[cfg(feature = "http2")] h2_builder: self.h2_builder, mode: self.mode, @@ -514,6 +530,9 @@ impl Http { if self.h1_half_close { conn.set_allow_half_close(); } + if self.h1_title_case_headers { + conn.set_title_case_headers(); + } conn.set_flush_pipeline(self.pipeline_flush); if let Some(max) = self.max_buf_size { conn.set_max_buf_size(max); diff --git a/src/server/server.rs b/src/server/server.rs index 48cc6e2803..e02ab94b16 100644 --- a/src/server/server.rs +++ b/src/server/server.rs @@ -231,6 +231,19 @@ impl Builder { self } + /// Set whether HTTP/1 connections will write header names as title case at + /// the socket level. + /// + /// Note that this setting does not affect HTTP/2. + /// + /// Default is false. + #[cfg(feature = "http1")] + #[cfg_attr(docsrs, doc(cfg(feature = "http1")))] + pub fn http1_title_case_headers(&mut self, val: bool) -> &mut Self { + self.protocol.http1_title_case_headers(val); + self + } + /// Sets whether HTTP/1 is required. /// /// Default is `false`.