Skip to content

Commit

Permalink
Upgrade rustls to 0.20
Browse files Browse the repository at this point in the history
  • Loading branch information
paolobarbolini committed Oct 26, 2021
1 parent d25ab07 commit 624f289
Show file tree
Hide file tree
Showing 8 changed files with 104 additions and 71 deletions.
57 changes: 32 additions & 25 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 4 additions & 0 deletions Cargo.toml
Expand Up @@ -151,6 +151,10 @@ serde = { version = "1.0.111", features = ["derive"] }
serde_json = "1.0.53"
url = "2.1.1"

[patch.crates-io]
# waiting for rustls 0.20.1
rustls = { git = "https://github.com/rustls/rustls.git" }

#
# Any
#
Expand Down
8 changes: 4 additions & 4 deletions sqlx-core/Cargo.toml
Expand Up @@ -93,7 +93,7 @@ _rt-actix = ["tokio-stream"]
_rt-async-std = []
_rt-tokio = ["tokio-stream"]
_tls-native-tls = []
_tls-rustls = ["rustls", "webpki", "webpki-roots"]
_tls-rustls = ["rustls", "rustls-pemfile", "webpki-roots"]

# support offline/decoupled building (enables serialization of `Describe`)
offline = ["serde", "either/serde"]
Expand Down Expand Up @@ -144,7 +144,8 @@ parking_lot = "0.11.0"
rand = { version = "0.8.3", default-features = false, optional = true, features = ["std", "std_rng"] }
regex = { version = "1.3.9", optional = true }
rsa = { version = "0.4.0", optional = true }
rustls = { version = "0.19.0", features = ["dangerous_configuration"], optional = true }
rustls = { version = "0.20.0", features = ["dangerous_configuration"], optional = true }
rustls-pemfile = { version = "0.2.0", optional = true }
serde = { version = "1.0.106", features = ["derive", "rc"], optional = true }
serde_json = { version = "1.0.51", features = ["raw_value"], optional = true }
sha-1 = { version = "0.9.0", default-features = false, optional = true }
Expand All @@ -156,8 +157,7 @@ tokio-stream = { version = "0.1.2", features = ["fs"], optional = true }
smallvec = "1.4.0"
url = { version = "2.1.1", default-features = false }
uuid = { version = "0.8.1", default-features = false, optional = true, features = ["std"] }
webpki = { version = "0.21.0", optional = true }
webpki-roots = { version = "0.21.0", optional = true }
webpki-roots = { version = "0.22.0", optional = true }
whoami = "1.0.1"
stringprep = "0.1.2"
bstr = { version = "0.2.14", default-features = false, features = ["std"], optional = true }
Expand Down
8 changes: 0 additions & 8 deletions sqlx-core/src/error.rs
Expand Up @@ -253,14 +253,6 @@ impl From<sqlx_rt::native_tls::Error> for Error {
}
}

#[cfg(feature = "_tls-rustls")]
impl From<webpki::InvalidDNSNameError> for Error {
#[inline]
fn from(error: webpki::InvalidDNSNameError) -> Self {
Error::Tls(Box::new(error))
}
}

