Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable RuntimeProvider in DoT implementations #1373

Merged
merged 5 commits into from Feb 3, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
8 changes: 6 additions & 2 deletions bin/tests/named_openssl_tests.rs
Expand Up @@ -22,12 +22,14 @@ use std::io::*;
use std::net::*;

use native_tls::Certificate;
use tokio::net::TcpStream as TokioTcpStream;
use tokio::runtime::Runtime;

use trust_dns_client::client::*;
use trust_dns_native_tls::TlsClientStreamBuilder;

use server_harness::{named_test_harness, query_a};
use trust_dns_proto::iocompat::AsyncIoTokioAsStd;

#[test]
fn test_example_tls_toml_startup() {
Expand Down Expand Up @@ -59,7 +61,8 @@ fn test_startup(toml: &'static str) {
.unwrap()
.next()
.unwrap();
let mut tls_conn_builder = TlsClientStreamBuilder::new();
let mut tls_conn_builder =
TlsClientStreamBuilder::<AsyncIoTokioAsStd<TokioTcpStream>>::new();
let cert = to_trust_anchor(&cert_der);
tls_conn_builder.add_ca(cert);
let (stream, sender) = tls_conn_builder.build(addr, "ns.example.com".to_string());
Expand All @@ -74,7 +77,8 @@ fn test_startup(toml: &'static str) {
.unwrap()
.next()
.unwrap();
let mut tls_conn_builder = TlsClientStreamBuilder::new();
let mut tls_conn_builder =
TlsClientStreamBuilder::<AsyncIoTokioAsStd<TokioTcpStream>>::new();
let cert = to_trust_anchor(&cert_der);
tls_conn_builder.add_ca(cert);
let (stream, sender) = tls_conn_builder.build(addr, "ns.example.com".to_string());
Expand Down
15 changes: 12 additions & 3 deletions bin/tests/named_rustls_tests.rs
Expand Up @@ -21,9 +21,11 @@ use std::sync::Arc;

use rustls::Certificate;
use rustls::ClientConfig;
use tokio::net::TcpStream as TokioTcpStream;
use tokio::runtime::Runtime;

use trust_dns_client::client::*;
use trust_dns_proto::iocompat::AsyncIoTokioAsStd;
use trust_dns_rustls::tls_client_connect;

use server_harness::{named_test_harness, query_a};
Expand Down Expand Up @@ -57,8 +59,11 @@ fn test_example_tls_toml_startup() {
config.root_store.add(&cert).expect("bad certificate");
let config = Arc::new(config);

let (stream, sender) =
tls_client_connect(addr, "ns.example.com".to_string(), config.clone());
let (stream, sender) = tls_client_connect::<AsyncIoTokioAsStd<TokioTcpStream>>(
addr,
"ns.example.com".to_string(),
config.clone(),
);
let client = AsyncClient::new(stream, Box::new(sender), None);

let (mut client, bg) = io_loop.block_on(client).expect("client failed to connect");
Expand All @@ -72,7 +77,11 @@ fn test_example_tls_toml_startup() {
.unwrap()
.next()
.unwrap();
let (stream, sender) = tls_client_connect(addr, "ns.example.com".to_string(), config);
let (stream, sender) = tls_client_connect::<AsyncIoTokioAsStd<TokioTcpStream>>(
addr,
"ns.example.com".to_string(),
config,
);
let client = AsyncClient::new(stream, Box::new(sender), None);

let (mut client, bg) = io_loop.block_on(client).expect("client failed to connect");
Expand Down
4 changes: 3 additions & 1 deletion crates/native-tls/src/tests.rs
Expand Up @@ -26,8 +26,10 @@ use std::{thread, time};
use futures_util::stream::StreamExt;
use native_tls;
use native_tls::{Certificate, TlsAcceptor};
use tokio::net::TcpStream as TokioTcpStream;
use tokio::runtime::Runtime;

use trust_dns_proto::iocompat::AsyncIoTokioAsStd;
use trust_dns_proto::xfer::SerialMessage;

#[allow(clippy::useless_attribute)]
Expand Down Expand Up @@ -193,7 +195,7 @@ fn tls_client_stream_test(server_addr: IpAddr, mtls: bool) {
let trust_chain = Certificate::from_der(&root_cert_der).unwrap();

// barrier.wait();
let mut builder = TlsStreamBuilder::new();
let mut builder = TlsStreamBuilder::<AsyncIoTokioAsStd<TokioTcpStream>>::new();
builder.add_ca(trust_chain);

// fix MTLS
Expand Down
17 changes: 9 additions & 8 deletions crates/native-tls/src/tls_client_stream.rs
Expand Up @@ -15,27 +15,28 @@ use futures_util::TryFutureExt;
use native_tls::Certificate;
#[cfg(feature = "mtls")]
use native_tls::Pkcs12;
use tokio::net::TcpStream as TokioTcpStream;
use tokio_native_tls::TlsStream as TokioTlsStream;

use trust_dns_proto::error::ProtoError;
use trust_dns_proto::iocompat::AsyncIoStdAsTokio;
use trust_dns_proto::iocompat::AsyncIoTokioAsStd;
use trust_dns_proto::tcp::TcpClientStream;
use trust_dns_proto::tcp::{Connect, TcpClientStream};
use trust_dns_proto::xfer::BufDnsStreamHandle;

use crate::TlsStreamBuilder;

/// TlsClientStream secure DNS over TCP stream
///
/// See TlsClientStreamBuilder::new()
pub type TlsClientStream = TcpClientStream<AsyncIoTokioAsStd<TokioTlsStream<TokioTcpStream>>>;
pub type TlsClientStream<S> =
TcpClientStream<AsyncIoTokioAsStd<TokioTlsStream<AsyncIoStdAsTokio<S>>>>;

/// Builder for TlsClientStream
pub struct TlsClientStreamBuilder(TlsStreamBuilder);
pub struct TlsClientStreamBuilder<S>(TlsStreamBuilder<S>);

impl TlsClientStreamBuilder {
impl<S: Connect> TlsClientStreamBuilder<S> {
/// Creates a builder fo the construction of a TlsClientStream
pub fn new() -> TlsClientStreamBuilder {
pub fn new() -> TlsClientStreamBuilder<S> {
TlsClientStreamBuilder(TlsStreamBuilder::new())
}

Expand Down Expand Up @@ -64,7 +65,7 @@ impl TlsClientStreamBuilder {
name_server: SocketAddr,
dns_name: String,
) -> (
Pin<Box<dyn Future<Output = Result<TlsClientStream, ProtoError>> + Send>>,
Pin<Box<dyn Future<Output = Result<TlsClientStream<S>, ProtoError>> + Send>>,
BufDnsStreamHandle,
) {
let (stream_future, sender) = self.0.build(name_server, dns_name);
Expand All @@ -81,7 +82,7 @@ impl TlsClientStreamBuilder {
}
}

impl Default for TlsClientStreamBuilder {
impl<S: Connect> Default for TlsClientStreamBuilder<S> {
fn default() -> Self {
Self::new()
}
Expand Down
32 changes: 17 additions & 15 deletions crates/native-tls/src/tls_stream.rs
Expand Up @@ -7,23 +7,23 @@

//! Base TlsStream

use std::future::Future;
use std::io;
use std::net::SocketAddr;
use std::pin::Pin;
use std::{future::Future, marker::PhantomData};

use futures_util::TryFutureExt;
use native_tls::Protocol::Tlsv12;
use native_tls::{Certificate, Identity, TlsConnector};
use tokio::net::TcpStream as TokioTcpStream;
use tokio_native_tls::{TlsConnector as TokioTlsConnector, TlsStream as TokioTlsStream};

use trust_dns_proto::iocompat::AsyncIoTokioAsStd;
use trust_dns_proto::tcp::{self, TcpStream};
use trust_dns_proto::iocompat::{AsyncIoStdAsTokio, AsyncIoTokioAsStd};
use trust_dns_proto::tcp::Connect;
use trust_dns_proto::tcp::TcpStream;
use trust_dns_proto::xfer::{BufStreamHandle, StreamReceiver};

/// A TlsStream counterpart to the TcpStream which embeds a secure TlsStream
pub type TlsStream = TcpStream<AsyncIoTokioAsStd<TokioTlsStream<TokioTcpStream>>>;
pub type TlsStream<S> = TcpStream<AsyncIoTokioAsStd<TokioTlsStream<AsyncIoStdAsTokio<S>>>>;

fn tls_new(certs: Vec<Certificate>, pkcs12: Option<Identity>) -> io::Result<TlsConnector> {
let mut builder = TlsConnector::builder();
Expand All @@ -47,10 +47,10 @@ fn tls_new(certs: Vec<Certificate>, pkcs12: Option<Identity>) -> io::Result<TlsC
/// Initializes a TlsStream with an existing tokio_tls::TlsStream.
///
/// This is intended for use with a TlsListener and Incoming connections
pub fn tls_from_stream(
stream: TokioTlsStream<TokioTcpStream>,
pub fn tls_from_stream<S: Connect>(
stream: TokioTlsStream<AsyncIoStdAsTokio<S>>,
peer_addr: SocketAddr,
) -> (TlsStream, BufStreamHandle) {
) -> (TlsStream<S>, BufStreamHandle) {
let (message_sender, outbound_messages) = BufStreamHandle::create();

let stream = TcpStream::from_stream_with_receiver(
Expand All @@ -64,17 +64,19 @@ pub fn tls_from_stream(

/// A builder for the TlsStream
#[derive(Default)]
pub struct TlsStreamBuilder {
pub struct TlsStreamBuilder<S> {
ca_chain: Vec<Certificate>,
identity: Option<Identity>,
marker: PhantomData<S>,
}

impl TlsStreamBuilder {
impl<S: Connect> TlsStreamBuilder<S> {
/// Constructs a new TlsStreamBuilder
pub fn new() -> TlsStreamBuilder {
pub fn new() -> TlsStreamBuilder<S> {
TlsStreamBuilder {
ca_chain: vec![],
identity: None,
marker: PhantomData,
}
}

Expand Down Expand Up @@ -123,7 +125,7 @@ impl TlsStreamBuilder {
dns_name: String,
) -> (
// TODO: change to impl?
Pin<Box<dyn Future<Output = Result<TlsStream, io::Error>> + Send>>,
Pin<Box<dyn Future<Output = Result<TlsStream<S>, io::Error>> + Send>>,
BufStreamHandle,
) {
let (message_sender, outbound_messages) = BufStreamHandle::create();
Expand All @@ -137,17 +139,17 @@ impl TlsStreamBuilder {
name_server: SocketAddr,
dns_name: String,
outbound_messages: StreamReceiver,
) -> Result<TlsStream, io::Error> {
) -> Result<TlsStream<S>, io::Error> {
use crate::tls_stream;

let ca_chain = self.ca_chain.clone();
let identity = self.identity;

let tcp_stream = tcp::tokio::connect(&name_server).await;
let tcp_stream = S::connect(name_server).await;

// TODO: for some reason the above wouldn't accept a ?
let tcp_stream = match tcp_stream {
Ok(tcp_stream) => tcp_stream,
Ok(tcp_stream) => AsyncIoStdAsTokio(tcp_stream),
Err(err) => return Err(err),
};

Expand Down
15 changes: 8 additions & 7 deletions crates/openssl/src/tls_client_stream.rs
Expand Up @@ -14,23 +14,24 @@ use futures_util::TryFutureExt;
#[cfg(feature = "mtls")]
use openssl::pkcs12::Pkcs12;
use openssl::x509::X509;
use tokio::net::TcpStream as TokioTcpStream;
use tokio_openssl::SslStream as TokioTlsStream;

use trust_dns_proto::error::ProtoError;
use trust_dns_proto::iocompat::AsyncIoStdAsTokio;
use trust_dns_proto::iocompat::AsyncIoTokioAsStd;
use trust_dns_proto::tcp::TcpClientStream;
use trust_dns_proto::tcp::{Connect, TcpClientStream};
use trust_dns_proto::xfer::BufDnsStreamHandle;

use super::TlsStreamBuilder;

/// A Type definition for the TLS stream
pub type TlsClientStream = TcpClientStream<AsyncIoTokioAsStd<TokioTlsStream<TokioTcpStream>>>;
pub type TlsClientStream<S> =
TcpClientStream<AsyncIoTokioAsStd<TokioTlsStream<AsyncIoStdAsTokio<S>>>>;

/// A Builder for the TlsClientStream
pub struct TlsClientStreamBuilder(TlsStreamBuilder);
pub struct TlsClientStreamBuilder<S>(TlsStreamBuilder<S>);

impl TlsClientStreamBuilder {
impl<S: Connect> TlsClientStreamBuilder<S> {
/// Creates a builder for the construction of a TlsClientStream.
pub fn new() -> Self {
TlsClientStreamBuilder(TlsStreamBuilder::new())
Expand Down Expand Up @@ -71,7 +72,7 @@ impl TlsClientStreamBuilder {
name_server: SocketAddr,
dns_name: String,
) -> (
Pin<Box<dyn Future<Output = Result<TlsClientStream, ProtoError>> + Send>>,
Pin<Box<dyn Future<Output = Result<TlsClientStream<S>, ProtoError>> + Send>>,
BufDnsStreamHandle,
) {
let (stream_future, sender) = self.0.build(name_server, dns_name);
Expand All @@ -88,7 +89,7 @@ impl TlsClientStreamBuilder {
}
}

impl Default for TlsClientStreamBuilder {
impl<S: Connect> Default for TlsClientStreamBuilder<S> {
fn default() -> Self {
Self::new()
}
Expand Down
33 changes: 18 additions & 15 deletions crates/openssl/src/tls_stream.rs
Expand Up @@ -5,10 +5,10 @@
// http://opensource.org/licenses/MIT>, at your option. This file may not be
// copied, modified, or distributed except according to those terms.

use std::future::Future;
use std::io;
use std::net::SocketAddr;
use std::pin::Pin;
use std::{future::Future, marker::PhantomData};

use futures_util::{future, TryFutureExt};
use openssl::pkcs12::ParsedPkcs12;
Expand All @@ -17,11 +17,11 @@ use openssl::ssl::{ConnectConfiguration, SslConnector, SslContextBuilder, SslMet
use openssl::stack::Stack;
use openssl::x509::store::X509StoreBuilder;
use openssl::x509::{X509Ref, X509};
use tokio::net::TcpStream as TokioTcpStream;
use tokio_openssl::{self, SslStream as TokioTlsStream};

use trust_dns_proto::iocompat::AsyncIoTokioAsStd;
use trust_dns_proto::tcp::{self, TcpStream};
use trust_dns_proto::iocompat::{AsyncIoStdAsTokio, AsyncIoTokioAsStd};
use trust_dns_proto::tcp::Connect;
use trust_dns_proto::tcp::TcpStream;
use trust_dns_proto::xfer::BufStreamHandle;

pub trait TlsIdentityExt {
Expand Down Expand Up @@ -57,7 +57,8 @@ impl TlsIdentityExt for SslContextBuilder {
}

/// A TlsStream counterpart to the TcpStream which embeds a secure TlsStream
pub type TlsStream = TcpStream<AsyncIoTokioAsStd<TokioTlsStream<TokioTcpStream>>>;
pub type TlsStream<S> = TcpStream<AsyncIoTokioAsStd<TokioTlsStream<S>>>;
pub type CompatTlsStream<S> = TlsStream<AsyncIoStdAsTokio<S>>;

fn new(certs: Vec<X509>, pkcs12: Option<ParsedPkcs12>) -> io::Result<SslConnector> {
let mut tls = SslConnector::builder(SslMethod::tls()).map_err(|e| {
Expand Down Expand Up @@ -115,29 +116,29 @@ fn new(certs: Vec<X509>, pkcs12: Option<ParsedPkcs12>) -> io::Result<SslConnecto
/// Initializes a TlsStream with an existing tokio_tls::TlsStream.
///
/// This is intended for use with a TlsListener and Incoming connections
pub fn tls_stream_from_existing_tls_stream(
stream: AsyncIoTokioAsStd<TokioTlsStream<TokioTcpStream>>,
pub fn tls_stream_from_existing_tls_stream<S: Connect>(
stream: AsyncIoTokioAsStd<TokioTlsStream<AsyncIoStdAsTokio<S>>>,
peer_addr: SocketAddr,
) -> (TlsStream, BufStreamHandle) {
) -> (CompatTlsStream<S>, BufStreamHandle) {
let (message_sender, outbound_messages) = BufStreamHandle::create();
let stream = TcpStream::from_stream_with_receiver(stream, peer_addr, outbound_messages);
(stream, message_sender)
}

async fn connect_tls(
async fn connect_tls<S: Connect>(
tls_config: ConnectConfiguration,
dns_name: String,
name_server: SocketAddr,
) -> Result<TokioTlsStream<TokioTcpStream>, io::Error> {
let tcp = tcp::tokio::connect(&name_server).await.map_err(|e| {
) -> Result<TokioTlsStream<AsyncIoStdAsTokio<S>>, io::Error> {
let tcp = S::connect(name_server).await.map_err(|e| {
io::Error::new(
io::ErrorKind::ConnectionRefused,
format!("tls error: {}", e),
)
})?;
let mut stream = tls_config
.into_ssl(&dns_name)
.and_then(|ssl| TokioTlsStream::new(ssl, tcp))
.and_then(|ssl| TokioTlsStream::new(ssl, AsyncIoStdAsTokio(tcp)))
.map_err(|e| io::Error::new(io::ErrorKind::Other, format!("tls error: {}", e)))?;
Pin::new(&mut stream).connect().await.map_err(|e| {
io::Error::new(
Expand All @@ -150,17 +151,19 @@ async fn connect_tls(

/// A builder for the TlsStream
#[derive(Default)]
pub struct TlsStreamBuilder {
pub struct TlsStreamBuilder<S> {
ca_chain: Vec<X509>,
identity: Option<ParsedPkcs12>,
marker: PhantomData<S>,
}

impl TlsStreamBuilder {
impl<S: Connect> TlsStreamBuilder<S> {
/// A builder for associating trust information to the `TlsStream`.
pub fn new() -> Self {
TlsStreamBuilder {
ca_chain: vec![],
identity: None,
marker: PhantomData,
}
}

Expand Down Expand Up @@ -208,7 +211,7 @@ impl TlsStreamBuilder {
name_server: SocketAddr,
dns_name: String,
) -> (
Pin<Box<dyn Future<Output = Result<TlsStream, io::Error>> + Send>>,
Pin<Box<dyn Future<Output = Result<CompatTlsStream<S>, io::Error>> + Send>>,
BufStreamHandle,
) {
let (message_sender, outbound_messages) = BufStreamHandle::create();
Expand Down