Skip to content

Commit

Permalink
Allow choosing rustls certs through flags
Browse files Browse the repository at this point in the history
  • Loading branch information
dnaka91 committed Aug 10, 2021
1 parent 2546fa0 commit 9653bca
Show file tree
Hide file tree
Showing 5 changed files with 47 additions and 26 deletions.
12 changes: 10 additions & 2 deletions Cargo.toml
Expand Up @@ -12,14 +12,16 @@ version = "0.15.0"
edition = "2018"

[package.metadata.docs.rs]
features = ["native-tls", "rustls-tls"]
features = ["native-tls", "__rustls-tls"]

[features]
default = ["connect"]
connect = ["stream", "tokio/net"]
native-tls = ["native-tls-crate", "tokio-native-tls", "stream", "tungstenite/native-tls"]
native-tls-vendored = ["native-tls", "native-tls-crate/vendored", "tungstenite/native-tls-vendored"]
rustls-tls = ["rustls", "tokio-rustls", "stream", "tungstenite/rustls-tls", "webpki", "webpki-roots"]
rustls-tls-native-roots = ["__rustls-tls", "rustls-native-certs"]
rustls-tls-webpki-roots = ["__rustls-tls", "webpki-roots"]
__rustls-tls = ["rustls", "tokio-rustls", "stream", "tungstenite/__rustls-tls", "webpki"]
stream = []

[dependencies]
Expand All @@ -31,6 +33,8 @@ tokio = { version = "1.0.0", default-features = false, features = ["io-util"] }
[dependencies.tungstenite]
version = "0.14.0"
default-features = false
git = "https://github.com/snapview/tungstenite-rs.git"
rev = "32450ae5af0070c7f3fd8341d244525119672506"

[dependencies.native-tls-crate]
optional = true
Expand All @@ -41,6 +45,10 @@ version = "0.2.7"
optional = true
version = "0.19.0"

[dependencies.rustls-native-certs]
optional = true
version = "0.5.0"

[dependencies.tokio-native-tls]
optional = true
version = "0.3.0"
Expand Down
4 changes: 2 additions & 2 deletions src/connect.rs
Expand Up @@ -45,12 +45,12 @@ where
let try_socket = TcpStream::connect(addr).await;
let socket = try_socket.map_err(Error::Io)?;

#[cfg(not(any(feature = "native-tls", feature = "rustls-tls")))]
#[cfg(not(any(feature = "native-tls", feature = "__rustls-tls")))]
{
crate::client_async_with_config(request, MaybeTlsStream::Plain(socket), config).await
}

#[cfg(any(feature = "native-tls", feature = "rustls-tls"))]
#[cfg(any(feature = "native-tls", feature = "__rustls-tls"))]
{
crate::tls::client_async_tls_with_config(request, socket, config, None).await
}
Expand Down
9 changes: 4 additions & 5 deletions src/lib.rs
Expand Up @@ -18,7 +18,7 @@ mod connect;
mod handshake;
#[cfg(feature = "stream")]
mod stream;
#[cfg(any(feature = "native-tls", feature = "rustls-tls"))]
#[cfg(any(feature = "native-tls", feature = "__rustls-tls"))]
mod tls;

use std::io::{Read, Write};
Expand All @@ -44,10 +44,9 @@ use tungstenite::{
HandshakeError,
},
protocol::{Message, Role, WebSocket, WebSocketConfig},
server,
};

#[cfg(any(feature = "native-tls", feature = "rustls-tls"))]
#[cfg(any(feature = "native-tls", feature = "__rustls-tls"))]
pub use tls::{client_async_tls, client_async_tls_with_config, Connector};

