Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add HTTP Upgrade support to Response. #1376

Merged
merged 5 commits into from Jul 28, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions src/async_impl/mod.rs
Expand Up @@ -13,3 +13,4 @@ pub mod decoder;
pub mod multipart;
pub(crate) mod request;
mod response;
mod upgrade;
72 changes: 30 additions & 42 deletions src/async_impl/response.rs
Expand Up @@ -24,14 +24,10 @@ use crate::response::ResponseUrl;

/// A Response to a submitted `Request`.
pub struct Response {
status: StatusCode,
headers: HeaderMap,
pub(super) res: hyper::Response<Decoder>,
// Boxed to save space (11 words to 1 word), and it's not accessed
// frequently internally.
url: Box<Url>,
body: Decoder,
version: Version,
extensions: http::Extensions,
}

impl Response {
Expand All @@ -41,46 +37,38 @@ impl Response {
accepts: Accepts,
timeout: Option<Pin<Box<Sleep>>>,
) -> Response {
let (parts, body) = res.into_parts();
let status = parts.status;
let version = parts.version;
let extensions = parts.extensions;

let mut headers = parts.headers;
let decoder = Decoder::detect(&mut headers, Body::response(body, timeout), accepts);
let (mut parts, body) = res.into_parts();
let decoder = Decoder::detect(&mut parts.headers, Body::response(body, timeout), accepts);
let res = hyper::Response::from_parts(parts, decoder);

Response {
status,
headers,
res,
url: Box::new(url),
body: decoder,
version,
extensions,
}
}

/// Get the `StatusCode` of this `Response`.
#[inline]
pub fn status(&self) -> StatusCode {
self.status
self.res.status()
}

/// Get the HTTP `Version` of this `Response`.
#[inline]
pub fn version(&self) -> Version {
self.version
self.res.version()
}

/// Get the `Headers` of this `Response`.
#[inline]
pub fn headers(&self) -> &HeaderMap {
&self.headers
self.res.headers()
}

/// Get a mutable reference to the `Headers` of this `Response`.
#[inline]
pub fn headers_mut(&mut self) -> &mut HeaderMap {
&mut self.headers
self.res.headers_mut()
}

/// Get the content-length of this response, if known.
Expand All @@ -93,7 +81,7 @@ impl Response {
pub fn content_length(&self) -> Option<u64> {
use hyper::body::HttpBody;

HttpBody::size_hint(&self.body).exact()
HttpBody::size_hint(self.res.body()).exact()
}

/// Retrieve the cookies contained in the response.
Expand All @@ -106,7 +94,7 @@ impl Response {
#[cfg(feature = "cookies")]
#[cfg_attr(docsrs, doc(cfg(feature = "cookies")))]
pub fn cookies<'a>(&'a self) -> impl Iterator<Item = cookie::Cookie<'a>> + 'a {
cookie::extract_response_cookies(&self.headers).filter_map(Result::ok)
cookie::extract_response_cookies(self.res.headers()).filter_map(Result::ok)
}

/// Get the final `Url` of this `Response`.
Expand All @@ -117,19 +105,20 @@ impl Response {

/// Get the remote address used to get this `Response`.
pub fn remote_addr(&self) -> Option<SocketAddr> {
self.extensions
self.res
.extensions()
.get::<HttpInfo>()
.map(|info| info.remote_addr())
}

/// Returns a reference to the associated extensions.
pub fn extensions(&self) -> &http::Extensions {
&self.extensions
self.res.extensions()
}

/// Returns a mutable reference to the associated extensions.
pub fn extensions_mut(&mut self) -> &mut http::Extensions {
&mut self.extensions
self.res.extensions_mut()
}

// body methods
Expand Down Expand Up @@ -183,7 +172,7 @@ impl Response {
/// ```
pub async fn text_with_charset(self, default_encoding: &str) -> crate::Result<String> {
let content_type = self
.headers
.headers()
.get(crate::header::CONTENT_TYPE)
.and_then(|value| value.to_str().ok())
.and_then(|value| value.parse::<Mime>().ok());
Expand Down Expand Up @@ -271,7 +260,7 @@ impl Response {
/// # }
/// ```
pub async fn bytes(self) -> crate::Result<Bytes> {
hyper::body::to_bytes(self.body).await
hyper::body::to_bytes(self.res.into_body()).await
}

/// Stream a chunk of the response body.
Expand All @@ -291,7 +280,7 @@ impl Response {
/// # }
/// ```
pub async fn chunk(&mut self) -> crate::Result<Option<Bytes>> {
if let Some(item) = self.body.next().await {
if let Some(item) = self.res.body_mut().next().await {
Ok(Some(item?))
} else {
Ok(None)
Expand Down Expand Up @@ -323,7 +312,7 @@ impl Response {
#[cfg(feature = "stream")]
#[cfg_attr(docsrs, doc(cfg(feature = "stream")))]
pub fn bytes_stream(self) -> impl futures_core::Stream<Item = crate::Result<Bytes>> {
self.body
self.res.into_body()
}

// util methods
Expand All @@ -350,8 +339,9 @@ impl Response {
/// # fn main() {}
/// ```
pub fn error_for_status(self) -> crate::Result<Self> {
if self.status.is_client_error() || self.status.is_server_error() {
Err(crate::error::status_code(*self.url, self.status))
let status = self.status();
if status.is_client_error() || status.is_server_error() {
Err(crate::error::status_code(*self.url, status))
} else {
Ok(self)
}
Expand Down Expand Up @@ -379,8 +369,9 @@ impl Response {
/// # fn main() {}
/// ```
pub fn error_for_status_ref(&self) -> crate::Result<&Self> {
if self.status.is_client_error() || self.status.is_server_error() {
Err(crate::error::status_code(*self.url.clone(), self.status))
let status = self.status();
if status.is_client_error() || status.is_server_error() {
Err(crate::error::status_code(*self.url.clone(), status))
} else {
Ok(self)
}
Expand All @@ -395,7 +386,7 @@ impl Response {
// This method is just used by the blocking API.
#[cfg(feature = "blocking")]
pub(crate) fn body_mut(&mut self) -> &mut Decoder {
&mut self.body
self.res.body_mut()
}
}

Expand All @@ -413,27 +404,24 @@ impl<T: Into<Body>> From<http::Response<T>> for Response {
fn from(r: http::Response<T>) -> Response {
let (mut parts, body) = r.into_parts();
let body = body.into();
let body = Decoder::detect(&mut parts.headers, body, Accepts::none());
let decoder = Decoder::detect(&mut parts.headers, body, Accepts::none());
let url = parts
.extensions
.remove::<ResponseUrl>()
.unwrap_or_else(|| ResponseUrl(Url::parse("http://no.url.provided.local").unwrap()));
let url = url.0;
let res = hyper::Response::from_parts(parts, decoder);
Response {
status: parts.status,
headers: parts.headers,
res,
url: Box::new(url),
body,
version: parts.version,
extensions: parts.extensions,
}
}
}

/// A `Response` can be piped as the `Body` of another request.
impl From<Response> for Body {
fn from(r: Response) -> Body {
Body::stream(r.body)
Body::stream(r.res.into_body())
}
}

Expand Down
73 changes: 73 additions & 0 deletions src/async_impl/upgrade.rs
@@ -0,0 +1,73 @@
use std::pin::Pin;
use std::task::{self, Poll};
use std::{fmt, io};

use futures_util::TryFutureExt;
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};

/// An upgraded HTTP connection.
pub struct Upgraded {
luqmana marked this conversation as resolved.
Show resolved Hide resolved
inner: hyper::upgrade::Upgraded,
}

impl AsyncRead for Upgraded {
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut task::Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
Pin::new(&mut self.inner).poll_read(cx, buf)
}
}

impl AsyncWrite for Upgraded {
fn poll_write(
mut self: Pin<&mut Self>,
cx: &mut task::Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
Pin::new(&mut self.inner).poll_write(cx, buf)
}

fn poll_write_vectored(
mut self: Pin<&mut Self>,
cx: &mut task::Context<'_>,
bufs: &[io::IoSlice<'_>],
) -> Poll<io::Result<usize>> {
Pin::new(&mut self.inner).poll_write_vectored(cx, bufs)
}

fn poll_flush(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<io::Result<()>> {
Pin::new(&mut self.inner).poll_flush(cx)
}

fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<io::Result<()>> {
Pin::new(&mut self.inner).poll_shutdown(cx)
}

fn is_write_vectored(&self) -> bool {
self.inner.is_write_vectored()
}
}

impl fmt::Debug for Upgraded {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("Upgraded").finish()
}
}

impl From<hyper::upgrade::Upgraded> for Upgraded {
fn from(inner: hyper::upgrade::Upgraded) -> Self {
Upgraded { inner }
}
}

impl super::response::Response {
/// Consumes the response and returns a future for a possible HTTP upgrade.
pub async fn upgrade(self) -> crate::Result<Upgraded> {
hyper::upgrade::on(self.res)
.map_ok(Upgraded::from)
.map_err(crate::error::upgrade)
.await
}
}
6 changes: 6 additions & 0 deletions src/error.rs
Expand Up @@ -185,6 +185,7 @@ impl fmt::Display for Error {
Kind::Body => f.write_str("request or response body error")?,
Kind::Decode => f.write_str("error decoding response body")?,
Kind::Redirect => f.write_str("error following redirect")?,
Kind::Upgrade => f.write_str("error upgrading connection")?,
Kind::Status(ref code) => {
let prefix = if code.is_client_error() {
"HTTP status client error"
Expand Down Expand Up @@ -236,6 +237,7 @@ pub(crate) enum Kind {
Status(StatusCode),
Body,
Decode,
Upgrade,
}

// constructors
Expand Down Expand Up @@ -274,6 +276,10 @@ if_wasm! {
}
}

pub(crate) fn upgrade<E: Into<BoxError>>(e: E) -> Error {
Error::new(Kind::Upgrade, Some(e))
}

// io::Error helpers

#[allow(unused)]
Expand Down
51 changes: 51 additions & 0 deletions tests/upgrade.rs
@@ -0,0 +1,51 @@
#![cfg(not(target_arch = "wasm32"))]
mod support;
use support::*;
use tokio::io::{AsyncReadExt, AsyncWriteExt};

#[tokio::test]
async fn http_upgrade() {
let server = server::http(move |req| {
assert_eq!(req.method(), "GET");
assert_eq!(req.headers()["connection"], "upgrade");
assert_eq!(req.headers()["upgrade"], "foobar");

tokio::spawn(async move {
let mut upgraded = hyper::upgrade::on(req).await.unwrap();

let mut buf = vec![0; 7];
upgraded.read_exact(&mut buf).await.unwrap();
assert_eq!(buf, b"foo=bar");

upgraded.write_all(b"bar=foo").await.unwrap();
});

async {
http::Response::builder()
.status(http::StatusCode::SWITCHING_PROTOCOLS)
.header(http::header::CONNECTION, "upgrade")
.header(http::header::UPGRADE, "foobar")
.body(hyper::Body::empty())
.unwrap()
}
});

let res = reqwest::Client::builder()
.build()
.unwrap()
.get(format!("http://{}", server.addr()))
.header(http::header::CONNECTION, "upgrade")
.header(http::header::UPGRADE, "foobar")
.send()
.await
.unwrap();

assert_eq!(res.status(), http::StatusCode::SWITCHING_PROTOCOLS);
let mut upgraded = res.upgrade().await.unwrap();

upgraded.write_all(b"foo=bar").await.unwrap();

let mut buf = vec![];
upgraded.read_to_end(&mut buf).await.unwrap();
assert_eq!(buf, b"bar=foo");
}