Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement CONNECT support over HTTP/2 #2523

Merged
merged 3 commits into from May 24, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion Cargo.toml
Expand Up @@ -31,7 +31,7 @@ http = "0.2"
http-body = "0.4"
httpdate = "1.0"
httparse = "1.4"
h2 = { version = "0.3", optional = true }
h2 = { version = "0.3.3", optional = true }
itoa = "0.4.1"
tracing = { version = "0.1", default-features = false, features = ["std"] }
pin-project = "1.0"
Expand Down
11 changes: 11 additions & 0 deletions src/body/length.rs
Expand Up @@ -3,6 +3,17 @@ use std::fmt;
#[derive(Clone, Copy, PartialEq, Eq)]
pub(crate) struct DecodedLength(u64);

#[cfg(any(feature = "http1", feature = "http2"))]
impl From<Option<u64>> for DecodedLength {
fn from(len: Option<u64>) -> Self {
len.and_then(|len| {
// If the length is u64::MAX, oh well, just reported chunked.
Self::checked_new(len).ok()
})
.unwrap_or(DecodedLength::CHUNKED)
}
}

#[cfg(any(feature = "http1", feature = "http2", test))]
const MAX_LEN: u64 = std::u64::MAX - 2;

Expand Down
7 changes: 2 additions & 5 deletions src/client/client.rs
Expand Up @@ -254,12 +254,9 @@ where
absolute_form(req.uri_mut());
} else {
origin_form(req.uri_mut());
};
}
} else if req.method() == Method::CONNECT {
debug!("client does not support CONNECT requests over HTTP2");
return Err(ClientError::Normal(
crate::Error::new_user_unsupported_request_method(),
));
authority_form(req.uri_mut());
}

