Skip to content

Commit

Permalink
Add to native TLS
Browse files Browse the repository at this point in the history
  • Loading branch information
kmod-midori committed Feb 3, 2021
1 parent e84e67c commit 7be8e7c
Show file tree
Hide file tree
Showing 5 changed files with 34 additions and 28 deletions.
5 changes: 3 additions & 2 deletions crates/native-tls/src/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,9 @@ use futures_util::stream::StreamExt;
use native_tls;
use native_tls::{Certificate, TlsAcceptor};
use tokio::runtime::Runtime;
use tokio::net::TcpStream as TokioTcpStream;

use trust_dns_proto::xfer::SerialMessage;
use trust_dns_proto::{iocompat::AsyncIoTokioAsStd, xfer::SerialMessage};

#[allow(clippy::useless_attribute)]
#[allow(unused)]
Expand Down Expand Up @@ -193,7 +194,7 @@ fn tls_client_stream_test(server_addr: IpAddr, mtls: bool) {
let trust_chain = Certificate::from_der(&root_cert_der).unwrap();

// barrier.wait();
let mut builder = TlsStreamBuilder::new();
let mut builder = TlsStreamBuilder::<AsyncIoTokioAsStd<TokioTcpStream>>::new();
builder.add_ca(trust_chain);

// fix MTLS
Expand Down
15 changes: 7 additions & 8 deletions crates/native-tls/src/tls_client_stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,9 @@ use futures_util::TryFutureExt;
use native_tls::Certificate;
#[cfg(feature = "mtls")]
use native_tls::Pkcs12;
use tokio::net::TcpStream as TokioTcpStream;
use tokio_native_tls::TlsStream as TokioTlsStream;

use trust_dns_proto::error::ProtoError;
use trust_dns_proto::{error::ProtoError, iocompat::AsyncIoStdAsTokio, tcp::Connect};
use trust_dns_proto::iocompat::AsyncIoTokioAsStd;
use trust_dns_proto::tcp::TcpClientStream;
use trust_dns_proto::xfer::BufDnsStreamHandle;
Expand All @@ -28,14 +27,14 @@ use crate::TlsStreamBuilder;
/// TlsClientStream secure DNS over TCP stream
///
/// See TlsClientStreamBuilder::new()
pub type TlsClientStream = TcpClientStream<AsyncIoTokioAsStd<TokioTlsStream<TokioTcpStream>>>;
pub type TlsClientStream<S> = TcpClientStream<AsyncIoTokioAsStd<TokioTlsStream<AsyncIoStdAsTokio<S>>>>;

/// Builder for TlsClientStream
pub struct TlsClientStreamBuilder(TlsStreamBuilder);
pub struct TlsClientStreamBuilder<S>(TlsStreamBuilder<S>);

impl TlsClientStreamBuilder {
impl<S: Connect> TlsClientStreamBuilder<S> {
/// Creates a builder fo the construction of a TlsClientStream
pub fn new() -> TlsClientStreamBuilder {
pub fn new() -> TlsClientStreamBuilder<S> {
TlsClientStreamBuilder(TlsStreamBuilder::new())
}

Expand Down Expand Up @@ -64,7 +63,7 @@ impl TlsClientStreamBuilder {
name_server: SocketAddr,
dns_name: String,
) -> (
Pin<Box<dyn Future<Output = Result<TlsClientStream, ProtoError>> + Send>>,
Pin<Box<dyn Future<Output = Result<TlsClientStream<S>, ProtoError>> + Send>>,
BufDnsStreamHandle,
) {
let (stream_future, sender) = self.0.build(name_server, dns_name);
Expand All @@ -81,7 +80,7 @@ impl TlsClientStreamBuilder {
}
}

impl Default for TlsClientStreamBuilder {
impl<S: Connect> Default for TlsClientStreamBuilder<S> {
fn default() -> Self {
Self::new()
}
Expand Down
34 changes: 19 additions & 15 deletions crates/native-tls/src/tls_stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,23 +7,25 @@

//! Base TlsStream

use std::future::Future;
use std::io;
use std::net::SocketAddr;
use std::pin::Pin;
use std::{future::Future, marker::PhantomData};

use futures_util::TryFutureExt;
use native_tls::Protocol::Tlsv12;
use native_tls::{Certificate, Identity, TlsConnector};
use tokio::net::TcpStream as TokioTcpStream;
use tokio_native_tls::{TlsConnector as TokioTlsConnector, TlsStream as TokioTlsStream};

use trust_dns_proto::iocompat::AsyncIoTokioAsStd;
use trust_dns_proto::tcp::{self, TcpStream};
use trust_dns_proto::tcp::TcpStream;
use trust_dns_proto::xfer::{BufStreamHandle, StreamReceiver};
use trust_dns_proto::{
iocompat::{AsyncIoStdAsTokio, AsyncIoTokioAsStd},
tcp::Connect,
};

/// A TlsStream counterpart to the TcpStream which embeds a secure TlsStream
pub type TlsStream = TcpStream<AsyncIoTokioAsStd<TokioTlsStream<TokioTcpStream>>>;
pub type TlsStream<S> = TcpStream<AsyncIoTokioAsStd<TokioTlsStream<AsyncIoStdAsTokio<S>>>>;

fn tls_new(certs: Vec<Certificate>, pkcs12: Option<Identity>) -> io::Result<TlsConnector> {
let mut builder = TlsConnector::builder();
Expand All @@ -47,10 +49,10 @@ fn tls_new(certs: Vec<Certificate>, pkcs12: Option<Identity>) -> io::Result<TlsC
/// Initializes a TlsStream with an existing tokio_tls::TlsStream.
///
/// This is intended for use with a TlsListener and Incoming connections
pub fn tls_from_stream(
stream: TokioTlsStream<TokioTcpStream>,
pub fn tls_from_stream<S: Connect>(
stream: TokioTlsStream<AsyncIoStdAsTokio<S>>,
peer_addr: SocketAddr,
) -> (TlsStream, BufStreamHandle) {
) -> (TlsStream<S>, BufStreamHandle) {
let (message_sender, outbound_messages) = BufStreamHandle::create();

let stream = TcpStream::from_stream_with_receiver(
Expand All @@ -64,17 +66,19 @@ pub fn tls_from_stream(

/// A builder for the TlsStream
#[derive(Default)]
pub struct TlsStreamBuilder {
pub struct TlsStreamBuilder<S> {
ca_chain: Vec<Certificate>,
identity: Option<Identity>,
marker: PhantomData<S>,
}

impl TlsStreamBuilder {
impl<S: Connect> TlsStreamBuilder<S> {
/// Constructs a new TlsStreamBuilder
pub fn new() -> TlsStreamBuilder {
pub fn new() -> TlsStreamBuilder<S> {
TlsStreamBuilder {
ca_chain: vec![],
identity: None,
marker: PhantomData,
}
}

Expand Down Expand Up @@ -123,7 +127,7 @@ impl TlsStreamBuilder {
dns_name: String,
) -> (
// TODO: change to impl?
Pin<Box<dyn Future<Output = Result<TlsStream, io::Error>> + Send>>,
Pin<Box<dyn Future<Output = Result<TlsStream<S>, io::Error>> + Send>>,
BufStreamHandle,
) {
let (message_sender, outbound_messages) = BufStreamHandle::create();
Expand All @@ -137,17 +141,17 @@ impl TlsStreamBuilder {
name_server: SocketAddr,
dns_name: String,
outbound_messages: StreamReceiver,
) -> Result<TlsStream, io::Error> {
) -> Result<TlsStream<S>, io::Error> {
use crate::tls_stream;

let ca_chain = self.ca_chain.clone();
let identity = self.identity;

let tcp_stream = tcp::tokio::connect(&name_server).await;
let tcp_stream = S::connect(name_server).await;

// TODO: for some reason the above wouldn't accept a ?
let tcp_stream = match tcp_stream {
Ok(tcp_stream) => tcp_stream,
Ok(tcp_stream) => AsyncIoStdAsTokio(tcp_stream),
Err(err) => return Err(err),
};

Expand Down
2 changes: 1 addition & 1 deletion crates/resolver/src/name_server/connection_provider.rs
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ where
let (stream, handle) =
{ crate::tls::new_tls_stream::<R>(socket_addr, tls_dns_name, client_config) };
#[cfg(not(feature = "dns-over-rustls"))]
let (stream, handle) = { crate::tls::new_tls_stream(socket_addr, tls_dns_name) };
let (stream, handle) = { crate::tls::new_tls_stream::<R>(socket_addr, tls_dns_name) };

let dns_conn = DnsMultiplexer::with_timeout(
stream,
Expand Down
6 changes: 4 additions & 2 deletions crates/resolver/src/tls/dns_over_native_tls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,14 @@ use proto::error::ProtoError;
use proto::BufDnsStreamHandle;
use trust_dns_native_tls::{TlsClientStream, TlsClientStreamBuilder};

use crate::name_server::RuntimeProvider;

#[allow(clippy::type_complexity)]
pub(crate) fn new_tls_stream(
pub(crate) fn new_tls_stream<R: RuntimeProvider>(
socket_addr: SocketAddr,
dns_name: String,
) -> (
Pin<Box<dyn Future<Output = Result<TlsClientStream, ProtoError>> + Send>>,
Pin<Box<dyn Future<Output = Result<TlsClientStream<R::Tcp>, ProtoError>> + Send>>,
BufDnsStreamHandle,
) {
let tls_builder = TlsClientStreamBuilder::new();
Expand Down

0 comments on commit 7be8e7c

Please sign in to comment.