Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make hickory_proto::h3::H3ClientStream Clonable #2182

Merged
merged 1 commit into from May 4, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
97 changes: 84 additions & 13 deletions crates/proto/src/h3/h3_client_stream.rs
Expand Up @@ -6,7 +6,7 @@
// copied, modified, or distributed except according to those terms.

use std::fmt::{self, Display};
use std::future::Future;
use std::future::{self, Future};
use std::net::SocketAddr;
use std::pin::Pin;
use std::str::FromStr;
Expand All @@ -16,12 +16,13 @@ use std::task::{Context, Poll};
use bytes::{Buf, BufMut, Bytes, BytesMut};
use futures_util::future::FutureExt;
use futures_util::stream::Stream;
use h3::client::{Connection, SendRequest};
use h3::client::SendRequest;
use h3_quinn::OpenStreams;
use http::header::{self, CONTENT_LENGTH};
use quinn::{ClientConfig, Endpoint, EndpointConfig, TransportConfig};
use rustls::ClientConfig as TlsClientConfig;
use tracing::debug;
use tokio::sync::mpsc;
use tracing::{debug, warn};

use crate::error::ProtoError;
use crate::http::Version;
Expand All @@ -34,13 +35,14 @@ use crate::xfer::{DnsRequest, DnsRequestSender, DnsResponse, DnsResponseStream};
use super::ALPN_H3;

/// A DNS client connection for DNS-over-HTTP/3
#[derive(Clone)]
#[must_use = "futures do nothing unless polled"]
pub struct H3ClientStream {
// Corresponds to the dns-name of the HTTP/3 server
name_server_name: Arc<str>,
name_server: SocketAddr,
driver: Connection<h3_quinn::Connection, Bytes>,
send_request: SendRequest<OpenStreams, Bytes>,
shutdown_tx: mpsc::Sender<()>,
is_shutdown: bool,
}

Expand Down Expand Up @@ -264,19 +266,19 @@ impl DnsRequestSender for H3ClientStream {
impl Stream for H3ClientStream {
type Item = Result<(), ProtoError>;

fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
fn poll_next(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
if self.is_shutdown {
return Poll::Ready(None);
}

// just checking if the connection is ok
match self.driver.poll_close(cx) {
Poll::Ready(Ok(())) => Poll::Ready(None),
Poll::Pending => Poll::Pending,
Poll::Ready(Err(e)) => Poll::Ready(Some(Err(ProtoError::from(format!(
"h3 stream errored: {e}",
))))),
if self.shutdown_tx.is_closed() {
return Poll::Ready(Some(Err(ProtoError::from(
"h3 connection is already shutdown",
))));
}

Poll::Ready(Some(Ok(())))
}
}

Expand Down Expand Up @@ -398,15 +400,31 @@ impl H3ClientStreamBuilder {
};

let h3_connection = h3_quinn::Connection::new(quic_connection);
let (driver, send_request) = h3::client::new(h3_connection)
let (mut driver, send_request) = h3::client::new(h3_connection)
.await
.map_err(|e| ProtoError::from(format!("h3 connection failed: {e}")))?;

let (shutdown_tx, mut shutdown_rx) = mpsc::channel::<()>(1);

// TODO: hand this back for others to run rather than spawning here?
debug!("h3 connection is ready: {}", name_server);
tokio::spawn(async move {
tokio::select! {
res = future::poll_fn(|cx| driver.poll_close(cx)) => {
res.map_err(|e| warn!("h3 connection failed: {e}"))
}
_ = shutdown_rx.recv() => {
debug!("h3 connection is shutting down: {}", name_server);
Ok(())
}
}
});

Ok(H3ClientStream {
name_server_name: Arc::from(dns_name),
name_server,
driver,
send_request,
shutdown_tx,
is_shutdown: false,
})
}
Expand Down Expand Up @@ -453,6 +471,7 @@ mod tests {

use rustls::KeyLogFile;
use tokio::runtime::Runtime;
use tokio::task::JoinSet;

use crate::op::{Message, Query, ResponseCode};
use crate::rr::rdata::{A, AAAA};
Expand Down Expand Up @@ -652,4 +671,56 @@ mod tests {
&AAAA::new(0x2606, 0x2800, 0x21f, 0xcb07, 0x6820, 0x80da, 0xaf6b, 0x8b2c)
);
}

#[test]
#[allow(clippy::print_stdout)]
fn test_h3_client_stream_clonable() {
// use google
let google = SocketAddr::from(([8, 8, 8, 8], 443));

let mut client_config = super::super::client_config_tls13().unwrap();
client_config.key_log = Arc::new(KeyLogFile::new());

let mut h3_builder = H3ClientStream::builder();
h3_builder.crypto_config(client_config);
let connect = h3_builder.build(google, "dns.google".to_string());

// tokio runtime stuff...
let runtime = Runtime::new().expect("could not start runtime");
let h3 = runtime.block_on(connect).expect("h3 connect failed");

// prepare request
let mut request = Message::new();
let query = Query::query(
Name::from_str("www.example.com.").unwrap(),
RecordType::AAAA,
);
request.add_query(query);
let request = DnsRequest::new(request, DnsRequestOptions::default());

runtime.block_on(async move {
let mut join_set = JoinSet::new();

for i in 0..50 {
let mut h3 = h3.clone();
let request = request.clone();

join_set.spawn(async move {
let start = std::time::Instant::now();
h3.send_message(request)
.first_answer()
.await
.expect("send_message failed");
println!("request[{i}] completed: {:?}", start.elapsed());
});
}

let total = join_set.len();
let mut idx = 0usize;
while join_set.join_next().await.is_some() {
println!("join_set completed {idx}/{total}");
idx += 1;
}
});
}
}