Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow choosing rustls certs through flags #184

Merged
merged 1 commit into from Aug 13, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
17 changes: 4 additions & 13 deletions .github/workflows/ci.yml
Expand Up @@ -39,20 +39,11 @@ jobs:
- name: Install dependencies
run: sudo apt-get install libssl-dev

- name: Check no-default-features
run: cargo check --no-default-features
- name: Install cargo-hack
run: cargo install cargo-hack

- name: Check default-features
run: cargo check

- name: Check native-tls
run: cargo check --features native-tls

- name: Check rustls
run: cargo check --features rustls-tls

- name: Check native-tls and rustls
run: cargo check --features native-tls,rustls-tls
- name: Check
run: cargo hack check --feature-powerset --all-targets

- name: Test
run: cargo test --release
Expand Down
12 changes: 9 additions & 3 deletions Cargo.toml
Expand Up @@ -13,14 +13,16 @@ edition = "2018"
include = ["examples/**/*", "src/**/*", "LICENSE", "README.md", "CHANGELOG.md"]

[package.metadata.docs.rs]
features = ["native-tls", "rustls-tls"]
features = ["native-tls", "__rustls-tls"]

[features]
default = ["connect"]
connect = ["stream", "tokio/net"]
native-tls = ["native-tls-crate", "tokio-native-tls", "stream", "tungstenite/native-tls"]
native-tls-vendored = ["native-tls", "native-tls-crate/vendored", "tungstenite/native-tls-vendored"]
rustls-tls = ["rustls", "tokio-rustls", "stream", "tungstenite/rustls-tls", "webpki", "webpki-roots"]
rustls-tls-native-roots = ["__rustls-tls", "rustls-native-certs"]
rustls-tls-webpki-roots = ["__rustls-tls", "webpki-roots"]
__rustls-tls = ["rustls", "tokio-rustls", "stream", "tungstenite/__rustls-tls", "webpki"]
stream = []

[dependencies]
Expand All @@ -30,7 +32,7 @@ pin-project = "1.0"
tokio = { version = "1.0.0", default-features = false, features = ["io-util"] }

[dependencies.tungstenite]
version = "0.14.0"
version = "0.15.0"
default-features = false

[dependencies.native-tls-crate]
Expand All @@ -42,6 +44,10 @@ version = "0.2.7"
optional = true
version = "0.19.0"

[dependencies.rustls-native-certs]
optional = true
version = "0.5.0"

[dependencies.tokio-native-tls]
optional = true
version = "0.3.0"
Expand Down
4 changes: 2 additions & 2 deletions src/connect.rs
Expand Up @@ -48,12 +48,12 @@ where
let try_socket = TcpStream::connect(addr).await;
let socket = try_socket.map_err(Error::Io)?;

#[cfg(not(any(feature = "native-tls", feature = "rustls-tls")))]
#[cfg(not(any(feature = "native-tls", feature = "__rustls-tls")))]
{
crate::client_async_with_config(request, MaybeTlsStream::Plain(socket), config).await
}

#[cfg(any(feature = "native-tls", feature = "rustls-tls"))]
#[cfg(any(feature = "native-tls", feature = "__rustls-tls"))]
{
crate::tls::client_async_tls_with_config(request, socket, config, None).await
}
Expand Down
9 changes: 4 additions & 5 deletions src/lib.rs
Expand Up @@ -18,7 +18,7 @@ mod connect;
mod handshake;
#[cfg(feature = "stream")]
mod stream;
#[cfg(any(feature = "native-tls", feature = "rustls-tls"))]
#[cfg(any(feature = "native-tls", feature = "__rustls-tls"))]
mod tls;

use std::io::{Read, Write};
Expand All @@ -44,10 +44,9 @@ use tungstenite::{
HandshakeError,
},
protocol::{Message, Role, WebSocket, WebSocketConfig},
server,
};

#[cfg(any(feature = "native-tls", feature = "rustls-tls"))]
#[cfg(any(feature = "native-tls", feature = "__rustls-tls"))]
pub use tls::{client_async_tls, client_async_tls_with_config, Connector};

