diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index da54f4741..cc0807888 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -191,7 +191,7 @@ jobs: strategy: matrix: - rust: [1.39.0] + rust: [1.45.2] steps: - name: Checkout diff --git a/Cargo.toml b/Cargo.toml index 4efc1e15d..5bc1686b8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -28,7 +28,7 @@ default = ["default-tls"] # Note: this doesn't enable the 'native-tls' feature, which adds specific # functionality for it. -default-tls = ["hyper-tls", "native-tls-crate", "__tls", "tokio-tls"] +default-tls = ["hyper-tls", "native-tls-crate", "__tls", "tokio-native-tls"] # Enables native-tls specific functionality not available by default. native-tls = ["default-tls"] @@ -39,13 +39,13 @@ rustls-tls-manual-roots = ["__rustls"] rustls-tls-webpki-roots = ["webpki-roots", "__rustls"] rustls-tls-native-roots = ["rustls-native-certs", "__rustls"] -blocking = ["futures-util/io", "tokio/rt-threaded", "tokio/rt-core", "tokio/sync"] +blocking = ["futures-util/io", "tokio/rt-multi-thread", "tokio/sync"] cookies = ["cookie_crate", "cookie_store", "time"] -gzip = ["async-compression", "async-compression/gzip"] +gzip = ["async-compression", "async-compression/gzip", "tokio-util"] -brotli = ["async-compression", "async-compression/brotli"] +brotli = ["async-compression", "async-compression/brotli", "tokio-util"] json = ["serde_json"] @@ -71,7 +71,7 @@ __internal_proxy_sys_no_cache = [] [dependencies] http = "0.2" url = "2.2" -bytes = "0.5" +bytes = "1.0" serde = "1.0" serde_urlencoded = "0.7" mime_guess = "2.0" @@ -83,29 +83,29 @@ base64 = "0.13" encoding_rs = "0.8" futures-core = { version = "0.3.0", default-features = false } futures-util = { version = "0.3.0", default-features = false } -http-body = "0.3.0" -hyper = { version = "0.13.4", default-features = false, features = ["tcp"] } +http-body = "0.4.0" +hyper = { version = "0.14", default-features = false, features = ["tcp", "http1", "http2", "client"] } lazy_static = "1.4" log = "0.4" mime = "0.3.7" percent-encoding = "2.1" -tokio = { version = "0.2.5", default-features = false, features = ["tcp", "time"] } +tokio = { version = "1.0", default-features = false, features = ["net", "time"] } pin-project-lite = "0.2.0" ipnet = "2.3" # Optional deps... ## default-tls -hyper-tls = { version = "0.4", optional = true } +hyper-tls = { version = "0.5", optional = true } native-tls-crate = { version = "0.2", optional = true, package = "native-tls" } -tokio-tls = { version = "0.3.0", optional = true } +tokio-native-tls = { version = "0.3.0", optional = true } # rustls-tls -hyper-rustls = { version = "0.21", default-features = false, optional = true } -rustls = { version = "0.18", features = ["dangerous_configuration"], optional = true } -tokio-rustls = { version = "0.14", optional = true } -webpki-roots = { version = "0.20", optional = true } -rustls-native-certs = { version = "0.4", optional = true } +hyper-rustls = { version = "0.22.1", default-features = false, optional = true } +rustls = { version = "0.19", features = ["dangerous_configuration"], optional = true } +tokio-rustls = { version = "0.22", optional = true } +webpki-roots = { version = "0.21", optional = true } +rustls-native-certs = { version = "0.5", optional = true } ## cookies cookie_crate = { version = "0.14", package = "cookie", optional = true } @@ -113,23 +113,23 @@ cookie_store = { version = "0.12", optional = true } time = { version = "0.2.11", optional = true } ## compression -async-compression = { version = "0.3.0", default-features = false, features = ["stream"], optional = true } - +async-compression = { version = "0.3.7", default-features = false, features = ["tokio"], optional = true } +tokio-util = { version = "0.6.0", default-features = false, features = ["codec", "io"], optional = true } ## socks -tokio-socks = { version = "0.3", optional = true } +tokio-socks = { version = "0.5", optional = true } ## trust-dns -trust-dns-resolver = { version = "0.19", optional = true } +trust-dns-resolver = { version = "0.20", optional = true } [target.'cfg(not(target_arch = "wasm32"))'.dev-dependencies] -env_logger = "0.7" -hyper = { version = "0.13", default-features = false, features = ["tcp", "stream"] } +env_logger = "0.8" +hyper = { version = "0.14", default-features = false, features = ["tcp", "stream", "http1", "http2", "client", "server"] } serde = { version = "1.0", features = ["derive"] } libflate = "1.0" brotli_crate = { package = "brotli", version = "3.3.0" } doc-comment = "0.3" -tokio = { version = "0.2.0", default-features = false, features = ["macros"] } +tokio = { version = "1.0", default-features = false, features = ["macros", "rt-multi-thread"] } [target.'cfg(windows)'.dependencies] winreg = "0.7" diff --git a/src/async_impl/body.rs b/src/async_impl/body.rs index 762d70a6c..2c1616669 100644 --- a/src/async_impl/body.rs +++ b/src/async_impl/body.rs @@ -7,7 +7,7 @@ use bytes::Bytes; use futures_core::Stream; use http_body::Body as HttpBody; use pin_project_lite::pin_project; -use tokio::time::Delay; +use tokio::time::Sleep; /// An asynchronous request body. pub struct Body { @@ -27,7 +27,7 @@ enum Inner { + Sync, >, >, - timeout: Option, + timeout: Option>>, }, } @@ -103,7 +103,7 @@ impl Body { } } - pub(crate) fn response(body: hyper::Body, timeout: Option) -> Body { + pub(crate) fn response(body: hyper::Body, timeout: Option>>) -> Body { Body { inner: Inner::Streaming { body: Box::pin(WrapHyper(body)), @@ -217,7 +217,7 @@ impl HttpBody for ImplStream { ref mut timeout, } => { if let Some(ref mut timeout) = timeout { - if let Poll::Ready(()) = Pin::new(timeout).poll(cx) { + if let Poll::Ready(()) = timeout.as_mut().poll(cx) { return Poll::Ready(Some(Err(crate::error::body(crate::error::TimedOut)))); } } diff --git a/src/async_impl/client.rs b/src/async_impl/client.rs index 0b842f0e8..089564dee 100644 --- a/src/async_impl/client.rs +++ b/src/async_impl/client.rs @@ -26,7 +26,7 @@ use rustls::RootCertStore; use std::future::Future; use std::pin::Pin; use std::task::{Context, Poll}; -use tokio::time::Delay; +use tokio::time::Sleep; use pin_project_lite::pin_project; use log::debug; @@ -96,7 +96,6 @@ struct Config { #[cfg(feature = "__tls")] tls: TlsBackend, http2_only: bool, - http1_writev: Option, http1_title_case_headers: bool, http2_initial_stream_window_size: Option, http2_initial_connection_window_size: Option, @@ -151,7 +150,6 @@ impl ClientBuilder { #[cfg(feature = "__tls")] tls: TlsBackend::default(), http2_only: false, - http1_writev: None, http1_title_case_headers: false, http2_initial_stream_window_size: None, http2_initial_connection_window_size: None, @@ -316,10 +314,6 @@ impl ClientBuilder { builder.http2_only(true); } - if let Some(http1_writev) = config.http1_writev { - builder.http1_writev(http1_writev); - } - if let Some(http2_initial_stream_window_size) = config.http2_initial_stream_window_size { builder.http2_initial_stream_window_size(http2_initial_stream_window_size); } @@ -655,14 +649,6 @@ impl ClientBuilder { self } - /// Force hyper to use either queued(if true), or flattened(if false) write strategy - /// This may eliminate unnecessary cloning of buffers for some TLS backends - /// By default hyper will try to guess which strategy to use - pub fn http1_writev(mut self, writev: bool) -> ClientBuilder { - self.config.http1_writev = Some(writev); - self - } - /// Only use HTTP/2. pub fn http2_prior_knowledge(mut self) -> ClientBuilder { self.config.http2_only = true; @@ -1103,7 +1089,8 @@ impl Client { let timeout = timeout .or(self.inner.request_timeout) - .map(tokio::time::delay_for); + .map(tokio::time::sleep) + .map(Box::pin); *req.headers_mut() = headers.clone(); @@ -1317,7 +1304,7 @@ pin_project! { #[pin] in_flight: ResponseFuture, #[pin] - timeout: Option, + timeout: Option>>, } } @@ -1326,7 +1313,7 @@ impl PendingRequest { self.project().in_flight } - fn timeout(self: Pin<&mut Self>) -> Pin<&mut Option> { + fn timeout(self: Pin<&mut Self>) -> Pin<&mut Option>>> { self.project().timeout } diff --git a/src/async_impl/decoder.rs b/src/async_impl/decoder.rs index e5c9c3f31..26def2050 100644 --- a/src/async_impl/decoder.rs +++ b/src/async_impl/decoder.rs @@ -4,10 +4,10 @@ use std::pin::Pin; use std::task::{Context, Poll}; #[cfg(feature = "gzip")] -use async_compression::stream::GzipDecoder; +use async_compression::tokio::bufread::GzipDecoder; #[cfg(feature = "brotli")] -use async_compression::stream::BrotliDecoder; +use async_compression::tokio::bufread::BrotliDecoder; use bytes::Bytes; use futures_core::Stream; @@ -15,6 +15,11 @@ use futures_util::stream::Peekable; use http::HeaderMap; use hyper::body::HttpBody; +#[cfg(any(feature = "gzip", feature = "brotli"))] +use tokio_util::io::StreamReader; +#[cfg(any(feature = "gzip", feature = "brotli"))] +use tokio_util::codec::{BytesCodec, FramedRead}; + use super::super::Body; use crate::error; @@ -39,11 +44,11 @@ enum Inner { /// A `Gzip` decoder will uncompress the gzipped response content before returning it. #[cfg(feature = "gzip")] - Gzip(GzipDecoder>), + Gzip(FramedRead, Bytes>>, BytesCodec>), /// A `Brotli` decoder will uncompress the brotlied response content before returning it. #[cfg(feature = "brotli")] - Brotli(BrotliDecoder>), + Brotli(FramedRead, Bytes>>, BytesCodec>), /// A decoder that doesn't have a value yet. #[cfg(any(feature = "brotli", feature = "gzip"))] @@ -229,7 +234,7 @@ impl Stream for Decoder { #[cfg(feature = "gzip")] Inner::Gzip(ref mut decoder) => { return match futures_core::ready!(Pin::new(decoder).poll_next(cx)) { - Some(Ok(bytes)) => Poll::Ready(Some(Ok(bytes))), + Some(Ok(bytes)) => Poll::Ready(Some(Ok(bytes.freeze()))), Some(Err(err)) => Poll::Ready(Some(Err(crate::error::decode_io(err)))), None => Poll::Ready(None), }; @@ -237,7 +242,7 @@ impl Stream for Decoder { #[cfg(feature = "brotli")] Inner::Brotli(ref mut decoder) => { return match futures_core::ready!(Pin::new(decoder).poll_next(cx)) { - Some(Ok(bytes)) => Poll::Ready(Some(Ok(bytes))), + Some(Ok(bytes)) => Poll::Ready(Some(Ok(bytes.freeze()))), Some(Err(err)) => Poll::Ready(Some(Err(crate::error::decode_io(err)))), None => Poll::Ready(None), }; @@ -302,9 +307,9 @@ impl Future for Pending { match self.1 { #[cfg(feature = "brotli")] - DecoderType::Brotli => Poll::Ready(Ok(Inner::Brotli(BrotliDecoder::new(_body)))), + DecoderType::Brotli => Poll::Ready(Ok(Inner::Brotli(FramedRead::new(BrotliDecoder::new(StreamReader::new(_body)), BytesCodec::new())))), #[cfg(feature = "gzip")] - DecoderType::Gzip => Poll::Ready(Ok(Inner::Gzip(GzipDecoder::new(_body)))), + DecoderType::Gzip => Poll::Ready(Ok(Inner::Gzip(FramedRead::new(GzipDecoder::new(StreamReader::new(_body)), BytesCodec::new())))), } } } diff --git a/src/async_impl/multipart.rs b/src/async_impl/multipart.rs index 526c68b95..62b56f478 100644 --- a/src/async_impl/multipart.rs +++ b/src/async_impl/multipart.rs @@ -521,11 +521,7 @@ mod tests { fn form_empty() { let form = Form::new(); - let mut rt = runtime::Builder::new() - .basic_scheduler() - .enable_all() - .build() - .expect("new rt"); + let rt = runtime::Builder::new_current_thread().enable_all().build().expect("new rt"); let body = form.stream().into_stream(); let s = body.map_ok(|try_c| try_c.to_vec()).try_concat(); @@ -572,11 +568,7 @@ mod tests { --boundary\r\n\ Content-Disposition: form-data; name=\"key3\"; filename=\"filename\"\r\n\r\n\ value3\r\n--boundary--\r\n"; - let mut rt = runtime::Builder::new() - .basic_scheduler() - .enable_all() - .build() - .expect("new rt"); + let rt = runtime::Builder::new_current_thread().enable_all().build().expect("new rt"); let body = form.stream().into_stream(); let s = body.map(|try_c| try_c.map(|r| r.to_vec())).try_concat(); @@ -603,11 +595,7 @@ mod tests { \r\n\ value2\r\n\ --boundary--\r\n"; - let mut rt = runtime::Builder::new() - .basic_scheduler() - .enable_all() - .build() - .expect("new rt"); + let rt = runtime::Builder::new_current_thread().enable_all().build().expect("new rt"); let body = form.stream().into_stream(); let s = body.map(|try_c| try_c.map(|r| r.to_vec())).try_concat(); diff --git a/src/async_impl/response.rs b/src/async_impl/response.rs index ef7e74afd..2035f01e4 100644 --- a/src/async_impl/response.rs +++ b/src/async_impl/response.rs @@ -1,6 +1,7 @@ use std::borrow::Cow; use std::fmt; use std::net::SocketAddr; +use std::pin::Pin; use bytes::Bytes; use encoding_rs::{Encoding, UTF_8}; @@ -12,7 +13,7 @@ use mime::Mime; use serde::de::DeserializeOwned; #[cfg(feature = "json")] use serde_json; -use tokio::time::Delay; +use tokio::time::Sleep; use url::Url; use super::body::Body; @@ -37,7 +38,7 @@ impl Response { res: hyper::Response, url: Url, accepts: Accepts, - timeout: Option, + timeout: Option>>, ) -> Response { let (parts, body) = res.into_parts(); let status = parts.status; diff --git a/src/blocking/body.rs b/src/blocking/body.rs index 3cb0d1cc6..c42e418e5 100644 --- a/src/blocking/body.rs +++ b/src/blocking/body.rs @@ -2,10 +2,11 @@ use std::fmt; use std::fs::File; use std::future::Future; use std::io::{self, Cursor, Read}; -use std::mem::{self, MaybeUninit}; +use std::mem; use std::ptr; use bytes::Bytes; +use bytes::buf::UninitSlice; use crate::async_impl; @@ -289,14 +290,14 @@ async fn send_future(sender: Sender) -> Result<(), crate::Error> { if buf.remaining_mut() == 0 { buf.reserve(8192); // zero out the reserved memory + let uninit = buf.chunk_mut(); unsafe { - let uninit = mem::transmute::<&mut [MaybeUninit], &mut [u8]>(buf.bytes_mut()); ptr::write_bytes(uninit.as_mut_ptr(), 0, uninit.len()); } } let bytes = unsafe { - mem::transmute::<&mut [MaybeUninit], &mut [u8]>(buf.bytes_mut()) + mem::transmute::<&mut UninitSlice, &mut [u8]>(buf.chunk_mut()) }; match body.read(bytes) { Ok(0) => { diff --git a/src/blocking/client.rs b/src/blocking/client.rs index 66c74dcaa..dbc16c826 100644 --- a/src/blocking/client.rs +++ b/src/blocking/client.rs @@ -764,7 +764,7 @@ impl ClientHandle { .name("reqwest-internal-sync-runtime".into()) .spawn(move || { use tokio::runtime; - let mut rt = match runtime::Builder::new().basic_scheduler().enable_all().build().map_err(crate::error::builder) { + let rt = match runtime::Builder::new_current_thread().enable_all().build().map_err(crate::error::builder) { Err(e) => { if let Err(e) = spawn_tx.send(Err(e)) { error!("Failed to communicate runtime creation failure: {:?}", e); diff --git a/src/blocking/wait.rs b/src/blocking/wait.rs index cfca8da9d..801f1678d 100644 --- a/src/blocking/wait.rs +++ b/src/blocking/wait.rs @@ -67,10 +67,9 @@ fn enter() { // Check we aren't already in a runtime #[cfg(debug_assertions)] { - tokio::runtime::Builder::new() - .core_threads(1) + tokio::runtime::Builder::new_current_thread() .build() .expect("build shell runtime") - .enter(|| {}); + .enter(); } } diff --git a/src/connect.rs b/src/connect.rs index 3061a16ff..c58ae8407 100644 --- a/src/connect.rs +++ b/src/connect.rs @@ -2,22 +2,21 @@ use hyper::service::Service; use http::uri::{Scheme, Authority}; use http::Uri; use hyper::client::connect::{Connected, Connection}; -use tokio::io::{AsyncRead, AsyncWrite}; +use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; #[cfg(feature = "native-tls-crate")] use native_tls_crate::{TlsConnector, TlsConnectorBuilder}; #[cfg(feature = "__tls")] use http::header::HeaderValue; use futures_util::future::Either; -use bytes::{Buf, BufMut}; use std::future::Future; use std::io; +use std::io::IoSlice; use std::net::IpAddr; use std::pin::Pin; use std::sync::Arc; use std::task::{Context, Poll}; use std::time::Duration; -use std::mem::MaybeUninit; use pin_project_lite::pin_project; #[cfg(feature = "trust-dns")] @@ -272,7 +271,7 @@ impl Connector { .ok_or("no host in url")? .to_string(); let conn = socks::connect(proxy, dst, dns).await?; - let tls_connector = tokio_tls::TlsConnector::from(tls.clone()); + let tls_connector = tokio_native_tls::TlsConnector::from(tls.clone()); let io = tls_connector .connect(&host, conn) .await?; @@ -342,13 +341,13 @@ impl Connector { http.set_nodelay(true); } - let tls_connector = tokio_tls::TlsConnector::from(tls.clone()); + let tls_connector = tokio_native_tls::TlsConnector::from(tls.clone()); let mut http = hyper_tls::HttpsConnector::from((http, tls_connector)); let io = http.call(dst).await?; if let hyper_tls::MaybeHttpsStream::Https(stream) = &io { if !self.nodelay { - stream.get_ref().set_nodelay(false)?; + stream.get_ref().get_ref().get_ref().set_nodelay(false)?; } } @@ -411,7 +410,7 @@ impl Connector { let host = dst.host().to_owned(); let port = dst.port().map(|p| p.as_u16()).unwrap_or(443); let http = http.clone(); - let tls_connector = tokio_tls::TlsConnector::from(tls.clone()); + let tls_connector = tokio_native_tls::TlsConnector::from(tls.clone()); let mut http = hyper_tls::HttpsConnector::from((http, tls_connector)); let conn = http.call(proxy_dst).await?; log::trace!("tunneling HTTPS over proxy"); @@ -424,7 +423,7 @@ impl Connector { self.user_agent.clone(), auth ).await?; - let tls_connector = tokio_tls::TlsConnector::from(tls.clone()); + let tls_connector = tokio_native_tls::TlsConnector::from(tls.clone()); let io = tls_connector .connect(&host.ok_or("no host in url")?, tunneled) .await?; @@ -569,30 +568,11 @@ impl AsyncRead for Conn { fn poll_read( self: Pin<&mut Self>, cx: &mut Context, - buf: &mut [u8] - ) -> Poll> { + buf: &mut ReadBuf<'_> + ) -> Poll> { let this = self.project(); AsyncRead::poll_read(this.inner, cx, buf) } - - unsafe fn prepare_uninitialized_buffer( - &self, - buf: &mut [MaybeUninit] - ) -> bool { - self.inner.prepare_uninitialized_buffer(buf) - } - - fn poll_read_buf( - self: Pin<&mut Self>, - cx: &mut Context, - buf: &mut B - ) -> Poll> - where - Self: Sized - { - let this = self.project(); - AsyncRead::poll_read_buf(this.inner, cx, buf) - } } impl AsyncWrite for Conn { @@ -605,6 +585,19 @@ impl AsyncWrite for Conn { AsyncWrite::poll_write(this.inner, cx, buf) } + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[IoSlice<'_>] + ) -> Poll> { + let this = self.project(); + AsyncWrite::poll_write_vectored(this.inner, cx, bufs) + } + + fn is_write_vectored(&self) -> bool { + self.inner.is_write_vectored() + } + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { let this = self.project(); AsyncWrite::poll_flush(this.inner, cx) @@ -617,16 +610,6 @@ impl AsyncWrite for Conn { let this = self.project(); AsyncWrite::poll_shutdown(this.inner, cx) } - - fn poll_write_buf( - self: Pin<&mut Self>, - cx: &mut Context, - buf: &mut B - ) -> Poll> where - Self: Sized { - let this = self.project(); - AsyncWrite::poll_write_buf(this.inner, cx, buf) - } } pub(crate) type Connecting = @@ -715,13 +698,11 @@ fn tunnel_eof() -> BoxError { #[cfg(feature = "default-tls")] mod native_tls_conn { - use std::mem::MaybeUninit; - use std::{pin::Pin, task::{Context, Poll}}; - use bytes::{Buf, BufMut}; + use std::{pin::Pin, task::{Context, Poll}, io::{self, IoSlice}}; use hyper::client::connect::{Connected, Connection}; use pin_project_lite::pin_project; - use tokio::io::{AsyncRead, AsyncWrite}; - use tokio_tls::TlsStream; + use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; + use tokio_native_tls::TlsStream; pin_project! { @@ -732,7 +713,7 @@ mod native_tls_conn { impl Connection for NativeTlsConn { fn connected(&self) -> Connected { - self.inner.get_ref().connected() + self.inner.get_ref().get_ref().get_ref().connected() } } @@ -740,30 +721,11 @@ mod native_tls_conn { fn poll_read( self: Pin<&mut Self>, cx: &mut Context, - buf: &mut [u8] - ) -> Poll> { + buf: &mut ReadBuf<'_> + ) -> Poll> { let this = self.project(); AsyncRead::poll_read(this.inner, cx, buf) } - - unsafe fn prepare_uninitialized_buffer( - &self, - buf: &mut [MaybeUninit] - ) -> bool { - self.inner.prepare_uninitialized_buffer(buf) - } - - fn poll_read_buf( - self: Pin<&mut Self>, - cx: &mut Context, - buf: &mut B - ) -> Poll> - where - Self: Sized - { - let this = self.project(); - AsyncRead::poll_read_buf(this.inner, cx, buf) - } } impl AsyncWrite for NativeTlsConn { @@ -776,6 +738,19 @@ mod native_tls_conn { AsyncWrite::poll_write(this.inner, cx, buf) } + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[IoSlice<'_>] + ) -> Poll> { + let this = self.project(); + AsyncWrite::poll_write_vectored(this.inner, cx, bufs) + } + + fn is_write_vectored(&self) -> bool { + self.inner.is_write_vectored() + } + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { let this = self.project(); AsyncWrite::poll_flush(this.inner, cx) @@ -788,28 +763,16 @@ mod native_tls_conn { let this = self.project(); AsyncWrite::poll_shutdown(this.inner, cx) } - - fn poll_write_buf( - self: Pin<&mut Self>, - cx: &mut Context, - buf: &mut B - ) -> Poll> where - Self: Sized { - let this = self.project(); - AsyncWrite::poll_write_buf(this.inner, cx, buf) - } } } #[cfg(feature = "__rustls")] mod rustls_tls_conn { use rustls::Session; - use std::mem::MaybeUninit; - use std::{pin::Pin, task::{Context, Poll}}; - use bytes::{Buf, BufMut}; + use std::{pin::Pin, task::{Context, Poll}, io::{self, IoSlice}}; use hyper::client::connect::{Connected, Connection}; use pin_project_lite::pin_project; - use tokio::io::{AsyncRead, AsyncWrite}; + use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; use tokio_rustls::client::TlsStream; @@ -833,30 +796,11 @@ mod rustls_tls_conn { fn poll_read( self: Pin<&mut Self>, cx: &mut Context, - buf: &mut [u8] - ) -> Poll> { + buf: &mut ReadBuf<'_> + ) -> Poll> { let this = self.project(); AsyncRead::poll_read(this.inner, cx, buf) } - - unsafe fn prepare_uninitialized_buffer( - &self, - buf: &mut [MaybeUninit] - ) -> bool { - self.inner.prepare_uninitialized_buffer(buf) - } - - fn poll_read_buf( - self: Pin<&mut Self>, - cx: &mut Context, - buf: &mut B - ) -> Poll> - where - Self: Sized - { - let this = self.project(); - AsyncRead::poll_read_buf(this.inner, cx, buf) - } } impl AsyncWrite for RustlsTlsConn { @@ -869,6 +813,19 @@ mod rustls_tls_conn { AsyncWrite::poll_write(this.inner, cx, buf) } + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[IoSlice<'_>] + ) -> Poll> { + let this = self.project(); + AsyncWrite::poll_write_vectored(this.inner, cx, bufs) + } + + fn is_write_vectored(&self) -> bool { + self.inner.is_write_vectored() + } + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { let this = self.project(); AsyncWrite::poll_flush(this.inner, cx) @@ -881,16 +838,6 @@ mod rustls_tls_conn { let this = self.project(); AsyncWrite::poll_shutdown(this.inner, cx) } - - fn poll_write_buf( - self: Pin<&mut Self>, - cx: &mut Context, - buf: &mut B - ) -> Poll> where - Self: Sized { - let this = self.project(); - AsyncWrite::poll_write_buf(this.inner, cx, buf) - } } } @@ -961,10 +908,11 @@ mod socks { mod verbose { use std::fmt; + use std::io::{self, IoSlice}; use std::pin::Pin; use std::task::{Context, Poll}; use hyper::client::connect::{Connected, Connection}; - use tokio::io::{AsyncRead, AsyncWrite}; + use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; pub(super) const OFF: Wrapper = Wrapper(false); @@ -1000,12 +948,12 @@ mod verbose { fn poll_read( mut self: Pin<&mut Self>, cx: &mut Context, - buf: &mut [u8] - ) -> Poll> { + buf: &mut ReadBuf<'_> + ) -> Poll> { match Pin::new(&mut self.inner).poll_read(cx, buf) { - Poll::Ready(Ok(n)) => { - log::trace!("{:08x} read: {:?}", self.id, Escape(&buf[..n])); - Poll::Ready(Ok(n)) + Poll::Ready(Ok(())) => { + log::trace!("{:08x} read: {:?}", self.id, Escape(buf.filled())); + Poll::Ready(Ok(())) }, Poll::Ready(Err(e)) => { Poll::Ready(Err(e)) @@ -1033,6 +981,18 @@ mod verbose { } } + fn poll_write_vectored( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[IoSlice<'_>] + ) -> Poll> { + Pin::new(&mut self.inner).poll_write_vectored(cx, bufs) + } + + fn is_write_vectored(&self) -> bool { + self.inner.is_write_vectored() + } + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { Pin::new(&mut self.inner).poll_flush(cx) } @@ -1137,7 +1097,7 @@ mod tests { fn test_tunnel() { let addr = mock_tunnel!(); - let mut rt = runtime::Builder::new().basic_scheduler().enable_all().build().expect("new rt"); + let rt = runtime::Builder::new_current_thread().enable_all().build().expect("new rt"); let f = async move { let tcp = TcpStream::connect(&addr).await?; let host = addr.ip().to_string(); @@ -1152,7 +1112,7 @@ mod tests { fn test_tunnel_eof() { let addr = mock_tunnel!(b"HTTP/1.1 200 OK"); - let mut rt = runtime::Builder::new().basic_scheduler().enable_all().build().expect("new rt"); + let rt = runtime::Builder::new_current_thread().enable_all().build().expect("new rt"); let f = async move { let tcp = TcpStream::connect(&addr).await?; let host = addr.ip().to_string(); @@ -1167,7 +1127,7 @@ mod tests { fn test_tunnel_non_http_response() { let addr = mock_tunnel!(b"foo bar baz hallo"); - let mut rt = runtime::Builder::new().basic_scheduler().enable_all().build().expect("new rt"); + let rt = runtime::Builder::new_current_thread().enable_all().build().expect("new rt"); let f = async move { let tcp = TcpStream::connect(&addr).await?; let host = addr.ip().to_string(); @@ -1188,7 +1148,7 @@ mod tests { " ); - let mut rt = runtime::Builder::new().basic_scheduler().enable_all().build().expect("new rt"); + let rt = runtime::Builder::new_current_thread().enable_all().build().expect("new rt"); let f = async move { let tcp = TcpStream::connect(&addr).await?; let host = addr.ip().to_string(); @@ -1207,7 +1167,7 @@ mod tests { "Proxy-Authorization: Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ==\r\n" ); - let mut rt = runtime::Builder::new().basic_scheduler().enable_all().build().expect("new rt"); + let rt = runtime::Builder::new_current_thread().enable_all().build().expect("new rt"); let f = async move { let tcp = TcpStream::connect(&addr).await?; let host = addr.ip().to_string(); diff --git a/src/dns.rs b/src/dns.rs index e533da115..adba67e65 100644 --- a/src/dns.rs +++ b/src/dns.rs @@ -3,6 +3,7 @@ use std::pin::Pin; use std::sync::Arc; use std::task::{self, Poll}; use std::io; +use std::net::SocketAddr; use hyper::client::connect::dns as hyper_dns; use hyper::service::Service; @@ -10,7 +11,7 @@ use tokio::sync::Mutex; use trust_dns_resolver::{ config::{ResolverConfig, ResolverOpts}, lookup_ip::LookupIpIntoIter, - system_conf, AsyncResolver, TokioConnection, TokioConnectionProvider, + system_conf, AsyncResolver, TokioConnection, TokioConnectionProvider, TokioHandle }; use crate::error::BoxError; @@ -26,6 +27,10 @@ pub(crate) struct TrustDnsResolver { state: Arc>, } +pub(crate) struct SocketAddrs { + iter: LookupIpIntoIter, +} + enum State { Init, Ready(SharedResolver), @@ -47,7 +52,7 @@ impl TrustDnsResolver { } impl Service for TrustDnsResolver { - type Response = LookupIpIntoIter; + type Response = SocketAddrs; type Error = BoxError; type Future = Pin> + Send>>; @@ -62,7 +67,7 @@ impl Service for TrustDnsResolver { let resolver = match &*lock { State::Init => { - let resolver = new_resolver(tokio::runtime::Handle::current()).await?; + let resolver = new_resolver().await?; *lock = State::Ready(resolver.clone()); resolver }, @@ -74,18 +79,24 @@ impl Service for TrustDnsResolver { drop(lock); let lookup = resolver.lookup_ip(name.as_str()).await?; - Ok(lookup.into_iter()) + Ok(SocketAddrs { iter: lookup.into_iter() }) }) } } -/// Takes a `Handle` argument as an indicator that it must be called from -/// within the context of a Tokio runtime. -async fn new_resolver(handle: tokio::runtime::Handle) -> Result { +impl Iterator for SocketAddrs { + type Item = SocketAddr; + + fn next(&mut self) -> Option { + self.iter.next().map(|ip_addr| SocketAddr::new(ip_addr, 0)) + } +} + +async fn new_resolver() -> Result { let (config, opts) = SYSTEM_CONF .as_ref() .expect("can't construct TrustDnsResolver if SYSTEM_CONF is error") .clone(); - let resolver = AsyncResolver::new(config, opts, handle).await?; + let resolver = AsyncResolver::new(config, opts, TokioHandle)?; Ok(Arc::new(resolver)) } diff --git a/tests/blocking.rs b/tests/blocking.rs index c98a2c319..93f204ac8 100644 --- a/tests/blocking.rs +++ b/tests/blocking.rs @@ -282,7 +282,9 @@ fn test_blocking_inside_a_runtime() { let url = format!("http://{}/text", server.addr()); - let mut rt = tokio::runtime::Builder::new().build().expect("new rt"); + let rt = tokio::runtime::Builder::new_current_thread() + .build() + .expect("new rt"); rt.block_on(async move { let _should_panic = reqwest::blocking::get(&url); diff --git a/tests/redirect.rs b/tests/redirect.rs index c1621ca54..16f7712f5 100644 --- a/tests/redirect.rs +++ b/tests/redirect.rs @@ -155,14 +155,15 @@ fn test_redirect_307_does_not_try_if_reader_cannot_reset() { async fn test_redirect_removes_sensitive_headers() { use tokio::sync::watch; - let (tx, rx) = watch::channel(None); + let (tx, rx) = watch::channel::>(None); let end_server = server::http(move |req| { let mut rx = rx.clone(); async move { assert_eq!(req.headers().get("cookie"), None); - let mid_addr = rx.recv().await.unwrap().unwrap(); + rx.changed().await.unwrap(); + let mid_addr = rx.borrow().unwrap(); assert_eq!( req.headers()["referer"], format!("http://{}/sensitive", mid_addr) @@ -182,7 +183,7 @@ async fn test_redirect_removes_sensitive_headers() { .unwrap() }); - tx.broadcast(Some(mid_server.addr())).unwrap(); + tx.send(Some(mid_server.addr())).unwrap(); reqwest::Client::builder() .build() diff --git a/tests/support/server.rs b/tests/support/server.rs index e64590494..4ac1a4a77 100644 --- a/tests/support/server.rs +++ b/tests/support/server.rs @@ -44,8 +44,7 @@ where { //Spawn new runtime in thread to prevent reactor execution context conflict thread::spawn(move || { - let mut rt = runtime::Builder::new() - .basic_scheduler() + let rt = runtime::Builder::new_current_thread() .enable_all() .build() .expect("new rt"); diff --git a/tests/timeouts.rs b/tests/timeouts.rs index 25c992471..35b105e93 100644 --- a/tests/timeouts.rs +++ b/tests/timeouts.rs @@ -11,7 +11,7 @@ async fn client_timeout() { let server = server::http(move |_req| { async { // delay returning the response - tokio::time::delay_for(Duration::from_secs(2)).await; + tokio::time::sleep(Duration::from_secs(2)).await; http::Response::default() } }); @@ -38,7 +38,7 @@ async fn request_timeout() { let server = server::http(move |_req| { async { // delay returning the response - tokio::time::delay_for(Duration::from_secs(2)).await; + tokio::time::sleep(Duration::from_secs(2)).await; http::Response::default() } }); @@ -94,7 +94,7 @@ async fn response_timeout() { async { // immediate response, but delayed body let body = hyper::Body::wrap_stream(futures_util::stream::once(async { - tokio::time::delay_for(Duration::from_secs(2)).await; + tokio::time::sleep(Duration::from_secs(2)).await; Ok::<_, std::convert::Infallible>("Hello") })); @@ -134,7 +134,7 @@ fn timeout_closes_connection() { let server = server::http(move |_req| { async { // delay returning the response - tokio::time::delay_for(Duration::from_secs(2)).await; + tokio::time::sleep(Duration::from_secs(2)).await; http::Response::default() } }); @@ -158,7 +158,7 @@ fn timeout_blocking_request() { let server = server::http(move |_req| { async { // delay returning the response - tokio::time::delay_for(Duration::from_secs(2)).await; + tokio::time::sleep(Duration::from_secs(2)).await; http::Response::default() } }); @@ -191,7 +191,7 @@ fn write_timeout_large_body() { let server = server::http(move |_req| { async { // delay returning the response - tokio::time::delay_for(Duration::from_secs(2)).await; + tokio::time::sleep(Duration::from_secs(2)).await; http::Response::default() } });