Skip to content

Commit

Permalink
feat(tls): upgrade to tokio-rustls 0.23 (rustls 0.20)
Browse files Browse the repository at this point in the history
  • Loading branch information
djc committed Dec 7, 2021
1 parent 5a88e11 commit 6d17e58
Show file tree
Hide file tree
Showing 4 changed files with 86 additions and 70 deletions.
9 changes: 5 additions & 4 deletions tonic/Cargo.toml
Expand Up @@ -27,7 +27,7 @@ codegen = ["async-trait"]
compression = ["flate2"]
default = ["transport", "codegen", "prost"]
prost = ["prost1", "prost-derive"]
tls = ["transport", "tokio-rustls"]
tls = ["rustls-pemfile", "transport", "tokio-rustls"]
tls-roots = ["tls-roots-common", "rustls-native-certs"]
tls-roots-common = ["tls"]
tls-webpki-roots = ["tls-roots-common", "webpki-roots"]
Expand Down Expand Up @@ -79,9 +79,10 @@ tower = {version = "0.4.7", features = ["balance", "buffer", "discover", "limit"
tracing-futures = {version = "0.2", optional = true}

# rustls
rustls-native-certs = {version = "0.5", optional = true}
tokio-rustls = {version = "0.22", optional = true}
webpki-roots = {version = "0.21.1", optional = true}
rustls-pemfile = { version = "0.2.1", optional = true }
rustls-native-certs = { version = "0.6.1", optional = true }
tokio-rustls = { version = "0.23.1", optional = true }
webpki-roots = { version = "0.22.1", optional = true }

# compression
flate2 = {version = "1.0", optional = true}
Expand Down
6 changes: 3 additions & 3 deletions tonic/src/transport/server/conn.rs
Expand Up @@ -7,7 +7,7 @@ use crate::transport::Certificate;
#[cfg(feature = "tls")]
use std::sync::Arc;
#[cfg(feature = "tls")]
use tokio_rustls::{rustls::Session, server::TlsStream};
use tokio_rustls::server::TlsStream;

/// Trait that connected IO resources implement and use to produce info about the connection.
///
Expand Down Expand Up @@ -115,10 +115,10 @@ where
let (inner, session) = self.get_ref();
let inner = inner.connect_info();

let certs = if let Some(certs) = session.get_peer_certificates() {
let certs = if let Some(certs) = session.peer_certificates() {
let certs = certs
.into_iter()
.map(|c| Certificate::from_pem(c.0))
.map(|c| Certificate::from_pem(c))
.collect();
Some(Arc::new(certs))
} else {
Expand Down
22 changes: 10 additions & 12 deletions tonic/src/transport/service/connector.rs
Expand Up @@ -3,6 +3,8 @@ use super::io::BoxedIo;
#[cfg(feature = "tls")]
use super::tls::TlsConnector;
use http::Uri;
#[cfg(feature = "tls-roots-common")]
use std::convert::TryInto;
use std::task::{Context, Poll};
use tower::make::MakeConnection;
use tower_service::Service;
Expand Down Expand Up @@ -39,22 +41,18 @@ impl<C> Connector<C> {

#[cfg(feature = "tls-roots-common")]
fn tls_or_default(&self, scheme: Option<&str>, host: Option<&str>) -> Option<TlsConnector> {
use tokio_rustls::webpki::DNSNameRef;

if self.tls.is_some() {
return self.tls.clone();
}

match (scheme, host) {
(Some("https"), Some(host)) => {
if DNSNameRef::try_from_ascii(host.as_bytes()).is_ok() {
TlsConnector::new_with_rustls_cert(None, None, host.to_owned()).ok()
} else {
None
}
}
_ => None,
}
let host = match (scheme, host) {
(Some("https"), Some(host)) => host,
_ => return None,
};

host.try_into()
.ok()
.and_then(|dns| TlsConnector::new_with_rustls_cert(None, None, dns).ok())
}
}

Expand Down
119 changes: 68 additions & 51 deletions tonic/src/transport/service/tls.rs
Expand Up @@ -5,12 +5,13 @@ use crate::transport::{
};
#[cfg(feature = "tls-roots")]
use rustls_native_certs;
#[cfg(feature = "tls")]
use std::convert::TryInto;
use std::{fmt, sync::Arc};
use tokio::io::{AsyncRead, AsyncWrite};
#[cfg(feature = "tls")]
use tokio_rustls::{
rustls::{ClientConfig, NoClientAuth, ServerConfig, Session},
webpki::DNSNameRef,
rustls::{ClientConfig, RootCertStore, ServerConfig, ServerName},
TlsAcceptor as RustlsAcceptor, TlsConnector as RustlsConnector,
};

Expand All @@ -31,7 +32,7 @@ enum TlsError {
#[derive(Clone)]
pub(crate) struct TlsConnector {
config: Arc<ClientConfig>,
domain: Arc<String>,
domain: Arc<ServerName>,
}

impl TlsConnector {
Expand All @@ -41,38 +42,47 @@ impl TlsConnector {
identity: Option<Identity>,
domain: String,
) -> Result<Self, crate::Error> {
let mut config = ClientConfig::new();
config.set_protocols(&[Vec::from(ALPN_H2)]);

if let Some(identity) = identity {
let (client_cert, client_key) = rustls_keys::load_identity(identity)?;
config.set_single_client_cert(client_cert, client_key)?;
}
let builder = ClientConfig::builder().with_safe_defaults();
let mut roots = RootCertStore::empty();

#[cfg(feature = "tls-roots")]
{
config.root_store = match rustls_native_certs::load_native_certs() {
Ok(store) | Err((Some(store), _)) => store,
Err((None, error)) => return Err(error.into()),
match rustls_native_certs::load_native_certs() {
Ok(certs) => roots.add_parsable_certificates(
&certs.into_iter().map(|cert| cert.0).collect::<Vec<_>>(),
),
Err(error) => return Err(error.into()),
};
}

#[cfg(feature = "tls-webpki-roots")]
{
config
.root_store
.add_server_trust_anchors(&webpki_roots::TLS_SERVER_ROOTS);
use tokio_rustls::rustls::OwnedTrustAnchor;

roots.add_server_trust_anchors(webpki_roots::TLS_SERVER_ROOTS.0.iter().map(|ta| {
OwnedTrustAnchor::from_subject_spki_name_constraints(
ta.subject,
ta.spki,
ta.name_constraints,
)
}));
}

if let Some(cert) = ca_cert {
let mut buf = std::io::Cursor::new(&cert.pem[..]);
config.root_store.add_pem_file(&mut buf).unwrap();
rustls_keys::add_certs_from_pem(std::io::Cursor::new(&cert.pem[..]), &mut roots)?;
}

Ok(Self {
config: Arc::new(config),
domain: Arc::new(domain),
})
let builder = builder.with_root_certificates(roots);
let mut config = match identity {
Some(identity) => {
let (client_cert, client_key) = rustls_keys::load_identity(identity)?;
builder.with_single_cert(client_cert, client_key)?
}
None => builder.with_no_client_auth(),
};

config.alpn_protocols.push(ALPN_H2.as_bytes().to_vec());
Self::new_with_rustls_raw(config, domain)
}

#[cfg(feature = "tls")]
Expand All @@ -82,7 +92,7 @@ impl TlsConnector {
) -> Result<Self, crate::Error> {
Ok(Self {
config: Arc::new(config),
domain: Arc::new(domain),
domain: Arc::new(domain.as_str().try_into()?),
})
}

Expand All @@ -91,15 +101,13 @@ impl TlsConnector {
I: AsyncRead + AsyncWrite + Send + Unpin + 'static,
{
let tls_io = {
let dns = DNSNameRef::try_from_ascii_str(self.domain.as_str())?.to_owned();

let io = RustlsConnector::from(self.config.clone())
.connect(dns.as_ref(), io)
.connect(self.domain.as_ref().to_owned(), io)
.await?;

let (_, session) = io.get_ref();

match session.get_alpn_protocol() {
match session.alpn_protocol() {
Some(b) if b == b"h2" => (),
_ => return Err(TlsError::H2NotNegotiated.into()),
};
Expand Down Expand Up @@ -128,26 +136,22 @@ impl TlsAcceptor {
identity: Identity,
client_ca_root: Option<Certificate>,
) -> Result<Self, crate::Error> {
let (cert, key) = rustls_keys::load_identity(identity)?;
let builder = ServerConfig::builder().with_safe_defaults();

let mut config = match client_ca_root {
None => ServerConfig::new(NoClientAuth::new()),
let builder = match client_ca_root {
None => builder.with_no_client_auth(),
Some(cert) => {
let mut cert = std::io::Cursor::new(&cert.pem[..]);

let mut client_root_cert_store = tokio_rustls::rustls::RootCertStore::empty();
if client_root_cert_store.add_pem_file(&mut cert).is_err() {
return Err(Box::new(TlsError::CertificateParseError));
}

let client_auth =
tokio_rustls::rustls::AllowAnyAuthenticatedClient::new(client_root_cert_store);
ServerConfig::new(client_auth)
use tokio_rustls::rustls::server::AllowAnyAuthenticatedClient;
let mut roots = RootCertStore::empty();
rustls_keys::add_certs_from_pem(std::io::Cursor::new(&cert.pem[..]), &mut roots)?;
builder.with_client_cert_verifier(AllowAnyAuthenticatedClient::new(roots))
}
};
config.set_single_cert(cert, key)?;
config.set_protocols(&[Vec::from(ALPN_H2)]);

let (cert, key) = rustls_keys::load_identity(identity)?;
let mut config = builder.with_single_cert(cert, key)?;

config.alpn_protocols.push(ALPN_H2.as_bytes().to_vec());
Ok(Self {
inner: Arc::new(config),
})
Expand Down Expand Up @@ -194,7 +198,9 @@ impl std::error::Error for TlsError {}

#[cfg(feature = "tls")]
mod rustls_keys {
use tokio_rustls::rustls::{internal::pemfile, Certificate, PrivateKey};
use std::io::Cursor;

use tokio_rustls::rustls::{Certificate, PrivateKey, RootCertStore};

use crate::transport::service::tls::TlsError;
use crate::transport::Identity;
Expand All @@ -203,17 +209,17 @@ mod rustls_keys {
mut cursor: std::io::Cursor<&[u8]>,
) -> Result<PrivateKey, crate::Error> {
// First attempt to load the private key assuming it is PKCS8-encoded
if let Ok(mut keys) = pemfile::pkcs8_private_keys(&mut cursor) {
if !keys.is_empty() {
return Ok(keys.remove(0));
if let Ok(mut keys) = rustls_pemfile::pkcs8_private_keys(&mut cursor) {
if let Some(key) = keys.pop() {
return Ok(PrivateKey(key));
}
}

// If it not, try loading the private key as an RSA key
cursor.set_position(0);
if let Ok(mut keys) = pemfile::rsa_private_keys(&mut cursor) {
if !keys.is_empty() {
return Ok(keys.remove(0));
if let Ok(mut keys) = rustls_pemfile::rsa_private_keys(&mut cursor) {
if let Some(key) = keys.pop() {
return Ok(PrivateKey(key));
}
}

Expand All @@ -226,8 +232,8 @@ mod rustls_keys {
) -> Result<(Vec<Certificate>, PrivateKey), crate::Error> {
let cert = {
let mut cert = std::io::Cursor::new(&identity.cert.pem[..]);
match pemfile::certs(&mut cert) {
Ok(certs) => certs,
match rustls_pemfile::certs(&mut cert) {
Ok(certs) => certs.into_iter().map(Certificate).collect(),
Err(_) => return Err(Box::new(TlsError::CertificateParseError)),
}
};
Expand All @@ -244,4 +250,15 @@ mod rustls_keys {

Ok((cert, key))
}

pub(crate) fn add_certs_from_pem(
mut certs: Cursor<&[u8]>,
roots: &mut RootCertStore,
) -> Result<(), crate::Error> {
let (_, ignored) = roots.add_parsable_certificates(&rustls_pemfile::certs(&mut certs)?);
match ignored == 0 {
true => Ok(()),
false => Err(Box::new(TlsError::CertificateParseError)),
}
}
}

0 comments on commit 6d17e58

Please sign in to comment.