#[cfg(feature = "connect")]
Expand Down Expand Up @@ -161,7 +160,7 @@ where
C: Callback + Unpin,
{
let f = handshake::server_handshake(stream, move |allow_std| {
server::accept_hdr_with_config(allow_std, callback, config)
tungstenite::accept_hdr_with_config(allow_std, callback, config)
});
f.await.map_err(|e| match e {
HandshakeError::Failure(e) => e,
Expand Down Expand Up @@ -324,7 +323,7 @@ where
}

/// Get a domain from an URL.
#[cfg(any(feature = "connect", feature = "native-tls", feature = "rustls-tls"))]
#[cfg(any(feature = "connect", feature = "native-tls", feature = "__rustls-tls"))]
#[inline]
fn domain(request: &tungstenite::handshake::client::Request) -> Result<String, WsError> {
match request.uri().host() {
Expand Down
10 changes: 5 additions & 5 deletions src/stream.rs
Expand Up @@ -22,7 +22,7 @@ pub enum MaybeTlsStream<S> {
#[cfg(feature = "native-tls")]
NativeTls(tokio_native_tls::TlsStream<S>),
/// Encrypted socket stream using `rustls`.
#[cfg(feature = "rustls-tls")]
#[cfg(feature = "__rustls-tls")]
Rustls(tokio_rustls::client::TlsStream<S>),
}

Expand All @@ -36,7 +36,7 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AsyncRead for MaybeTlsStream<S> {
StreamProj::Plain(ref mut s) => Pin::new(s).poll_read(cx, buf),
#[cfg(feature = "native-tls")]
StreamProj::NativeTls(s) => Pin::new(s).poll_read(cx, buf),
#[cfg(feature = "rustls-tls")]
#[cfg(feature = "__rustls-tls")]
StreamProj::Rustls(s) => Pin::new(s).poll_read(cx, buf),
}
}
Expand All @@ -52,7 +52,7 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AsyncWrite for MaybeTlsStream<S> {
StreamProj::Plain(ref mut s) => Pin::new(s).poll_write(cx, buf),
#[cfg(feature = "native-tls")]
StreamProj::NativeTls(s) => Pin::new(s).poll_write(cx, buf),
#[cfg(feature = "rustls-tls")]
#[cfg(feature = "__rustls-tls")]
StreamProj::Rustls(s) => Pin::new(s).poll_write(cx, buf),
}
}
Expand All @@ -62,7 +62,7 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AsyncWrite for MaybeTlsStream<S> {
StreamProj::Plain(ref mut s) => Pin::new(s).poll_flush(cx),
#[cfg(feature = "native-tls")]
StreamProj::NativeTls(s) => Pin::new(s).poll_flush(cx),
#[cfg(feature = "rustls-tls")]
#[cfg(feature = "__rustls-tls")]
StreamProj::Rustls(s) => Pin::new(s).poll_flush(cx),
}
}
Expand All @@ -75,7 +75,7 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AsyncWrite for MaybeTlsStream<S> {
StreamProj::Plain(ref mut s) => Pin::new(s).poll_shutdown(cx),
#[cfg(feature = "native-tls")]
StreamProj::NativeTls(s) => Pin::new(s).poll_shutdown(cx),
#[cfg(feature = "rustls-tls")]
#[cfg(feature = "__rustls-tls")]
StreamProj::Rustls(s) => Pin::new(s).poll_shutdown(cx),
}
}
Expand Down
38 changes: 26 additions & 12 deletions src/tls.rs
Expand Up @@ -20,7 +20,7 @@ pub enum Connector {
#[cfg(feature = "native-tls")]
NativeTls(native_tls_crate::TlsConnector),
/// `rustls` TLS connector.
#[cfg(feature = "rustls-tls")]
#[cfg(feature = "__rustls-tls")]
Rustls(std::sync::Arc<rustls::ClientConfig>),
}

Expand Down Expand Up @@ -61,7 +61,7 @@ mod encryption {
}
}

#[cfg(feature = "rustls-tls")]
#[cfg(feature = "__rustls-tls")]
pub mod rustls {
pub use rustls::ClientConfig;
use tokio_rustls::{webpki::DNSNameRef, TlsConnector as TokioTlsConnector};
Expand All @@ -85,12 +85,26 @@ mod encryption {
match mode {
Mode::Plain => Ok(MaybeTlsStream::Plain(socket)),
Mode::Tls => {
let config = tls_connector.unwrap_or_else(|| {
let mut config = ClientConfig::new();
config.root_store.add_server_trust_anchors(&webpki_roots::TLS_SERVER_ROOTS);

Arc::new(config)
});
let config = match tls_connector {
Some(config) => config,
None => {
#[allow(unused_mut)]
let mut config = ClientConfig::new();
#[cfg(feature = "rustls-tls-native-roots")]
{
config.root_store = rustls_native_certs::load_native_certs()
.map_err(|(_, err)| err)?;
}
#[cfg(feature = "rustls-tls-webpki-roots")]
{
config
.root_store
.add_server_trust_anchors(&webpki_roots::TLS_SERVER_ROOTS);
}

Arc::new(config)
}
};
let domain = DNSNameRef::try_from_ascii_str(&domain).map_err(TlsError::Dns)?;
let stream = TokioTlsConnector::from(config);
let connected = stream.connect(domain, socket).await;
Expand Down Expand Up @@ -158,7 +172,7 @@ where
{
let request = request.into_client_request()?;

#[cfg(any(feature = "native-tls", feature = "rustls-tls"))]
#[cfg(any(feature = "native-tls", feature = "__rustls-tls"))]
let domain = domain(&request)?;

// Make sure we check domain and mode first. URL must be valid.
Expand All @@ -170,7 +184,7 @@ where
Connector::NativeTls(conn) => {
self::encryption::native_tls::wrap_stream(stream, domain, mode, Some(conn)).await
}
#[cfg(feature = "rustls-tls")]
#[cfg(feature = "__rustls-tls")]
Connector::Rustls(conn) => {
self::encryption::rustls::wrap_stream(stream, domain, mode, Some(conn)).await
}
Expand All @@ -181,11 +195,11 @@ where
{
self::encryption::native_tls::wrap_stream(stream, domain, mode, None).await
}
#[cfg(all(feature = "rustls-tls", not(feature = "native-tls")))]
#[cfg(all(feature = "__rustls-tls", not(feature = "native-tls")))]
{
self::encryption::rustls::wrap_stream(stream, domain, mode, None).await
}
#[cfg(not(any(feature = "native-tls", feature = "rustls-tls")))]
#[cfg(not(any(feature = "native-tls", feature = "__rustls-tls")))]
{
self::encryption::plain::wrap_stream(stream, mode).await
}
Expand Down