Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use failable headers methods from http #3270

Draft
wants to merge 1 commit into
base: 0.14.x
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
3 changes: 3 additions & 0 deletions Cargo.toml
Expand Up @@ -256,3 +256,6 @@ required-features = ["full"]
name = "server"
path = "tests/server.rs"
required-features = ["full"]

[patch.crates-io]
http = { git = "https://github.com/glebpom/http", branch = "failable-headers-allocations" }
3 changes: 2 additions & 1 deletion examples/upgrades.rs
Expand Up @@ -66,7 +66,8 @@ async fn server_upgrade(mut req: Request<Body>) -> Result<Response<Body>> {
// made-up 'foobar' protocol.
*res.status_mut() = StatusCode::SWITCHING_PROTOCOLS;
res.headers_mut()
.insert(UPGRADE, HeaderValue::from_static("foobar"));
.try_insert(UPGRADE, HeaderValue::from_static("foobar"))
.expect("FIXME");
Ok(res)
}

Expand Down
11 changes: 9 additions & 2 deletions src/client/client.rs
Expand Up @@ -257,7 +257,8 @@ where

if self.config.set_host {
let uri = req.uri().clone();
req.headers_mut().entry(HOST).or_insert_with(|| {
let entry = headers_op!(req.headers_mut().try_entry(HOST), error);
headers_op!(entry.or_try_insert_with(|| {
let hostname = uri.host().expect("authority implies host");
if let Some(port) = get_non_default_port(&uri) {
let s = format!("{}:{}", hostname, port);
Expand All @@ -266,7 +267,7 @@ where
HeaderValue::from_str(hostname)
}
.expect("uri host is valid header value")
});
}), capacity);
}

// CONNECT always sends authority-form, so check it first...
Expand Down Expand Up @@ -752,6 +753,12 @@ enum ClientError<B> {
},
}

impl<B> From<crate::error::Parse> for ClientError<B> {
fn from(e: crate::error::Parse) -> Self {
ClientError::Normal(e.into())
}
}

impl<B> ClientError<B> {
fn map_with_reused(conn_reused: bool) -> impl Fn((crate::Error, Option<Request<B>>)) -> Self {
move |(err, orig_req)| {
Expand Down
46 changes: 46 additions & 0 deletions src/error.rs
Expand Up @@ -65,6 +65,9 @@ pub(super) enum Kind {
/// A general error from h2.
#[cfg(feature = "http2")]
Http2,

/// Error with HTTP request or response
Http(http::Error),
}

#[derive(Debug)]
Expand Down Expand Up @@ -518,6 +521,7 @@ impl Error {
Kind::User(User::DispatchGone) => "dispatch task is gone",
#[cfg(feature = "ffi")]
Kind::User(User::AbortedByCallback) => "operation aborted by an application callback",
Kind::Http(ref _e) => "FIXME: HTTP request or response building error"
}
}
}
Expand Down Expand Up @@ -559,6 +563,20 @@ impl From<Parse> for Error {
}
}

#[doc(hidden)]
impl From<http::CapacityOverflow> for Error {
fn from(err: http::CapacityOverflow) -> Error {
Error::new(Kind::Http(err.into()))
}
}

#[doc(hidden)]
impl From<http::Error> for Error {
fn from(err: http::Error) -> Error {
Error::new(Kind::Http(err))
}
}

