Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable RuntimeProvider in DoT implementations #1373

Merged
merged 5 commits into from Feb 3, 2021
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
15 changes: 12 additions & 3 deletions bin/tests/named_rustls_tests.rs
Expand Up @@ -21,9 +21,11 @@ use std::sync::Arc;

use rustls::Certificate;
use rustls::ClientConfig;
use tokio::net::TcpStream as TokioTcpStream;
use tokio::runtime::Runtime;

use trust_dns_client::client::*;
use trust_dns_proto::iocompat::AsyncIoTokioAsStd;
use trust_dns_rustls::tls_client_connect;

use server_harness::{named_test_harness, query_a};
Expand Down Expand Up @@ -57,8 +59,11 @@ fn test_example_tls_toml_startup() {
config.root_store.add(&cert).expect("bad certificate");
let config = Arc::new(config);

let (stream, sender) =
tls_client_connect(addr, "ns.example.com".to_string(), config.clone());
let (stream, sender) = tls_client_connect::<AsyncIoTokioAsStd<TokioTcpStream>>(
addr,
"ns.example.com".to_string(),
config.clone(),
);
let client = AsyncClient::new(stream, Box::new(sender), None);

let (mut client, bg) = io_loop.block_on(client).expect("client failed to connect");
Expand All @@ -72,7 +77,11 @@ fn test_example_tls_toml_startup() {
.unwrap()
.next()
.unwrap();
let (stream, sender) = tls_client_connect(addr, "ns.example.com".to_string(), config);
let (stream, sender) = tls_client_connect::<AsyncIoTokioAsStd<TokioTcpStream>>(
addr,
"ns.example.com".to_string(),
config,
);
let client = AsyncClient::new(stream, Box::new(sender), None);

let (mut client, bg) = io_loop.block_on(client).expect("client failed to connect");
Expand Down
22 changes: 12 additions & 10 deletions crates/resolver/src/name_server/connection_provider.rs
Expand Up @@ -28,7 +28,10 @@ use proto;
use proto::error::ProtoError;

#[cfg(feature = "tokio-runtime")]
use proto::{iocompat::AsyncIoTokioAsStd, TokioTime};
use proto::{
iocompat::{AsyncIoStdAsTokio, AsyncIoTokioAsStd},
TokioTime,
};

#[cfg(feature = "mdns")]
use proto::multicast::{MdnsClientConnect, MdnsClientStream, MdnsQueryType};
Expand Down Expand Up @@ -153,7 +156,7 @@ where

#[cfg(feature = "dns-over-rustls")]
let (stream, handle) =
{ crate::tls::new_tls_stream(socket_addr, tls_dns_name, client_config) };
{ crate::tls::new_tls_stream::<R>(socket_addr, tls_dns_name, client_config) };
#[cfg(not(feature = "dns-over-rustls"))]
let (stream, handle) = { crate::tls::new_tls_stream(socket_addr, tls_dns_name) };

Expand Down Expand Up @@ -205,6 +208,10 @@ where
}
}

#[cfg(feature = "dns-over-tls")]
/// Predefined type for TLS client stream
type TlsClientStream<S> = TcpClientStream<AsyncIoTokioAsStd<TokioTlsStream<AsyncIoStdAsTokio<S>>>>;

