Skip to content

Commit

Permalink
Add OpenSSL
Browse files Browse the repository at this point in the history
  • Loading branch information
chengyuhui committed Feb 3, 2021
1 parent 7be8e7c commit 9198206
Show file tree
Hide file tree
Showing 11 changed files with 64 additions and 45 deletions.
8 changes: 6 additions & 2 deletions bin/tests/named_openssl_tests.rs
Expand Up @@ -22,12 +22,14 @@ use std::io::*;
use std::net::*;

use native_tls::Certificate;
use tokio::net::TcpStream as TokioTcpStream;
use tokio::runtime::Runtime;

use trust_dns_client::client::*;
use trust_dns_native_tls::TlsClientStreamBuilder;

use server_harness::{named_test_harness, query_a};
use trust_dns_proto::iocompat::AsyncIoTokioAsStd;

#[test]
fn test_example_tls_toml_startup() {
Expand Down Expand Up @@ -59,7 +61,8 @@ fn test_startup(toml: &'static str) {
.unwrap()
.next()
.unwrap();
let mut tls_conn_builder = TlsClientStreamBuilder::new();
let mut tls_conn_builder =
TlsClientStreamBuilder::<AsyncIoTokioAsStd<TokioTcpStream>>::new();
let cert = to_trust_anchor(&cert_der);
tls_conn_builder.add_ca(cert);
let (stream, sender) = tls_conn_builder.build(addr, "ns.example.com".to_string());
Expand All @@ -74,7 +77,8 @@ fn test_startup(toml: &'static str) {
.unwrap()
.next()
.unwrap();
let mut tls_conn_builder = TlsClientStreamBuilder::new();
let mut tls_conn_builder =
TlsClientStreamBuilder::<AsyncIoTokioAsStd<TokioTcpStream>>::new();
let cert = to_trust_anchor(&cert_der);
tls_conn_builder.add_ca(cert);
let (stream, sender) = tls_conn_builder.build(addr, "ns.example.com".to_string());
Expand Down
2 changes: 1 addition & 1 deletion crates/native-tls/src/tests.rs
Expand Up @@ -26,8 +26,8 @@ use std::{thread, time};
use futures_util::stream::StreamExt;
use native_tls;
use native_tls::{Certificate, TlsAcceptor};
use tokio::runtime::Runtime;
use tokio::net::TcpStream as TokioTcpStream;
use tokio::runtime::Runtime;

use trust_dns_proto::{iocompat::AsyncIoTokioAsStd, xfer::SerialMessage};

Expand Down
8 changes: 5 additions & 3 deletions crates/native-tls/src/tls_client_stream.rs
Expand Up @@ -17,17 +17,19 @@ use native_tls::Certificate;
use native_tls::Pkcs12;
use tokio_native_tls::TlsStream as TokioTlsStream;

use trust_dns_proto::{error::ProtoError, iocompat::AsyncIoStdAsTokio, tcp::Connect};
use trust_dns_proto::error::ProtoError;
use trust_dns_proto::iocompat::AsyncIoStdAsTokio;
use trust_dns_proto::iocompat::AsyncIoTokioAsStd;
use trust_dns_proto::tcp::TcpClientStream;
use trust_dns_proto::tcp::{Connect, TcpClientStream};
use trust_dns_proto::xfer::BufDnsStreamHandle;

use crate::TlsStreamBuilder;

/// TlsClientStream secure DNS over TCP stream
///
/// See TlsClientStreamBuilder::new()
pub type TlsClientStream<S> = TcpClientStream<AsyncIoTokioAsStd<TokioTlsStream<AsyncIoStdAsTokio<S>>>>;
pub type TlsClientStream<S> =
TcpClientStream<AsyncIoTokioAsStd<TokioTlsStream<AsyncIoStdAsTokio<S>>>>;

/// Builder for TlsClientStream
pub struct TlsClientStreamBuilder<S>(TlsStreamBuilder<S>);
Expand Down
15 changes: 8 additions & 7 deletions crates/openssl/src/tls_client_stream.rs
Expand Up @@ -14,23 +14,24 @@ use futures_util::TryFutureExt;
#[cfg(feature = "mtls")]
use openssl::pkcs12::Pkcs12;
use openssl::x509::X509;
use tokio::net::TcpStream as TokioTcpStream;
use tokio_openssl::SslStream as TokioTlsStream;

use trust_dns_proto::error::ProtoError;
use trust_dns_proto::iocompat::AsyncIoStdAsTokio;
use trust_dns_proto::iocompat::AsyncIoTokioAsStd;
use trust_dns_proto::tcp::TcpClientStream;
use trust_dns_proto::tcp::{Connect, TcpClientStream};
use trust_dns_proto::xfer::BufDnsStreamHandle;

use super::TlsStreamBuilder;

/// A Type definition for the TLS stream
pub type TlsClientStream = TcpClientStream<AsyncIoTokioAsStd<TokioTlsStream<TokioTcpStream>>>;
pub type TlsClientStream<S> =
TcpClientStream<AsyncIoTokioAsStd<TokioTlsStream<AsyncIoStdAsTokio<S>>>>;

/// A Builder for the TlsClientStream
pub struct TlsClientStreamBuilder(TlsStreamBuilder);
pub struct TlsClientStreamBuilder<S>(TlsStreamBuilder<S>);

impl TlsClientStreamBuilder {
impl<S: Connect> TlsClientStreamBuilder<S> {
/// Creates a builder for the construction of a TlsClientStream.
pub fn new() -> Self {
TlsClientStreamBuilder(TlsStreamBuilder::new())
Expand Down Expand Up @@ -71,7 +72,7 @@ impl TlsClientStreamBuilder {
name_server: SocketAddr,
dns_name: String,
) -> (
Pin<Box<dyn Future<Output = Result<TlsClientStream, ProtoError>> + Send>>,
Pin<Box<dyn Future<Output = Result<TlsClientStream<S>, ProtoError>> + Send>>,
BufDnsStreamHandle,
) {
let (stream_future, sender) = self.0.build(name_server, dns_name);
Expand All @@ -88,7 +89,7 @@ impl TlsClientStreamBuilder {
}
}

impl Default for TlsClientStreamBuilder {
impl<S: Connect> Default for TlsClientStreamBuilder<S> {
fn default() -> Self {
Self::new()
}
Expand Down
35 changes: 20 additions & 15 deletions crates/openssl/src/tls_stream.rs
Expand Up @@ -5,10 +5,10 @@
// http://opensource.org/licenses/MIT>, at your option. This file may not be
// copied, modified, or distributed except according to those terms.

use std::future::Future;
use std::io;
use std::net::SocketAddr;
use std::pin::Pin;
use std::{future::Future, marker::PhantomData};

use futures_util::{future, TryFutureExt};
use openssl::pkcs12::ParsedPkcs12;
Expand All @@ -17,12 +17,14 @@ use openssl::ssl::{ConnectConfiguration, SslConnector, SslContextBuilder, SslMet
use openssl::stack::Stack;
use openssl::x509::store::X509StoreBuilder;
use openssl::x509::{X509Ref, X509};
use tokio::net::TcpStream as TokioTcpStream;
use tokio_openssl::{self, SslStream as TokioTlsStream};

use trust_dns_proto::iocompat::AsyncIoTokioAsStd;
use trust_dns_proto::tcp::{self, TcpStream};
use trust_dns_proto::tcp::TcpStream;
use trust_dns_proto::xfer::BufStreamHandle;
use trust_dns_proto::{
iocompat::{AsyncIoStdAsTokio, AsyncIoTokioAsStd},
tcp::Connect,
};

pub trait TlsIdentityExt {
fn identity(&mut self, pkcs12: &ParsedPkcs12) -> io::Result<()> {
Expand Down Expand Up @@ -57,7 +59,8 @@ impl TlsIdentityExt for SslContextBuilder {
}

/// A TlsStream counterpart to the TcpStream which embeds a secure TlsStream
pub type TlsStream = TcpStream<AsyncIoTokioAsStd<TokioTlsStream<TokioTcpStream>>>;
pub type TlsStream<S> = TcpStream<AsyncIoTokioAsStd<TokioTlsStream<S>>>;
pub type CompatTlsStream<S> = TlsStream<AsyncIoStdAsTokio<S>>;

fn new(certs: Vec<X509>, pkcs12: Option<ParsedPkcs12>) -> io::Result<SslConnector> {
let mut tls = SslConnector::builder(SslMethod::tls()).map_err(|e| {
Expand Down Expand Up @@ -115,29 +118,29 @@ fn new(certs: Vec<X509>, pkcs12: Option<ParsedPkcs12>) -> io::Result<SslConnecto
/// Initializes a TlsStream with an existing tokio_tls::TlsStream.
///
/// This is intended for use with a TlsListener and Incoming connections
pub fn tls_stream_from_existing_tls_stream(
stream: AsyncIoTokioAsStd<TokioTlsStream<TokioTcpStream>>,
pub fn tls_stream_from_existing_tls_stream<S: Connect>(
stream: AsyncIoTokioAsStd<TokioTlsStream<AsyncIoStdAsTokio<S>>>,
peer_addr: SocketAddr,
) -> (TlsStream, BufStreamHandle) {
) -> (CompatTlsStream<S>, BufStreamHandle) {
let (message_sender, outbound_messages) = BufStreamHandle::create();
let stream = TcpStream::from_stream_with_receiver(stream, peer_addr, outbound_messages);
(stream, message_sender)
}

async fn connect_tls(
async fn connect_tls<S: Connect>(
tls_config: ConnectConfiguration,
dns_name: String,
name_server: SocketAddr,
) -> Result<TokioTlsStream<TokioTcpStream>, io::Error> {
let tcp = tcp::tokio::connect(&name_server).await.map_err(|e| {
) -> Result<TokioTlsStream<AsyncIoStdAsTokio<S>>, io::Error> {
let tcp = S::connect(name_server).await.map_err(|e| {
io::Error::new(
io::ErrorKind::ConnectionRefused,
format!("tls error: {}", e),
)
})?;
let mut stream = tls_config
.into_ssl(&dns_name)
.and_then(|ssl| TokioTlsStream::new(ssl, tcp))
.and_then(|ssl| TokioTlsStream::new(ssl, AsyncIoStdAsTokio(tcp)))
.map_err(|e| io::Error::new(io::ErrorKind::Other, format!("tls error: {}", e)))?;
Pin::new(&mut stream).connect().await.map_err(|e| {
io::Error::new(
Expand All @@ -150,17 +153,19 @@ async fn connect_tls(

/// A builder for the TlsStream
#[derive(Default)]
pub struct TlsStreamBuilder {
pub struct TlsStreamBuilder<S> {
ca_chain: Vec<X509>,
identity: Option<ParsedPkcs12>,
marker: PhantomData<S>,
}

impl TlsStreamBuilder {
impl<S: Connect> TlsStreamBuilder<S> {
/// A builder for associating trust information to the `TlsStream`.
pub fn new() -> Self {
TlsStreamBuilder {
ca_chain: vec![],
identity: None,
marker: PhantomData,
}
}

Expand Down Expand Up @@ -208,7 +213,7 @@ impl TlsStreamBuilder {
name_server: SocketAddr,
dns_name: String,
) -> (
Pin<Box<dyn Future<Output = Result<TlsStream, io::Error>> + Send>>,
Pin<Box<dyn Future<Output = Result<CompatTlsStream<S>, io::Error>> + Send>>,
BufStreamHandle,
) {
let (message_sender, outbound_messages) = BufStreamHandle::create();
Expand Down
9 changes: 5 additions & 4 deletions crates/openssl/tests/openssl_tests.rs
Expand Up @@ -19,6 +19,7 @@ use openssl::pkey::*;
use openssl::ssl::*;
use openssl::x509::store::X509StoreBuilder;
use openssl::x509::*;
use tokio::net::TcpStream as TokioTcpStream;
use tokio::runtime::Runtime;

use openssl::asn1::*;
Expand All @@ -29,7 +30,7 @@ use openssl::pkcs12::*;
use openssl::rsa::*;
use openssl::x509::extension::*;

use trust_dns_proto::xfer::SerialMessage;
use trust_dns_proto::{iocompat::AsyncIoTokioAsStd, tcp::Connect, xfer::SerialMessage};

use trust_dns_openssl::TlsStreamBuilder;

Expand Down Expand Up @@ -198,7 +199,7 @@ fn tls_client_stream_test(server_addr: IpAddr, mtls: bool) {
let trust_chain = X509::from_der(&root_cert_der).unwrap();

// barrier.wait();
let mut builder = TlsStreamBuilder::new();
let mut builder = TlsStreamBuilder::<AsyncIoTokioAsStd<TokioTcpStream>>::new();
builder.add_ca(trust_chain);

if mtls {
Expand Down Expand Up @@ -228,11 +229,11 @@ fn tls_client_stream_test(server_addr: IpAddr, mtls: bool) {
}

#[allow(unused_variables)]
fn config_mtls(
fn config_mtls<S: Connect>(
root_pkey: &PKey<Private>,
root_name: &X509Name,
root_cert: &X509,
builder: &mut TlsStreamBuilder,
builder: &mut TlsStreamBuilder<S>,
) {
#[cfg(feature = "mtls")]
{
Expand Down
10 changes: 4 additions & 6 deletions crates/resolver/src/name_server/connection_provider.rs
Expand Up @@ -28,10 +28,7 @@ use proto;
use proto::error::ProtoError;

#[cfg(feature = "tokio-runtime")]
use proto::{
iocompat::{AsyncIoStdAsTokio, AsyncIoTokioAsStd},
TokioTime,
};
use proto::{iocompat::AsyncIoTokioAsStd, TokioTime};

#[cfg(feature = "mdns")]
use proto::multicast::{MdnsClientConnect, MdnsClientStream, MdnsQueryType};
Expand Down Expand Up @@ -158,7 +155,8 @@ where
let (stream, handle) =
{ crate::tls::new_tls_stream::<R>(socket_addr, tls_dns_name, client_config) };
#[cfg(not(feature = "dns-over-rustls"))]
let (stream, handle) = { crate::tls::new_tls_stream::<R>(socket_addr, tls_dns_name) };
let (stream, handle) =
{ crate::tls::new_tls_stream::<R>(socket_addr, tls_dns_name) };

let dns_conn = DnsMultiplexer::with_timeout(
stream,
Expand Down Expand Up @@ -210,7 +208,7 @@ where

#[cfg(feature = "dns-over-tls")]
/// Predefined type for TLS client stream
type TlsClientStream<S> = TcpClientStream<AsyncIoTokioAsStd<TokioTlsStream<AsyncIoStdAsTokio<S>>>>;
type TlsClientStream<S> = TcpClientStream<AsyncIoTokioAsStd<TokioTlsStream<proto::iocompat::AsyncIoStdAsTokio<S>>>>;

/// The variants of all supported connections for the Resolver
#[allow(clippy::large_enum_variant, clippy::type_complexity)]
Expand Down
6 changes: 4 additions & 2 deletions crates/resolver/src/tls/dns_over_openssl.rs
Expand Up @@ -17,12 +17,14 @@ use proto::error::ProtoError;
use proto::BufDnsStreamHandle;
use trust_dns_openssl::{TlsClientStream, TlsClientStreamBuilder};

use crate::name_server::RuntimeProvider;

#[allow(clippy::type_complexity)]
pub(crate) fn new_tls_stream(
pub(crate) fn new_tls_stream<R: RuntimeProvider>(
socket_addr: SocketAddr,
dns_name: String,
) -> (
Pin<Box<dyn Future<Output = Result<TlsClientStream, ProtoError>> + Send>>,
Pin<Box<dyn Future<Output = Result<TlsClientStream<R::Tcp>, ProtoError>> + Send>>,
BufDnsStreamHandle,
) {
let tls_builder = TlsClientStreamBuilder::new();
Expand Down
2 changes: 1 addition & 1 deletion crates/resolver/src/tls/dns_over_rustls.rs
Expand Up @@ -19,8 +19,8 @@ use proto::error::ProtoError;
use proto::BufDnsStreamHandle;
use trust_dns_rustls::{tls_client_connect, TlsClientStream};

use crate::name_server::RuntimeProvider;
use crate::config::TlsClientConfig;
use crate::name_server::RuntimeProvider;

const ALPN_H2: &[u8] = b"h2";

Expand Down
5 changes: 3 additions & 2 deletions crates/rustls/src/tls_client_stream.rs
Expand Up @@ -15,9 +15,10 @@ use std::sync::Arc;
use futures_util::TryFutureExt;
use rustls::ClientConfig;

use trust_dns_proto::{error::ProtoError, iocompat::AsyncIoStdAsTokio, tcp::Connect};
use trust_dns_proto::error::ProtoError;
use trust_dns_proto::iocompat::AsyncIoStdAsTokio;
use trust_dns_proto::iocompat::AsyncIoTokioAsStd;
use trust_dns_proto::tcp::TcpClientStream;
use trust_dns_proto::tcp::{Connect, TcpClientStream};
use trust_dns_proto::xfer::BufDnsStreamHandle;

use crate::tls_stream::tls_connect;
Expand Down
9 changes: 7 additions & 2 deletions tests/integration-tests/tests/server_future_tests.rs
Expand Up @@ -7,6 +7,7 @@ use std::time::Duration;

use futures::{future, Future, FutureExt};
use tokio::net::TcpListener;
use tokio::net::TcpStream as TokioTcpStream;
use tokio::net::UdpSocket;
use tokio::runtime::Runtime;

Expand All @@ -15,8 +16,8 @@ use trust_dns_client::op::*;
use trust_dns_client::rr::*;
use trust_dns_client::tcp::TcpClientConnection;
use trust_dns_client::udp::UdpClientConnection;
use trust_dns_proto::error::ProtoError;
use trust_dns_proto::xfer::DnsRequestSender;
use trust_dns_proto::{error::ProtoError, iocompat::AsyncIoTokioAsStd};

use trust_dns_server::authority::{Authority, Catalog};
use trust_dns_server::ServerFuture;
Expand Down Expand Up @@ -194,7 +195,11 @@ fn lazy_tcp_client(ipaddr: SocketAddr) -> TcpClientConnection {
}

#[cfg(all(feature = "dns-over-openssl", not(feature = "dns-over-rustls")))]
fn lazy_tls_client(ipaddr: SocketAddr, dns_name: String, cert_der: Vec<u8>) -> TlsClientConnection {
fn lazy_tls_client(
ipaddr: SocketAddr,
dns_name: String,
cert_der: Vec<u8>,
) -> TlsClientConnection<AsyncIoTokioAsStd<TokioTcpStream>> {
use rustls::{Certificate, ClientConfig};

let trust_chain = Certificate(cert_der);
Expand Down

0 comments on commit 9198206

Please sign in to comment.