From 188d7b863c2b1e38681f2298cf5653dd27c965c9 Mon Sep 17 00:00:00 2001 From: Luqman Aden Date: Sat, 30 Oct 2021 15:53:58 -0700 Subject: [PATCH] Expose upgrade method on Response to reuse underlying connection after HTTP Upgrade. --- Cargo.toml | 4 +++ src/async_impl/mod.rs | 1 + src/async_impl/response.rs | 2 +- src/async_impl/upgrade.rs | 72 ++++++++++++++++++++++++++++++++++++++ src/error.rs | 6 ++++ tests/upgrade.rs | 51 +++++++++++++++++++++++++++ 6 files changed, 135 insertions(+), 1 deletion(-) create mode 100644 src/async_impl/upgrade.rs create mode 100644 tests/upgrade.rs diff --git a/Cargo.toml b/Cargo.toml index e50bf7a24..589251ab9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -230,3 +230,7 @@ required-features = ["deflate"] name = "multipart" path = "tests/multipart.rs" required-features = ["multipart"] + +# TODO: Remove once https://github.com/hyperium/hyper/pull/2680 gets merged and released +[patch.crates-io] +hyper = { git = "https://github.com/luqmana/hyper.git", branch = "generic-upgrade" } diff --git a/src/async_impl/mod.rs b/src/async_impl/mod.rs index b8bc8aa6f..6fc29fc29 100644 --- a/src/async_impl/mod.rs +++ b/src/async_impl/mod.rs @@ -13,3 +13,4 @@ pub mod decoder; pub mod multipart; pub(crate) mod request; mod response; +mod upgrade; diff --git a/src/async_impl/response.rs b/src/async_impl/response.rs index 064567889..9b857a11f 100644 --- a/src/async_impl/response.rs +++ b/src/async_impl/response.rs @@ -24,7 +24,7 @@ use crate::response::ResponseUrl; /// A Response to a submitted `Request`. pub struct Response { - res: hyper::Response, + pub(super) res: hyper::Response, // Boxed to save space (11 words to 1 word), and it's not accessed // frequently internally. url: Box, diff --git a/src/async_impl/upgrade.rs b/src/async_impl/upgrade.rs new file mode 100644 index 000000000..5b9cf12a5 --- /dev/null +++ b/src/async_impl/upgrade.rs @@ -0,0 +1,72 @@ +use std::pin::Pin; +use std::task::{self, Poll}; +use std::{fmt, io}; + +use futures_util::TryFutureExt; +use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; + +/// An upgraded HTTP connection. +pub struct Upgraded { + inner: hyper::upgrade::Upgraded, +} + +impl AsyncRead for Upgraded { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut task::Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + Pin::new(&mut self.inner).poll_read(cx, buf) + } +} + +impl AsyncWrite for Upgraded { + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut task::Context<'_>, + buf: &[u8], + ) -> Poll> { + Pin::new(&mut self.inner).poll_write(cx, buf) + } + + fn poll_write_vectored( + mut self: Pin<&mut Self>, + cx: &mut task::Context<'_>, + bufs: &[io::IoSlice<'_>], + ) -> Poll> { + Pin::new(&mut self.inner).poll_write_vectored(cx, bufs) + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll> { + Pin::new(&mut self.inner).poll_flush(cx) + } + + fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll> { + Pin::new(&mut self.inner).poll_shutdown(cx) + } + + fn is_write_vectored(&self) -> bool { + self.inner.is_write_vectored() + } +} + +impl fmt::Debug for Upgraded { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Upgraded").finish() + } +} + +impl From for Upgraded { + fn from(inner: hyper::upgrade::Upgraded) -> Self { + Upgraded { inner } + } +} + +impl super::response::Response { + /// Consumes the response and returns a future for a possible HTTP upgrade. + pub fn upgrade(self) -> impl futures_core::Future> { + hyper::upgrade::on(self.res) + .map_ok(Upgraded::from) + .map_err(crate::error::upgrade) + } +} diff --git a/src/error.rs b/src/error.rs index fb73de97f..65b2e7b96 100644 --- a/src/error.rs +++ b/src/error.rs @@ -176,6 +176,7 @@ impl fmt::Display for Error { Kind::Body => f.write_str("request or response body error")?, Kind::Decode => f.write_str("error decoding response body")?, Kind::Redirect => f.write_str("error following redirect")?, + Kind::Upgrade => f.write_str("error upgrading connection")?, Kind::Status(ref code) => { let prefix = if code.is_client_error() { "HTTP status client error" @@ -225,6 +226,7 @@ pub(crate) enum Kind { Status(StatusCode), Body, Decode, + Upgrade, } // constructors @@ -263,6 +265,10 @@ if_wasm! { } } +pub(crate) fn upgrade>(e: E) -> Error { + Error::new(Kind::Upgrade, Some(e)) +} + // io::Error helpers #[allow(unused)] diff --git a/tests/upgrade.rs b/tests/upgrade.rs new file mode 100644 index 000000000..78b462529 --- /dev/null +++ b/tests/upgrade.rs @@ -0,0 +1,51 @@ +#![cfg(not(target_arch = "wasm32"))] +mod support; +use support::*; +use tokio::io::{AsyncReadExt, AsyncWriteExt}; + +#[tokio::test] +async fn http_upgrade() { + let server = server::http(move |req| { + assert_eq!(req.method(), "GET"); + assert_eq!(req.headers()["connection"], "upgrade"); + assert_eq!(req.headers()["upgrade"], "foobar"); + + tokio::spawn(async move { + let mut upgraded = hyper::upgrade::on(req).await.unwrap(); + + let mut buf = vec![0; 7]; + upgraded.read_exact(&mut buf).await.unwrap(); + assert_eq!(buf, b"foo=bar"); + + upgraded.write_all(b"bar=foo").await.unwrap(); + }); + + async { + http::Response::builder() + .status(http::StatusCode::SWITCHING_PROTOCOLS) + .header(http::header::CONNECTION, "upgrade") + .header(http::header::UPGRADE, "foobar") + .body(hyper::Body::empty()) + .unwrap() + } + }); + + let res = reqwest::Client::builder() + .build() + .unwrap() + .get(format!("http://{}", server.addr())) + .header(http::header::CONNECTION, "upgrade") + .header(http::header::UPGRADE, "foobar") + .send() + .await + .unwrap(); + + assert_eq!(res.status(), http::StatusCode::SWITCHING_PROTOCOLS); + let mut upgraded = res.upgrade().await.unwrap(); + + upgraded.write_all(b"foo=bar").await.unwrap(); + + let mut buf = vec![]; + upgraded.read_to_end(&mut buf).await.unwrap(); + assert_eq!(buf, b"bar=foo"); +}