From 18fd9a63b0eb7bf51d2e2b7fe31b4567f0b05779 Mon Sep 17 00:00:00 2001 From: Sean McArthur Date: Thu, 19 Dec 2019 11:43:03 -0800 Subject: [PATCH] Re-enable rustls (#747) --- .github/workflows/ci.yml | 12 ++-- Cargo.toml | 12 ++-- src/connect.rs | 129 +++++++++++++++++++++++++++++++++------ 3 files changed, 123 insertions(+), 30 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index ca54f8342..004b50b9c 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -54,8 +54,8 @@ jobs: - windows / stable-x86_64-gnu - windows / stable-i686-gnu - "feat.: default-tls disabled" - # - "feat.: rustls-tls" - # - "feat.: default-tls and rustls-tls" + - "feat.: rustls-tls" + - "feat.: default-tls and rustls-tls" - "feat.: cookies" - "feat.: blocking" - "feat.: gzip" @@ -98,10 +98,10 @@ jobs: - name: "feat.: default-tls disabled" features: "--no-default-features" - # - name: "feat.: rustls-tls - # features: "--no-default-features --features rustls-tls" - # - name: "feat.: default-tls and rustls-tls" - # features: "--features rustls-tls" + - name: "feat.: rustls-tls" + features: "--no-default-features --features rustls-tls" + - name: "feat.: default-tls and rustls-tls" + features: "--features rustls-tls" - name: "feat.: cookies" features: "--features cookies" - name: "feat.: blocking" diff --git a/Cargo.toml b/Cargo.toml index be83beca7..bef47582b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -23,7 +23,7 @@ tls = [] default-tls = ["hyper-tls", "native-tls", "tls", "tokio-tls"] default-tls-vendored = ["default-tls", "native-tls/vendored"] -#rustls-tls = ["hyper-rustls", "tokio-rustls", "webpki-roots", "rustls", "tls"] +rustls-tls = ["hyper-rustls", "tokio-rustls", "webpki-roots", "rustls", "tls"] blocking = ["futures-channel", "futures-util/io", "tokio/rt-threaded", "tokio/rt-core"] @@ -77,11 +77,11 @@ hyper-tls = { version = "0.4", optional = true } native-tls = { version = "0.2", optional = true } tokio-tls = { version = "0.3.0", optional = true } -## rustls-tls -#hyper-rustls = { version = "=0.18.0-alpha.1", optional = true } -#rustls = { version = "0.16", features = ["dangerous_configuration"], optional = true } -#tokio-rustls = { version = "=0.12.0-alpha.2", optional = true } -#webpki-roots = { version = "0.17", optional = true } +# rustls-tls +hyper-rustls = { version = "0.19", optional = true } +rustls = { version = "0.16", features = ["dangerous_configuration"], optional = true } +tokio-rustls = { version = "0.12", optional = true } +webpki-roots = { version = "0.17", optional = true } ## blocking futures-channel = { version = "0.3.0", optional = true } diff --git a/src/connect.rs b/src/connect.rs index 81ba84d9e..d8e2fe1ac 100644 --- a/src/connect.rs +++ b/src/connect.rs @@ -25,6 +25,8 @@ use crate::proxy::{Proxy, ProxyScheme}; use crate::error::BoxError; #[cfg(feature = "default-tls")] use self::native_tls_conn::NativeTlsConn; +#[cfg(feature = "rustls-tls")] +use self::rustls_tls_conn::RustlsTlsConn; //#[cfg(feature = "trust-dns")] //type HttpConnector = hyper::client::HttpConnector; @@ -244,12 +246,13 @@ impl Connector { // Disable Nagle's algorithm for TLS handshake // // https://www.openssl.org/docs/man1.1.1/man3/SSL_connect.html#NOTES - http.set_nodelay(no_delay || (dst.scheme() == Some(&Scheme::HTTPS))); + http.set_nodelay(self.nodelay || (dst.scheme() == Some(&Scheme::HTTPS))); + + let mut http = hyper_rustls::HttpsConnector::from((http, tls.clone())); + let io = http.call(dst).await?; - let http = hyper_rustls::HttpsConnector::from((http, tls.clone())); - let io = http.connect(dst).await?; if let hyper_rustls::MaybeHttpsStream::Https(stream) = &io { - if !no_delay { + if !self.nodelay { let (io, _) = stream.get_ref(); io.set_nodelay(false)?; } @@ -320,35 +323,32 @@ impl Connector { tls_proxy, } => { if dst.scheme() == Some(&Scheme::HTTPS) { - use rustls::Session; use tokio_rustls::webpki::DNSNameRef; use tokio_rustls::TlsConnector as RustlsConnector; - let host = dst.host().to_owned(); - let port = dst.port().unwrap_or(443); + let host = dst.host() + .ok_or(io::Error::new(io::ErrorKind::Other, "no host in url"))? + .to_string(); + let port = dst.port().map(|r| r.as_u16()).unwrap_or(443); let mut http = http.clone(); - http.set_nodelay(no_delay); - let http = hyper_rustls::HttpsConnector::from((http, tls_proxy.clone())); + http.set_nodelay(self.nodelay); + let mut http = hyper_rustls::HttpsConnector::from((http, tls_proxy.clone())); let tls = tls.clone(); - let (conn, connected) = http.connect(proxy_dst).await?; + let conn = http.call(proxy_dst).await?; log::trace!("tunneling HTTPS over proxy"); let maybe_dnsname = DNSNameRef::try_from_ascii_str(&host) .map(|dnsname| dnsname.to_owned()) .map_err(|_| io::Error::new(io::ErrorKind::Other, "Invalid DNS Name")); - let tunneled = tunnel(conn, host, port, auth).await?; + let tunneled = tunnel(conn, host, port, self.user_agent.clone(), auth).await?; let dnsname = maybe_dnsname?; let io = RustlsConnector::from(tls) .connect(dnsname.as_ref(), tunneled) .await .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?; - let connected = if io.get_ref().1.get_alpn_protocol() == Some(b"h2") { - connected.negotiated_h2() - } else { - connected - }; + return Ok(Conn { - inner: Box::new(io), - connected: Connected::new(), + inner: Box::new(RustlsTlsConn { inner: io }), + is_proxy: false, }); } } @@ -682,6 +682,99 @@ mod native_tls_conn { } } +#[cfg(feature = "rustls-tls")] +mod rustls_tls_conn { + use rustls::Session; + use std::mem::MaybeUninit; + use std::{pin::Pin, task::{Context, Poll}}; + use bytes::{Buf, BufMut}; + use hyper::client::connect::{Connected, Connection}; + use pin_project_lite::pin_project; + use tokio::io::{AsyncRead, AsyncWrite}; + use tokio_rustls::client::TlsStream; + + + pin_project! { + pub(super) struct RustlsTlsConn { + #[pin] pub(super) inner: TlsStream, + } + } + + impl Connection for RustlsTlsConn { + fn connected(&self) -> Connected { + if self.inner.get_ref().1.get_alpn_protocol() == Some(b"h2") { + self.inner.get_ref().0.connected().negotiated_h2() + } else { + self.inner.get_ref().0.connected() + } + } + } + + impl AsyncRead for RustlsTlsConn { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context, + buf: &mut [u8] + ) -> Poll> { + let this = self.project(); + AsyncRead::poll_read(this.inner, cx, buf) + } + + unsafe fn prepare_uninitialized_buffer( + &self, + buf: &mut [MaybeUninit] + ) -> bool { + self.inner.prepare_uninitialized_buffer(buf) + } + + fn poll_read_buf( + self: Pin<&mut Self>, + cx: &mut Context, + buf: &mut B + ) -> Poll> + where + Self: Sized + { + let this = self.project(); + AsyncRead::poll_read_buf(this.inner, cx, buf) + } + } + + impl AsyncWrite for RustlsTlsConn { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context, + buf: &[u8] + ) -> Poll> { + let this = self.project(); + AsyncWrite::poll_write(this.inner, cx, buf) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + let this = self.project(); + AsyncWrite::poll_flush(this.inner, cx) + } + + fn poll_shutdown( + self: Pin<&mut Self>, + cx: &mut Context + ) -> Poll> { + let this = self.project(); + AsyncWrite::poll_shutdown(this.inner, cx) + } + + fn poll_write_buf( + self: Pin<&mut Self>, + cx: &mut Context, + buf: &mut B + ) -> Poll> where + Self: Sized { + let this = self.project(); + AsyncWrite::poll_write_buf(this.inner, cx, buf) + } + } +} + #[cfg(feature = "socks")] mod socks { use std::io;