diff --git a/Cargo.toml b/Cargo.toml index 4a94f06aa9..93624a1ca6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -31,7 +31,7 @@ http = "0.2" http-body = "0.4" httpdate = "1.0" httparse = "1.4" -h2 = { version = "0.3", optional = true } +h2 = { version = "0.3.3", optional = true } itoa = "0.4.1" tracing = { version = "0.1", default-features = false, features = ["std"] } pin-project = "1.0" diff --git a/src/body/length.rs b/src/body/length.rs index aa9cf3dcd5..633a911fb2 100644 --- a/src/body/length.rs +++ b/src/body/length.rs @@ -3,6 +3,17 @@ use std::fmt; #[derive(Clone, Copy, PartialEq, Eq)] pub(crate) struct DecodedLength(u64); +#[cfg(any(feature = "http1", feature = "http2"))] +impl From> for DecodedLength { + fn from(len: Option) -> Self { + len.and_then(|len| { + // If the length is u64::MAX, oh well, just reported chunked. + Self::checked_new(len).ok() + }) + .unwrap_or(DecodedLength::CHUNKED) + } +} + #[cfg(any(feature = "http1", feature = "http2", test))] const MAX_LEN: u64 = std::u64::MAX - 2; diff --git a/src/error.rs b/src/error.rs index 663156e0a9..42029b129e 100644 --- a/src/error.rs +++ b/src/error.rs @@ -90,7 +90,7 @@ pub(super) enum User { /// User tried to send a certain header in an unexpected context. /// /// For example, sending both `content-length` and `transfer-encoding`. - #[cfg(feature = "http1")] + #[cfg(any(feature = "http1", feature = "http2"))] #[cfg(feature = "server")] UnexpectedHeader, /// User tried to create a Request with bad version. @@ -279,7 +279,7 @@ impl Error { Error::new(Kind::User(user)) } - #[cfg(feature = "http1")] + #[cfg(any(feature = "http1", feature = "http2"))] #[cfg(feature = "server")] pub(super) fn new_user_header() -> Error { Error::new_user(User::UnexpectedHeader) @@ -394,7 +394,7 @@ impl Error { Kind::User(User::MakeService) => "error from user's MakeService", #[cfg(any(feature = "http1", feature = "http2"))] Kind::User(User::Service) => "error from user's Service", - #[cfg(feature = "http1")] + #[cfg(any(feature = "http1", feature = "http2"))] #[cfg(feature = "server")] Kind::User(User::UnexpectedHeader) => "user sent unexpected header", #[cfg(any(feature = "http1", feature = "http2"))] diff --git a/src/proto/h2/mod.rs b/src/proto/h2/mod.rs index cf06592903..688a5c7464 100644 --- a/src/proto/h2/mod.rs +++ b/src/proto/h2/mod.rs @@ -1,5 +1,5 @@ -use bytes::Buf; -use h2::SendStream; +use bytes::{Buf, Bytes}; +use h2::{RecvStream, SendStream}; use http::header::{ HeaderName, CONNECTION, PROXY_AUTHENTICATE, PROXY_AUTHORIZATION, TE, TRAILER, TRANSFER_ENCODING, UPGRADE, @@ -7,11 +7,15 @@ use http::header::{ use http::HeaderMap; use pin_project::pin_project; use std::error::Error as StdError; -use std::io::IoSlice; +use std::io::{self, Cursor, IoSlice}; +use std::mem; +use std::task::Context; +use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; use crate::body::{DecodedLength, HttpBody}; use crate::common::{task, Future, Pin, Poll}; use crate::headers::content_length_parse_all; +use crate::proto::h2::ping::Recorder; pub(crate) mod ping; @@ -84,12 +88,7 @@ fn strip_connection_headers(headers: &mut HeaderMap, is_request: bool) { } fn decode_content_length(headers: &HeaderMap) -> DecodedLength { - if let Some(len) = content_length_parse_all(headers) { - // If the length is u64::MAX, oh well, just reported chunked. - DecodedLength::checked_new(len).unwrap_or_else(|_| DecodedLength::CHUNKED) - } else { - DecodedLength::CHUNKED - } + content_length_parse_all(headers).into() } // body adapters used by both Client and Server @@ -172,7 +171,7 @@ where is_eos, ); - let buf = SendBuf(Some(chunk)); + let buf = SendBuf::Buf(chunk); me.body_tx .send_data(buf, is_eos) .map_err(crate::Error::new_body_write)?; @@ -243,32 +242,201 @@ impl SendStreamExt for SendStream> { fn send_eos_frame(&mut self) -> crate::Result<()> { trace!("send body eos"); - self.send_data(SendBuf(None), true) + self.send_data(SendBuf::None, true) .map_err(crate::Error::new_body_write) } } -struct SendBuf(Option); +enum SendBuf { + Buf(B), + Cursor(Cursor>), + None, +} impl Buf for SendBuf { #[inline] fn remaining(&self) -> usize { - self.0.as_ref().map(|b| b.remaining()).unwrap_or(0) + match *self { + Self::Buf(ref b) => b.remaining(), + Self::Cursor(ref c) => c.remaining(), + Self::None => 0, + } } #[inline] fn chunk(&self) -> &[u8] { - self.0.as_ref().map(|b| b.chunk()).unwrap_or(&[]) + match *self { + Self::Buf(ref b) => b.chunk(), + Self::Cursor(ref c) => c.chunk(), + Self::None => &[], + } } #[inline] fn advance(&mut self, cnt: usize) { - if let Some(b) = self.0.as_mut() { - b.advance(cnt) + match *self { + Self::Buf(ref mut b) => b.advance(cnt), + Self::Cursor(ref mut c) => c.advance(cnt), + Self::None => {} } } fn chunks_vectored<'a>(&'a self, dst: &mut [IoSlice<'a>]) -> usize { - self.0.as_ref().map(|b| b.chunks_vectored(dst)).unwrap_or(0) + match *self { + Self::Buf(ref b) => b.chunks_vectored(dst), + Self::Cursor(ref c) => c.chunks_vectored(dst), + Self::None => 0, + } + } +} + +struct H2Upgraded +where + B: Buf, +{ + ping: Recorder, + send_stream: UpgradedSendStream, + recv_stream: RecvStream, + buf: Bytes, +} + +impl AsyncRead for H2Upgraded +where + B: Buf, +{ + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + read_buf: &mut ReadBuf<'_>, + ) -> Poll> { + if self.buf.is_empty() { + self.buf = loop { + match ready!(self.recv_stream.poll_data(cx)) { + None => return Poll::Ready(Ok(())), + Some(Ok(buf)) if buf.is_empty() && !self.recv_stream.is_end_stream() => { + continue + } + Some(Ok(buf)) => { + self.ping.record_data(buf.len()); + break buf; + } + Some(Err(e)) => { + return Poll::Ready(Err(h2_to_io_error(e))); + } + } + }; + } + let cnt = std::cmp::min(self.buf.len(), read_buf.remaining()); + read_buf.put_slice(&self.buf[..cnt]); + self.buf.advance(cnt); + let _ = self.recv_stream.flow_control().release_capacity(cnt); + Poll::Ready(Ok(())) + } +} + +impl AsyncWrite for H2Upgraded +where + B: Buf, +{ + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + if let Poll::Ready(reset) = self.send_stream.poll_reset(cx) { + return Poll::Ready(Err(h2_to_io_error(match reset { + Ok(reason) => reason.into(), + Err(e) => e, + }))); + } + if buf.is_empty() { + return Poll::Ready(Ok(0)); + } + self.send_stream.reserve_capacity(buf.len()); + Poll::Ready(match ready!(self.send_stream.poll_capacity(cx)) { + None => Ok(0), + Some(Ok(cnt)) => self.send_stream.write(&buf[..cnt], false).map(|()| cnt), + Some(Err(e)) => Err(h2_to_io_error(e)), + }) + } + + fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn poll_shutdown( + mut self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll> { + Poll::Ready(self.send_stream.write(&[], true)) + } +} + +fn h2_to_io_error(e: h2::Error) -> io::Error { + if e.is_io() { + e.into_io().unwrap() + } else { + io::Error::new(io::ErrorKind::Other, e) + } +} + +struct UpgradedSendStream(SendStream>>); + +impl UpgradedSendStream +where + B: Buf, +{ + unsafe fn new(inner: SendStream>) -> Self { + assert_eq!(mem::size_of::(), mem::size_of::>()); + Self(mem::transmute(inner)) + } + + fn reserve_capacity(&mut self, cnt: usize) { + unsafe { self.as_inner_unchecked().reserve_capacity(cnt) } + } + + fn poll_capacity(&mut self, cx: &mut Context<'_>) -> Poll>> { + unsafe { self.as_inner_unchecked().poll_capacity(cx) } + } + + fn poll_reset(&mut self, cx: &mut Context<'_>) -> Poll> { + unsafe { self.as_inner_unchecked().poll_reset(cx) } + } + + fn write(&mut self, buf: &[u8], end_of_stream: bool) -> Result<(), io::Error> { + let send_buf = SendBuf::Cursor(Cursor::new(buf.into())); + unsafe { + self.as_inner_unchecked() + .send_data(send_buf, end_of_stream) + .map_err(h2_to_io_error) + } + } + + unsafe fn as_inner_unchecked(&mut self) -> &mut SendStream> { + &mut *(&mut self.0 as *mut _ as *mut _) + } +} + +#[repr(transparent)] +struct Neutered { + _inner: B, + impossible: Impossible, +} + +enum Impossible {} + +unsafe impl Send for Neutered {} + +impl Buf for Neutered { + fn remaining(&self) -> usize { + match self.impossible {} + } + + fn chunk(&self) -> &[u8] { + match self.impossible {} + } + + fn advance(&mut self, _cnt: usize) { + match self.impossible {} } } diff --git a/src/proto/h2/server.rs b/src/proto/h2/server.rs index 16e6a1af3e..de77eaa232 100644 --- a/src/proto/h2/server.rs +++ b/src/proto/h2/server.rs @@ -3,19 +3,24 @@ use std::marker::Unpin; #[cfg(feature = "runtime")] use std::time::Duration; +use bytes::Bytes; use h2::server::{Connection, Handshake, SendResponse}; -use h2::Reason; +use h2::{Reason, RecvStream}; +use http::{Method, Request}; use pin_project::pin_project; use tokio::io::{AsyncRead, AsyncWrite}; -use super::{decode_content_length, ping, PipeToSendStream, SendBuf}; +use super::{ping, PipeToSendStream, SendBuf}; use crate::body::HttpBody; use crate::common::exec::ConnStreamExec; use crate::common::{date, task, Future, Pin, Poll}; use crate::headers; +use crate::proto::h2::ping::Recorder; +use crate::proto::h2::{H2Upgraded, UpgradedSendStream}; use crate::proto::Dispatched; use crate::service::HttpService; +use crate::upgrade::{OnUpgrade, Pending, Upgraded}; use crate::{Body, Response}; // Our defaults are chosen for the "majority" case, which usually are not @@ -255,9 +260,9 @@ where // When the service is ready, accepts an incoming request. match ready!(self.conn.poll_accept(cx)) { - Some(Ok((req, respond))) => { + Some(Ok((req, mut respond))) => { trace!("incoming request"); - let content_length = decode_content_length(req.headers()); + let content_length = headers::content_length_parse_all(req.headers()); let ping = self .ping .as_ref() @@ -267,8 +272,36 @@ where // Record the headers received ping.record_non_data(); - let req = req.map(|stream| crate::Body::h2(stream, content_length, ping)); - let fut = H2Stream::new(service.call(req), respond); + let is_connect = req.method() == Method::CONNECT; + let (mut parts, stream) = req.into_parts(); + let (req, connect_parts) = if !is_connect { + ( + Request::from_parts( + parts, + crate::Body::h2(stream, content_length.into(), ping), + ), + None, + ) + } else { + if content_length.map_or(false, |len| len != 0) { + warn!("h2 connect request with non-zero body not supported"); + respond.send_reset(h2::Reason::INTERNAL_ERROR); + return Poll::Ready(Ok(())); + } + let (pending, upgrade) = crate::upgrade::pending(); + debug_assert!(parts.extensions.get::().is_none()); + parts.extensions.insert(upgrade); + ( + Request::from_parts(parts, crate::Body::empty()), + Some(ConnectParts { + pending, + ping, + recv_stream: stream, + }), + ) + }; + + let fut = H2Stream::new(service.call(req), connect_parts, respond); exec.execute_h2stream(fut); } Some(Err(e)) => { @@ -331,18 +364,28 @@ enum H2StreamState where B: HttpBody, { - Service(#[pin] F), + Service(#[pin] F, Option), Body(#[pin] PipeToSendStream), } +struct ConnectParts { + pending: Pending, + ping: Recorder, + recv_stream: RecvStream, +} + impl H2Stream where B: HttpBody, { - fn new(fut: F, respond: SendResponse>) -> H2Stream { + fn new( + fut: F, + connect_parts: Option, + respond: SendResponse>, + ) -> H2Stream { H2Stream { reply: respond, - state: H2StreamState::Service(fut), + state: H2StreamState::Service(fut, connect_parts), } } } @@ -372,7 +415,7 @@ where let mut me = self.project(); loop { let next = match me.state.as_mut().project() { - H2StreamStateProj::Service(h) => { + H2StreamStateProj::Service(h, connect_parts) => { let res = match h.poll(cx) { Poll::Ready(Ok(r)) => r, Poll::Pending => { @@ -403,6 +446,29 @@ where .entry(::http::header::DATE) .or_insert_with(date::update_and_header_value); + if let Some(connect_parts) = connect_parts.take() { + if res.status().is_success() { + if headers::content_length_parse_all(res.headers()) + .map_or(false, |len| len != 0) + { + warn!("h2 successful response to CONNECT request with body not supported"); + me.reply.send_reset(h2::Reason::INTERNAL_ERROR); + return Poll::Ready(Err(crate::Error::new_user_header())); + } + let send_stream = reply!(me, res, false); + connect_parts.pending.fulfill(Upgraded::new( + H2Upgraded { + ping: connect_parts.ping, + recv_stream: connect_parts.recv_stream, + send_stream: unsafe { UpgradedSendStream::new(send_stream) }, + buf: Bytes::new(), + }, + Bytes::new(), + )); + return Poll::Ready(Ok(())); + } + } + // automatically set Content-Length from body... if let Some(len) = body.size_hint().exact() { headers::set_content_length_if_missing(res.headers_mut(), len); diff --git a/src/upgrade.rs b/src/upgrade.rs index 6004c1a31a..efab10a6fc 100644 --- a/src/upgrade.rs +++ b/src/upgrade.rs @@ -62,12 +62,12 @@ pub fn on(msg: T) -> OnUpgrade { msg.on_upgrade() } -#[cfg(feature = "http1")] +#[cfg(any(feature = "http1", feature = "http2"))] pub(super) struct Pending { tx: oneshot::Sender>, } -#[cfg(feature = "http1")] +#[cfg(any(feature = "http1", feature = "http2"))] pub(super) fn pending() -> (Pending, OnUpgrade) { let (tx, rx) = oneshot::channel(); (Pending { tx }, OnUpgrade { rx: Some(rx) }) @@ -76,7 +76,7 @@ pub(super) fn pending() -> (Pending, OnUpgrade) { // ===== impl Upgraded ===== impl Upgraded { - #[cfg(any(feature = "http1", test))] + #[cfg(any(feature = "http1", feature = "http2", test))] pub(super) fn new(io: T, read_buf: Bytes) -> Self where T: AsyncRead + AsyncWrite + Unpin + Send + 'static, @@ -187,13 +187,14 @@ impl fmt::Debug for OnUpgrade { // ===== impl Pending ===== -#[cfg(feature = "http1")] +#[cfg(any(feature = "http1", feature = "http2"))] impl Pending { pub(super) fn fulfill(self, upgraded: Upgraded) { trace!("pending upgrade fulfill"); let _ = self.tx.send(Ok(upgraded)); } + #[cfg(feature = "http1")] /// Don't fulfill the pending Upgrade, but instead signal that /// upgrades are handled manually. pub(super) fn manual(self) { diff --git a/tests/server.rs b/tests/server.rs index 662e903d57..297b09ac73 100644 --- a/tests/server.rs +++ b/tests/server.rs @@ -13,10 +13,13 @@ use std::task::{Context, Poll}; use std::thread; use std::time::Duration; +use bytes::Bytes; use futures_channel::oneshot; use futures_util::future::{self, Either, FutureExt, TryFutureExt}; #[cfg(feature = "stream")] use futures_util::stream::StreamExt as _; +use h2::client::SendRequest; +use h2::{RecvStream, SendStream}; use http::header::{HeaderName, HeaderValue}; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadBuf}; use tokio::net::{TcpListener, TcpStream as TkTcpStream}; @@ -1482,6 +1485,339 @@ async fn http_connect_new() { assert_eq!(s(&vec), "bar=foo"); } +#[tokio::test] +async fn h2_connect() { + use tokio::io::{AsyncReadExt, AsyncWriteExt}; + + let _ = pretty_env_logger::try_init(); + let listener = tcp_bind(&"127.0.0.1:0".parse().unwrap()).unwrap(); + let addr = listener.local_addr().unwrap(); + let conn = connect_async(addr).await; + + let (h2, connection) = h2::client::handshake(conn).await.unwrap(); + tokio::spawn(async move { + connection.await.unwrap(); + }); + let mut h2 = h2.ready().await.unwrap(); + + async fn connect_and_recv_bread( + h2: &mut SendRequest, + ) -> (RecvStream, SendStream) { + let request = Request::connect("localhost").body(()).unwrap(); + let (response, send_stream) = h2.send_request(request, false).unwrap(); + let response = response.await.unwrap(); + assert_eq!(response.status(), StatusCode::OK); + + let mut body = response.into_body(); + let bytes = body.data().await.unwrap().unwrap(); + assert_eq!(&bytes[..], b"Bread?"); + let _ = body.flow_control().release_capacity(bytes.len()); + + (body, send_stream) + } + + tokio::spawn(async move { + let (mut recv_stream, mut send_stream) = connect_and_recv_bread(&mut h2).await; + + send_stream.send_data("Baguette!".into(), true).unwrap(); + + assert!(recv_stream.data().await.unwrap().unwrap().is_empty()); + }); + + let svc = service_fn(move |req: Request| { + let on_upgrade = hyper::upgrade::on(req); + + tokio::spawn(async move { + let mut upgraded = on_upgrade.await.expect("on_upgrade"); + upgraded.write_all(b"Bread?").await.unwrap(); + + let mut vec = vec![]; + upgraded.read_to_end(&mut vec).await.unwrap(); + assert_eq!(s(&vec), "Baguette!"); + + upgraded.shutdown().await.unwrap(); + }); + + future::ok::<_, hyper::Error>( + Response::builder() + .status(200) + .body(hyper::Body::empty()) + .unwrap(), + ) + }); + + let (socket, _) = listener.accept().await.unwrap(); + Http::new() + .http2_only(true) + .serve_connection(socket, svc) + .with_upgrades() + .await + .unwrap(); +} + +#[tokio::test] +async fn h2_connect_multiplex() { + use futures_util::stream::FuturesUnordered; + use tokio::io::{AsyncReadExt, AsyncWriteExt}; + + let _ = pretty_env_logger::try_init(); + let listener = tcp_bind(&"127.0.0.1:0".parse().unwrap()).unwrap(); + let addr = listener.local_addr().unwrap(); + let conn = connect_async(addr).await; + + let (h2, connection) = h2::client::handshake(conn).await.unwrap(); + tokio::spawn(async move { + connection.await.unwrap(); + }); + let mut h2 = h2.ready().await.unwrap(); + + tokio::spawn(async move { + let mut streams = vec![]; + for i in 0..80 { + let request = Request::connect(format!("localhost_{}", i % 4)) + .body(()) + .unwrap(); + let (response, send_stream) = h2.send_request(request, false).unwrap(); + streams.push((i, response, send_stream)); + } + + let futures = streams + .into_iter() + .map(|(i, response, mut send_stream)| async move { + if i % 4 == 0 { + return; + } + + let response = response.await.unwrap(); + assert_eq!(response.status(), StatusCode::OK); + + if i % 4 == 1 { + return; + } + + let mut body = response.into_body(); + let bytes = body.data().await.unwrap().unwrap(); + assert_eq!(&bytes[..], b"Bread?"); + let _ = body.flow_control().release_capacity(bytes.len()); + + if i % 4 == 2 { + return; + } + + send_stream.send_data("Baguette!".into(), true).unwrap(); + + assert!(body.data().await.unwrap().unwrap().is_empty()); + }) + .collect::>(); + + futures.for_each(future::ready).await; + }); + + let svc = service_fn(move |req: Request| { + let authority = req.uri().authority().unwrap().to_string(); + let on_upgrade = hyper::upgrade::on(req); + + tokio::spawn(async move { + let upgrade_res = on_upgrade.await; + if authority == "localhost_0" { + assert!(upgrade_res.expect_err("upgrade cancelled").is_canceled()); + return; + } + let mut upgraded = upgrade_res.expect("upgrade successful"); + + upgraded.write_all(b"Bread?").await.unwrap(); + + let mut vec = vec![]; + let read_res = upgraded.read_to_end(&mut vec).await; + + if authority == "localhost_1" || authority == "localhost_2" { + let err = read_res.expect_err("read failed"); + assert_eq!(err.kind(), io::ErrorKind::Other); + assert_eq!( + err.get_ref() + .unwrap() + .downcast_ref::() + .unwrap() + .reason(), + Some(h2::Reason::CANCEL), + ); + return; + } + + read_res.unwrap(); + assert_eq!(s(&vec), "Baguette!"); + + upgraded.shutdown().await.unwrap(); + }); + + future::ok::<_, hyper::Error>( + Response::builder() + .status(200) + .body(hyper::Body::empty()) + .unwrap(), + ) + }); + + let (socket, _) = listener.accept().await.unwrap(); + Http::new() + .http2_only(true) + .serve_connection(socket, svc) + .with_upgrades() + .await + .unwrap(); +} + +#[tokio::test] +async fn h2_connect_large_body() { + use tokio::io::{AsyncReadExt, AsyncWriteExt}; + + let _ = pretty_env_logger::try_init(); + let listener = tcp_bind(&"127.0.0.1:0".parse().unwrap()).unwrap(); + let addr = listener.local_addr().unwrap(); + let conn = connect_async(addr).await; + + let (h2, connection) = h2::client::handshake(conn).await.unwrap(); + tokio::spawn(async move { + connection.await.unwrap(); + }); + let mut h2 = h2.ready().await.unwrap(); + + const NO_BREAD: &str = "All work and no bread makes nox a dull boy.\n"; + + async fn connect_and_recv_bread( + h2: &mut SendRequest, + ) -> (RecvStream, SendStream) { + let request = Request::connect("localhost").body(()).unwrap(); + let (response, send_stream) = h2.send_request(request, false).unwrap(); + let response = response.await.unwrap(); + assert_eq!(response.status(), StatusCode::OK); + + let mut body = response.into_body(); + let bytes = body.data().await.unwrap().unwrap(); + assert_eq!(&bytes[..], b"Bread?"); + let _ = body.flow_control().release_capacity(bytes.len()); + + (body, send_stream) + } + + tokio::spawn(async move { + let (mut recv_stream, mut send_stream) = connect_and_recv_bread(&mut h2).await; + + let large_body = Bytes::from(NO_BREAD.repeat(9000)); + + send_stream.send_data(large_body.clone(), false).unwrap(); + send_stream.send_data(large_body, true).unwrap(); + + assert!(recv_stream.data().await.unwrap().unwrap().is_empty()); + }); + + let svc = service_fn(move |req: Request| { + let on_upgrade = hyper::upgrade::on(req); + + tokio::spawn(async move { + let mut upgraded = on_upgrade.await.expect("on_upgrade"); + upgraded.write_all(b"Bread?").await.unwrap(); + + let mut vec = vec![]; + if upgraded.read_to_end(&mut vec).await.is_err() { + return; + } + assert_eq!(vec.len(), NO_BREAD.len() * 9000 * 2); + + upgraded.shutdown().await.unwrap(); + }); + + future::ok::<_, hyper::Error>( + Response::builder() + .status(200) + .body(hyper::Body::empty()) + .unwrap(), + ) + }); + + let (socket, _) = listener.accept().await.unwrap(); + Http::new() + .http2_only(true) + .serve_connection(socket, svc) + .with_upgrades() + .await + .unwrap(); +} + +#[tokio::test] +async fn h2_connect_empty_frames() { + use tokio::io::{AsyncReadExt, AsyncWriteExt}; + + let _ = pretty_env_logger::try_init(); + let listener = tcp_bind(&"127.0.0.1:0".parse().unwrap()).unwrap(); + let addr = listener.local_addr().unwrap(); + let conn = connect_async(addr).await; + + let (h2, connection) = h2::client::handshake(conn).await.unwrap(); + tokio::spawn(async move { + connection.await.unwrap(); + }); + let mut h2 = h2.ready().await.unwrap(); + + async fn connect_and_recv_bread( + h2: &mut SendRequest, + ) -> (RecvStream, SendStream) { + let request = Request::connect("localhost").body(()).unwrap(); + let (response, send_stream) = h2.send_request(request, false).unwrap(); + let response = response.await.unwrap(); + assert_eq!(response.status(), StatusCode::OK); + + let mut body = response.into_body(); + let bytes = body.data().await.unwrap().unwrap(); + assert_eq!(&bytes[..], b"Bread?"); + let _ = body.flow_control().release_capacity(bytes.len()); + + (body, send_stream) + } + + tokio::spawn(async move { + let (mut recv_stream, mut send_stream) = connect_and_recv_bread(&mut h2).await; + + send_stream.send_data("".into(), false).unwrap(); + send_stream.send_data("".into(), false).unwrap(); + send_stream.send_data("".into(), false).unwrap(); + send_stream.send_data("Baguette!".into(), false).unwrap(); + send_stream.send_data("".into(), true).unwrap(); + + assert!(recv_stream.data().await.unwrap().unwrap().is_empty()); + }); + + let svc = service_fn(move |req: Request| { + let on_upgrade = hyper::upgrade::on(req); + + tokio::spawn(async move { + let mut upgraded = on_upgrade.await.expect("on_upgrade"); + upgraded.write_all(b"Bread?").await.unwrap(); + + let mut vec = vec![]; + upgraded.read_to_end(&mut vec).await.unwrap(); + assert_eq!(s(&vec), "Baguette!"); + + upgraded.shutdown().await.unwrap(); + }); + + future::ok::<_, hyper::Error>( + Response::builder() + .status(200) + .body(hyper::Body::empty()) + .unwrap(), + ) + }); + + let (socket, _) = listener.accept().await.unwrap(); + Http::new() + .http2_only(true) + .serve_connection(socket, svc) + .with_upgrades() + .await + .unwrap(); +} + #[tokio::test] async fn parse_errors_send_4xx_response() { let listener = tcp_bind(&"127.0.0.1:0".parse().unwrap()).unwrap();