/// The variants of all supported connections for the Resolver
#[allow(clippy::large_enum_variant, clippy::type_complexity)]
pub(crate) enum ConnectionConnect<R: RuntimeProvider> {
Expand All @@ -228,22 +235,17 @@ pub(crate) enum ConnectionConnect<R: RuntimeProvider> {
Box<
dyn Future<
Output = Result<
TcpClientStream<
AsyncIoTokioAsStd<TokioTlsStream<TokioTcpStream>>,
>,
TlsClientStream<<R as RuntimeProvider>::Tcp>,
ProtoError,
>,
> + Send
+ 'static,
>,
>,
TcpClientStream<AsyncIoTokioAsStd<TokioTlsStream<TokioTcpStream>>>,
NoopMessageFinalizer,
>,
DnsMultiplexer<
TcpClientStream<AsyncIoTokioAsStd<TokioTlsStream<TokioTcpStream>>>,
TlsClientStream<<R as RuntimeProvider>::Tcp>,
NoopMessageFinalizer,
>,
DnsMultiplexer<TlsClientStream<<R as RuntimeProvider>::Tcp>, NoopMessageFinalizer>,
TokioTime,
>,
),
Expand Down
5 changes: 3 additions & 2 deletions crates/resolver/src/tls/dns_over_rustls.rs
Expand Up @@ -19,6 +19,7 @@ use proto::error::ProtoError;
use proto::BufDnsStreamHandle;
use trust_dns_rustls::{tls_client_connect, TlsClientStream};

use crate::name_server::RuntimeProvider;
use crate::config::TlsClientConfig;

const ALPN_H2: &[u8] = b"h2";
Expand All @@ -40,12 +41,12 @@ lazy_static! {
}

#[allow(clippy::type_complexity)]
pub(crate) fn new_tls_stream(
pub(crate) fn new_tls_stream<R: RuntimeProvider>(
socket_addr: SocketAddr,
dns_name: String,
client_config: Option<TlsClientConfig>,
) -> (
Pin<Box<dyn Future<Output = Result<TlsClientStream, ProtoError>> + Send>>,
Pin<Box<dyn Future<Output = Result<TlsClientStream<R::Tcp>, ProtoError>> + Send>>,
BufDnsStreamHandle,
) {
let client_config = client_config.map_or_else(
Expand Down
10 changes: 7 additions & 3 deletions crates/rustls/src/tests.rs
Expand Up @@ -26,9 +26,9 @@ use openssl::x509::*;
use futures_util::stream::StreamExt;
use rustls::Certificate;
use rustls::ClientConfig;
use tokio::runtime::Runtime;
use tokio::{net::TcpStream as TokioTcpStream, runtime::Runtime};
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's put this on a separate line, in keeping with the surrounding style.


use trust_dns_proto::xfer::SerialMessage;
use trust_dns_proto::{iocompat::AsyncIoTokioAsStd, xfer::SerialMessage};

use crate::tls_connect;

Expand Down Expand Up @@ -214,7 +214,11 @@ fn tls_client_stream_test(server_addr: IpAddr, mtls: bool) {
// config_mtls(&root_pkey, &root_name, &root_cert, &mut builder);
// }

let (stream, mut sender) = tls_connect(server_addr, dns_name.to_string(), Arc::new(config));
let (stream, mut sender) = tls_connect::<AsyncIoTokioAsStd<TokioTcpStream>>(
server_addr,
dns_name.to_string(),
Arc::new(config),
);

// TODO: there is a race failure here... a race with the server thread most likely...
let mut stream = io_loop.block_on(stream).expect("run failed to get stream");
Expand Down
11 changes: 5 additions & 6 deletions crates/rustls/src/tls_client_stream.rs
Expand Up @@ -14,18 +14,17 @@ use std::sync::Arc;

use futures_util::TryFutureExt;
use rustls::ClientConfig;
use tokio::net::TcpStream as TokioTcpStream;

use trust_dns_proto::error::ProtoError;
use trust_dns_proto::{error::ProtoError, iocompat::AsyncIoStdAsTokio, tcp::Connect};
use trust_dns_proto::iocompat::AsyncIoTokioAsStd;
use trust_dns_proto::tcp::TcpClientStream;
use trust_dns_proto::xfer::BufDnsStreamHandle;

use crate::tls_stream::tls_connect;

/// Type of TlsClientStream used with Rustls
pub type TlsClientStream =
TcpClientStream<AsyncIoTokioAsStd<tokio_rustls::client::TlsStream<TokioTcpStream>>>;
pub type TlsClientStream<S> =
TcpClientStream<AsyncIoTokioAsStd<tokio_rustls::client::TlsStream<AsyncIoStdAsTokio<S>>>>;

/// Creates a new TlsStream to the specified name_server
///
Expand All @@ -34,12 +33,12 @@ pub type TlsClientStream =
/// * `name_server` - IP and Port for the remote DNS resolver
/// * `dns_name` - The DNS name, Subject Public Key Info (SPKI) name, as associated to a certificate
#[allow(clippy::type_complexity)]
pub fn tls_client_connect(
pub fn tls_client_connect<S: Connect>(
name_server: SocketAddr,
dns_name: String,
client_config: Arc<ClientConfig>,
) -> (
Pin<Box<dyn Future<Output = Result<TlsClientStream, ProtoError>> + Send + Unpin>>,
Pin<Box<dyn Future<Output = Result<TlsClientStream<S>, ProtoError>> + Send + Unpin>>,
BufDnsStreamHandle,
) {
let (stream_future, sender) = tls_connect(name_server, dns_name, client_config);
Expand Down
24 changes: 15 additions & 9 deletions crates/rustls/src/tls_stream.rs
Expand Up @@ -20,12 +20,15 @@ use tokio::net::TcpStream as TokioTcpStream;
use tokio_rustls::TlsConnector;
use webpki::{DNSName, DNSNameRef};

use trust_dns_proto::iocompat::AsyncIoTokioAsStd;
use trust_dns_proto::tcp::{self, DnsTcpStream, TcpStream};
use trust_dns_proto::tcp::{DnsTcpStream, TcpStream};
use trust_dns_proto::xfer::{BufStreamHandle, StreamReceiver};
use trust_dns_proto::{
iocompat::{AsyncIoStdAsTokio, AsyncIoTokioAsStd},
tcp::Connect,
};

/// Predefined type for abstracting the TlsClientStream with TokioTls
pub type TokioTlsClientStream = tokio_rustls::client::TlsStream<TokioTcpStream>;
pub type TokioTlsClientStream<S> = tokio_rustls::client::TlsStream<AsyncIoStdAsTokio<S>>;

/// Predefined type for abstracting the TlsServerStream with TokioTls
pub type TokioTlsServerStream = tokio_rustls::server::TlsStream<TokioTcpStream>;
Expand Down Expand Up @@ -71,15 +74,18 @@ pub fn tls_from_stream<S: DnsTcpStream>(
/// * `name_server` - IP and Port for the remote DNS resolver
/// * `dns_name` - The DNS name, Subject Public Key Info (SPKI) name, as associated to a certificate
#[allow(clippy::type_complexity)]
pub fn tls_connect(
pub fn tls_connect<S: Connect>(
name_server: SocketAddr,
dns_name: String,
client_config: Arc<ClientConfig>,
) -> (
Pin<
Box<
dyn Future<
Output = Result<TlsStream<AsyncIoTokioAsStd<TokioTlsClientStream>>, io::Error>,
Output = Result<
TlsStream<AsyncIoTokioAsStd<TokioTlsClientStream<S>>>,
io::Error,
>,
> + Send,
>,
>,
Expand All @@ -101,20 +107,20 @@ pub fn tls_connect(
(stream, message_sender)
}

async fn connect_tls(
async fn connect_tls<S: Connect>(
tls_connector: TlsConnector,
name_server: SocketAddr,
dns_name: String,
outbound_messages: StreamReceiver,
) -> io::Result<TcpStream<AsyncIoTokioAsStd<TokioTlsClientStream>>> {
let tcp = tcp::tokio::connect(&name_server).await?;
) -> io::Result<TcpStream<AsyncIoTokioAsStd<TokioTlsClientStream<S>>>> {
let tcp = S::connect(name_server).await?;

let dns_name = DNSNameRef::try_from_ascii_str(&dns_name)
.map(DNSName::from)
.map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "bad dns_name"))?;

let s = tls_connector
.connect(dns_name.as_ref(), tcp)
.connect(dns_name.as_ref(), AsyncIoStdAsTokio(tcp))
.map_err(|e| {
io::Error::new(
io::ErrorKind::ConnectionRefused,
Expand Down
18 changes: 10 additions & 8 deletions tests/integration-tests/src/tls_client_connection.rs
Expand Up @@ -8,31 +8,32 @@
//! TLS based DNS client connection for Client impls
//! TODO: This modules was moved from trust-dns-rustls, it really doesn't need to exist if tests are refactored...

use std::net::SocketAddr;
use std::pin::Pin;
use std::sync::Arc;
use std::{marker::PhantomData, net::SocketAddr};

use futures::Future;

use trust_dns_client::client::ClientConnection;
use trust_dns_client::rr::dnssec::Signer;
use trust_dns_proto::error::ProtoError;
use trust_dns_proto::xfer::{DnsMultiplexer, DnsMultiplexerConnect};
use trust_dns_proto::{error::ProtoError, tcp::Connect};

use rustls::ClientConfig;
use trust_dns_rustls::{tls_client_connect, TlsClientStream};

/// Tls client connection
///
/// Use with `trust_dns_client::client::Client` impls
pub struct TlsClientConnection {
pub struct TlsClientConnection<T> {
name_server: SocketAddr,
dns_name: String,
client_config: Arc<ClientConfig>,
marker: PhantomData<T>,
}

#[cfg(all(feature = "dns-over-openssl", not(feature = "dns-over-rustls")))]
impl TlsClientConnection {
impl<T> TlsClientConnection<T> {
pub fn new(
name_server: SocketAddr,
dns_name: String,
Expand All @@ -42,16 +43,17 @@ impl TlsClientConnection {
name_server,
dns_name,
client_config,
marker: PhantomData,
}
}
}

#[allow(clippy::type_complexity)]
impl ClientConnection for TlsClientConnection {
type Sender = DnsMultiplexer<TlsClientStream, Signer>;
impl<T: Connect> ClientConnection for TlsClientConnection<T> {
type Sender = DnsMultiplexer<TlsClientStream<T>, Signer>;
type SenderFuture = DnsMultiplexerConnect<
Pin<Box<dyn Future<Output = Result<TlsClientStream, ProtoError>> + Send>>,
TlsClientStream,
Pin<Box<dyn Future<Output = Result<TlsClientStream<T>, ProtoError>> + Send>>,
TlsClientStream<T>,
Signer,
>;

Expand Down