Skip to content

Commit

Permalink
Separate default rustls::ClientConfig for each protocol
Browse files Browse the repository at this point in the history
  • Loading branch information
daxpedda committed Sep 21, 2023
1 parent 90edda6 commit 2e35a45
Show file tree
Hide file tree
Showing 6 changed files with 225 additions and 111 deletions.
15 changes: 15 additions & 0 deletions crates/resolver/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -923,6 +923,15 @@ pub struct ResolverOpts {
pub authentic_data: bool,
/// Shuffle DNS servers before each query.
pub shuffle_dns_servers: bool,
#[cfg(feature = "dns-over-rustls")]
#[cfg_attr(feature = "serde-config", serde(skip))]
pub(crate) tls_client_config: Option<TlsClientConfig>,
#[cfg(feature = "dns-over-https-rustls")]
#[cfg_attr(feature = "serde-config", serde(skip))]
pub(crate) https_client_config: Option<TlsClientConfig>,
#[cfg(all(feature = "dns-over-quic", feature = "dns-over-rustls"))]
#[cfg_attr(feature = "serde-config", serde(skip))]
pub(crate) quic_client_config: Option<TlsClientConfig>,
}

impl Default for ResolverOpts {
Expand Down Expand Up @@ -955,6 +964,12 @@ impl Default for ResolverOpts {
recursion_desired: true,
authentic_data: false,
shuffle_dns_servers: false,
#[cfg(feature = "dns-over-rustls")]
tls_client_config: None,
#[cfg(feature = "dns-over-https-rustls")]
https_client_config: None,
#[cfg(feature = "dns-over-quic")]
quic_client_config: None,
}
}
}
Expand Down
45 changes: 21 additions & 24 deletions crates/resolver/src/https.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,36 +8,42 @@
use std::future::Future;
use std::net::SocketAddr;

use crate::tls::CLIENT_CONFIG;

use proto::https::{HttpsClientConnect, HttpsClientStream, HttpsClientStreamBuilder};
use proto::tcp::{Connect, DnsTcpStream};
use proto::xfer::{DnsExchange, DnsExchangeConnect};
use proto::TokioTime;
use rustls::ClientConfig;
use trust_dns_proto::error::ProtoError;

use crate::config::TlsClientConfig;

const ALPN_H2: &[u8] = b"h2";

pub(crate) fn http_client_config() -> Result<ClientConfig, ProtoError> {
let mut client_config = ClientConfig::builder()
.with_safe_default_cipher_suites()
.with_safe_default_kx_groups()
.with_safe_default_protocol_versions()
.unwrap()
.with_root_certificates(crate::tls::root_store()?)
.with_no_client_auth();

client_config.alpn_protocols.push(ALPN_H2.to_vec());
Ok(client_config)
}

