Skip to content

Commit

Permalink
Expose upgrade method on Response to reuse underlying connection afte…
Browse files Browse the repository at this point in the history
…r HTTP Upgrade.
  • Loading branch information
luqmana committed Nov 3, 2021
1 parent f1383d5 commit 05202a2
Show file tree
Hide file tree
Showing 6 changed files with 135 additions and 1 deletion.
4 changes: 4 additions & 0 deletions Cargo.toml
Expand Up @@ -230,3 +230,7 @@ required-features = ["deflate"]
name = "multipart"
path = "tests/multipart.rs"
required-features = ["multipart"]

# TODO: Remove once a new version w/ https://github.com/hyperium/hyper/pull/2680 gets released
[patch.crates-io]
hyper = { git = "https://github.com/hyperium/hyper.git", branch = "master" }
1 change: 1 addition & 0 deletions src/async_impl/mod.rs
Expand Up @@ -13,3 +13,4 @@ pub mod decoder;
pub mod multipart;
pub(crate) mod request;
mod response;
mod upgrade;
2 changes: 1 addition & 1 deletion src/async_impl/response.rs
Expand Up @@ -24,7 +24,7 @@ use crate::response::ResponseUrl;

/// A Response to a submitted `Request`.
pub struct Response {
res: hyper::Response<Decoder>,
pub(super) res: hyper::Response<Decoder>,
// Boxed to save space (11 words to 1 word), and it's not accessed
// frequently internally.
url: Box<Url>,
Expand Down
72 changes: 72 additions & 0 deletions src/async_impl/upgrade.rs
@@ -0,0 +1,72 @@
use std::pin::Pin;
use std::task::{self, Poll};
use std::{fmt, io};

use futures_util::TryFutureExt;
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};

/// An upgraded HTTP connection.
pub struct Upgraded {
inner: hyper::upgrade::Upgraded,
}

impl AsyncRead for Upgraded {
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut task::Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
Pin::new(&mut self.inner).poll_read(cx, buf)
}
}

impl AsyncWrite for Upgraded {
fn poll_write(
mut self: Pin<&mut Self>,
cx: &mut task::Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
Pin::new(&mut self.inner).poll_write(cx, buf)
}

fn poll_write_vectored(
mut self: Pin<&mut Self>,
cx: &mut task::Context<'_>,
bufs: &[io::IoSlice<'_>],
) -> Poll<io::Result<usize>> {
Pin::new(&mut self.inner).poll_write_vectored(cx, bufs)
}

fn poll_flush(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<io::Result<()>> {
Pin::new(&mut self.inner).poll_flush(cx)
}

fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<io::Result<()>> {
Pin::new(&mut self.inner).poll_shutdown(cx)
}

fn is_write_vectored(&self) -> bool {
self.inner.is_write_vectored()
}
}

impl fmt::Debug for Upgraded {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("Upgraded").finish()
}
}

impl From<hyper::upgrade::Upgraded> for Upgraded {
fn from(inner: hyper::upgrade::Upgraded) -> Self {
Upgraded { inner }
}
}

impl super::response::Response {
/// Consumes the response and returns a future for a possible HTTP upgrade.
pub fn upgrade(self) -> impl futures_core::Future<Output = crate::Result<Upgraded>> {
hyper::upgrade::on(self.res)
.map_ok(Upgraded::from)
.map_err(crate::error::upgrade)
}
}
6 changes: 6 additions & 0 deletions src/error.rs
Expand Up @@ -176,6 +176,7 @@ impl fmt::Display for Error {
Kind::Body => f.write_str("request or response body error")?,
Kind::Decode => f.write_str("error decoding response body")?,
Kind::Redirect => f.write_str("error following redirect")?,
Kind::Upgrade => f.write_str("error upgrading connection")?,
Kind::Status(ref code) => {
let prefix = if code.is_client_error() {
"HTTP status client error"
Expand Down Expand Up @@ -225,6 +226,7 @@ pub(crate) enum Kind {
Status(StatusCode),
Body,
Decode,
Upgrade,
}

// constructors
Expand Down Expand Up @@ -263,6 +265,10 @@ if_wasm! {
}
}

pub(crate) fn upgrade<E: Into<BoxError>>(e: E) -> Error {
Error::new(Kind::Upgrade, Some(e))
}

// io::Error helpers

#[allow(unused)]
Expand Down
51 changes: 51 additions & 0 deletions tests/upgrade.rs
@@ -0,0 +1,51 @@
#![cfg(not(target_arch = "wasm32"))]
mod support;
use support::*;
use tokio::io::{AsyncReadExt, AsyncWriteExt};

#[tokio::test]
async fn http_upgrade() {
let server = server::http(move |req| {
assert_eq!(req.method(), "GET");
assert_eq!(req.headers()["connection"], "upgrade");
assert_eq!(req.headers()["upgrade"], "foobar");

tokio::spawn(async move {
let mut upgraded = hyper::upgrade::on(req).await.unwrap();

let mut buf = vec![0; 7];
upgraded.read_exact(&mut buf).await.unwrap();
assert_eq!(buf, b"foo=bar");

upgraded.write_all(b"bar=foo").await.unwrap();
});

async {
http::Response::builder()
.status(http::StatusCode::SWITCHING_PROTOCOLS)
.header(http::header::CONNECTION, "upgrade")
.header(http::header::UPGRADE, "foobar")
.body(hyper::Body::empty())
.unwrap()
}
});

let res = reqwest::Client::builder()
.build()
.unwrap()
.get(format!("http://{}", server.addr()))
.header(http::header::CONNECTION, "upgrade")
.header(http::header::UPGRADE, "foobar")
.send()
.await
.unwrap();

assert_eq!(res.status(), http::StatusCode::SWITCHING_PROTOCOLS);
let mut upgraded = res.upgrade().await.unwrap();

upgraded.write_all(b"foo=bar").await.unwrap();

let mut buf = vec![];
upgraded.read_to_end(&mut buf).await.unwrap();
assert_eq!(buf, b"bar=foo");
}

0 comments on commit 05202a2

Please sign in to comment.