From 336e52639fed6187dd98c05a190cb04dcb0ff15f Mon Sep 17 00:00:00 2001 From: Anthony Ramine Date: Thu, 4 Nov 2021 10:18:06 +0100 Subject: [PATCH] fix(h2): don't use h2::ext::Protocol publicly --- Cargo.toml | 3 --- src/ext.rs | 57 +++++++++++++++++++++++++++++++++++++++++- src/proto/h2/client.rs | 5 ++++ src/proto/h2/server.rs | 7 +++++- 4 files changed, 67 insertions(+), 5 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 84d5594b16..8941140742 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -253,6 +253,3 @@ required-features = ["full"] name = "server" path = "tests/server.rs" required-features = ["full"] - -[patch.crates-io] -h2 = { git = "https://github.com/hyperium/h2.git", branch = "rfc8441" } diff --git a/src/ext.rs b/src/ext.rs index 8885417a0e..e9d4587784 100644 --- a/src/ext.rs +++ b/src/ext.rs @@ -4,9 +4,64 @@ use bytes::Bytes; #[cfg(feature = "http1")] use http::header::{HeaderName, IntoHeaderName, ValueIter}; use http::HeaderMap; +#[cfg(feature = "http2")] +use std::fmt; #[cfg(feature = "http2")] -pub use h2::ext::Protocol; +/// Represents the `:protocol` pseudo-header used by +/// the [Extended CONNECT Protocol]. +/// +/// [Extended CONNECT Protocol]: https://datatracker.ietf.org/doc/html/rfc8441#section-4 +#[derive(Clone, Eq, PartialEq)] +pub struct Protocol { + inner: h2::ext::Protocol, +} + +#[cfg(feature = "http2")] +impl Protocol { + /// Converts a static string to a protocol name. + pub const fn from_static(value: &'static str) -> Self { + Self { + inner: h2::ext::Protocol::from_static(value), + } + } + + /// Returns a str representation of the header. + pub fn as_str(&self) -> &str { + self.inner.as_str() + } + + pub(crate) fn from_inner(inner: h2::ext::Protocol) -> Self { + Self { inner } + } + + pub(crate) fn into_inner(self) -> h2::ext::Protocol { + self.inner + } +} + +#[cfg(feature = "http2")] +impl<'a> From<&'a str> for Protocol { + fn from(value: &'a str) -> Self { + Self { + inner: h2::ext::Protocol::from(value), + } + } +} + +#[cfg(feature = "http2")] +impl AsRef<[u8]> for Protocol { + fn as_ref(&self) -> &[u8] { + self.inner.as_ref() + } +} + +#[cfg(feature = "http2")] +impl fmt::Debug for Protocol { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + self.inner.fmt(f) + } +} /// A map from header names to their original casing as received in an HTTP message. /// diff --git a/src/proto/h2/client.rs b/src/proto/h2/client.rs index 0e78c09b2c..013f6fb5a8 100644 --- a/src/proto/h2/client.rs +++ b/src/proto/h2/client.rs @@ -14,6 +14,7 @@ use tracing::{debug, trace, warn}; use super::{ping, H2Upgraded, PipeToSendStream, SendBuf}; use crate::body::HttpBody; use crate::common::{exec::Exec, task, Future, Never, Pin, Poll}; +use crate::ext::Protocol; use crate::headers; use crate::proto::h2::UpgradedSendStream; use crate::proto::Dispatched; @@ -269,6 +270,10 @@ where } } + if let Some(protocol) = req.extensions_mut().remove::() { + req.extensions_mut().insert(protocol.into_inner()); + } + let (fut, body_tx) = match self.h2_tx.send_request(req, !is_connect && eos) { Ok(ok) => ok, Err(err) => { diff --git a/src/proto/h2/server.rs b/src/proto/h2/server.rs index 3278df8e46..b9037ee3dd 100644 --- a/src/proto/h2/server.rs +++ b/src/proto/h2/server.rs @@ -15,6 +15,7 @@ use super::{ping, PipeToSendStream, SendBuf}; use crate::body::HttpBody; use crate::common::exec::ConnStreamExec; use crate::common::{date, task, Future, Pin, Poll}; +use crate::ext::Protocol; use crate::headers; use crate::proto::h2::ping::Recorder; use crate::proto::h2::{H2Upgraded, UpgradedSendStream}; @@ -285,7 +286,7 @@ where let is_connect = req.method() == Method::CONNECT; let (mut parts, stream) = req.into_parts(); - let (req, connect_parts) = if !is_connect { + let (mut req, connect_parts) = if !is_connect { ( Request::from_parts( parts, @@ -312,6 +313,10 @@ where ) }; + if let Some(protocol) = req.extensions_mut().remove::() { + req.extensions_mut().insert(Protocol::from_inner(protocol)); + } + let fut = H2Stream::new(service.call(req), connect_parts, respond); exec.execute_h2stream(fut); }