let fut = pooled
Expand Down
6 changes: 3 additions & 3 deletions src/error.rs
Expand Up @@ -90,7 +90,7 @@ pub(super) enum User {
/// User tried to send a certain header in an unexpected context.
///
/// For example, sending both `content-length` and `transfer-encoding`.
#[cfg(feature = "http1")]
#[cfg(any(feature = "http1", feature = "http2"))]
#[cfg(feature = "server")]
UnexpectedHeader,
/// User tried to create a Request with bad version.
Expand Down Expand Up @@ -290,7 +290,7 @@ impl Error {
Error::new(Kind::User(user))
}

#[cfg(feature = "http1")]
#[cfg(any(feature = "http1", feature = "http2"))]
#[cfg(feature = "server")]
pub(super) fn new_user_header() -> Error {
Error::new_user(User::UnexpectedHeader)
Expand Down Expand Up @@ -405,7 +405,7 @@ impl Error {
Kind::User(User::MakeService) => "error from user's MakeService",
#[cfg(any(feature = "http1", feature = "http2"))]
Kind::User(User::Service) => "error from user's Service",
#[cfg(feature = "http1")]
#[cfg(any(feature = "http1", feature = "http2"))]
#[cfg(feature = "server")]
Kind::User(User::UnexpectedHeader) => "user sent unexpected header",
#[cfg(any(feature = "http1", feature = "http2"))]
Expand Down
121 changes: 89 additions & 32 deletions src/proto/h2/client.rs
Expand Up @@ -2,17 +2,21 @@ use std::error::Error as StdError;
#[cfg(feature = "runtime")]
use std::time::Duration;

use bytes::Bytes;
use futures_channel::{mpsc, oneshot};
use futures_util::future::{self, Either, FutureExt as _, TryFutureExt as _};
use futures_util::stream::StreamExt as _;
use h2::client::{Builder, SendRequest};
use http::{Method, StatusCode};
use tokio::io::{AsyncRead, AsyncWrite};

use super::{decode_content_length, ping, PipeToSendStream, SendBuf};
use super::{ping, H2Upgraded, PipeToSendStream, SendBuf};
use crate::body::HttpBody;
use crate::common::{exec::Exec, task, Future, Never, Pin, Poll};
use crate::headers;
use crate::proto::h2::UpgradedSendStream;
use crate::proto::Dispatched;
use crate::upgrade::Upgraded;
use crate::{Body, Request, Response};

type ClientRx<B> = crate::client::dispatch::Receiver<Request<B>, Response<Body>>;
Expand Down Expand Up @@ -233,8 +237,25 @@ where
headers::set_content_length_if_missing(req.headers_mut(), len);
}
}

let is_connect = req.method() == Method::CONNECT;
let eos = body.is_end_stream();
let (fut, body_tx) = match self.h2_tx.send_request(req, eos) {
let ping = self.ping.clone();

if is_connect {
if headers::content_length_parse_all(req.headers())
.map_or(false, |len| len != 0)
{
warn!("h2 connect request with non-zero body not supported");
cb.send(Err((
crate::Error::new_h2(h2::Reason::INTERNAL_ERROR.into()),
None,
)));
continue;
}
}

let (fut, body_tx) = match self.h2_tx.send_request(req, !is_connect && eos) {
Ok(ok) => ok,
Err(err) => {
debug!("client send request error: {}", err);
Expand All @@ -243,45 +264,81 @@ where
}
};

let ping = self.ping.clone();
if !eos {
let mut pipe = Box::pin(PipeToSendStream::new(body, body_tx)).map(|res| {
if let Err(e) = res {
debug!("client request body error: {}", e);
}
});

// eagerly see if the body pipe is ready and
// can thus skip allocating in the executor
match Pin::new(&mut pipe).poll(cx) {
Poll::Ready(_) => (),
Poll::Pending => {
let conn_drop_ref = self.conn_drop_ref.clone();
// keep the ping recorder's knowledge of an
// "open stream" alive while this body is
// still sending...
let ping = ping.clone();
let pipe = pipe.map(move |x| {
drop(conn_drop_ref);
drop(ping);
x
let send_stream = if !is_connect {
if !eos {
let mut pipe =
Box::pin(PipeToSendStream::new(body, body_tx)).map(|res| {
if let Err(e) = res {
debug!("client request body error: {}", e);
}
});
self.executor.execute(pipe);

// eagerly see if the body pipe is ready and
// can thus skip allocating in the executor
match Pin::new(&mut pipe).poll(cx) {
Poll::Ready(_) => (),
Poll::Pending => {
let conn_drop_ref = self.conn_drop_ref.clone();
// keep the ping recorder's knowledge of an
// "open stream" alive while this body is
// still sending...
let ping = ping.clone();
let pipe = pipe.map(move |x| {
drop(conn_drop_ref);
drop(ping);
x
});
self.executor.execute(pipe);
}
}
}
}

None
} else {
Some(body_tx)
};

let fut = fut.map(move |result| match result {
Ok(res) => {
// record that we got the response headers
ping.record_non_data();

let content_length = decode_content_length(res.headers());
let res = res.map(|stream| {
let ping = ping.for_stream(&stream);
crate::Body::h2(stream, content_length, ping)
});
Ok(res)
let content_length = headers::content_length_parse_all(res.headers());
if let (Some(mut send_stream), StatusCode::OK) =
(send_stream, res.status())
{
if content_length.map_or(false, |len| len != 0) {
warn!("h2 connect response with non-zero body not supported");
seanmonstar marked this conversation as resolved.
Show resolved Hide resolved

send_stream.send_reset(h2::Reason::INTERNAL_ERROR);
return Err((
crate::Error::new_h2(h2::Reason::INTERNAL_ERROR.into()),
None,
));
}
let (parts, recv_stream) = res.into_parts();
let mut res = Response::from_parts(parts, Body::empty());

let (pending, on_upgrade) = crate::upgrade::pending();
let io = H2Upgraded {
ping,
send_stream: unsafe { UpgradedSendStream::new(send_stream) },
recv_stream,
buf: Bytes::new(),
};
let upgraded = Upgraded::new(io, Bytes::new());

pending.fulfill(upgraded);
res.extensions_mut().insert(on_upgrade);

Ok(res)
} else {
let res = res.map(|stream| {
let ping = ping.for_stream(&stream);
crate::Body::h2(stream, content_length.into(), ping)
});
Ok(res)
}
}
Err(err) => {
ping.ensure_not_timed_out().map_err(|e| (e, None))?;
Expand Down