Skip to content

Commit

Permalink
feat: support native-tls connections (bytebeamio#501)
Browse files Browse the repository at this point in the history
* feat: Add `native-tls` support (bytebeamio#378)

* feat: support for native-tls with custom config

* doc: add changelog for PR

* fix: `tls_connector` rename

Co-authored-by: David Mládek <david.mladek.cz@gmail.com>
  • Loading branch information
de-sh and mladedav committed Nov 15, 2022
1 parent bb03842 commit fc22724
Show file tree
Hide file tree
Showing 7 changed files with 129 additions and 41 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Expand Up @@ -2,6 +2,7 @@
---
rumqttc
-------
- Add support for native-tls within rumqttc (#501)
- Fixed panicking in `recv_timeout` and `try_recv` by entering tokio runtime context (#492, #497)
- Removed unused dependencies and updated version of some of used libraries to fix dependabots warning (#475)

Expand Down
2 changes: 2 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

25 changes: 17 additions & 8 deletions rumqttc/Cargo.toml
Expand Up @@ -15,24 +15,33 @@ rustdoc-args = ["--cfg", "docsrs"]

[features]
default = ["use-rustls"]
websocket = ["async-tungstenite", "ws_stream_tungstenite", "http"]
use-rustls = ["tokio-rustls", "rustls-pemfile", "rustls-native-certs"]
use-native-tls = ["tokio-native-tls", "native-tls"]
websocket = ["async-tungstenite", "ws_stream_tungstenite", "http"]

[dependencies]
async-tungstenite = { version = "0.16", default-features = false, features = ["tokio-rustls-native-certs"], optional = true }
bytes = "1"
flume = "0.10"
futures = "0.3"
http = { version = "0.2", optional = true}
tokio = { version = "1.0", features = ["rt", "macros", "io-util", "net", "time"] }
bytes = "1.0"
log = "0.4"
pollster = "0.2"
rustls-pemfile = { version = "0.3", optional = true }
flume = "0.10"
thiserror = "1"
tokio = { version = "1", features = ["rt", "macros", "io-util", "net", "time"] }

# Optional
# rustls
tokio-rustls = { version = "0.23", optional = true }
rustls-pemfile = { version = "0.3", optional = true }
rustls-native-certs = { version = "0.6", optional = true }
url = { version = "2", default-features = false, optional = true }
# websockets
async-tungstenite = { version = "0.16", default-features = false, features = ["tokio-rustls-native-certs"], optional = true }
ws_stream_tungstenite = { version = "0.7", default-features = false, features = ["tokio_io"], optional = true }
http = { version = "0.2", optional = true }
# native-tls
tokio-native-tls = { version = "0.3.0", optional = true }
native-tls = { version = "0.2.8", optional = true }
# url
url = { version = "2", default-features = false, optional = true }

[dev-dependencies]
color-backtrace = "0.4"
Expand Down
13 changes: 7 additions & 6 deletions rumqttc/src/eventloop.rs
@@ -1,7 +1,8 @@
use crate::framed::Network;
#[cfg(feature = "use-rustls")]
#[cfg(any(feature = "use-rustls", feature = "use-native-tls"))]
use crate::tls;
use crate::{Incoming, MqttOptions, MqttState, Outgoing, Packet, Request, StateError, Transport};
use crate::{framed::Network, Transport};
use crate::{Incoming, MqttState, Packet, Request, StateError};
use crate::{MqttOptions, Outgoing};

use crate::mqttbytes::v4::*;
#[cfg(feature = "websocket")]
Expand Down Expand Up @@ -37,7 +38,7 @@ pub enum ConnectionError {
#[cfg(feature = "websocket")]
#[error("Websocket Connect: {0}")]
WsConnect(#[from] http::Error),
#[cfg(feature = "use-rustls")]
#[cfg(any(feature = "use-rustls", feature = "use-native-tls"))]
#[error("TLS: {0}")]
Tls(#[from] tls::Error),
#[error("I/O: {0}")]
Expand Down Expand Up @@ -239,7 +240,7 @@ async fn network_connect(options: &MqttOptions) -> Result<Network, ConnectionErr
let socket = TcpStream::connect((addr, port)).await?;
Network::new(socket, options.max_incoming_packet_size)
}
#[cfg(feature = "use-rustls")]
#[cfg(any(feature = "use-rustls", feature = "use-native-tls"))]
Transport::Tls(tls_config) => {
let socket = tls::tls_connect(&options.broker_addr, options.port, &tls_config).await?;
Network::new(socket, options.max_incoming_packet_size)
Expand Down Expand Up @@ -270,7 +271,7 @@ async fn network_connect(options: &MqttOptions) -> Result<Network, ConnectionErr
.header("Sec-WebSocket-Protocol", "mqttv3.1")
.body(())?;

let connector = tls::tls_connector(&tls_config).await?;
let connector = tls::rustls_connector(&tls_config).await?;

let (socket, _) = connect_async_with_tls_connector(request, Some(connector)).await?;

Expand Down
23 changes: 18 additions & 5 deletions rumqttc/src/lib.rs
Expand Up @@ -108,7 +108,7 @@ mod eventloop;
mod framed;
pub mod mqttbytes;
mod state;
#[cfg(feature = "use-rustls")]
#[cfg(any(feature = "use-rustls", feature = "use-native-tls"))]
mod tls;
pub mod v5;

Expand All @@ -121,7 +121,7 @@ pub use mqttbytes::*;
#[cfg(feature = "use-rustls")]
use rustls_native_certs::load_native_certs;
pub use state::{MqttState, StateError};
#[cfg(feature = "use-rustls")]
#[cfg(any(feature = "use-rustls", feature = "use-native-tls"))]
pub use tls::Error as TlsError;
#[cfg(feature = "use-rustls")]
pub use tokio_rustls;
Expand Down Expand Up @@ -204,7 +204,7 @@ impl From<Unsubscribe> for Request {
#[derive(Clone)]
pub enum Transport {
Tcp,
#[cfg(feature = "use-rustls")]
#[cfg(any(feature = "use-rustls", feature = "use-native-tls"))]
Tls(TlsConfiguration),
#[cfg(unix)]
Unix,
Expand Down Expand Up @@ -249,7 +249,7 @@ impl Transport {
Self::tls_with_config(config)
}

#[cfg(feature = "use-rustls")]
#[cfg(any(feature = "use-rustls", feature = "use-native-tls"))]
pub fn tls_with_config(tls_config: TlsConfiguration) -> Self {
Self::Tls(tls_config)
}
Expand Down Expand Up @@ -298,8 +298,9 @@ impl Transport {

/// TLS configuration method
#[derive(Clone)]
#[cfg(feature = "use-rustls")]
#[cfg(any(feature = "use-rustls", feature = "use-native-tls"))]
pub enum TlsConfiguration {
#[cfg(feature = "use-rustls")]
Simple {
/// connection method
ca: Vec<u8>,
Expand All @@ -308,8 +309,20 @@ pub enum TlsConfiguration {
/// tls client_authentication
client_auth: Option<(Vec<u8>, Key)>,
},
#[cfg(feature = "use-native-tls")]
SimpleNative {
/// ca certificate
ca: Vec<u8>,
/// pkcs12 binary der
der: Vec<u8>,
/// password for use with der
password: String,
},
#[cfg(feature = "use-rustls")]
/// Injected rustls ClientConfig for TLS, to allow more customisation.
Rustls(Arc<ClientConfig>),
#[cfg(feature = "use-native-tls")]
Native,
}

#[cfg(feature = "use-rustls")]
Expand Down
98 changes: 80 additions & 18 deletions rumqttc/src/tls.rs
@@ -1,21 +1,39 @@
use tokio::net::TcpStream;

#[cfg(feature = "use-rustls")]
use tokio_rustls::rustls;
#[cfg(feature = "use-rustls")]
use tokio_rustls::rustls::client::InvalidDnsNameError;
#[cfg(feature = "use-rustls")]
use tokio_rustls::rustls::{
Certificate, ClientConfig, OwnedTrustAnchor, PrivateKey, RootCertStore, ServerName,
};
#[cfg(feature = "use-rustls")]
use tokio_rustls::webpki;
use tokio_rustls::{client::TlsStream, TlsConnector};

use crate::{Key, TlsConfiguration};
#[cfg(feature = "use-rustls")]
use tokio_rustls::TlsConnector as RustlsConnector;

#[cfg(feature = "use-rustls")]
use crate::Key;
#[cfg(feature = "use-rustls")]
use std::convert::TryFrom;
use std::io;
#[cfg(feature = "use-rustls")]
use std::io::{BufReader, Cursor};
use std::net::AddrParseError;
#[cfg(feature = "use-rustls")]
use std::sync::Arc;

/// TLS backend error
use crate::framed::N;
use crate::TlsConfiguration;

#[cfg(feature = "use-native-tls")]
use tokio_native_tls::TlsConnector as NativeTlsConnector;

#[cfg(feature = "use-native-tls")]
use tokio_native_tls::native_tls::{Error as NativeTlsError, Identity};

use std::io;
use std::net::AddrParseError;

#[derive(Debug, thiserror::Error)]
pub enum Error {
/// Error parsing IP address
Expand All @@ -24,28 +42,36 @@ pub enum Error {
/// I/O related error
#[error("I/O: {0}")]
Io(#[from] io::Error),
#[cfg(feature = "use-rustls")]
/// Certificate/Name validation error
#[error("Web Pki: {0}")]
WebPki(#[from] webpki::Error),
#[cfg(feature = "use-rustls")]
/// Invalid DNS name
#[error("DNS name")]
DNSName(#[from] InvalidDnsNameError),
#[cfg(feature = "use-rustls")]
/// Error from rustls module
#[error("TLS error: {0}")]
TLS(#[from] rustls::Error),
#[cfg(feature = "use-rustls")]
/// No valid certificate in chain
#[error("No valid certificate in chain")]
NoValidCertInChain,
#[cfg(feature = "use-native-tls")]
#[error("Native TLS error {0}")]
NativeTls(#[from] NativeTlsError),
}

// The cert handling functions return unit right now, this is a shortcut
impl From<()> for Error {
fn from(_: ()) -> Self {
Error::NoValidCertInChain
}
}
// // The cert handling functions return unit right now, this is a shortcut
// impl From<()> for Error {
// fn from(_: ()) -> Self {
// Error::NoValidCertInChain
// }
// }

pub async fn tls_connector(tls_config: &TlsConfiguration) -> Result<TlsConnector, Error> {
#[cfg(feature = "use-rustls")]
pub async fn rustls_connector(tls_config: &TlsConfiguration) -> Result<RustlsConnector, Error> {
let config = match tls_config {
TlsConfiguration::Simple {
ca,
Expand Down Expand Up @@ -118,19 +144,55 @@ pub async fn tls_connector(tls_config: &TlsConfiguration) -> Result<TlsConnector
Arc::new(config)
}
TlsConfiguration::Rustls(tls_client_config) => tls_client_config.clone(),
#[allow(unreachable_patterns)]
_ => unreachable!("This cannot be called for other TLS backends than Rustls"),
};

Ok(TlsConnector::from(config))
Ok(RustlsConnector::from(config))
}

#[cfg(feature = "use-native-tls")]
pub async fn native_tls_connector(
tls_config: &TlsConfiguration,
) -> Result<NativeTlsConnector, Error> {
let connector = match tls_config {
TlsConfiguration::SimpleNative { ca, der, password } => {
let cert = native_tls::Certificate::from_pem(ca)?;
let identity = Identity::from_pkcs12(der, password)?;
native_tls::TlsConnector::builder()
.add_root_certificate(cert)
.identity(identity)
.build()?
}
TlsConfiguration::Native => native_tls::TlsConnector::new()?,
#[allow(unreachable_patterns)]
_ => unreachable!("This cannot be called for other TLS backends than Native TLS"),
};

Ok(connector.into())
}

pub async fn tls_connect(
addr: &str,
port: u16,
tls_config: &TlsConfiguration,
) -> Result<TlsStream<TcpStream>, Error> {
let connector = tls_connector(tls_config).await?;
let domain = ServerName::try_from(addr)?;
) -> Result<Box<dyn N>, Error> {
let tcp = TcpStream::connect((addr, port)).await?;
let tls = connector.connect(domain, tcp).await?;

let tls: Box<dyn N> = match tls_config {
#[cfg(feature = "use-rustls")]
TlsConfiguration::Simple { .. } | TlsConfiguration::Rustls(_) => {
let connector = rustls_connector(tls_config).await?;
let domain = ServerName::try_from(addr)?;
Box::new(connector.connect(domain, tcp).await?)
}
#[cfg(feature = "use-native-tls")]
TlsConfiguration::Native | TlsConfiguration::SimpleNative { .. } => {
let connector = native_tls_connector(tls_config).await?;
Box::new(connector.connect(addr, tcp).await?)
}
#[allow(unreachable_patterns)]
_ => panic!("Unknown or not enabled TLS backend configuration"),
};
Ok(tls)
}
8 changes: 4 additions & 4 deletions rumqttc/src/v5/eventloop.rs
@@ -1,7 +1,7 @@
use super::framed::Network;
use super::mqttbytes::{v5::*, *};
use super::{Incoming, MqttOptions, MqttState, Outgoing, Request, StateError, Transport};
#[cfg(feature = "use-rustls")]
#[cfg(any(feature = "use-rustls", feature = "use-native-tls"))]
use crate::tls;

#[cfg(feature = "websocket")]
Expand Down Expand Up @@ -39,7 +39,7 @@ pub enum ConnectionError {
#[cfg(feature = "websocket")]
#[error("Websocket Connect: {0}")]
WsConnect(#[from] http::Error),
#[cfg(feature = "use-rustls")]
#[cfg(any(feature = "use-rustls", feature = "use-native-tls"))]
#[error("TLS: {0}")]
Tls(#[from] tls::Error),
#[error("I/O: {0}")]
Expand Down Expand Up @@ -241,7 +241,7 @@ async fn network_connect(options: &MqttOptions) -> Result<Network, ConnectionErr
let socket = TcpStream::connect((addr, port)).await?;
Network::new(socket, options.max_incoming_packet_size)
}
#[cfg(feature = "use-rustls")]
#[cfg(any(feature = "use-native-tls", feature = "use-rustls"))]
Transport::Tls(tls_config) => {
let socket = tls::tls_connect(&options.broker_addr, options.port, &tls_config).await?;
Network::new(socket, options.max_incoming_packet_size)
Expand Down Expand Up @@ -272,7 +272,7 @@ async fn network_connect(options: &MqttOptions) -> Result<Network, ConnectionErr
.header("Sec-WebSocket-Protocol", "mqttv3.1")
.body(())?;

let connector = tls::tls_connector(&tls_config).await?;
let connector = tls::rustls_connector(&tls_config).await?;

let (socket, _) = connect_async_with_tls_connector(request, Some(connector)).await?;

Expand Down

0 comments on commit fc22724

Please sign in to comment.