Skip to content

Commit

Permalink
Re-enable rustls (#747)
Browse files Browse the repository at this point in the history
  • Loading branch information
seanmonstar committed Dec 19, 2019
1 parent f78846b commit 18fd9a6
Show file tree
Hide file tree
Showing 3 changed files with 123 additions and 30 deletions.
12 changes: 6 additions & 6 deletions .github/workflows/ci.yml
Expand Up @@ -54,8 +54,8 @@ jobs:
- windows / stable-x86_64-gnu
- windows / stable-i686-gnu
- "feat.: default-tls disabled"
# - "feat.: rustls-tls"
# - "feat.: default-tls and rustls-tls"
- "feat.: rustls-tls"
- "feat.: default-tls and rustls-tls"
- "feat.: cookies"
- "feat.: blocking"
- "feat.: gzip"
Expand Down Expand Up @@ -98,10 +98,10 @@ jobs:

- name: "feat.: default-tls disabled"
features: "--no-default-features"
# - name: "feat.: rustls-tls
# features: "--no-default-features --features rustls-tls"
# - name: "feat.: default-tls and rustls-tls"
# features: "--features rustls-tls"
- name: "feat.: rustls-tls"
features: "--no-default-features --features rustls-tls"
- name: "feat.: default-tls and rustls-tls"
features: "--features rustls-tls"
- name: "feat.: cookies"
features: "--features cookies"
- name: "feat.: blocking"
Expand Down
12 changes: 6 additions & 6 deletions Cargo.toml
Expand Up @@ -23,7 +23,7 @@ tls = []
default-tls = ["hyper-tls", "native-tls", "tls", "tokio-tls"]
default-tls-vendored = ["default-tls", "native-tls/vendored"]

#rustls-tls = ["hyper-rustls", "tokio-rustls", "webpki-roots", "rustls", "tls"]
rustls-tls = ["hyper-rustls", "tokio-rustls", "webpki-roots", "rustls", "tls"]

blocking = ["futures-channel", "futures-util/io", "tokio/rt-threaded", "tokio/rt-core"]

Expand Down Expand Up @@ -77,11 +77,11 @@ hyper-tls = { version = "0.4", optional = true }
native-tls = { version = "0.2", optional = true }
tokio-tls = { version = "0.3.0", optional = true }

## rustls-tls
#hyper-rustls = { version = "=0.18.0-alpha.1", optional = true }
#rustls = { version = "0.16", features = ["dangerous_configuration"], optional = true }
#tokio-rustls = { version = "=0.12.0-alpha.2", optional = true }
#webpki-roots = { version = "0.17", optional = true }
# rustls-tls
hyper-rustls = { version = "0.19", optional = true }
rustls = { version = "0.16", features = ["dangerous_configuration"], optional = true }
tokio-rustls = { version = "0.12", optional = true }
webpki-roots = { version = "0.17", optional = true }

## blocking
futures-channel = { version = "0.3.0", optional = true }
Expand Down
129 changes: 111 additions & 18 deletions src/connect.rs
Expand Up @@ -25,6 +25,8 @@ use crate::proxy::{Proxy, ProxyScheme};
use crate::error::BoxError;
#[cfg(feature = "default-tls")]
use self::native_tls_conn::NativeTlsConn;
#[cfg(feature = "rustls-tls")]
use self::rustls_tls_conn::RustlsTlsConn;

//#[cfg(feature = "trust-dns")]
//type HttpConnector = hyper::client::HttpConnector<TrustDnsResolver>;
Expand Down Expand Up @@ -244,12 +246,13 @@ impl Connector {
// Disable Nagle's algorithm for TLS handshake
//
// https://www.openssl.org/docs/man1.1.1/man3/SSL_connect.html#NOTES
http.set_nodelay(no_delay || (dst.scheme() == Some(&Scheme::HTTPS)));
http.set_nodelay(self.nodelay || (dst.scheme() == Some(&Scheme::HTTPS)));

let mut http = hyper_rustls::HttpsConnector::from((http, tls.clone()));
let io = http.call(dst).await?;

let http = hyper_rustls::HttpsConnector::from((http, tls.clone()));
let io = http.connect(dst).await?;
if let hyper_rustls::MaybeHttpsStream::Https(stream) = &io {
if !no_delay {
if !self.nodelay {
let (io, _) = stream.get_ref();
io.set_nodelay(false)?;
}
Expand Down Expand Up @@ -320,35 +323,32 @@ impl Connector {
tls_proxy,
} => {
if dst.scheme() == Some(&Scheme::HTTPS) {
use rustls::Session;
use tokio_rustls::webpki::DNSNameRef;
use tokio_rustls::TlsConnector as RustlsConnector;

let host = dst.host().to_owned();
let port = dst.port().unwrap_or(443);
let host = dst.host()
.ok_or(io::Error::new(io::ErrorKind::Other, "no host in url"))?
.to_string();
let port = dst.port().map(|r| r.as_u16()).unwrap_or(443);
let mut http = http.clone();
http.set_nodelay(no_delay);
let http = hyper_rustls::HttpsConnector::from((http, tls_proxy.clone()));
http.set_nodelay(self.nodelay);
let mut http = hyper_rustls::HttpsConnector::from((http, tls_proxy.clone()));
let tls = tls.clone();
let (conn, connected) = http.connect(proxy_dst).await?;
let conn = http.call(proxy_dst).await?;
log::trace!("tunneling HTTPS over proxy");
let maybe_dnsname = DNSNameRef::try_from_ascii_str(&host)
.map(|dnsname| dnsname.to_owned())
.map_err(|_| io::Error::new(io::ErrorKind::Other, "Invalid DNS Name"));
let tunneled = tunnel(conn, host, port, auth).await?;
let tunneled = tunnel(conn, host, port, self.user_agent.clone(), auth).await?;
let dnsname = maybe_dnsname?;
let io = RustlsConnector::from(tls)
.connect(dnsname.as_ref(), tunneled)
.await
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
let connected = if io.get_ref().1.get_alpn_protocol() == Some(b"h2") {
connected.negotiated_h2()
} else {
connected
};

return Ok(Conn {
inner: Box::new(io),
connected: Connected::new(),
inner: Box::new(RustlsTlsConn { inner: io }),
is_proxy: false,
});
}
}
Expand Down Expand Up @@ -682,6 +682,99 @@ mod native_tls_conn {
}
}

#[cfg(feature = "rustls-tls")]
mod rustls_tls_conn {
use rustls::Session;
use std::mem::MaybeUninit;
use std::{pin::Pin, task::{Context, Poll}};
use bytes::{Buf, BufMut};
use hyper::client::connect::{Connected, Connection};
use pin_project_lite::pin_project;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio_rustls::client::TlsStream;


pin_project! {
pub(super) struct RustlsTlsConn<T> {
#[pin] pub(super) inner: TlsStream<T>,
}
}

impl<T: Connection + AsyncRead + AsyncWrite + Unpin> Connection for RustlsTlsConn<T> {
fn connected(&self) -> Connected {
if self.inner.get_ref().1.get_alpn_protocol() == Some(b"h2") {
self.inner.get_ref().0.connected().negotiated_h2()
} else {
self.inner.get_ref().0.connected()
}
}
}

impl<T: AsyncRead + AsyncWrite + Unpin> AsyncRead for RustlsTlsConn<T> {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context,
buf: &mut [u8]
) -> Poll<tokio::io::Result<usize>> {
let this = self.project();
AsyncRead::poll_read(this.inner, cx, buf)
}

unsafe fn prepare_uninitialized_buffer(
&self,
buf: &mut [MaybeUninit<u8>]
) -> bool {
self.inner.prepare_uninitialized_buffer(buf)
}

fn poll_read_buf<B: BufMut>(
self: Pin<&mut Self>,
cx: &mut Context,
buf: &mut B
) -> Poll<tokio::io::Result<usize>>
where
Self: Sized
{
let this = self.project();
AsyncRead::poll_read_buf(this.inner, cx, buf)
}
}

impl<T: AsyncRead + AsyncWrite + Unpin> AsyncWrite for RustlsTlsConn<T> {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context,
buf: &[u8]
) -> Poll<Result<usize, tokio::io::Error>> {
let this = self.project();
AsyncWrite::poll_write(this.inner, cx, buf)
}

fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), tokio::io::Error>> {
let this = self.project();
AsyncWrite::poll_flush(this.inner, cx)
}

fn poll_shutdown(
self: Pin<&mut Self>,
cx: &mut Context
) -> Poll<Result<(), tokio::io::Error>> {
let this = self.project();
AsyncWrite::poll_shutdown(this.inner, cx)
}

fn poll_write_buf<B: Buf>(
self: Pin<&mut Self>,
cx: &mut Context,
buf: &mut B
) -> Poll<Result<usize, tokio::io::Error>> where
Self: Sized {
let this = self.project();
AsyncWrite::poll_write_buf(this.inner, cx, buf)
}
}
}

#[cfg(feature = "socks")]
mod socks {
use std::io;
Expand Down

0 comments on commit 18fd9a6

Please sign in to comment.