// Format an error message as a `Protocol` error
macro_rules! err_protocol {
($expr:expr) => {
Expand Down
3 changes: 2 additions & 1 deletion sqlx-core/src/net/tls/mod.rs
@@ -1,5 +1,6 @@
#![allow(dead_code)]

use std::convert::TryFrom;
use std::io;
use std::ops::{Deref, DerefMut};
use std::path::PathBuf;
Expand Down Expand Up @@ -104,7 +105,7 @@ where
};

#[cfg(feature = "_tls-rustls")]
let host = webpki::DNSNameRef::try_from_ascii_str(host)?;
let host = ::rustls::ServerName::try_from(host).map_err(|err| Error::Tls(err.into()))?;

*self = MaybeTlsStream::Tls(connector.connect(host, stream).await?);

Expand Down
87 changes: 58 additions & 29 deletions sqlx-core/src/net/tls/rustls.rs
@@ -1,11 +1,11 @@
use crate::net::CertificateInput;
use rustls::{
Certificate, ClientConfig, RootCertStore, ServerCertVerified, ServerCertVerifier, TLSError,
WebPKIVerifier,
client::{ServerCertVerified, ServerCertVerifier, WebPkiVerifier},
ClientConfig, Error as TlsError, OwnedTrustAnchor, RootCertStore, ServerName,
};
use std::io::Cursor;
use std::sync::Arc;
use webpki::DNSNameRef;
use std::time::SystemTime;

use crate::error::Error;

Expand All @@ -14,32 +14,47 @@ pub async fn configure_tls_connector(
accept_invalid_hostnames: bool,
root_cert_path: Option<&CertificateInput>,
) -> Result<sqlx_rt::TlsConnector, Error> {
let mut config = ClientConfig::new();
let config = ClientConfig::builder().with_safe_defaults();

if accept_invalid_certs {
let config = if accept_invalid_certs {
config
.dangerous()
.set_certificate_verifier(Arc::new(DummyTlsVerifier));
.with_custom_certificate_verifier(Arc::new(DummyTlsVerifier))
.with_no_client_auth()
} else {
config
.root_store
.add_server_trust_anchors(&webpki_roots::TLS_SERVER_ROOTS);
let mut cert_store = RootCertStore::empty();
cert_store.add_server_trust_anchors(webpki_roots::TLS_SERVER_ROOTS.0.iter().map(|ta| {
OwnedTrustAnchor::from_subject_spki_name_constraints(
ta.subject,
ta.spki,
ta.name_constraints,
)
}));

if let Some(ca) = root_cert_path {
let data = ca.data().await?;
let mut cursor = Cursor::new(data);
config
.root_store
.add_pem_file(&mut cursor)
.map_err(|_| Error::Tls(format!("Invalid certificate {}", ca).into()))?;

for cert in rustls_pemfile::certs(&mut cursor)
.map_err(|_| Error::Tls(format!("Invalid certificate {}", ca).into()))?
{
cert_store
.add(&rustls::Certificate(cert))
.map_err(|err| Error::Tls(err.into()))?;
}
}

if accept_invalid_hostnames {
let verifier = WebPkiVerifier::new(cert_store, None);

config
.with_custom_certificate_verifier(Arc::new(NoHostnameTlsVerifier { verifier }))
.with_no_client_auth()
} else {
config
.dangerous()
.set_certificate_verifier(Arc::new(NoHostnameTlsVerifier));
.with_root_certificates(cert_store)
.with_no_client_auth()
}
}
};

Ok(Arc::new(config).into())
}
Expand All @@ -49,28 +64,42 @@ struct DummyTlsVerifier;
impl ServerCertVerifier for DummyTlsVerifier {
fn verify_server_cert(
&self,
_roots: &RootCertStore,
_presented_certs: &[Certificate],
_dns_name: DNSNameRef<'_>,
_end_entity: &rustls::Certificate,
_intermediates: &[rustls::Certificate],
_server_name: &ServerName,
_scts: &mut dyn Iterator<Item = &[u8]>,
_ocsp_response: &[u8],
) -> Result<ServerCertVerified, TLSError> {
_now: SystemTime,
) -> Result<ServerCertVerified, TlsError> {
Ok(ServerCertVerified::assertion())
}
}

pub struct NoHostnameTlsVerifier;
pub struct NoHostnameTlsVerifier {
verifier: WebPkiVerifier,
}

impl ServerCertVerifier for NoHostnameTlsVerifier {
fn verify_server_cert(
&self,
roots: &RootCertStore,
presented_certs: &[Certificate],
dns_name: DNSNameRef<'_>,
end_entity: &rustls::Certificate,
intermediates: &[rustls::Certificate],
server_name: &ServerName,
scts: &mut dyn Iterator<Item = &[u8]>,
ocsp_response: &[u8],
) -> Result<ServerCertVerified, TLSError> {
let verifier = WebPKIVerifier::new();
match verifier.verify_server_cert(roots, presented_certs, dns_name, ocsp_response) {
Err(TLSError::WebPKIError(webpki::Error::CertNotValidForName)) => {
now: SystemTime,
) -> Result<ServerCertVerified, TlsError> {
match self.verifier.verify_server_cert(
end_entity,
intermediates,
server_name,
scts,
ocsp_response,
now,
) {
Err(TlsError::InvalidCertificateData(reason))
if reason.contains("CertNotValidForName") =>
{
Ok(ServerCertVerified::assertion())
}
res => res,
Expand Down
6 changes: 3 additions & 3 deletions sqlx-rt/Cargo.toml
Expand Up @@ -20,7 +20,7 @@ runtime-async-std-native-tls = [
runtime-tokio-native-tls = ["_rt-tokio", "_tls-native-tls", "tokio-native-tls"]

runtime-actix-rustls = ["_rt-actix", "_tls-rustls", "tokio-rustls"]
runtime-async-std-rustls = ["_rt-async-std", "_tls-rustls", "async-rustls"]
runtime-async-std-rustls = ["_rt-async-std", "_tls-rustls", "futures-rustls"]
runtime-tokio-rustls = ["_rt-tokio", "_tls-rustls", "tokio-rustls"]

# Not used directly and not re-exported from sqlx
Expand All @@ -32,11 +32,11 @@ _tls-rustls = []

[dependencies]
async-native-tls = { version = "0.3.3", optional = true }
async-rustls = { version = "0.2.0", optional = true }
futures-rustls = { version = "0.22.0", optional = true }
actix-rt = { version = "2.0.0", default-features = false, optional = true }
async-std = { version = "1.7.0", features = ["unstable"], optional = true }
tokio-native-tls = { version = "0.3.0", optional = true }
tokio-rustls = { version = "0.22.0", optional = true }
tokio-rustls = { version = "0.23.0", optional = true }
native-tls = { version = "0.2.4", optional = true }
once_cell = { version = "1.4", features = ["std"], optional = true }

Expand Down
2 changes: 1 addition & 1 deletion sqlx-rt/src/lib.rs
Expand Up @@ -193,4 +193,4 @@ pub use async_native_tls::{TlsConnector, TlsStream};
feature = "_rt-actix"
)),
))]
pub use async_rustls::{client::TlsStream, TlsConnector};
pub use futures_rustls::{client::TlsStream, TlsConnector};

0 comments on commit 624f289

Please sign in to comment.