#[cfg(feature = "http1")]
impl Parse {
pub(crate) fn content_length_invalid() -> Self {
Expand Down Expand Up @@ -628,6 +646,34 @@ impl fmt::Display for TimedOut {

impl StdError for TimedOut {}

macro_rules! headers_op {
($op:expr, capacity) => {
match $op {
Err(http::CapacityOverflow { .. }) => {
return Err(crate::error::Parse::TooLarge.into());
},
Ok(r) => r,
}
};
($op:expr, error) => {
match $op {
Err(e) if e.is::<http::CapacityOverflow>() => {
return Err(crate::error::Parse::TooLarge.into());
},
Err(e) if e.is::<http::header::InvalidHeaderName>() => {
// FIXME!
return Err(crate::error::Parse::Internal.into());
},
Err(_e) => {
// FIXME!
return Err(crate::error::Parse::Internal.into());
},
Ok(r) => r,
}
};
}


#[cfg(test)]
mod tests {
use super::*;
Expand Down
12 changes: 7 additions & 5 deletions src/ext.rs
Expand Up @@ -5,7 +5,7 @@ use bytes::Bytes;
use http::header::HeaderName;
#[cfg(feature = "http1")]
use http::header::{IntoHeaderName, ValueIter};
use http::HeaderMap;
use http::{CapacityOverflow, HeaderMap};
#[cfg(feature = "ffi")]
use std::collections::HashMap;
#[cfg(feature = "http2")]
Expand Down Expand Up @@ -119,15 +119,17 @@ impl HeaderCaseMap {
}

#[cfg(any(test, feature = "ffi"))]
pub(crate) fn insert(&mut self, name: HeaderName, orig: Bytes) {
self.0.insert(name, orig);
pub(crate) fn try_insert(&mut self, name: HeaderName, orig: Bytes) -> Result<(), CapacityOverflow> {
self.0.try_insert(name, orig)?;
Ok(())
}

pub(crate) fn append<N>(&mut self, name: N, orig: Bytes)
pub(crate) fn try_append<N>(&mut self, name: N, orig: Bytes) -> Result<(), CapacityOverflow>
where
N: IntoHeaderName,
{
self.0.append(name, orig);
self.0.try_append(name, orig)?;
Ok(())
}
}

Expand Down
2 changes: 2 additions & 0 deletions src/ffi/error.rs
Expand Up @@ -24,6 +24,8 @@ pub enum hyper_code {
HYPERE_FEATURE_NOT_ENABLED,
/// The peer sent an HTTP message that could not be parsed.
HYPERE_INVALID_PEER_MESSAGE,
/// HTTP headers reached the maximum capacity
HYPERE_HEADERS_CAPACITY,
}

// ===== impl hyper_error =====
Expand Down
28 changes: 24 additions & 4 deletions src/ffi/http_types.rs
Expand Up @@ -468,8 +468,18 @@ ffi_fn! {
let headers = non_null!(&mut *headers ?= hyper_code::HYPERE_INVALID_ARG);
match unsafe { raw_name_value(name, name_len, value, value_len) } {
Ok((name, value, orig_name)) => {
headers.headers.insert(&name, value);
headers.orig_casing.insert(name.clone(), orig_name.clone());
match headers.headers.try_insert(&name, value) {
Err(http::CapacityOverflow { .. }) => {
return hyper_code::HYPERE_HEADERS_CAPACITY;
}
Ok(_) => {},
}
match headers.orig_casing.try_insert(name.clone(), orig_name.clone()) {
Err(http::CapacityOverflow { .. }) => {
return hyper_code::HYPERE_HEADERS_CAPACITY;
}
Ok(()) => {},
}
headers.orig_order.insert(name);
hyper_code::HYPERE_OK
}
Expand All @@ -488,8 +498,18 @@ ffi_fn! {

match unsafe { raw_name_value(name, name_len, value, value_len) } {
Ok((name, value, orig_name)) => {
headers.headers.append(&name, value);
headers.orig_casing.append(&name, orig_name.clone());
match headers.headers.try_append(&name, value) {
Err(http::CapacityOverflow { .. }) => {
return hyper_code::HYPERE_HEADERS_CAPACITY;
}
Ok(_) => {},
}
match headers.orig_casing.try_append(&name, orig_name.clone()) {
Err(http::CapacityOverflow { .. }) => {
return hyper_code::HYPERE_HEADERS_CAPACITY;
}
Ok(()) => {},
}
headers.orig_order.append(name);
hyper_code::HYPERE_OK
}
Expand Down
18 changes: 13 additions & 5 deletions src/headers.rs
Expand Up @@ -2,7 +2,7 @@
use bytes::BytesMut;
use http::header::CONTENT_LENGTH;
use http::header::{HeaderValue, ValueIter};
use http::HeaderMap;
use http::{CapacityOverflow, HeaderMap};
#[cfg(all(feature = "http2", feature = "client"))]
use http::Method;

Expand Down Expand Up @@ -100,10 +100,18 @@ pub(super) fn method_has_defined_payload_semantics(method: &Method) -> bool {
}

#[cfg(feature = "http2")]
pub(super) fn set_content_length_if_missing(headers: &mut HeaderMap, len: u64) {
headers
.entry(CONTENT_LENGTH)
.or_insert_with(|| HeaderValue::from(len));
pub(super) fn set_content_length_if_missing(headers: &mut HeaderMap, len: u64) -> Result<(), CapacityOverflow> {
match headers
.try_entry(CONTENT_LENGTH){
Err(e) if e.is::<CapacityOverflow>() => return Err(CapacityOverflow::new()),
Err(_) => {
unreachable!()
},
Ok(e) => {
e.or_try_insert_with(|| HeaderValue::from(len))?;
}
};
Ok(())
}

#[cfg(feature = "http1")]
Expand Down
1 change: 1 addition & 0 deletions src/lib.rs
Expand Up @@ -77,6 +77,7 @@ mod cfg;
#[macro_use]
mod common;
pub mod body;
#[macro_use]
mod error;
pub mod ext;
#[cfg(test)]
Expand Down
20 changes: 14 additions & 6 deletions src/proto/h1/conn.rs
Expand Up @@ -551,7 +551,11 @@ where
self.state.busy();
}

self.enforce_version(&mut head);
if let Err(err) = self.enforce_version(&mut head) {
self.state.error = Some(err.into());
self.state.writing = Writing::Closed;
return None;
}

let buf = self.io.headers_buf();
match super::role::encode_headers::<T>(
Expand Down Expand Up @@ -587,7 +591,7 @@ where
}

// Fix keep-alive when Connection: keep-alive header is not present
fn fix_keep_alive(&mut self, head: &mut MessageHead<T::Outgoing>) {
fn fix_keep_alive(&mut self, head: &mut MessageHead<T::Outgoing>) -> Result<(), http::Error> {
let outgoing_is_keep_alive = head
.headers
.get(CONNECTION)
Expand All @@ -604,27 +608,31 @@ where
Version::HTTP_11 => {
if self.state.wants_keep_alive() {
head.headers
.insert(CONNECTION, HeaderValue::from_static("keep-alive"));
.try_insert(CONNECTION, HeaderValue::from_static("keep-alive"))?;
}
}
_ => (),
}
}

Ok(())
}

// If we know the remote speaks an older version, we try to fix up any messages
// to work with our older peer.
fn enforce_version(&mut self, head: &mut MessageHead<T::Outgoing>) {
fn enforce_version(&mut self, head: &mut MessageHead<T::Outgoing>) -> Result<(), http::Error> {
if let Version::HTTP_10 = self.state.version {
// Fixes response or connection when keep-alive header is not present
self.fix_keep_alive(head);
self.fix_keep_alive(head)?;
// If the remote only knows HTTP/1.0, we should force ourselves
// to do only speak HTTP/1.0 as well.
head.version = Version::HTTP_10;
}
// If the remote speaks HTTP/1.1, then it *should* be fine with
// both HTTP/1.0 and HTTP/1.1 from us. So again, we just let
// the user's headers be.

Ok(())
}

pub(crate) fn write_body(&mut self, chunk: B) {
Expand Down Expand Up @@ -1061,7 +1069,7 @@ mod tests {
let io = tokio_test::io::Builder::new().build();
let mut conn = Conn::<_, bytes::Bytes, crate::proto::h1::ServerTransaction>::new(io);
*conn.io.read_buf_mut() = ::bytes::BytesMut::from(&s[..]);
conn.state.cached_headers = Some(HeaderMap::with_capacity(2));
conn.state.cached_headers = Some(HeaderMap::try_with_capacity(2).unwrap());

let rt = tokio::runtime::Builder::new_current_thread()
.enable_all()
Expand Down