From 6d17e586c3ae2c99eba3b610718cadf61210935b Mon Sep 17 00:00:00 2001 From: Dirkjan Ochtman Date: Mon, 6 Dec 2021 11:05:56 +0100 Subject: [PATCH] feat(tls): upgrade to tokio-rustls 0.23 (rustls 0.20) --- tonic/Cargo.toml | 9 +- tonic/src/transport/server/conn.rs | 6 +- tonic/src/transport/service/connector.rs | 22 ++--- tonic/src/transport/service/tls.rs | 119 +++++++++++++---------- 4 files changed, 86 insertions(+), 70 deletions(-) diff --git a/tonic/Cargo.toml b/tonic/Cargo.toml index 957342666..9c28dd036 100644 --- a/tonic/Cargo.toml +++ b/tonic/Cargo.toml @@ -27,7 +27,7 @@ codegen = ["async-trait"] compression = ["flate2"] default = ["transport", "codegen", "prost"] prost = ["prost1", "prost-derive"] -tls = ["transport", "tokio-rustls"] +tls = ["rustls-pemfile", "transport", "tokio-rustls"] tls-roots = ["tls-roots-common", "rustls-native-certs"] tls-roots-common = ["tls"] tls-webpki-roots = ["tls-roots-common", "webpki-roots"] @@ -79,9 +79,10 @@ tower = {version = "0.4.7", features = ["balance", "buffer", "discover", "limit" tracing-futures = {version = "0.2", optional = true} # rustls -rustls-native-certs = {version = "0.5", optional = true} -tokio-rustls = {version = "0.22", optional = true} -webpki-roots = {version = "0.21.1", optional = true} +rustls-pemfile = { version = "0.2.1", optional = true } +rustls-native-certs = { version = "0.6.1", optional = true } +tokio-rustls = { version = "0.23.1", optional = true } +webpki-roots = { version = "0.22.1", optional = true } # compression flate2 = {version = "1.0", optional = true} diff --git a/tonic/src/transport/server/conn.rs b/tonic/src/transport/server/conn.rs index 40d60b232..53bd47c31 100644 --- a/tonic/src/transport/server/conn.rs +++ b/tonic/src/transport/server/conn.rs @@ -7,7 +7,7 @@ use crate::transport::Certificate; #[cfg(feature = "tls")] use std::sync::Arc; #[cfg(feature = "tls")] -use tokio_rustls::{rustls::Session, server::TlsStream}; +use tokio_rustls::server::TlsStream; /// Trait that connected IO resources implement and use to produce info about the connection. /// @@ -115,10 +115,10 @@ where let (inner, session) = self.get_ref(); let inner = inner.connect_info(); - let certs = if let Some(certs) = session.get_peer_certificates() { + let certs = if let Some(certs) = session.peer_certificates() { let certs = certs .into_iter() - .map(|c| Certificate::from_pem(c.0)) + .map(|c| Certificate::from_pem(c)) .collect(); Some(Arc::new(certs)) } else { diff --git a/tonic/src/transport/service/connector.rs b/tonic/src/transport/service/connector.rs index c4d216b83..d2a35a973 100644 --- a/tonic/src/transport/service/connector.rs +++ b/tonic/src/transport/service/connector.rs @@ -3,6 +3,8 @@ use super::io::BoxedIo; #[cfg(feature = "tls")] use super::tls::TlsConnector; use http::Uri; +#[cfg(feature = "tls-roots-common")] +use std::convert::TryInto; use std::task::{Context, Poll}; use tower::make::MakeConnection; use tower_service::Service; @@ -39,22 +41,18 @@ impl Connector { #[cfg(feature = "tls-roots-common")] fn tls_or_default(&self, scheme: Option<&str>, host: Option<&str>) -> Option { - use tokio_rustls::webpki::DNSNameRef; - if self.tls.is_some() { return self.tls.clone(); } - match (scheme, host) { - (Some("https"), Some(host)) => { - if DNSNameRef::try_from_ascii(host.as_bytes()).is_ok() { - TlsConnector::new_with_rustls_cert(None, None, host.to_owned()).ok() - } else { - None - } - } - _ => None, - } + let host = match (scheme, host) { + (Some("https"), Some(host)) => host, + _ => return None, + }; + + host.try_into() + .ok() + .and_then(|dns| TlsConnector::new_with_rustls_cert(None, None, dns).ok()) } } diff --git a/tonic/src/transport/service/tls.rs b/tonic/src/transport/service/tls.rs index 5bd960c48..9f5c5102f 100644 --- a/tonic/src/transport/service/tls.rs +++ b/tonic/src/transport/service/tls.rs @@ -5,12 +5,13 @@ use crate::transport::{ }; #[cfg(feature = "tls-roots")] use rustls_native_certs; +#[cfg(feature = "tls")] +use std::convert::TryInto; use std::{fmt, sync::Arc}; use tokio::io::{AsyncRead, AsyncWrite}; #[cfg(feature = "tls")] use tokio_rustls::{ - rustls::{ClientConfig, NoClientAuth, ServerConfig, Session}, - webpki::DNSNameRef, + rustls::{ClientConfig, RootCertStore, ServerConfig, ServerName}, TlsAcceptor as RustlsAcceptor, TlsConnector as RustlsConnector, }; @@ -31,7 +32,7 @@ enum TlsError { #[derive(Clone)] pub(crate) struct TlsConnector { config: Arc, - domain: Arc, + domain: Arc, } impl TlsConnector { @@ -41,38 +42,47 @@ impl TlsConnector { identity: Option, domain: String, ) -> Result { - let mut config = ClientConfig::new(); - config.set_protocols(&[Vec::from(ALPN_H2)]); - - if let Some(identity) = identity { - let (client_cert, client_key) = rustls_keys::load_identity(identity)?; - config.set_single_client_cert(client_cert, client_key)?; - } + let builder = ClientConfig::builder().with_safe_defaults(); + let mut roots = RootCertStore::empty(); #[cfg(feature = "tls-roots")] { - config.root_store = match rustls_native_certs::load_native_certs() { - Ok(store) | Err((Some(store), _)) => store, - Err((None, error)) => return Err(error.into()), + match rustls_native_certs::load_native_certs() { + Ok(certs) => roots.add_parsable_certificates( + &certs.into_iter().map(|cert| cert.0).collect::>(), + ), + Err(error) => return Err(error.into()), }; } #[cfg(feature = "tls-webpki-roots")] { - config - .root_store - .add_server_trust_anchors(&webpki_roots::TLS_SERVER_ROOTS); + use tokio_rustls::rustls::OwnedTrustAnchor; + + roots.add_server_trust_anchors(webpki_roots::TLS_SERVER_ROOTS.0.iter().map(|ta| { + OwnedTrustAnchor::from_subject_spki_name_constraints( + ta.subject, + ta.spki, + ta.name_constraints, + ) + })); } if let Some(cert) = ca_cert { - let mut buf = std::io::Cursor::new(&cert.pem[..]); - config.root_store.add_pem_file(&mut buf).unwrap(); + rustls_keys::add_certs_from_pem(std::io::Cursor::new(&cert.pem[..]), &mut roots)?; } - Ok(Self { - config: Arc::new(config), - domain: Arc::new(domain), - }) + let builder = builder.with_root_certificates(roots); + let mut config = match identity { + Some(identity) => { + let (client_cert, client_key) = rustls_keys::load_identity(identity)?; + builder.with_single_cert(client_cert, client_key)? + } + None => builder.with_no_client_auth(), + }; + + config.alpn_protocols.push(ALPN_H2.as_bytes().to_vec()); + Self::new_with_rustls_raw(config, domain) } #[cfg(feature = "tls")] @@ -82,7 +92,7 @@ impl TlsConnector { ) -> Result { Ok(Self { config: Arc::new(config), - domain: Arc::new(domain), + domain: Arc::new(domain.as_str().try_into()?), }) } @@ -91,15 +101,13 @@ impl TlsConnector { I: AsyncRead + AsyncWrite + Send + Unpin + 'static, { let tls_io = { - let dns = DNSNameRef::try_from_ascii_str(self.domain.as_str())?.to_owned(); - let io = RustlsConnector::from(self.config.clone()) - .connect(dns.as_ref(), io) + .connect(self.domain.as_ref().to_owned(), io) .await?; let (_, session) = io.get_ref(); - match session.get_alpn_protocol() { + match session.alpn_protocol() { Some(b) if b == b"h2" => (), _ => return Err(TlsError::H2NotNegotiated.into()), }; @@ -128,26 +136,22 @@ impl TlsAcceptor { identity: Identity, client_ca_root: Option, ) -> Result { - let (cert, key) = rustls_keys::load_identity(identity)?; + let builder = ServerConfig::builder().with_safe_defaults(); - let mut config = match client_ca_root { - None => ServerConfig::new(NoClientAuth::new()), + let builder = match client_ca_root { + None => builder.with_no_client_auth(), Some(cert) => { - let mut cert = std::io::Cursor::new(&cert.pem[..]); - - let mut client_root_cert_store = tokio_rustls::rustls::RootCertStore::empty(); - if client_root_cert_store.add_pem_file(&mut cert).is_err() { - return Err(Box::new(TlsError::CertificateParseError)); - } - - let client_auth = - tokio_rustls::rustls::AllowAnyAuthenticatedClient::new(client_root_cert_store); - ServerConfig::new(client_auth) + use tokio_rustls::rustls::server::AllowAnyAuthenticatedClient; + let mut roots = RootCertStore::empty(); + rustls_keys::add_certs_from_pem(std::io::Cursor::new(&cert.pem[..]), &mut roots)?; + builder.with_client_cert_verifier(AllowAnyAuthenticatedClient::new(roots)) } }; - config.set_single_cert(cert, key)?; - config.set_protocols(&[Vec::from(ALPN_H2)]); + let (cert, key) = rustls_keys::load_identity(identity)?; + let mut config = builder.with_single_cert(cert, key)?; + + config.alpn_protocols.push(ALPN_H2.as_bytes().to_vec()); Ok(Self { inner: Arc::new(config), }) @@ -194,7 +198,9 @@ impl std::error::Error for TlsError {} #[cfg(feature = "tls")] mod rustls_keys { - use tokio_rustls::rustls::{internal::pemfile, Certificate, PrivateKey}; + use std::io::Cursor; + + use tokio_rustls::rustls::{Certificate, PrivateKey, RootCertStore}; use crate::transport::service::tls::TlsError; use crate::transport::Identity; @@ -203,17 +209,17 @@ mod rustls_keys { mut cursor: std::io::Cursor<&[u8]>, ) -> Result { // First attempt to load the private key assuming it is PKCS8-encoded - if let Ok(mut keys) = pemfile::pkcs8_private_keys(&mut cursor) { - if !keys.is_empty() { - return Ok(keys.remove(0)); + if let Ok(mut keys) = rustls_pemfile::pkcs8_private_keys(&mut cursor) { + if let Some(key) = keys.pop() { + return Ok(PrivateKey(key)); } } // If it not, try loading the private key as an RSA key cursor.set_position(0); - if let Ok(mut keys) = pemfile::rsa_private_keys(&mut cursor) { - if !keys.is_empty() { - return Ok(keys.remove(0)); + if let Ok(mut keys) = rustls_pemfile::rsa_private_keys(&mut cursor) { + if let Some(key) = keys.pop() { + return Ok(PrivateKey(key)); } } @@ -226,8 +232,8 @@ mod rustls_keys { ) -> Result<(Vec, PrivateKey), crate::Error> { let cert = { let mut cert = std::io::Cursor::new(&identity.cert.pem[..]); - match pemfile::certs(&mut cert) { - Ok(certs) => certs, + match rustls_pemfile::certs(&mut cert) { + Ok(certs) => certs.into_iter().map(Certificate).collect(), Err(_) => return Err(Box::new(TlsError::CertificateParseError)), } }; @@ -244,4 +250,15 @@ mod rustls_keys { Ok((cert, key)) } + + pub(crate) fn add_certs_from_pem( + mut certs: Cursor<&[u8]>, + roots: &mut RootCertStore, + ) -> Result<(), crate::Error> { + let (_, ignored) = roots.add_parsable_certificates(&rustls_pemfile::certs(&mut certs)?); + match ignored == 0 { + true => Ok(()), + false => Err(Box::new(TlsError::CertificateParseError)), + } + } }