From 46cc7426a2737e8b9b38f57e8dd8db8df761a846 Mon Sep 17 00:00:00 2001 From: Anthony Ramine Date: Mon, 26 Apr 2021 13:15:15 +0200 Subject: [PATCH] feat(h2): implement CONNECT support (fixes #2508) --- Cargo.toml | 5 +- src/body/length.rs | 11 ++ src/error.rs | 6 +- src/proto/h2/mod.rs | 152 ++++++++++++++++++--- src/proto/h2/server.rs | 74 ++++++++-- src/upgrade.rs | 9 +- tests/server.rs | 300 +++++++++++++++++++++++++++++++++++++++++ 7 files changed, 522 insertions(+), 35 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 4a94f06aa9..c98d17c4f0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -22,6 +22,9 @@ include = [ [lib] crate-type = ["lib", "staticlib", "cdylib"] +[patch.crates-io] +h2 = { git = "https://github.com/hyperium/h2.git", branch = "master" } + [dependencies] bytes = "1" futures-core = { version = "0.3", default-features = false } @@ -31,7 +34,7 @@ http = "0.2" http-body = "0.4" httpdate = "1.0" httparse = "1.4" -h2 = { version = "0.3", optional = true } +h2 = { version = "0.3.2", optional = true } itoa = "0.4.1" tracing = { version = "0.1", default-features = false, features = ["std"] } pin-project = "1.0" diff --git a/src/body/length.rs b/src/body/length.rs index aa9cf3dcd5..633a911fb2 100644 --- a/src/body/length.rs +++ b/src/body/length.rs @@ -3,6 +3,17 @@ use std::fmt; #[derive(Clone, Copy, PartialEq, Eq)] pub(crate) struct DecodedLength(u64); +#[cfg(any(feature = "http1", feature = "http2"))] +impl From> for DecodedLength { + fn from(len: Option) -> Self { + len.and_then(|len| { + // If the length is u64::MAX, oh well, just reported chunked. + Self::checked_new(len).ok() + }) + .unwrap_or(DecodedLength::CHUNKED) + } +} + #[cfg(any(feature = "http1", feature = "http2", test))] const MAX_LEN: u64 = std::u64::MAX - 2; diff --git a/src/error.rs b/src/error.rs index 663156e0a9..42029b129e 100644 --- a/src/error.rs +++ b/src/error.rs @@ -90,7 +90,7 @@ pub(super) enum User { /// User tried to send a certain header in an unexpected context. /// /// For example, sending both `content-length` and `transfer-encoding`. - #[cfg(feature = "http1")] + #[cfg(any(feature = "http1", feature = "http2"))] #[cfg(feature = "server")] UnexpectedHeader, /// User tried to create a Request with bad version. @@ -279,7 +279,7 @@ impl Error { Error::new(Kind::User(user)) } - #[cfg(feature = "http1")] + #[cfg(any(feature = "http1", feature = "http2"))] #[cfg(feature = "server")] pub(super) fn new_user_header() -> Error { Error::new_user(User::UnexpectedHeader) @@ -394,7 +394,7 @@ impl Error { Kind::User(User::MakeService) => "error from user's MakeService", #[cfg(any(feature = "http1", feature = "http2"))] Kind::User(User::Service) => "error from user's Service", - #[cfg(feature = "http1")] + #[cfg(any(feature = "http1", feature = "http2"))] #[cfg(feature = "server")] Kind::User(User::UnexpectedHeader) => "user sent unexpected header", #[cfg(any(feature = "http1", feature = "http2"))] diff --git a/src/proto/h2/mod.rs b/src/proto/h2/mod.rs index cf06592903..7e71017e80 100644 --- a/src/proto/h2/mod.rs +++ b/src/proto/h2/mod.rs @@ -1,17 +1,20 @@ -use bytes::Buf; -use h2::SendStream; +use bytes::{Buf, Bytes}; +use h2::{RecvStream, SendStream}; use http::header::{ HeaderName, CONNECTION, PROXY_AUTHENTICATE, PROXY_AUTHORIZATION, TE, TRAILER, TRANSFER_ENCODING, UPGRADE, }; use http::HeaderMap; use pin_project::pin_project; +use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; use std::error::Error as StdError; -use std::io::IoSlice; +use std::io::{self, Cursor, IoSlice}; +use std::task::Context; use crate::body::{DecodedLength, HttpBody}; use crate::common::{task, Future, Pin, Poll}; use crate::headers::content_length_parse_all; +use crate::proto::h2::ping::Recorder; pub(crate) mod ping; @@ -84,12 +87,7 @@ fn strip_connection_headers(headers: &mut HeaderMap, is_request: bool) { } fn decode_content_length(headers: &HeaderMap) -> DecodedLength { - if let Some(len) = content_length_parse_all(headers) { - // If the length is u64::MAX, oh well, just reported chunked. - DecodedLength::checked_new(len).unwrap_or_else(|_| DecodedLength::CHUNKED) - } else { - DecodedLength::CHUNKED - } + content_length_parse_all(headers).into() } // body adapters used by both Client and Server @@ -172,7 +170,7 @@ where is_eos, ); - let buf = SendBuf(Some(chunk)); + let buf = SendBuf::Buf(chunk); me.body_tx .send_data(buf, is_eos) .map_err(crate::Error::new_body_write)?; @@ -243,32 +241,152 @@ impl SendStreamExt for SendStream> { fn send_eos_frame(&mut self) -> crate::Result<()> { trace!("send body eos"); - self.send_data(SendBuf(None), true) + self.send_data(SendBuf::None, true) .map_err(crate::Error::new_body_write) } } -struct SendBuf(Option); +enum SendBuf { + Buf(B), + Cursor(Cursor>), + None, +} impl Buf for SendBuf { #[inline] fn remaining(&self) -> usize { - self.0.as_ref().map(|b| b.remaining()).unwrap_or(0) + match *self { + Self::Buf(ref b) => b.remaining(), + Self::Cursor(ref c) => c.remaining(), + Self::None => 0, + } } #[inline] fn chunk(&self) -> &[u8] { - self.0.as_ref().map(|b| b.chunk()).unwrap_or(&[]) + match *self { + Self::Buf(ref b) => b.chunk(), + Self::Cursor(ref c) => c.chunk(), + Self::None => &[], + } } #[inline] fn advance(&mut self, cnt: usize) { - if let Some(b) = self.0.as_mut() { - b.advance(cnt) + match *self { + Self::Buf(ref mut b) => b.advance(cnt), + Self::Cursor(ref mut c) => c.advance(cnt), + Self::None => {}, } } fn chunks_vectored<'a>(&'a self, dst: &mut [IoSlice<'a>]) -> usize { - self.0.as_ref().map(|b| b.chunks_vectored(dst)).unwrap_or(0) + match *self { + Self::Buf(ref b) => b.chunks_vectored(dst), + Self::Cursor(ref c) => c.chunks_vectored(dst), + Self::None => 0, + } + } +} + +struct H2Upgraded +where + B: Buf, +{ + ping: Recorder, + send_stream: SendStream>, + recv_stream: RecvStream, + buf: Bytes, +} + +impl AsyncRead for H2Upgraded +where + B: Buf, +{ + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + read_buf: &mut ReadBuf<'_>, + ) -> Poll> { + if self.buf.is_empty() { + self.buf = loop { + match ready!(self.recv_stream.poll_data(cx)) { + None => return Poll::Ready(Ok(())), + Some(Ok(buf)) if buf.is_empty() && !self.recv_stream.is_end_stream() => continue, + Some(Ok(buf)) => { + self.ping.record_data(buf.len()); + break buf; + } + Some(Err(e)) => { + return Poll::Ready(Err(h2_to_io_error(e))); + } + } + }; + } + let cnt = std::cmp::min(self.buf.len(), read_buf.remaining()); + read_buf.put_slice(&self.buf[..cnt]); + self.buf.advance(cnt); + let _ = self.recv_stream.flow_control().release_capacity(cnt); + Poll::Ready(Ok(())) + } +} + +impl AsyncWrite for H2Upgraded +where + B: Buf, +{ + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + if buf.is_empty() { + return Poll::Ready(Ok(0)); + } + // FIXME(nox): PipeToSendStream does some weird stuff, first reserving + // one byte and then polling reset if the capacity is 0, should we do + // that here too? Should we poll reset somewhere? + self.send_stream.reserve_capacity(buf.len()); + Poll::Ready(match ready!(self.send_stream.poll_capacity(cx)) { + None => Ok(0), + Some(Ok(cnt)) => self.write(&buf[..cnt], false).map(|()| cnt), + Some(Err(e)) => { + // FIXME(nox): Should all H2 errors be returned as is with a + // ErrorKind::Other, or should some be special-cased, say for + // example, CANCEL? + Err(h2_to_io_error(e)) + }, + }) + } + + fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn poll_shutdown( + mut self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll> { + Poll::Ready(self.write(&[], true)) + } +} + +impl H2Upgraded +where + B: Buf, +{ + fn write(&mut self, buf: &[u8], end_of_stream: bool) -> Result<(), io::Error> { + let send_buf = SendBuf::Cursor(Cursor::new(buf.into())); + self.send_stream + .send_data(send_buf, end_of_stream) + .map_err(h2_to_io_error) + } +} + +fn h2_to_io_error(e: h2::Error) -> io::Error { + if e.is_io() { + e.into_io().unwrap() + } else { + io::Error::new(io::ErrorKind::Other, e) } } diff --git a/src/proto/h2/server.rs b/src/proto/h2/server.rs index d34802b727..811c98e3c7 100644 --- a/src/proto/h2/server.rs +++ b/src/proto/h2/server.rs @@ -3,19 +3,24 @@ use std::marker::Unpin; #[cfg(feature = "runtime")] use std::time::Duration; +use bytes::Bytes; use h2::server::{Connection, Handshake, SendResponse}; -use h2::Reason; +use h2::{Reason, RecvStream}; +use http::{Method, Request}; use pin_project::pin_project; use tokio::io::{AsyncRead, AsyncWrite}; -use super::{decode_content_length, ping, PipeToSendStream, SendBuf}; +use super::{ping, PipeToSendStream, SendBuf}; use crate::body::HttpBody; use crate::common::exec::ConnStreamExec; use crate::common::{date, task, Future, Pin, Poll}; use crate::headers; +use crate::proto::h2::ping::Recorder; +use crate::proto::h2::H2Upgraded; use crate::proto::Dispatched; use crate::service::HttpService; +use crate::upgrade::{OnUpgrade, Pending, Upgraded}; use crate::{Body, Response}; // Our defaults are chosen for the "majority" case, which usually are not @@ -257,9 +262,9 @@ where // When the service is ready, accepts an incoming request. match ready!(self.conn.poll_accept(cx)) { - Some(Ok((req, respond))) => { + Some(Ok((req, mut respond))) => { trace!("incoming request"); - let content_length = decode_content_length(req.headers()); + let content_length = headers::content_length_parse_all(req.headers()); let ping = self .ping .as_ref() @@ -269,8 +274,32 @@ where // Record the headers received ping.record_non_data(); - let req = req.map(|stream| crate::Body::h2(stream, content_length, ping)); - let fut = H2Stream::new(service.call(req), respond); + let is_connect = req.method() == Method::CONNECT; + let (mut parts, stream) = req.into_parts(); + let (req, connect_parts) = if !is_connect { + ( + Request::from_parts( + parts, + crate::Body::h2(stream, content_length.into(), ping), + ), + None, + ) + } else { + if content_length.map_or(false, |len| len != 0) { + warn!("h2 connect request with non-zero body not supported"); + respond.send_reset(h2::Reason::INTERNAL_ERROR); + return Poll::Ready(Ok(())); + } + let (pending, upgrade) = crate::upgrade::pending(); + debug_assert!(parts.extensions.get::().is_none()); + parts.extensions.insert(upgrade); + ( + Request::from_parts(parts, crate::Body::empty()), + Some((pending, ping, stream)), + ) + }; + + let fut = H2Stream::new(service.call(req), connect_parts, respond); exec.execute_h2stream(fut); } Some(Err(e)) => { @@ -333,7 +362,7 @@ enum H2StreamState where B: HttpBody, { - Service(#[pin] F), + Service(#[pin] F, Option<(Pending, Recorder, RecvStream)>), Body(#[pin] PipeToSendStream), } @@ -341,10 +370,14 @@ impl H2Stream where B: HttpBody, { - fn new(fut: F, respond: SendResponse>) -> H2Stream { + fn new( + fut: F, + connect_parts: Option<(Pending, Recorder, RecvStream)>, + respond: SendResponse>, + ) -> H2Stream { H2Stream { reply: respond, - state: H2StreamState::Service(fut), + state: H2StreamState::Service(fut, connect_parts), } } } @@ -374,7 +407,7 @@ where let mut me = self.project(); loop { let next = match me.state.as_mut().project() { - H2StreamStateProj::Service(h) => { + H2StreamStateProj::Service(h, connect_parts) => { let res = match h.poll(cx) { Poll::Ready(Ok(r)) => r, Poll::Pending => { @@ -405,6 +438,27 @@ where .entry(::http::header::DATE) .or_insert_with(date::update_and_header_value); + if let Some((pending, ping, recv_stream)) = connect_parts.take() { + if res.status().is_success() { + if headers::content_length_parse_all(res.headers()).map_or(false, |len| len != 0) { + warn!("h2 successful response to CONNECT request with body not supported"); + me.reply.send_reset(h2::Reason::INTERNAL_ERROR); + return Poll::Ready(Err(crate::Error::new_user_header())); + } + let send_stream = reply!(me, res, false); + pending.fulfill(Upgraded::new( + H2Upgraded { + ping, + recv_stream, + send_stream, + buf: Bytes::new(), + }, + Bytes::new(), + )); + return Poll::Ready(Ok(())); + } + } + // automatically set Content-Length from body... if let Some(len) = body.size_hint().exact() { headers::set_content_length_if_missing(res.headers_mut(), len); diff --git a/src/upgrade.rs b/src/upgrade.rs index 6004c1a31a..efab10a6fc 100644 --- a/src/upgrade.rs +++ b/src/upgrade.rs @@ -62,12 +62,12 @@ pub fn on(msg: T) -> OnUpgrade { msg.on_upgrade() } -#[cfg(feature = "http1")] +#[cfg(any(feature = "http1", feature = "http2"))] pub(super) struct Pending { tx: oneshot::Sender>, } -#[cfg(feature = "http1")] +#[cfg(any(feature = "http1", feature = "http2"))] pub(super) fn pending() -> (Pending, OnUpgrade) { let (tx, rx) = oneshot::channel(); (Pending { tx }, OnUpgrade { rx: Some(rx) }) @@ -76,7 +76,7 @@ pub(super) fn pending() -> (Pending, OnUpgrade) { // ===== impl Upgraded ===== impl Upgraded { - #[cfg(any(feature = "http1", test))] + #[cfg(any(feature = "http1", feature = "http2", test))] pub(super) fn new(io: T, read_buf: Bytes) -> Self where T: AsyncRead + AsyncWrite + Unpin + Send + 'static, @@ -187,13 +187,14 @@ impl fmt::Debug for OnUpgrade { // ===== impl Pending ===== -#[cfg(feature = "http1")] +#[cfg(any(feature = "http1", feature = "http2"))] impl Pending { pub(super) fn fulfill(self, upgraded: Upgraded) { trace!("pending upgrade fulfill"); let _ = self.tx.send(Ok(upgraded)); } + #[cfg(feature = "http1")] /// Don't fulfill the pending Upgrade, but instead signal that /// upgrades are handled manually. pub(super) fn manual(self) { diff --git a/tests/server.rs b/tests/server.rs index 662e903d57..5da08adc8d 100644 --- a/tests/server.rs +++ b/tests/server.rs @@ -13,10 +13,13 @@ use std::task::{Context, Poll}; use std::thread; use std::time::Duration; +use bytes::Bytes; use futures_channel::oneshot; use futures_util::future::{self, Either, FutureExt, TryFutureExt}; #[cfg(feature = "stream")] use futures_util::stream::StreamExt as _; +use h2::client::SendRequest; +use h2::{RecvStream, SendStream}; use http::header::{HeaderName, HeaderValue}; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadBuf}; use tokio::net::{TcpListener, TcpStream as TkTcpStream}; @@ -1482,6 +1485,303 @@ async fn http_connect_new() { assert_eq!(s(&vec), "bar=foo"); } +#[tokio::test] +async fn h2_connect() { + use tokio::io::{AsyncReadExt, AsyncWriteExt}; + + let _ = pretty_env_logger::try_init(); + let listener = tcp_bind(&"127.0.0.1:0".parse().unwrap()).unwrap(); + let addr = listener.local_addr().unwrap(); + let conn = connect_async(addr).await; + + let (h2, connection) = h2::client::handshake(conn).await.unwrap(); + tokio::spawn(async move { + connection.await.unwrap(); + }); + let mut h2 = h2.ready().await.unwrap(); + + async fn connect_and_recv_bread( + h2: &mut SendRequest, + ) -> (RecvStream, SendStream) { + let request = Request::connect("localhost").body(()).unwrap(); + let (response, send_stream) = h2.send_request(request, false).unwrap(); + let response = response.await.unwrap(); + assert_eq!(response.status(), StatusCode::OK); + + let mut body = response.into_body(); + let bytes = body.data().await.unwrap().unwrap(); + assert_eq!(&bytes[..], b"Bread?"); + let _ = body.flow_control().release_capacity(bytes.len()); + + (body, send_stream) + } + + tokio::spawn(async move { + let (mut recv_stream, mut send_stream) = connect_and_recv_bread(&mut h2).await; + + send_stream.send_data("Baguette!".into(), true).unwrap(); + + assert!(recv_stream.data().await.unwrap().unwrap().is_empty()); + }); + + let svc = service_fn(move |req: Request| { + let on_upgrade = hyper::upgrade::on(req); + + tokio::spawn(async move { + let mut upgraded = on_upgrade.await.expect("on_upgrade"); + upgraded.write_all(b"Bread?").await.unwrap(); + + let mut vec = vec![]; + upgraded.read_to_end(&mut vec).await.unwrap(); + assert_eq!(s(&vec), "Baguette!"); + + upgraded.shutdown().await.unwrap(); + }); + + future::ok::<_, hyper::Error>( + Response::builder() + .status(200) + .body(hyper::Body::empty()) + .unwrap(), + ) + }); + + let (socket, _) = listener.accept().await.unwrap(); + Http::new() + .http2_only(true) + .serve_connection(socket, svc) + .with_upgrades() + .await + .unwrap(); +} + +#[tokio::test] +async fn h2_connect_multiplex() { + use tokio::io::{AsyncReadExt, AsyncWriteExt}; + + let _ = pretty_env_logger::try_init(); + let listener = tcp_bind(&"127.0.0.1:0".parse().unwrap()).unwrap(); + let addr = listener.local_addr().unwrap(); + let conn = connect_async(addr).await; + + let (h2, connection) = h2::client::handshake(conn).await.unwrap(); + tokio::spawn(async move { + connection.await.unwrap(); + }); + let mut h2 = h2.ready().await.unwrap(); + + async fn connect_and_recv_bread( + h2: &mut SendRequest, + ) -> (RecvStream, SendStream) { + let request = Request::connect("localhost").body(()).unwrap(); + let (response, send_stream) = h2.send_request(request, false).unwrap(); + let response = response.await.unwrap(); + assert_eq!(response.status(), StatusCode::OK); + + let mut body = response.into_body(); + let bytes = body.data().await.unwrap().unwrap(); + assert_eq!(&bytes[..], b"Bread?"); + let _ = body.flow_control().release_capacity(bytes.len()); + + (body, send_stream) + } + + tokio::spawn(async move { + let (mut recv_stream_1, mut send_stream_1) = connect_and_recv_bread(&mut h2).await; + + let (_, mut send_stream_2) = connect_and_recv_bread(&mut h2).await; + + send_stream_1.send_data("Baguette!".into(), true).unwrap(); + + send_stream_2.send_reset(h2::Reason::PROTOCOL_ERROR); + + assert!(recv_stream_1.data().await.unwrap().unwrap().is_empty()); + }); + + let svc = service_fn(move |req: Request| { + let on_upgrade = hyper::upgrade::on(req); + + tokio::spawn(async move { + let mut upgraded = on_upgrade.await.expect("on_upgrade"); + upgraded.write_all(b"Bread?").await.unwrap(); + + let mut vec = vec![]; + if upgraded.read_to_end(&mut vec).await.is_err() { + return; + } + assert_eq!(s(&vec), "Baguette!"); + + upgraded.shutdown().await.unwrap(); + }); + + future::ok::<_, hyper::Error>( + Response::builder() + .status(200) + .body(hyper::Body::empty()) + .unwrap(), + ) + }); + + let (socket, _) = listener.accept().await.unwrap(); + Http::new() + .http2_only(true) + .serve_connection(socket, svc) + .with_upgrades() + .await + .unwrap(); +} + +#[tokio::test] +async fn h2_connect_large_body() { + use tokio::io::{AsyncReadExt, AsyncWriteExt}; + + let _ = pretty_env_logger::try_init(); + let listener = tcp_bind(&"127.0.0.1:0".parse().unwrap()).unwrap(); + let addr = listener.local_addr().unwrap(); + let conn = connect_async(addr).await; + + let (h2, connection) = h2::client::handshake(conn).await.unwrap(); + tokio::spawn(async move { + connection.await.unwrap(); + }); + let mut h2 = h2.ready().await.unwrap(); + + const NO_BREAD: &str = "All work and no bread makes nox a dull boy.\n"; + + async fn connect_and_recv_bread( + h2: &mut SendRequest, + ) -> (RecvStream, SendStream) { + let request = Request::connect("localhost").body(()).unwrap(); + let (response, send_stream) = h2.send_request(request, false).unwrap(); + let response = response.await.unwrap(); + assert_eq!(response.status(), StatusCode::OK); + + let mut body = response.into_body(); + let bytes = body.data().await.unwrap().unwrap(); + assert_eq!(&bytes[..], b"Bread?"); + let _ = body.flow_control().release_capacity(bytes.len()); + + (body, send_stream) + } + + tokio::spawn(async move { + let (mut recv_stream, mut send_stream) = connect_and_recv_bread(&mut h2).await; + + let large_body = Bytes::from(NO_BREAD.repeat(9000)); + + send_stream.send_data(large_body.clone(), false).unwrap(); + send_stream.send_data(large_body, true).unwrap(); + + assert!(recv_stream.data().await.unwrap().unwrap().is_empty()); + }); + + let svc = service_fn(move |req: Request| { + let on_upgrade = hyper::upgrade::on(req); + + tokio::spawn(async move { + let mut upgraded = on_upgrade.await.expect("on_upgrade"); + upgraded.write_all(b"Bread?").await.unwrap(); + + let mut vec = vec![]; + if upgraded.read_to_end(&mut vec).await.is_err() { + return; + } + assert_eq!(vec.len(), NO_BREAD.len() * 9000 * 2); + + upgraded.shutdown().await.unwrap(); + }); + + future::ok::<_, hyper::Error>( + Response::builder() + .status(200) + .body(hyper::Body::empty()) + .unwrap(), + ) + }); + + let (socket, _) = listener.accept().await.unwrap(); + Http::new() + .http2_only(true) + .serve_connection(socket, svc) + .with_upgrades() + .await + .unwrap(); +} + +#[tokio::test] +async fn h2_connect_empty_frames() { + use tokio::io::{AsyncReadExt, AsyncWriteExt}; + + let _ = pretty_env_logger::try_init(); + let listener = tcp_bind(&"127.0.0.1:0".parse().unwrap()).unwrap(); + let addr = listener.local_addr().unwrap(); + let conn = connect_async(addr).await; + + let (h2, connection) = h2::client::handshake(conn).await.unwrap(); + tokio::spawn(async move { + connection.await.unwrap(); + }); + let mut h2 = h2.ready().await.unwrap(); + + async fn connect_and_recv_bread( + h2: &mut SendRequest, + ) -> (RecvStream, SendStream) { + let request = Request::connect("localhost").body(()).unwrap(); + let (response, send_stream) = h2.send_request(request, false).unwrap(); + let response = response.await.unwrap(); + assert_eq!(response.status(), StatusCode::OK); + + let mut body = response.into_body(); + let bytes = body.data().await.unwrap().unwrap(); + assert_eq!(&bytes[..], b"Bread?"); + let _ = body.flow_control().release_capacity(bytes.len()); + + (body, send_stream) + } + + tokio::spawn(async move { + let (mut recv_stream, mut send_stream) = connect_and_recv_bread(&mut h2).await; + + send_stream.send_data("".into(), false).unwrap(); + send_stream.send_data("".into(), false).unwrap(); + send_stream.send_data("".into(), false).unwrap(); + send_stream.send_data("Baguette!".into(), false).unwrap(); + send_stream.send_data("".into(), true).unwrap(); + + assert!(recv_stream.data().await.unwrap().unwrap().is_empty()); + }); + + let svc = service_fn(move |req: Request| { + let on_upgrade = hyper::upgrade::on(req); + + tokio::spawn(async move { + let mut upgraded = on_upgrade.await.expect("on_upgrade"); + upgraded.write_all(b"Bread?").await.unwrap(); + + let mut vec = vec![]; + upgraded.read_to_end(&mut vec).await.unwrap(); + assert_eq!(s(&vec), "Baguette!"); + + upgraded.shutdown().await.unwrap(); + }); + + future::ok::<_, hyper::Error>( + Response::builder() + .status(200) + .body(hyper::Body::empty()) + .unwrap(), + ) + }); + + let (socket, _) = listener.accept().await.unwrap(); + Http::new() + .http2_only(true) + .serve_connection(socket, svc) + .with_upgrades() + .await + .unwrap(); +} + #[tokio::test] async fn parse_errors_send_4xx_response() { let listener = tcp_bind(&"127.0.0.1:0".parse().unwrap()).unwrap();