#[allow(clippy::type_complexity)]
#[allow(unused)]
pub(crate) fn new_https_stream<S>(
socket_addr: SocketAddr,
bind_addr: Option<SocketAddr>,
dns_name: String,
client_config: Option<TlsClientConfig>,
client_config: TlsClientConfig,
) -> DnsExchangeConnect<HttpsClientConnect<S>, HttpsClientStream, TokioTime>
where
S: Connect,
{
let client_config = if let Some(TlsClientConfig(client_config)) = client_config {
client_config
} else {
match CLIENT_CONFIG.clone() {
Ok(client_config) => client_config,
Err(error) => return DnsExchange::error(error),
}
};

let mut https_builder = HttpsClientStreamBuilder::with_client_config(client_config);
let mut https_builder = HttpsClientStreamBuilder::with_client_config(client_config.0);
if let Some(bind_addr) = bind_addr {
https_builder.bind_addr(bind_addr);
}
Expand All @@ -49,24 +55,15 @@ pub(crate) fn new_https_stream_with_future<S, F>(
future: F,
socket_addr: SocketAddr,
dns_name: String,
client_config: Option<TlsClientConfig>,
client_config: TlsClientConfig,
) -> DnsExchangeConnect<HttpsClientConnect<S>, HttpsClientStream, TokioTime>
where
S: DnsTcpStream,
F: Future<Output = std::io::Result<S>> + Send + Unpin + 'static,
{
let client_config = if let Some(TlsClientConfig(client_config)) = client_config {
client_config
} else {
match CLIENT_CONFIG.clone() {
Ok(client_config) => client_config,
Err(error) => return DnsExchange::error(error),
}
};

DnsExchange::connect(HttpsClientStreamBuilder::build_with_future(
future,
client_config,
client_config.0,
socket_addr,
dns_name,
))
Expand Down
95 changes: 66 additions & 29 deletions crates/resolver/src/name_server/connection_provider.rs
Original file line number Diff line number Diff line change
Expand Up @@ -309,20 +309,34 @@ impl<P: RuntimeProvider> ConnectionProvider for GenericConnector<P> {
let tls_dns_name = config.tls_dns_name.clone().unwrap_or_default();
let tcp_future = self.runtime_provider.connect_tcp(socket_addr);

#[cfg(feature = "dns-over-rustls")]
let client_config = config.tls_config.clone();

#[cfg(feature = "dns-over-rustls")]
let (stream, handle) = {
crate::tls::new_tls_stream_with_future(
let (stream, handle) = 'stream: {
#[cfg(feature = "dns-over-rustls")]
let client_config = if let Some(client_config) = config.tls_config.clone() {
client_config
} else {
match crate::tls::tls_client_config() {
Ok(client_config) => {
crate::config::TlsClientConfig(Arc::new(client_config))
}
Err(err) => {
break 'stream (
Box::pin(std::future::ready(Err(err)))
as Pin<Box<dyn Future<Output = _> + Send>>,
proto::BufDnsStreamHandle::new(socket_addr).0,
)
}
}
};

#[cfg(feature = "dns-over-rustls")]
break 'stream crate::tls::new_tls_stream_with_future(
tcp_future,
socket_addr,
tls_dns_name,
client_config,
)
};
#[cfg(not(feature = "dns-over-rustls"))]
let (stream, handle) = {
);

#[cfg(not(feature = "dns-over-rustls"))]
crate::tls::new_tls_stream_with_future(tcp_future, socket_addr, tls_dns_name)
};

Expand All @@ -340,16 +354,28 @@ impl<P: RuntimeProvider> ConnectionProvider for GenericConnector<P> {
Protocol::Https => {
let socket_addr = config.socket_addr;
let tls_dns_name = config.tls_dns_name.clone().unwrap_or_default();
#[cfg(feature = "dns-over-rustls")]
let client_config = config.tls_config.clone();
let tcp_future = self.runtime_provider.connect_tcp(socket_addr);

let exchange = crate::https::new_https_stream_with_future(
tcp_future,
socket_addr,
tls_dns_name,
client_config,
);
let exchange = 'exchange: {
#[cfg(feature = "dns-over-rustls")]
let client_config = if let Some(client_config) = config.tls_config.clone() {
client_config
} else {
match crate::https::http_client_config() {
Ok(client_config) => {
crate::config::TlsClientConfig(Arc::new(client_config))
}
Err(err) => break 'exchange DnsExchange::error(err),
}
};
let tcp_future = self.runtime_provider.connect_tcp(socket_addr);

crate::https::new_https_stream_with_future(
tcp_future,
socket_addr,
tls_dns_name,
client_config,
)
};
ConnectionConnect::Https(exchange)
}
#[cfg(feature = "dns-over-quic")]
Expand All @@ -362,16 +388,27 @@ impl<P: RuntimeProvider> ConnectionProvider for GenericConnector<P> {
}
});
let tls_dns_name = config.tls_dns_name.clone().unwrap_or_default();
#[cfg(feature = "dns-over-rustls")]
let client_config = config.tls_config.clone();
let udp_future = self.runtime_provider.bind_udp(bind_addr, socket_addr);

let exchange = crate::quic::new_quic_stream_with_future(
udp_future,
socket_addr,
tls_dns_name,
client_config,
);
let exchange = 'exchange: {
#[cfg(feature = "dns-over-rustls")]
let client_config = if let Some(client_config) = config.tls_config.clone() {
client_config
} else {
match crate::quic::quic_client_config() {
Ok(client_config) => {
crate::config::TlsClientConfig(Arc::new(client_config))
}
Err(err) => break 'exchange DnsExchange::error(err),
}
};
let udp_future = self.runtime_provider.bind_udp(bind_addr, socket_addr);

crate::quic::new_quic_stream_with_future(
udp_future,
socket_addr,
tls_dns_name,
client_config,
)
};
ConnectionConnect::Quic(exchange)
}
#[cfg(feature = "mdns")]
Expand Down

0 comments on commit 2e35a45

Please sign in to comment.