#[cfg(feature = "connect")]
Expand Down Expand Up @@ -158,7 +157,7 @@ where
C: Callback + Unpin,
{
let f = handshake::server_handshake(stream, move |allow_std| {
server::accept_hdr_with_config(allow_std, callback, config)
tungstenite::accept_hdr_with_config(allow_std, callback, config)
});
f.await.map_err(|e| match e {
HandshakeError::Failure(e) => e,
Expand Down Expand Up @@ -321,7 +320,7 @@ where
}

/// Get a domain from an URL.
#[cfg(any(feature = "connect", feature = "native-tls", feature = "rustls-tls"))]
#[cfg(any(feature = "connect", feature = "native-tls", feature = "__rustls-tls"))]
#[inline]
fn domain(request: &tungstenite::handshake::client::Request) -> Result<String, WsError> {
match request.uri().host() {
Expand Down
10 changes: 5 additions & 5 deletions src/stream.rs
Expand Up @@ -22,7 +22,7 @@ pub enum MaybeTlsStream<S> {
#[cfg(feature = "native-tls")]
NativeTls(tokio_native_tls::TlsStream<S>),
/// Encrypted socket stream using `rustls`.
#[cfg(feature = "rustls-tls")]
#[cfg(feature = "__rustls-tls")]
Rustls(tokio_rustls::client::TlsStream<S>),
}

Expand All @@ -36,7 +36,7 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AsyncRead for MaybeTlsStream<S> {
StreamProj::Plain(ref mut s) => Pin::new(s).poll_read(cx, buf),
#[cfg(feature = "native-tls")]
StreamProj::NativeTls(s) => Pin::new(s).poll_read(cx, buf),
#[cfg(feature = "rustls-tls")]
#[cfg(feature = "__rustls-tls")]
StreamProj::Rustls(s) => Pin::new(s).poll_read(cx, buf),
}
}
Expand All @@ -52,7 +52,7 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AsyncWrite for MaybeTlsStream<S> {
StreamProj::Plain(ref mut s) => Pin::new(s).poll_write(cx, buf),
#[cfg(feature = "native-tls")]
StreamProj::NativeTls(s) => Pin::new(s).poll_write(cx, buf),
#[cfg(feature = "rustls-tls")]
#[cfg(feature = "__rustls-tls")]
StreamProj::Rustls(s) => Pin::new(s).poll_write(cx, buf),
}
}
Expand All @@ -62,7 +62,7 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AsyncWrite for MaybeTlsStream<S> {
StreamProj::Plain(ref mut s) => Pin::new(s).poll_flush(cx),
#[cfg(feature = "native-tls")]
StreamProj::NativeTls(s) => Pin::new(s).poll_flush(cx),
#[cfg(feature = "rustls-tls")]
#[cfg(feature = "__rustls-tls")]
StreamProj::Rustls(s) => Pin::new(s).poll_flush(cx),
}
}
Expand All @@ -75,7 +75,7 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AsyncWrite for MaybeTlsStream<S> {
StreamProj::Plain(ref mut s) => Pin::new(s).poll_shutdown(cx),
#[cfg(feature = "native-tls")]
StreamProj::NativeTls(s) => Pin::new(s).poll_shutdown(cx),
#[cfg(feature = "rustls-tls")]
#[cfg(feature = "__rustls-tls")]
StreamProj::Rustls(s) => Pin::new(s).poll_shutdown(cx),
}
}
Expand Down
38 changes: 26 additions & 12 deletions src/tls.rs
Expand Up @@ -20,7 +20,7 @@ pub enum Connector {
#[cfg(feature = "native-tls")]
NativeTls(native_tls_crate::TlsConnector),
/// `rustls` TLS connector.
#[cfg(feature = "rustls-tls")]
#[cfg(feature = "__rustls-tls")]
Rustls(std::sync::Arc<rustls::ClientConfig>),
}

Expand Down Expand Up @@ -61,7 +61,7 @@ mod encryption {
}
}

#[cfg(feature = "rustls-tls")]
#[cfg(feature = "__rustls-tls")]
pub mod rustls {
pub use rustls::ClientConfig;
use tokio_rustls::{webpki::DNSNameRef, TlsConnector as TokioTlsConnector};
Expand All @@ -85,12 +85,26 @@ mod encryption {
match mode {
Mode::Plain => Ok(MaybeTlsStream::Plain(socket)),
Mode::Tls => {
let config = tls_connector.unwrap_or_else(|| {
let mut config = ClientConfig::new();
config.root_store.add_server_trust_anchors(&webpki_roots::TLS_SERVER_ROOTS);

Arc::new(config)
});
let config = match tls_connector {
Some(config) => config,
None => {
#[allow(unused_mut)]
let mut config = ClientConfig::new();
#[cfg(feature = "rustls-tls-native-roots")]
{
config.root_store = rustls_native_certs::load_native_certs()
.map_err(|(_, err)| err)?;
}
#[cfg(feature = "rustls-tls-webpki-roots")]
{
config
.root_store
.add_server_trust_anchors(&webpki_roots::TLS_SERVER_ROOTS);
}

Arc::new(config)
}
};
let domain = DNSNameRef::try_from_ascii_str(&domain).map_err(TlsError::Dns)?;
let stream = TokioTlsConnector::from(config);
let connected = stream.connect(domain, socket).await;
Expand Down Expand Up @@ -158,7 +172,7 @@ where
{
let request = request.into_client_request()?;

#[cfg(any(feature = "native-tls", feature = "rustls-tls"))]
#[cfg(any(feature = "native-tls", feature = "__rustls-tls"))]
let domain = domain(&request)?;

// Make sure we check domain and mode first. URL must be valid.
Expand All @@ -170,7 +184,7 @@ where
Connector::NativeTls(conn) => {
self::encryption::native_tls::wrap_stream(stream, domain, mode, Some(conn)).await
}
#[cfg(feature = "rustls-tls")]
#[cfg(feature = "__rustls-tls")]
Connector::Rustls(conn) => {
self::encryption::rustls::wrap_stream(stream, domain, mode, Some(conn)).await
}
Expand All @@ -181,11 +195,11 @@ where
{
self::encryption::native_tls::wrap_stream(stream, domain, mode, None).await
}
#[cfg(all(feature = "rustls-tls", not(feature = "native-tls")))]
#[cfg(all(feature = "__rustls-tls", not(feature = "native-tls")))]
{
self::encryption::rustls::wrap_stream(stream, domain, mode, None).await
}
#[cfg(not(any(feature = "native-tls", feature = "rustls-tls")))]
#[cfg(not(any(feature = "native-tls", feature = "__rustls-tls")))]
{
self::encryption::plain::wrap_stream(stream, mode).await
}
Expand Down

0 comments on commit 9653bca

Please sign in to comment.