Skip to content

Commit

Permalink
Enable RuntimeProvider in DoT implementations (#1373)
Browse files Browse the repository at this point in the history
  • Loading branch information
chengyuhui committed Feb 3, 2021
1 parent 6b2ed70 commit 26b842b
Show file tree
Hide file tree
Showing 17 changed files with 142 additions and 94 deletions.
8 changes: 6 additions & 2 deletions bin/tests/named_openssl_tests.rs
Expand Up @@ -22,12 +22,14 @@ use std::io::*;
use std::net::*;

use native_tls::Certificate;
use tokio::net::TcpStream as TokioTcpStream;
use tokio::runtime::Runtime;

use trust_dns_client::client::*;
use trust_dns_native_tls::TlsClientStreamBuilder;

use server_harness::{named_test_harness, query_a};
use trust_dns_proto::iocompat::AsyncIoTokioAsStd;

#[test]
fn test_example_tls_toml_startup() {
Expand Down Expand Up @@ -59,7 +61,8 @@ fn test_startup(toml: &'static str) {
.unwrap()
.next()
.unwrap();
let mut tls_conn_builder = TlsClientStreamBuilder::new();
let mut tls_conn_builder =
TlsClientStreamBuilder::<AsyncIoTokioAsStd<TokioTcpStream>>::new();
let cert = to_trust_anchor(&cert_der);
tls_conn_builder.add_ca(cert);
let (stream, sender) = tls_conn_builder.build(addr, "ns.example.com".to_string());
Expand All @@ -74,7 +77,8 @@ fn test_startup(toml: &'static str) {
.unwrap()
.next()
.unwrap();
let mut tls_conn_builder = TlsClientStreamBuilder::new();
let mut tls_conn_builder =
TlsClientStreamBuilder::<AsyncIoTokioAsStd<TokioTcpStream>>::new();
let cert = to_trust_anchor(&cert_der);
tls_conn_builder.add_ca(cert);
let (stream, sender) = tls_conn_builder.build(addr, "ns.example.com".to_string());
Expand Down
15 changes: 12 additions & 3 deletions bin/tests/named_rustls_tests.rs
Expand Up @@ -21,9 +21,11 @@ use std::sync::Arc;

use rustls::Certificate;
use rustls::ClientConfig;
use tokio::net::TcpStream as TokioTcpStream;
use tokio::runtime::Runtime;

use trust_dns_client::client::*;
use trust_dns_proto::iocompat::AsyncIoTokioAsStd;
use trust_dns_rustls::tls_client_connect;

use server_harness::{named_test_harness, query_a};
Expand Down Expand Up @@ -57,8 +59,11 @@ fn test_example_tls_toml_startup() {
config.root_store.add(&cert).expect("bad certificate");
let config = Arc::new(config);

let (stream, sender) =
tls_client_connect(addr, "ns.example.com".to_string(), config.clone());
let (stream, sender) = tls_client_connect::<AsyncIoTokioAsStd<TokioTcpStream>>(
addr,
"ns.example.com".to_string(),
config.clone(),
);
let client = AsyncClient::new(stream, Box::new(sender), None);

let (mut client, bg) = io_loop.block_on(client).expect("client failed to connect");
Expand All @@ -72,7 +77,11 @@ fn test_example_tls_toml_startup() {
.unwrap()
.next()
.unwrap();
let (stream, sender) = tls_client_connect(addr, "ns.example.com".to_string(), config);
let (stream, sender) = tls_client_connect::<AsyncIoTokioAsStd<TokioTcpStream>>(
addr,
"ns.example.com".to_string(),
config,
);
let client = AsyncClient::new(stream, Box::new(sender), None);

let (mut client, bg) = io_loop.block_on(client).expect("client failed to connect");
Expand Down
4 changes: 3 additions & 1 deletion crates/native-tls/src/tests.rs
Expand Up @@ -26,8 +26,10 @@ use std::{thread, time};
use futures_util::stream::StreamExt;
use native_tls;
use native_tls::{Certificate, TlsAcceptor};
use tokio::net::TcpStream as TokioTcpStream;
use tokio::runtime::Runtime;

use trust_dns_proto::iocompat::AsyncIoTokioAsStd;
use trust_dns_proto::xfer::SerialMessage;

#[allow(clippy::useless_attribute)]
Expand Down Expand Up @@ -193,7 +195,7 @@ fn tls_client_stream_test(server_addr: IpAddr, mtls: bool) {
let trust_chain = Certificate::from_der(&root_cert_der).unwrap();

// barrier.wait();
let mut builder = TlsStreamBuilder::new();
let mut builder = TlsStreamBuilder::<AsyncIoTokioAsStd<TokioTcpStream>>::new();
builder.add_ca(trust_chain);

// fix MTLS
Expand Down
17 changes: 9 additions & 8 deletions crates/native-tls/src/tls_client_stream.rs
Expand Up @@ -15,27 +15,28 @@ use futures_util::TryFutureExt;
use native_tls::Certificate;
#[cfg(feature = "mtls")]
use native_tls::Pkcs12;
use tokio::net::TcpStream as TokioTcpStream;
use tokio_native_tls::TlsStream as TokioTlsStream;

use trust_dns_proto::error::ProtoError;
use trust_dns_proto::iocompat::AsyncIoStdAsTokio;
use trust_dns_proto::iocompat::AsyncIoTokioAsStd;
use trust_dns_proto::tcp::TcpClientStream;
use trust_dns_proto::tcp::{Connect, TcpClientStream};
use trust_dns_proto::xfer::BufDnsStreamHandle;

use crate::TlsStreamBuilder;

/// TlsClientStream secure DNS over TCP stream
///
/// See TlsClientStreamBuilder::new()
pub type TlsClientStream = TcpClientStream<AsyncIoTokioAsStd<TokioTlsStream<TokioTcpStream>>>;
pub type TlsClientStream<S> =
TcpClientStream<AsyncIoTokioAsStd<TokioTlsStream<AsyncIoStdAsTokio<S>>>>;

/// Builder for TlsClientStream
pub struct TlsClientStreamBuilder(TlsStreamBuilder);
pub struct TlsClientStreamBuilder<S>(TlsStreamBuilder<S>);

impl TlsClientStreamBuilder {
impl<S: Connect> TlsClientStreamBuilder<S> {
/// Creates a builder fo the construction of a TlsClientStream
pub fn new() -> TlsClientStreamBuilder {
pub fn new() -> TlsClientStreamBuilder<S> {
TlsClientStreamBuilder(TlsStreamBuilder::new())
}

Expand Down Expand Up @@ -64,7 +65,7 @@ impl TlsClientStreamBuilder {
name_server: SocketAddr,
dns_name: String,
) -> (
Pin<Box<dyn Future<Output = Result<TlsClientStream, ProtoError>> + Send>>,
Pin<Box<dyn Future<Output = Result<TlsClientStream<S>, ProtoError>> + Send>>,
BufDnsStreamHandle,
) {
let (stream_future, sender) = self.0.build(name_server, dns_name);
Expand All @@ -81,7 +82,7 @@ impl TlsClientStreamBuilder {
}
}

impl Default for TlsClientStreamBuilder {
impl<S: Connect> Default for TlsClientStreamBuilder<S> {
fn default() -> Self {
Self::new()
}
Expand Down
32 changes: 17 additions & 15 deletions crates/native-tls/src/tls_stream.rs
Expand Up @@ -7,23 +7,23 @@

//! Base TlsStream

use std::future::Future;
use std::io;
use std::net::SocketAddr;
use std::pin::Pin;
use std::{future::Future, marker::PhantomData};

use futures_util::TryFutureExt;
use native_tls::Protocol::Tlsv12;
use native_tls::{Certificate, Identity, TlsConnector};
use tokio::net::TcpStream as TokioTcpStream;
use tokio_native_tls::{TlsConnector as TokioTlsConnector, TlsStream as TokioTlsStream};

use trust_dns_proto::iocompat::AsyncIoTokioAsStd;
use trust_dns_proto::tcp::{self, TcpStream};
use trust_dns_proto::iocompat::{AsyncIoStdAsTokio, AsyncIoTokioAsStd};
use trust_dns_proto::tcp::Connect;
use trust_dns_proto::tcp::TcpStream;
use trust_dns_proto::xfer::{BufStreamHandle, StreamReceiver};

/// A TlsStream counterpart to the TcpStream which embeds a secure TlsStream
pub type TlsStream = TcpStream<AsyncIoTokioAsStd<TokioTlsStream<TokioTcpStream>>>;
pub type TlsStream<S> = TcpStream<AsyncIoTokioAsStd<TokioTlsStream<AsyncIoStdAsTokio<S>>>>;

fn tls_new(certs: Vec<Certificate>, pkcs12: Option<Identity>) -> io::Result<TlsConnector> {
let mut builder = TlsConnector::builder();
Expand All @@ -47,10 +47,10 @@ fn tls_new(certs: Vec<Certificate>, pkcs12: Option<Identity>) -> io::Result<TlsC
/// Initializes a TlsStream with an existing tokio_tls::TlsStream.
///
/// This is intended for use with a TlsListener and Incoming connections
pub fn tls_from_stream(
stream: TokioTlsStream<TokioTcpStream>,
pub fn tls_from_stream<S: Connect>(
stream: TokioTlsStream<AsyncIoStdAsTokio<S>>,
peer_addr: SocketAddr,
) -> (TlsStream, BufStreamHandle) {
) -> (TlsStream<S>, BufStreamHandle) {
let (message_sender, outbound_messages) = BufStreamHandle::create();

let stream = TcpStream::from_stream_with_receiver(
Expand All @@ -64,17 +64,19 @@ pub fn tls_from_stream(

/// A builder for the TlsStream
#[derive(Default)]
pub struct TlsStreamBuilder {
pub struct TlsStreamBuilder<S> {
ca_chain: Vec<Certificate>,
identity: Option<Identity>,
marker: PhantomData<S>,
}

impl TlsStreamBuilder {
impl<S: Connect> TlsStreamBuilder<S> {
/// Constructs a new TlsStreamBuilder
pub fn new() -> TlsStreamBuilder {
pub fn new() -> TlsStreamBuilder<S> {
TlsStreamBuilder {
ca_chain: vec![],
identity: None,
marker: PhantomData,
}
}

Expand Down Expand Up @@ -123,7 +125,7 @@ impl TlsStreamBuilder {
dns_name: String,
) -> (
// TODO: change to impl?
Pin<Box<dyn Future<Output = Result<TlsStream, io::Error>> + Send>>,
Pin<Box<dyn Future<Output = Result<TlsStream<S>, io::Error>> + Send>>,
BufStreamHandle,
) {
let (message_sender, outbound_messages) = BufStreamHandle::create();
Expand All @@ -137,17 +139,17 @@ impl TlsStreamBuilder {
name_server: SocketAddr,
dns_name: String,
outbound_messages: StreamReceiver,
) -> Result<TlsStream, io::Error> {
) -> Result<TlsStream<S>, io::Error> {
use crate::tls_stream;

let ca_chain = self.ca_chain.clone();
let identity = self.identity;

let tcp_stream = tcp::tokio::connect(&name_server).await;
let tcp_stream = S::connect(name_server).await;

// TODO: for some reason the above wouldn't accept a ?
let tcp_stream = match tcp_stream {
Ok(tcp_stream) => tcp_stream,
Ok(tcp_stream) => AsyncIoStdAsTokio(tcp_stream),
Err(err) => return Err(err),
};

Expand Down
15 changes: 8 additions & 7 deletions crates/openssl/src/tls_client_stream.rs
Expand Up @@ -14,23 +14,24 @@ use futures_util::TryFutureExt;
#[cfg(feature = "mtls")]
use openssl::pkcs12::Pkcs12;
use openssl::x509::X509;
use tokio::net::TcpStream as TokioTcpStream;
use tokio_openssl::SslStream as TokioTlsStream;

use trust_dns_proto::error::ProtoError;
use trust_dns_proto::iocompat::AsyncIoStdAsTokio;
use trust_dns_proto::iocompat::AsyncIoTokioAsStd;
use trust_dns_proto::tcp::TcpClientStream;
use trust_dns_proto::tcp::{Connect, TcpClientStream};
use trust_dns_proto::xfer::BufDnsStreamHandle;

use super::TlsStreamBuilder;

/// A Type definition for the TLS stream
pub type TlsClientStream = TcpClientStream<AsyncIoTokioAsStd<TokioTlsStream<TokioTcpStream>>>;
pub type TlsClientStream<S> =
TcpClientStream<AsyncIoTokioAsStd<TokioTlsStream<AsyncIoStdAsTokio<S>>>>;

/// A Builder for the TlsClientStream
pub struct TlsClientStreamBuilder(TlsStreamBuilder);
pub struct TlsClientStreamBuilder<S>(TlsStreamBuilder<S>);

impl TlsClientStreamBuilder {
impl<S: Connect> TlsClientStreamBuilder<S> {
/// Creates a builder for the construction of a TlsClientStream.
pub fn new() -> Self {
TlsClientStreamBuilder(TlsStreamBuilder::new())
Expand Down Expand Up @@ -71,7 +72,7 @@ impl TlsClientStreamBuilder {
name_server: SocketAddr,
dns_name: String,
) -> (
Pin<Box<dyn Future<Output = Result<TlsClientStream, ProtoError>> + Send>>,
Pin<Box<dyn Future<Output = Result<TlsClientStream<S>, ProtoError>> + Send>>,
BufDnsStreamHandle,
) {
let (stream_future, sender) = self.0.build(name_server, dns_name);
Expand All @@ -88,7 +89,7 @@ impl TlsClientStreamBuilder {
}
}

impl Default for TlsClientStreamBuilder {
impl<S: Connect> Default for TlsClientStreamBuilder<S> {
fn default() -> Self {
Self::new()
}
Expand Down
33 changes: 18 additions & 15 deletions crates/openssl/src/tls_stream.rs
Expand Up @@ -5,10 +5,10 @@
// http://opensource.org/licenses/MIT>, at your option. This file may not be
// copied, modified, or distributed except according to those terms.

use std::future::Future;
use std::io;
use std::net::SocketAddr;
use std::pin::Pin;
use std::{future::Future, marker::PhantomData};

use futures_util::{future, TryFutureExt};
use openssl::pkcs12::ParsedPkcs12;
Expand All @@ -17,11 +17,11 @@ use openssl::ssl::{ConnectConfiguration, SslConnector, SslContextBuilder, SslMet
use openssl::stack::Stack;
use openssl::x509::store::X509StoreBuilder;
use openssl::x509::{X509Ref, X509};
use tokio::net::TcpStream as TokioTcpStream;
use tokio_openssl::{self, SslStream as TokioTlsStream};

use trust_dns_proto::iocompat::AsyncIoTokioAsStd;
use trust_dns_proto::tcp::{self, TcpStream};
use trust_dns_proto::iocompat::{AsyncIoStdAsTokio, AsyncIoTokioAsStd};
use trust_dns_proto::tcp::Connect;
use trust_dns_proto::tcp::TcpStream;
use trust_dns_proto::xfer::BufStreamHandle;

pub trait TlsIdentityExt {
Expand Down Expand Up @@ -57,7 +57,8 @@ impl TlsIdentityExt for SslContextBuilder {
}

/// A TlsStream counterpart to the TcpStream which embeds a secure TlsStream
pub type TlsStream = TcpStream<AsyncIoTokioAsStd<TokioTlsStream<TokioTcpStream>>>;
pub type TlsStream<S> = TcpStream<AsyncIoTokioAsStd<TokioTlsStream<S>>>;
pub type CompatTlsStream<S> = TlsStream<AsyncIoStdAsTokio<S>>;

fn new(certs: Vec<X509>, pkcs12: Option<ParsedPkcs12>) -> io::Result<SslConnector> {
let mut tls = SslConnector::builder(SslMethod::tls()).map_err(|e| {
Expand Down Expand Up @@ -115,29 +116,29 @@ fn new(certs: Vec<X509>, pkcs12: Option<ParsedPkcs12>) -> io::Result<SslConnecto
/// Initializes a TlsStream with an existing tokio_tls::TlsStream.
///
/// This is intended for use with a TlsListener and Incoming connections
pub fn tls_stream_from_existing_tls_stream(
stream: AsyncIoTokioAsStd<TokioTlsStream<TokioTcpStream>>,
pub fn tls_stream_from_existing_tls_stream<S: Connect>(
stream: AsyncIoTokioAsStd<TokioTlsStream<AsyncIoStdAsTokio<S>>>,
peer_addr: SocketAddr,
) -> (TlsStream, BufStreamHandle) {
) -> (CompatTlsStream<S>, BufStreamHandle) {
let (message_sender, outbound_messages) = BufStreamHandle::create();
let stream = TcpStream::from_stream_with_receiver(stream, peer_addr, outbound_messages);
(stream, message_sender)
}

async fn connect_tls(
async fn connect_tls<S: Connect>(
tls_config: ConnectConfiguration,
dns_name: String,
name_server: SocketAddr,
) -> Result<TokioTlsStream<TokioTcpStream>, io::Error> {
let tcp = tcp::tokio::connect(&name_server).await.map_err(|e| {
) -> Result<TokioTlsStream<AsyncIoStdAsTokio<S>>, io::Error> {
let tcp = S::connect(name_server).await.map_err(|e| {
io::Error::new(
io::ErrorKind::ConnectionRefused,
format!("tls error: {}", e),
)
})?;
let mut stream = tls_config
.into_ssl(&dns_name)
.and_then(|ssl| TokioTlsStream::new(ssl, tcp))
.and_then(|ssl| TokioTlsStream::new(ssl, AsyncIoStdAsTokio(tcp)))
.map_err(|e| io::Error::new(io::ErrorKind::Other, format!("tls error: {}", e)))?;
Pin::new(&mut stream).connect().await.map_err(|e| {
io::Error::new(
Expand All @@ -150,17 +151,19 @@ async fn connect_tls(

/// A builder for the TlsStream
#[derive(Default)]
pub struct TlsStreamBuilder {
pub struct TlsStreamBuilder<S> {
ca_chain: Vec<X509>,
identity: Option<ParsedPkcs12>,
marker: PhantomData<S>,
}

impl TlsStreamBuilder {
impl<S: Connect> TlsStreamBuilder<S> {
/// A builder for associating trust information to the `TlsStream`.
pub fn new() -> Self {
TlsStreamBuilder {
ca_chain: vec![],
identity: None,
marker: PhantomData,
}
}

Expand Down Expand Up @@ -208,7 +211,7 @@ impl TlsStreamBuilder {
name_server: SocketAddr,
dns_name: String,
) -> (
Pin<Box<dyn Future<Output = Result<TlsStream, io::Error>> + Send>>,
Pin<Box<dyn Future<Output = Result<CompatTlsStream<S>, io::Error>> + Send>>,
BufStreamHandle,
) {
let (message_sender, outbound_messages) = BufStreamHandle::create();
Expand Down

0 comments on commit 26b842b

Please sign in to comment.