diff --git a/Cargo.toml b/Cargo.toml index 4a94f06aa9..c98d17c4f0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -22,6 +22,9 @@ include = [ [lib] crate-type = ["lib", "staticlib", "cdylib"] +[patch.crates-io] +h2 = { git = "https://github.com/hyperium/h2.git", branch = "master" } + [dependencies] bytes = "1" futures-core = { version = "0.3", default-features = false } @@ -31,7 +34,7 @@ http = "0.2" http-body = "0.4" httpdate = "1.0" httparse = "1.4" -h2 = { version = "0.3", optional = true } +h2 = { version = "0.3.2", optional = true } itoa = "0.4.1" tracing = { version = "0.1", default-features = false, features = ["std"] } pin-project = "1.0" diff --git a/src/proto/h2/mod.rs b/src/proto/h2/mod.rs index cf06592903..a1e6ed3a64 100644 --- a/src/proto/h2/mod.rs +++ b/src/proto/h2/mod.rs @@ -1,17 +1,20 @@ -use bytes::Buf; -use h2::SendStream; +use bytes::{Buf, Bytes}; +use h2::{RecvStream, SendStream}; use http::header::{ HeaderName, CONNECTION, PROXY_AUTHENTICATE, PROXY_AUTHORIZATION, TE, TRAILER, TRANSFER_ENCODING, UPGRADE, }; use http::HeaderMap; use pin_project::pin_project; +use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; use std::error::Error as StdError; -use std::io::IoSlice; +use std::io::{self, Cursor, IoSlice}; +use std::task::Context; use crate::body::{DecodedLength, HttpBody}; use crate::common::{task, Future, Pin, Poll}; use crate::headers::content_length_parse_all; +use crate::proto::h2::ping::Recorder; pub(crate) mod ping; @@ -172,7 +175,7 @@ where is_eos, ); - let buf = SendBuf(Some(chunk)); + let buf = SendBuf::Buf(chunk); me.body_tx .send_data(buf, is_eos) .map_err(crate::Error::new_body_write)?; @@ -243,32 +246,155 @@ impl SendStreamExt for SendStream> { fn send_eos_frame(&mut self) -> crate::Result<()> { trace!("send body eos"); - self.send_data(SendBuf(None), true) + self.send_data(SendBuf::None, true) .map_err(crate::Error::new_body_write) } } -struct SendBuf(Option); +enum SendBuf { + Buf(B), + Cursor(Cursor>), + None, +} impl Buf for SendBuf { #[inline] fn remaining(&self) -> usize { - self.0.as_ref().map(|b| b.remaining()).unwrap_or(0) + match *self { + Self::Buf(ref b) => b.remaining(), + Self::Cursor(ref c) => c.remaining(), + Self::None => 0, + } } #[inline] fn chunk(&self) -> &[u8] { - self.0.as_ref().map(|b| b.chunk()).unwrap_or(&[]) + match *self { + Self::Buf(ref b) => b.chunk(), + Self::Cursor(ref c) => c.chunk(), + Self::None => &[], + } } #[inline] fn advance(&mut self, cnt: usize) { - if let Some(b) = self.0.as_mut() { - b.advance(cnt) + match *self { + Self::Buf(ref mut b) => b.advance(cnt), + Self::Cursor(ref mut c) => c.advance(cnt), + Self::None => {}, } } fn chunks_vectored<'a>(&'a self, dst: &mut [IoSlice<'a>]) -> usize { - self.0.as_ref().map(|b| b.chunks_vectored(dst)).unwrap_or(0) + match *self { + Self::Buf(ref b) => b.chunks_vectored(dst), + Self::Cursor(ref c) => c.chunks_vectored(dst), + Self::None => 0, + } + } +} + +// FIXME(nox): Should this type be public? I'm asking this because +// the HTTP/2 RFC says that a proxy that encounters a TCP error with the +// upstream peer should send back to the client a stream error with reason +// CONNECT_ERROR, so we need *something* to send that, but all the user +// gets is a hyper::upgrade::Upgraded, so you can't send anything but a +// data frame back. +struct H2Upgraded +where + B: Buf, +{ + ping: Recorder, + send_stream: SendStream>, + recv_stream: RecvStream, + buf: Bytes, +} + +impl AsyncRead for H2Upgraded +where + B: Buf, +{ + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + read_buf: &mut ReadBuf<'_>, + ) -> Poll> { + if self.buf.is_empty() { + self.buf = match ready!(self.recv_stream.poll_data(cx)) { + None => return Poll::Ready(Ok(())), + Some(Ok(buf)) => { + self.ping.record_data(buf.len()); + buf + } + Some(Err(e)) => { + return Poll::Ready(Err(h2_to_io_error(e))); + } + }; + } + let cnt = std::cmp::min(self.buf.len(), read_buf.remaining()); + read_buf.put_slice(&self.buf[..cnt]); + self.buf.advance(cnt); + let _ = self.recv_stream.flow_control().release_capacity(cnt); + Poll::Ready(Ok(())) + } +} + +impl AsyncWrite for H2Upgraded +where + B: Buf, +{ + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + if buf.is_empty() { + return Poll::Ready(Ok(0)); + } + // FIXME(nox): PipeToSendStream does some weird stuff, first reserving + // one byte and then polling reset if the capacity is 0, should we do + // that here too? Should we poll reset somewhere? + self.send_stream.reserve_capacity(buf.len()); + Poll::Ready(match ready!(self.send_stream.poll_capacity(cx)) { + None => Ok(0), + Some(Ok(cnt)) => self.write(&buf[..cnt], false).map(|()| cnt), + Some(Err(e)) => { + // FIXME(nox): Should all H2 errors be returned as is with a + // ErrorKind::Other, or should some be special-cased, say for + // example, CANCEL? + Err(h2_to_io_error(e)) + }, + }) + } + + fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn poll_shutdown( + mut self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll> { + Poll::Ready(self.write(&[], true)) + } +} + +impl H2Upgraded +where + B: Buf, +{ + fn write(&mut self, buf: &[u8], end_of_stream: bool) -> Result<(), io::Error> { + let send_buf = SendBuf::Cursor(Cursor::new(buf.into())); + self.send_stream + .send_data(send_buf, end_of_stream) + .map_err(h2_to_io_error) + } +} + +fn h2_to_io_error(e: h2::Error) -> io::Error { + if e.is_io() { + e.into_io().unwrap() + } else { + io::Error::new(io::ErrorKind::Other, e) } } diff --git a/src/proto/h2/server.rs b/src/proto/h2/server.rs index d34802b727..bf1b36de9b 100644 --- a/src/proto/h2/server.rs +++ b/src/proto/h2/server.rs @@ -3,8 +3,10 @@ use std::marker::Unpin; #[cfg(feature = "runtime")] use std::time::Duration; +use bytes::Bytes; use h2::server::{Connection, Handshake, SendResponse}; -use h2::Reason; +use h2::{Reason, RecvStream}; +use http::{Method, Request}; use pin_project::pin_project; use tokio::io::{AsyncRead, AsyncWrite}; @@ -13,9 +15,12 @@ use crate::body::HttpBody; use crate::common::exec::ConnStreamExec; use crate::common::{date, task, Future, Pin, Poll}; use crate::headers; +use crate::proto::h2::ping::Recorder; +use crate::proto::h2::H2Upgraded; use crate::proto::Dispatched; use crate::service::HttpService; +use crate::upgrade::{OnUpgrade, Pending, Upgraded}; use crate::{Body, Response}; // Our defaults are chosen for the "majority" case, which usually are not @@ -269,8 +274,28 @@ where // Record the headers received ping.record_non_data(); - let req = req.map(|stream| crate::Body::h2(stream, content_length, ping)); - let fut = H2Stream::new(service.call(req), respond); + let is_connect = req.method() == Method::CONNECT; + let (mut parts, stream) = req.into_parts(); + let (req, connect_parts) = if !is_connect { + ( + Request::from_parts( + parts, + crate::Body::h2(stream, content_length, ping), + ), + None, + ) + } else { + // FIXME(nox): What happens to the request body? Should we check `content_length`? + let (pending, upgrade) = crate::upgrade::pending(); + debug_assert!(parts.extensions.get::().is_none()); + parts.extensions.insert(upgrade); + ( + Request::from_parts(parts, crate::Body::empty()), + Some((pending, ping, stream)), + ) + }; + + let fut = H2Stream::new(service.call(req), connect_parts, respond); exec.execute_h2stream(fut); } Some(Err(e)) => { @@ -333,7 +358,7 @@ enum H2StreamState where B: HttpBody, { - Service(#[pin] F), + Service(#[pin] F, Option<(Pending, Recorder, RecvStream)>), Body(#[pin] PipeToSendStream), } @@ -341,10 +366,14 @@ impl H2Stream where B: HttpBody, { - fn new(fut: F, respond: SendResponse>) -> H2Stream { + fn new( + fut: F, + connect_parts: Option<(Pending, Recorder, RecvStream)>, + respond: SendResponse>, + ) -> H2Stream { H2Stream { reply: respond, - state: H2StreamState::Service(fut), + state: H2StreamState::Service(fut, connect_parts), } } } @@ -374,7 +403,7 @@ where let mut me = self.project(); loop { let next = match me.state.as_mut().project() { - H2StreamStateProj::Service(h) => { + H2StreamStateProj::Service(h, connect_parts) => { let res = match h.poll(cx) { Poll::Ready(Ok(r)) => r, Poll::Pending => { @@ -405,6 +434,21 @@ where .entry(::http::header::DATE) .or_insert_with(date::update_and_header_value); + if let Some((pending, ping, recv_stream)) = connect_parts.take() { + // FIXME(nox): What do we do about the response body? AFAIK h1 returns an error. + let send_stream = reply!(me, res, false); + pending.fulfill(Upgraded::new( + H2Upgraded { + ping, + recv_stream, + send_stream, + buf: Bytes::new(), + }, + Bytes::new(), + )); + return Poll::Ready(Ok(())); + } + // automatically set Content-Length from body... if let Some(len) = body.size_hint().exact() { headers::set_content_length_if_missing(res.headers_mut(), len); diff --git a/src/upgrade.rs b/src/upgrade.rs index 6004c1a31a..aaeed2b39d 100644 --- a/src/upgrade.rs +++ b/src/upgrade.rs @@ -62,12 +62,12 @@ pub fn on(msg: T) -> OnUpgrade { msg.on_upgrade() } -#[cfg(feature = "http1")] +#[cfg(any(feature = "http1", feature = "http2"))] pub(super) struct Pending { tx: oneshot::Sender>, } -#[cfg(feature = "http1")] +#[cfg(any(feature = "http1", feature = "http2"))] pub(super) fn pending() -> (Pending, OnUpgrade) { let (tx, rx) = oneshot::channel(); (Pending { tx }, OnUpgrade { rx: Some(rx) }) @@ -187,13 +187,14 @@ impl fmt::Debug for OnUpgrade { // ===== impl Pending ===== -#[cfg(feature = "http1")] impl Pending { + #[cfg(any(feature = "http1", feature = "http2"))] pub(super) fn fulfill(self, upgraded: Upgraded) { trace!("pending upgrade fulfill"); let _ = self.tx.send(Ok(upgraded)); } + #[cfg(feature = "http1")] /// Don't fulfill the pending Upgrade, but instead signal that /// upgrades are handled manually. pub(super) fn manual(self) { diff --git a/tests/server.rs b/tests/server.rs index 662e903d57..12c3ecc8f3 100644 --- a/tests/server.rs +++ b/tests/server.rs @@ -1482,6 +1482,72 @@ async fn http_connect_new() { assert_eq!(s(&vec), "bar=foo"); } +#[tokio::test] +async fn h2_connect() { + use tokio::io::{AsyncReadExt, AsyncWriteExt}; + let _ = pretty_env_logger::try_init(); + let listener = tcp_bind(&"127.0.0.1:0".parse().unwrap()).unwrap(); + let addr = listener.local_addr().unwrap(); + let (done_tx, mut done_rx) = tokio::sync::mpsc::channel(10); + + let conn = connect_async(addr).await; + let (h2, connection) = h2::client::handshake(conn).await.unwrap(); + tokio::spawn(async move { + dbg!(connection.await).unwrap(); + }); + + tokio::spawn(async move { + let mut h2 = h2.ready().await.unwrap(); + + let request = Request::connect("localhost").body(()).unwrap(); + let (response, mut send_stream) = h2.send_request(request, false).unwrap(); + let mut response = response.await.unwrap(); + assert_eq!(response.status(), StatusCode::OK); + + let bytes = response.body_mut().data().await.unwrap().unwrap(); + assert_eq!(&bytes[..], b"Bread?"); + let _ = response + .body_mut() + .flow_control() + .release_capacity(bytes.len()); + + send_stream.send_data("Baguette!".into(), true).unwrap(); + + done_rx.recv().await.unwrap(); + }); + + let svc = service_fn(move |req: Request| { + let on_upgrade = hyper::upgrade::on(req); + let done_tx = done_tx.clone(); + + tokio::spawn(async move { + let mut upgraded = on_upgrade.await.expect("on_upgrade"); + upgraded.write_all(b"Bread?").await.unwrap(); + + let mut vec = vec![]; + upgraded.read_to_end(&mut vec).await.unwrap(); + assert_eq!(s(&vec), "Baguette!"); + + done_tx.send(()).await.unwrap(); + }); + + future::ok::<_, hyper::Error>( + Response::builder() + .status(200) + .body(hyper::Body::empty()) + .unwrap(), + ) + }); + + let (socket, _) = listener.accept().await.unwrap(); + Http::new() + .http2_only(true) + .serve_connection(socket, svc) + .with_upgrades() + .await + .unwrap(); +} + #[tokio::test] async fn parse_errors_send_4xx_response() { let listener = tcp_bind(&"127.0.0.1:0".parse().unwrap()).unwrap();