diff --git a/bin/tests/named_openssl_tests.rs b/bin/tests/named_openssl_tests.rs index 4c4f68b263..fac1750fdd 100644 --- a/bin/tests/named_openssl_tests.rs +++ b/bin/tests/named_openssl_tests.rs @@ -22,12 +22,14 @@ use std::io::*; use std::net::*; use native_tls::Certificate; +use tokio::net::TcpStream as TokioTcpStream; use tokio::runtime::Runtime; use trust_dns_client::client::*; use trust_dns_native_tls::TlsClientStreamBuilder; use server_harness::{named_test_harness, query_a}; +use trust_dns_proto::iocompat::AsyncIoTokioAsStd; #[test] fn test_example_tls_toml_startup() { @@ -59,7 +61,8 @@ fn test_startup(toml: &'static str) { .unwrap() .next() .unwrap(); - let mut tls_conn_builder = TlsClientStreamBuilder::new(); + let mut tls_conn_builder = + TlsClientStreamBuilder::>::new(); let cert = to_trust_anchor(&cert_der); tls_conn_builder.add_ca(cert); let (stream, sender) = tls_conn_builder.build(addr, "ns.example.com".to_string()); @@ -74,7 +77,8 @@ fn test_startup(toml: &'static str) { .unwrap() .next() .unwrap(); - let mut tls_conn_builder = TlsClientStreamBuilder::new(); + let mut tls_conn_builder = + TlsClientStreamBuilder::>::new(); let cert = to_trust_anchor(&cert_der); tls_conn_builder.add_ca(cert); let (stream, sender) = tls_conn_builder.build(addr, "ns.example.com".to_string()); diff --git a/crates/native-tls/src/tests.rs b/crates/native-tls/src/tests.rs index c4b3dad028..a1eefdb41a 100644 --- a/crates/native-tls/src/tests.rs +++ b/crates/native-tls/src/tests.rs @@ -26,8 +26,8 @@ use std::{thread, time}; use futures_util::stream::StreamExt; use native_tls; use native_tls::{Certificate, TlsAcceptor}; -use tokio::runtime::Runtime; use tokio::net::TcpStream as TokioTcpStream; +use tokio::runtime::Runtime; use trust_dns_proto::{iocompat::AsyncIoTokioAsStd, xfer::SerialMessage}; diff --git a/crates/native-tls/src/tls_client_stream.rs b/crates/native-tls/src/tls_client_stream.rs index 21fa8a5e67..13164e1b9b 100644 --- a/crates/native-tls/src/tls_client_stream.rs +++ b/crates/native-tls/src/tls_client_stream.rs @@ -17,9 +17,10 @@ use native_tls::Certificate; use native_tls::Pkcs12; use tokio_native_tls::TlsStream as TokioTlsStream; -use trust_dns_proto::{error::ProtoError, iocompat::AsyncIoStdAsTokio, tcp::Connect}; +use trust_dns_proto::error::ProtoError; +use trust_dns_proto::iocompat::AsyncIoStdAsTokio; use trust_dns_proto::iocompat::AsyncIoTokioAsStd; -use trust_dns_proto::tcp::TcpClientStream; +use trust_dns_proto::tcp::{Connect, TcpClientStream}; use trust_dns_proto::xfer::BufDnsStreamHandle; use crate::TlsStreamBuilder; @@ -27,7 +28,8 @@ use crate::TlsStreamBuilder; /// TlsClientStream secure DNS over TCP stream /// /// See TlsClientStreamBuilder::new() -pub type TlsClientStream = TcpClientStream>>>; +pub type TlsClientStream = + TcpClientStream>>>; /// Builder for TlsClientStream pub struct TlsClientStreamBuilder(TlsStreamBuilder); diff --git a/crates/openssl/src/tls_client_stream.rs b/crates/openssl/src/tls_client_stream.rs index 9ddec8a937..a191158ee5 100644 --- a/crates/openssl/src/tls_client_stream.rs +++ b/crates/openssl/src/tls_client_stream.rs @@ -14,23 +14,24 @@ use futures_util::TryFutureExt; #[cfg(feature = "mtls")] use openssl::pkcs12::Pkcs12; use openssl::x509::X509; -use tokio::net::TcpStream as TokioTcpStream; use tokio_openssl::SslStream as TokioTlsStream; use trust_dns_proto::error::ProtoError; +use trust_dns_proto::iocompat::AsyncIoStdAsTokio; use trust_dns_proto::iocompat::AsyncIoTokioAsStd; -use trust_dns_proto::tcp::TcpClientStream; +use trust_dns_proto::tcp::{Connect, TcpClientStream}; use trust_dns_proto::xfer::BufDnsStreamHandle; use super::TlsStreamBuilder; /// A Type definition for the TLS stream -pub type TlsClientStream = TcpClientStream>>; +pub type TlsClientStream = + TcpClientStream>>>; /// A Builder for the TlsClientStream -pub struct TlsClientStreamBuilder(TlsStreamBuilder); +pub struct TlsClientStreamBuilder(TlsStreamBuilder); -impl TlsClientStreamBuilder { +impl TlsClientStreamBuilder { /// Creates a builder for the construction of a TlsClientStream. pub fn new() -> Self { TlsClientStreamBuilder(TlsStreamBuilder::new()) @@ -71,7 +72,7 @@ impl TlsClientStreamBuilder { name_server: SocketAddr, dns_name: String, ) -> ( - Pin> + Send>>, + Pin, ProtoError>> + Send>>, BufDnsStreamHandle, ) { let (stream_future, sender) = self.0.build(name_server, dns_name); @@ -88,7 +89,7 @@ impl TlsClientStreamBuilder { } } -impl Default for TlsClientStreamBuilder { +impl Default for TlsClientStreamBuilder { fn default() -> Self { Self::new() } diff --git a/crates/openssl/src/tls_stream.rs b/crates/openssl/src/tls_stream.rs index 0ed11db89d..3548df4e27 100644 --- a/crates/openssl/src/tls_stream.rs +++ b/crates/openssl/src/tls_stream.rs @@ -5,10 +5,10 @@ // http://opensource.org/licenses/MIT>, at your option. This file may not be // copied, modified, or distributed except according to those terms. -use std::future::Future; use std::io; use std::net::SocketAddr; use std::pin::Pin; +use std::{future::Future, marker::PhantomData}; use futures_util::{future, TryFutureExt}; use openssl::pkcs12::ParsedPkcs12; @@ -17,12 +17,14 @@ use openssl::ssl::{ConnectConfiguration, SslConnector, SslContextBuilder, SslMet use openssl::stack::Stack; use openssl::x509::store::X509StoreBuilder; use openssl::x509::{X509Ref, X509}; -use tokio::net::TcpStream as TokioTcpStream; use tokio_openssl::{self, SslStream as TokioTlsStream}; -use trust_dns_proto::iocompat::AsyncIoTokioAsStd; -use trust_dns_proto::tcp::{self, TcpStream}; +use trust_dns_proto::tcp::TcpStream; use trust_dns_proto::xfer::BufStreamHandle; +use trust_dns_proto::{ + iocompat::{AsyncIoStdAsTokio, AsyncIoTokioAsStd}, + tcp::Connect, +}; pub trait TlsIdentityExt { fn identity(&mut self, pkcs12: &ParsedPkcs12) -> io::Result<()> { @@ -57,7 +59,8 @@ impl TlsIdentityExt for SslContextBuilder { } /// A TlsStream counterpart to the TcpStream which embeds a secure TlsStream -pub type TlsStream = TcpStream>>; +pub type TlsStream = TcpStream>>; +pub type CompatTlsStream = TlsStream>; fn new(certs: Vec, pkcs12: Option) -> io::Result { let mut tls = SslConnector::builder(SslMethod::tls()).map_err(|e| { @@ -115,21 +118,21 @@ fn new(certs: Vec, pkcs12: Option) -> io::Result>, +pub fn tls_stream_from_existing_tls_stream( + stream: AsyncIoTokioAsStd>>, peer_addr: SocketAddr, -) -> (TlsStream, BufStreamHandle) { +) -> (CompatTlsStream, BufStreamHandle) { let (message_sender, outbound_messages) = BufStreamHandle::create(); let stream = TcpStream::from_stream_with_receiver(stream, peer_addr, outbound_messages); (stream, message_sender) } -async fn connect_tls( +async fn connect_tls( tls_config: ConnectConfiguration, dns_name: String, name_server: SocketAddr, -) -> Result, io::Error> { - let tcp = tcp::tokio::connect(&name_server).await.map_err(|e| { +) -> Result>, io::Error> { + let tcp = S::connect(name_server).await.map_err(|e| { io::Error::new( io::ErrorKind::ConnectionRefused, format!("tls error: {}", e), @@ -137,7 +140,7 @@ async fn connect_tls( })?; let mut stream = tls_config .into_ssl(&dns_name) - .and_then(|ssl| TokioTlsStream::new(ssl, tcp)) + .and_then(|ssl| TokioTlsStream::new(ssl, AsyncIoStdAsTokio(tcp))) .map_err(|e| io::Error::new(io::ErrorKind::Other, format!("tls error: {}", e)))?; Pin::new(&mut stream).connect().await.map_err(|e| { io::Error::new( @@ -150,17 +153,19 @@ async fn connect_tls( /// A builder for the TlsStream #[derive(Default)] -pub struct TlsStreamBuilder { +pub struct TlsStreamBuilder { ca_chain: Vec, identity: Option, + marker: PhantomData, } -impl TlsStreamBuilder { +impl TlsStreamBuilder { /// A builder for associating trust information to the `TlsStream`. pub fn new() -> Self { TlsStreamBuilder { ca_chain: vec![], identity: None, + marker: PhantomData, } } @@ -208,7 +213,7 @@ impl TlsStreamBuilder { name_server: SocketAddr, dns_name: String, ) -> ( - Pin> + Send>>, + Pin, io::Error>> + Send>>, BufStreamHandle, ) { let (message_sender, outbound_messages) = BufStreamHandle::create(); diff --git a/crates/openssl/tests/openssl_tests.rs b/crates/openssl/tests/openssl_tests.rs index bffd2015c7..98e536cfec 100644 --- a/crates/openssl/tests/openssl_tests.rs +++ b/crates/openssl/tests/openssl_tests.rs @@ -19,6 +19,7 @@ use openssl::pkey::*; use openssl::ssl::*; use openssl::x509::store::X509StoreBuilder; use openssl::x509::*; +use tokio::net::TcpStream as TokioTcpStream; use tokio::runtime::Runtime; use openssl::asn1::*; @@ -29,7 +30,7 @@ use openssl::pkcs12::*; use openssl::rsa::*; use openssl::x509::extension::*; -use trust_dns_proto::xfer::SerialMessage; +use trust_dns_proto::{iocompat::AsyncIoTokioAsStd, tcp::Connect, xfer::SerialMessage}; use trust_dns_openssl::TlsStreamBuilder; @@ -198,7 +199,7 @@ fn tls_client_stream_test(server_addr: IpAddr, mtls: bool) { let trust_chain = X509::from_der(&root_cert_der).unwrap(); // barrier.wait(); - let mut builder = TlsStreamBuilder::new(); + let mut builder = TlsStreamBuilder::>::new(); builder.add_ca(trust_chain); if mtls { @@ -228,11 +229,11 @@ fn tls_client_stream_test(server_addr: IpAddr, mtls: bool) { } #[allow(unused_variables)] -fn config_mtls( +fn config_mtls( root_pkey: &PKey, root_name: &X509Name, root_cert: &X509, - builder: &mut TlsStreamBuilder, + builder: &mut TlsStreamBuilder, ) { #[cfg(feature = "mtls")] { diff --git a/crates/resolver/src/name_server/connection_provider.rs b/crates/resolver/src/name_server/connection_provider.rs index afa459703e..9430d73ce7 100644 --- a/crates/resolver/src/name_server/connection_provider.rs +++ b/crates/resolver/src/name_server/connection_provider.rs @@ -28,10 +28,7 @@ use proto; use proto::error::ProtoError; #[cfg(feature = "tokio-runtime")] -use proto::{ - iocompat::{AsyncIoStdAsTokio, AsyncIoTokioAsStd}, - TokioTime, -}; +use proto::{iocompat::AsyncIoTokioAsStd, TokioTime}; #[cfg(feature = "mdns")] use proto::multicast::{MdnsClientConnect, MdnsClientStream, MdnsQueryType}; @@ -158,7 +155,8 @@ where let (stream, handle) = { crate::tls::new_tls_stream::(socket_addr, tls_dns_name, client_config) }; #[cfg(not(feature = "dns-over-rustls"))] - let (stream, handle) = { crate::tls::new_tls_stream::(socket_addr, tls_dns_name) }; + let (stream, handle) = + { crate::tls::new_tls_stream::(socket_addr, tls_dns_name) }; let dns_conn = DnsMultiplexer::with_timeout( stream, @@ -210,7 +208,7 @@ where #[cfg(feature = "dns-over-tls")] /// Predefined type for TLS client stream -type TlsClientStream = TcpClientStream>>>; +type TlsClientStream = TcpClientStream>>>; /// The variants of all supported connections for the Resolver #[allow(clippy::large_enum_variant, clippy::type_complexity)] diff --git a/crates/resolver/src/tls/dns_over_openssl.rs b/crates/resolver/src/tls/dns_over_openssl.rs index 94aa9a404c..4289ce021c 100644 --- a/crates/resolver/src/tls/dns_over_openssl.rs +++ b/crates/resolver/src/tls/dns_over_openssl.rs @@ -17,12 +17,14 @@ use proto::error::ProtoError; use proto::BufDnsStreamHandle; use trust_dns_openssl::{TlsClientStream, TlsClientStreamBuilder}; +use crate::name_server::RuntimeProvider; + #[allow(clippy::type_complexity)] -pub(crate) fn new_tls_stream( +pub(crate) fn new_tls_stream( socket_addr: SocketAddr, dns_name: String, ) -> ( - Pin> + Send>>, + Pin, ProtoError>> + Send>>, BufDnsStreamHandle, ) { let tls_builder = TlsClientStreamBuilder::new(); diff --git a/crates/resolver/src/tls/dns_over_rustls.rs b/crates/resolver/src/tls/dns_over_rustls.rs index 471c7c156a..e95d5db9f7 100644 --- a/crates/resolver/src/tls/dns_over_rustls.rs +++ b/crates/resolver/src/tls/dns_over_rustls.rs @@ -19,8 +19,8 @@ use proto::error::ProtoError; use proto::BufDnsStreamHandle; use trust_dns_rustls::{tls_client_connect, TlsClientStream}; -use crate::name_server::RuntimeProvider; use crate::config::TlsClientConfig; +use crate::name_server::RuntimeProvider; const ALPN_H2: &[u8] = b"h2"; diff --git a/crates/rustls/src/tls_client_stream.rs b/crates/rustls/src/tls_client_stream.rs index 11d335a88f..db560235ea 100644 --- a/crates/rustls/src/tls_client_stream.rs +++ b/crates/rustls/src/tls_client_stream.rs @@ -15,9 +15,10 @@ use std::sync::Arc; use futures_util::TryFutureExt; use rustls::ClientConfig; -use trust_dns_proto::{error::ProtoError, iocompat::AsyncIoStdAsTokio, tcp::Connect}; +use trust_dns_proto::error::ProtoError; +use trust_dns_proto::iocompat::AsyncIoStdAsTokio; use trust_dns_proto::iocompat::AsyncIoTokioAsStd; -use trust_dns_proto::tcp::TcpClientStream; +use trust_dns_proto::tcp::{Connect, TcpClientStream}; use trust_dns_proto::xfer::BufDnsStreamHandle; use crate::tls_stream::tls_connect; diff --git a/tests/integration-tests/tests/server_future_tests.rs b/tests/integration-tests/tests/server_future_tests.rs index 8efdd9b27d..d459016ce6 100644 --- a/tests/integration-tests/tests/server_future_tests.rs +++ b/tests/integration-tests/tests/server_future_tests.rs @@ -7,6 +7,7 @@ use std::time::Duration; use futures::{future, Future, FutureExt}; use tokio::net::TcpListener; +use tokio::net::TcpStream as TokioTcpStream; use tokio::net::UdpSocket; use tokio::runtime::Runtime; @@ -15,8 +16,8 @@ use trust_dns_client::op::*; use trust_dns_client::rr::*; use trust_dns_client::tcp::TcpClientConnection; use trust_dns_client::udp::UdpClientConnection; -use trust_dns_proto::error::ProtoError; use trust_dns_proto::xfer::DnsRequestSender; +use trust_dns_proto::{error::ProtoError, iocompat::AsyncIoTokioAsStd}; use trust_dns_server::authority::{Authority, Catalog}; use trust_dns_server::ServerFuture; @@ -194,7 +195,11 @@ fn lazy_tcp_client(ipaddr: SocketAddr) -> TcpClientConnection { } #[cfg(all(feature = "dns-over-openssl", not(feature = "dns-over-rustls")))] -fn lazy_tls_client(ipaddr: SocketAddr, dns_name: String, cert_der: Vec) -> TlsClientConnection { +fn lazy_tls_client( + ipaddr: SocketAddr, + dns_name: String, + cert_der: Vec, +) -> TlsClientConnection> { use rustls::{Certificate, ClientConfig}; let trust_chain = Certificate(cert_der);