Skip to content

Commit

Permalink
feat(h2): implement CONNECT support (fixes hyperium#2508)
Browse files Browse the repository at this point in the history
  • Loading branch information
nox committed May 6, 2021
1 parent a4114eb commit a2385ae
Show file tree
Hide file tree
Showing 7 changed files with 565 additions and 35 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Expand Up @@ -31,7 +31,7 @@ http = "0.2"
http-body = "0.4"
httpdate = "1.0"
httparse = "1.4"
h2 = { version = "0.3", optional = true }
h2 = { version = "0.3.3", optional = true }
itoa = "0.4.1"
tracing = { version = "0.1", default-features = false, features = ["std"] }
pin-project = "1.0"
Expand Down
11 changes: 11 additions & 0 deletions src/body/length.rs
Expand Up @@ -3,6 +3,17 @@ use std::fmt;
#[derive(Clone, Copy, PartialEq, Eq)]
pub(crate) struct DecodedLength(u64);

#[cfg(any(feature = "http1", feature = "http2"))]
impl From<Option<u64>> for DecodedLength {
fn from(len: Option<u64>) -> Self {
len.and_then(|len| {
// If the length is u64::MAX, oh well, just reported chunked.
Self::checked_new(len).ok()
})
.unwrap_or(DecodedLength::CHUNKED)
}
}

#[cfg(any(feature = "http1", feature = "http2", test))]
const MAX_LEN: u64 = std::u64::MAX - 2;

Expand Down
6 changes: 3 additions & 3 deletions src/error.rs
Expand Up @@ -90,7 +90,7 @@ pub(super) enum User {
/// User tried to send a certain header in an unexpected context.
///
/// For example, sending both `content-length` and `transfer-encoding`.
#[cfg(feature = "http1")]
#[cfg(any(feature = "http1", feature = "http2"))]
#[cfg(feature = "server")]
UnexpectedHeader,
/// User tried to create a Request with bad version.
Expand Down Expand Up @@ -279,7 +279,7 @@ impl Error {
Error::new(Kind::User(user))
}

#[cfg(feature = "http1")]
#[cfg(any(feature = "http1", feature = "http2"))]
#[cfg(feature = "server")]
pub(super) fn new_user_header() -> Error {
Error::new_user(User::UnexpectedHeader)
Expand Down Expand Up @@ -394,7 +394,7 @@ impl Error {
Kind::User(User::MakeService) => "error from user's MakeService",
#[cfg(any(feature = "http1", feature = "http2"))]
Kind::User(User::Service) => "error from user's Service",
#[cfg(feature = "http1")]
#[cfg(any(feature = "http1", feature = "http2"))]
#[cfg(feature = "server")]
Kind::User(User::UnexpectedHeader) => "user sent unexpected header",
#[cfg(any(feature = "http1", feature = "http2"))]
Expand Down
150 changes: 133 additions & 17 deletions src/proto/h2/mod.rs
@@ -1,17 +1,20 @@
use bytes::Buf;
use h2::SendStream;
use bytes::{Buf, Bytes};
use h2::{RecvStream, SendStream};
use http::header::{
HeaderName, CONNECTION, PROXY_AUTHENTICATE, PROXY_AUTHORIZATION, TE, TRAILER,
TRANSFER_ENCODING, UPGRADE,
};
use http::HeaderMap;
use pin_project::pin_project;
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use std::error::Error as StdError;
use std::io::IoSlice;
use std::io::{self, Cursor, IoSlice};
use std::task::Context;

use crate::body::{DecodedLength, HttpBody};
use crate::common::{task, Future, Pin, Poll};
use crate::headers::content_length_parse_all;
use crate::proto::h2::ping::Recorder;

pub(crate) mod ping;

Expand Down Expand Up @@ -84,12 +87,7 @@ fn strip_connection_headers(headers: &mut HeaderMap, is_request: bool) {
}

fn decode_content_length(headers: &HeaderMap) -> DecodedLength {
if let Some(len) = content_length_parse_all(headers) {
// If the length is u64::MAX, oh well, just reported chunked.
DecodedLength::checked_new(len).unwrap_or_else(|_| DecodedLength::CHUNKED)
} else {
DecodedLength::CHUNKED
}
content_length_parse_all(headers).into()
}

// body adapters used by both Client and Server
Expand Down Expand Up @@ -172,7 +170,7 @@ where
is_eos,
);

let buf = SendBuf(Some(chunk));
let buf = SendBuf::Buf(chunk);
me.body_tx
.send_data(buf, is_eos)
.map_err(crate::Error::new_body_write)?;
Expand Down Expand Up @@ -243,32 +241,150 @@ impl<B: Buf> SendStreamExt for SendStream<SendBuf<B>> {

fn send_eos_frame(&mut self) -> crate::Result<()> {
trace!("send body eos");
self.send_data(SendBuf(None), true)
self.send_data(SendBuf::None, true)
.map_err(crate::Error::new_body_write)
}
}

struct SendBuf<B>(Option<B>);
enum SendBuf<B> {
Buf(B),
Cursor(Cursor<Box<[u8]>>),
None,
}

impl<B: Buf> Buf for SendBuf<B> {
#[inline]
fn remaining(&self) -> usize {
self.0.as_ref().map(|b| b.remaining()).unwrap_or(0)
match *self {
Self::Buf(ref b) => b.remaining(),
Self::Cursor(ref c) => c.remaining(),
Self::None => 0,
}
}

#[inline]
fn chunk(&self) -> &[u8] {
self.0.as_ref().map(|b| b.chunk()).unwrap_or(&[])
match *self {
Self::Buf(ref b) => b.chunk(),
Self::Cursor(ref c) => c.chunk(),
Self::None => &[],
}
}

#[inline]
fn advance(&mut self, cnt: usize) {
if let Some(b) = self.0.as_mut() {
b.advance(cnt)
match *self {
Self::Buf(ref mut b) => b.advance(cnt),
Self::Cursor(ref mut c) => c.advance(cnt),
Self::None => {},
}
}

fn chunks_vectored<'a>(&'a self, dst: &mut [IoSlice<'a>]) -> usize {
self.0.as_ref().map(|b| b.chunks_vectored(dst)).unwrap_or(0)
match *self {
Self::Buf(ref b) => b.chunks_vectored(dst),
Self::Cursor(ref c) => c.chunks_vectored(dst),
Self::None => 0,
}
}
}

struct H2Upgraded<B>
where
B: Buf,
{
ping: Recorder,
send_stream: SendStream<SendBuf<B>>,
recv_stream: RecvStream,
buf: Bytes,
}

impl<B> AsyncRead for H2Upgraded<B>
where
B: Buf,
{
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
read_buf: &mut ReadBuf<'_>,
) -> Poll<Result<(), io::Error>> {
if self.buf.is_empty() {
self.buf = loop {
match ready!(self.recv_stream.poll_data(cx)) {
None => return Poll::Ready(Ok(())),
Some(Ok(buf)) if buf.is_empty() && !self.recv_stream.is_end_stream() => continue,
Some(Ok(buf)) => {
self.ping.record_data(buf.len());
break buf;
}
Some(Err(e)) => {
return Poll::Ready(Err(h2_to_io_error(e)));
}
}
};
}
let cnt = std::cmp::min(self.buf.len(), read_buf.remaining());
read_buf.put_slice(&self.buf[..cnt]);
self.buf.advance(cnt);
let _ = self.recv_stream.flow_control().release_capacity(cnt);
Poll::Ready(Ok(()))
}
}

impl<B> AsyncWrite for H2Upgraded<B>
where
B: Buf,
{
fn poll_write(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<Result<usize, io::Error>> {
if let Poll::Ready(reset) = self.send_stream.poll_reset(cx) {
return Poll::Ready(Err(h2_to_io_error(match reset {
Ok(reason) => reason.into(),
Err(e) => e,
})));
}
if buf.is_empty() {
return Poll::Ready(Ok(0));
}
self.send_stream.reserve_capacity(buf.len());
Poll::Ready(match ready!(self.send_stream.poll_capacity(cx)) {
None => Ok(0),
Some(Ok(cnt)) => self.write(&buf[..cnt], false).map(|()| cnt),
Some(Err(e)) => Err(h2_to_io_error(e)),
})
}

fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
Poll::Ready(Ok(()))
}

fn poll_shutdown(
mut self: Pin<&mut Self>,
_cx: &mut Context<'_>,
) -> Poll<Result<(), io::Error>> {
Poll::Ready(self.write(&[], true))
}
}

impl<B> H2Upgraded<B>
where
B: Buf,
{
fn write(&mut self, buf: &[u8], end_of_stream: bool) -> Result<(), io::Error> {
let send_buf = SendBuf::Cursor(Cursor::new(buf.into()));
self.send_stream
.send_data(send_buf, end_of_stream)
.map_err(h2_to_io_error)
}
}

fn h2_to_io_error(e: h2::Error) -> io::Error {
if e.is_io() {
e.into_io().unwrap()
} else {
io::Error::new(io::ErrorKind::Other, e)
}
}

0 comments on commit a2385ae

Please sign in to comment.