diff --git a/src/async_impl/mod.rs b/src/async_impl/mod.rs index b8bc8aa6f..6fc29fc29 100644 --- a/src/async_impl/mod.rs +++ b/src/async_impl/mod.rs @@ -13,3 +13,4 @@ pub mod decoder; pub mod multipart; pub(crate) mod request; mod response; +mod upgrade; diff --git a/src/async_impl/response.rs b/src/async_impl/response.rs index 33d859ed4..09295af70 100644 --- a/src/async_impl/response.rs +++ b/src/async_impl/response.rs @@ -24,14 +24,10 @@ use crate::response::ResponseUrl; /// A Response to a submitted `Request`. pub struct Response { - status: StatusCode, - headers: HeaderMap, + pub(super) res: hyper::Response, // Boxed to save space (11 words to 1 word), and it's not accessed // frequently internally. url: Box, - body: Decoder, - version: Version, - extensions: http::Extensions, } impl Response { @@ -41,46 +37,38 @@ impl Response { accepts: Accepts, timeout: Option>>, ) -> Response { - let (parts, body) = res.into_parts(); - let status = parts.status; - let version = parts.version; - let extensions = parts.extensions; - - let mut headers = parts.headers; - let decoder = Decoder::detect(&mut headers, Body::response(body, timeout), accepts); + let (mut parts, body) = res.into_parts(); + let decoder = Decoder::detect(&mut parts.headers, Body::response(body, timeout), accepts); + let res = hyper::Response::from_parts(parts, decoder); Response { - status, - headers, + res, url: Box::new(url), - body: decoder, - version, - extensions, } } /// Get the `StatusCode` of this `Response`. #[inline] pub fn status(&self) -> StatusCode { - self.status + self.res.status() } /// Get the HTTP `Version` of this `Response`. #[inline] pub fn version(&self) -> Version { - self.version + self.res.version() } /// Get the `Headers` of this `Response`. #[inline] pub fn headers(&self) -> &HeaderMap { - &self.headers + self.res.headers() } /// Get a mutable reference to the `Headers` of this `Response`. #[inline] pub fn headers_mut(&mut self) -> &mut HeaderMap { - &mut self.headers + self.res.headers_mut() } /// Get the content-length of this response, if known. @@ -93,7 +81,7 @@ impl Response { pub fn content_length(&self) -> Option { use hyper::body::HttpBody; - HttpBody::size_hint(&self.body).exact() + HttpBody::size_hint(self.res.body()).exact() } /// Retrieve the cookies contained in the response. @@ -106,7 +94,7 @@ impl Response { #[cfg(feature = "cookies")] #[cfg_attr(docsrs, doc(cfg(feature = "cookies")))] pub fn cookies<'a>(&'a self) -> impl Iterator> + 'a { - cookie::extract_response_cookies(&self.headers).filter_map(Result::ok) + cookie::extract_response_cookies(self.res.headers()).filter_map(Result::ok) } /// Get the final `Url` of this `Response`. @@ -117,19 +105,20 @@ impl Response { /// Get the remote address used to get this `Response`. pub fn remote_addr(&self) -> Option { - self.extensions + self.res + .extensions() .get::() .map(|info| info.remote_addr()) } /// Returns a reference to the associated extensions. pub fn extensions(&self) -> &http::Extensions { - &self.extensions + self.res.extensions() } /// Returns a mutable reference to the associated extensions. pub fn extensions_mut(&mut self) -> &mut http::Extensions { - &mut self.extensions + self.res.extensions_mut() } // body methods @@ -183,7 +172,7 @@ impl Response { /// ``` pub async fn text_with_charset(self, default_encoding: &str) -> crate::Result { let content_type = self - .headers + .headers() .get(crate::header::CONTENT_TYPE) .and_then(|value| value.to_str().ok()) .and_then(|value| value.parse::().ok()); @@ -271,7 +260,7 @@ impl Response { /// # } /// ``` pub async fn bytes(self) -> crate::Result { - hyper::body::to_bytes(self.body).await + hyper::body::to_bytes(self.res.into_body()).await } /// Stream a chunk of the response body. @@ -291,7 +280,7 @@ impl Response { /// # } /// ``` pub async fn chunk(&mut self) -> crate::Result> { - if let Some(item) = self.body.next().await { + if let Some(item) = self.res.body_mut().next().await { Ok(Some(item?)) } else { Ok(None) @@ -323,7 +312,7 @@ impl Response { #[cfg(feature = "stream")] #[cfg_attr(docsrs, doc(cfg(feature = "stream")))] pub fn bytes_stream(self) -> impl futures_core::Stream> { - self.body + self.res.into_body() } // util methods @@ -350,8 +339,9 @@ impl Response { /// # fn main() {} /// ``` pub fn error_for_status(self) -> crate::Result { - if self.status.is_client_error() || self.status.is_server_error() { - Err(crate::error::status_code(*self.url, self.status)) + let status = self.status(); + if status.is_client_error() || status.is_server_error() { + Err(crate::error::status_code(*self.url, status)) } else { Ok(self) } @@ -379,8 +369,9 @@ impl Response { /// # fn main() {} /// ``` pub fn error_for_status_ref(&self) -> crate::Result<&Self> { - if self.status.is_client_error() || self.status.is_server_error() { - Err(crate::error::status_code(*self.url.clone(), self.status)) + let status = self.status(); + if status.is_client_error() || status.is_server_error() { + Err(crate::error::status_code(*self.url.clone(), status)) } else { Ok(self) } @@ -395,7 +386,7 @@ impl Response { // This method is just used by the blocking API. #[cfg(feature = "blocking")] pub(crate) fn body_mut(&mut self) -> &mut Decoder { - &mut self.body + self.res.body_mut() } } @@ -413,19 +404,16 @@ impl> From> for Response { fn from(r: http::Response) -> Response { let (mut parts, body) = r.into_parts(); let body = body.into(); - let body = Decoder::detect(&mut parts.headers, body, Accepts::none()); + let decoder = Decoder::detect(&mut parts.headers, body, Accepts::none()); let url = parts .extensions .remove::() .unwrap_or_else(|| ResponseUrl(Url::parse("http://no.url.provided.local").unwrap())); let url = url.0; + let res = hyper::Response::from_parts(parts, decoder); Response { - status: parts.status, - headers: parts.headers, + res, url: Box::new(url), - body, - version: parts.version, - extensions: parts.extensions, } } } @@ -433,7 +421,7 @@ impl> From> for Response { /// A `Response` can be piped as the `Body` of another request. impl From for Body { fn from(r: Response) -> Body { - Body::stream(r.body) + Body::stream(r.res.into_body()) } } diff --git a/src/async_impl/upgrade.rs b/src/async_impl/upgrade.rs new file mode 100644 index 000000000..4a69b4db5 --- /dev/null +++ b/src/async_impl/upgrade.rs @@ -0,0 +1,73 @@ +use std::pin::Pin; +use std::task::{self, Poll}; +use std::{fmt, io}; + +use futures_util::TryFutureExt; +use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; + +/// An upgraded HTTP connection. +pub struct Upgraded { + inner: hyper::upgrade::Upgraded, +} + +impl AsyncRead for Upgraded { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut task::Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + Pin::new(&mut self.inner).poll_read(cx, buf) + } +} + +impl AsyncWrite for Upgraded { + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut task::Context<'_>, + buf: &[u8], + ) -> Poll> { + Pin::new(&mut self.inner).poll_write(cx, buf) + } + + fn poll_write_vectored( + mut self: Pin<&mut Self>, + cx: &mut task::Context<'_>, + bufs: &[io::IoSlice<'_>], + ) -> Poll> { + Pin::new(&mut self.inner).poll_write_vectored(cx, bufs) + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll> { + Pin::new(&mut self.inner).poll_flush(cx) + } + + fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll> { + Pin::new(&mut self.inner).poll_shutdown(cx) + } + + fn is_write_vectored(&self) -> bool { + self.inner.is_write_vectored() + } +} + +impl fmt::Debug for Upgraded { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Upgraded").finish() + } +} + +impl From for Upgraded { + fn from(inner: hyper::upgrade::Upgraded) -> Self { + Upgraded { inner } + } +} + +impl super::response::Response { + /// Consumes the response and returns a future for a possible HTTP upgrade. + pub async fn upgrade(self) -> crate::Result { + hyper::upgrade::on(self.res) + .map_ok(Upgraded::from) + .map_err(crate::error::upgrade) + .await + } +} diff --git a/src/error.rs b/src/error.rs index 3f829d99a..0e6bd247d 100644 --- a/src/error.rs +++ b/src/error.rs @@ -185,6 +185,7 @@ impl fmt::Display for Error { Kind::Body => f.write_str("request or response body error")?, Kind::Decode => f.write_str("error decoding response body")?, Kind::Redirect => f.write_str("error following redirect")?, + Kind::Upgrade => f.write_str("error upgrading connection")?, Kind::Status(ref code) => { let prefix = if code.is_client_error() { "HTTP status client error" @@ -236,6 +237,7 @@ pub(crate) enum Kind { Status(StatusCode), Body, Decode, + Upgrade, } // constructors @@ -274,6 +276,10 @@ if_wasm! { } } +pub(crate) fn upgrade>(e: E) -> Error { + Error::new(Kind::Upgrade, Some(e)) +} + // io::Error helpers #[allow(unused)] diff --git a/tests/upgrade.rs b/tests/upgrade.rs new file mode 100644 index 000000000..78b462529 --- /dev/null +++ b/tests/upgrade.rs @@ -0,0 +1,51 @@ +#![cfg(not(target_arch = "wasm32"))] +mod support; +use support::*; +use tokio::io::{AsyncReadExt, AsyncWriteExt}; + +#[tokio::test] +async fn http_upgrade() { + let server = server::http(move |req| { + assert_eq!(req.method(), "GET"); + assert_eq!(req.headers()["connection"], "upgrade"); + assert_eq!(req.headers()["upgrade"], "foobar"); + + tokio::spawn(async move { + let mut upgraded = hyper::upgrade::on(req).await.unwrap(); + + let mut buf = vec![0; 7]; + upgraded.read_exact(&mut buf).await.unwrap(); + assert_eq!(buf, b"foo=bar"); + + upgraded.write_all(b"bar=foo").await.unwrap(); + }); + + async { + http::Response::builder() + .status(http::StatusCode::SWITCHING_PROTOCOLS) + .header(http::header::CONNECTION, "upgrade") + .header(http::header::UPGRADE, "foobar") + .body(hyper::Body::empty()) + .unwrap() + } + }); + + let res = reqwest::Client::builder() + .build() + .unwrap() + .get(format!("http://{}", server.addr())) + .header(http::header::CONNECTION, "upgrade") + .header(http::header::UPGRADE, "foobar") + .send() + .await + .unwrap(); + + assert_eq!(res.status(), http::StatusCode::SWITCHING_PROTOCOLS); + let mut upgraded = res.upgrade().await.unwrap(); + + upgraded.write_all(b"foo=bar").await.unwrap(); + + let mut buf = vec![]; + upgraded.read_to_end(&mut buf).await.unwrap(); + assert_eq!(buf, b"bar=foo"); +}