Skip to content

Commit

Permalink
Make H3ClientStream Clonable
Browse files Browse the repository at this point in the history
  • Loading branch information
0xffffharry committed Apr 16, 2024
1 parent 94ac564 commit e088e09
Showing 1 changed file with 32 additions and 13 deletions.
45 changes: 32 additions & 13 deletions crates/proto/src/h3/h3_client_stream.rs
Expand Up @@ -6,7 +6,7 @@
// copied, modified, or distributed except according to those terms.

use std::fmt::{self, Display};
use std::future::Future;
use std::future::{self, Future};
use std::net::SocketAddr;
use std::pin::Pin;
use std::str::FromStr;
Expand All @@ -16,12 +16,13 @@ use std::task::{Context, Poll};
use bytes::{Buf, BufMut, Bytes, BytesMut};
use futures_util::future::FutureExt;
use futures_util::stream::Stream;
use h3::client::{Connection, SendRequest};
use h3::client::SendRequest;
use h3_quinn::OpenStreams;
use http::header::{self, CONTENT_LENGTH};
use quinn::{ClientConfig, Endpoint, EndpointConfig, TransportConfig};
use rustls::ClientConfig as TlsClientConfig;
use tracing::debug;
use tokio::sync::mpsc;
use tracing::{debug, warn};

use crate::error::ProtoError;
use crate::http::Version;
Expand All @@ -34,13 +35,14 @@ use crate::xfer::{DnsRequest, DnsRequestSender, DnsResponse, DnsResponseStream};
use super::ALPN_H3;

/// A DNS client connection for DNS-over-HTTP/3
#[derive(Clone)]
#[must_use = "futures do nothing unless polled"]
pub struct H3ClientStream {
// Corresponds to the dns-name of the HTTP/3 server
name_server_name: Arc<str>,
name_server: SocketAddr,
driver: Connection<h3_quinn::Connection, Bytes>,
send_request: SendRequest<OpenStreams, Bytes>,
shutdown_tx: mpsc::Sender<()>,
is_shutdown: bool,
}

Expand Down Expand Up @@ -264,19 +266,19 @@ impl DnsRequestSender for H3ClientStream {
impl Stream for H3ClientStream {
type Item = Result<(), ProtoError>;

fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
fn poll_next(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
if self.is_shutdown {
return Poll::Ready(None);
}

// just checking if the connection is ok
match self.driver.poll_close(cx) {
Poll::Ready(Ok(())) => Poll::Ready(None),
Poll::Pending => Poll::Pending,
Poll::Ready(Err(e)) => Poll::Ready(Some(Err(ProtoError::from(format!(
"h3 stream errored: {e}",
))))),
if self.shutdown_tx.is_closed() {
return Poll::Ready(Some(Err(ProtoError::from(
"h3 connection is already shutdown",
))));
}

Poll::Ready(Some(Ok(())))
}
}

Expand Down Expand Up @@ -398,15 +400,32 @@ impl H3ClientStreamBuilder {
};

let h3_connection = h3_quinn::Connection::new(quic_connection);
let (driver, send_request) = h3::client::new(h3_connection)
let (mut driver, send_request) = h3::client::new(h3_connection)
.await
.map_err(|e| ProtoError::from(format!("h3 connection failed: {e}")))?;


let (shutdown_tx, mut shutdown_rx) = mpsc::channel::<()>(1);

// TODO: hand this back for others to run rather than spawning here?
debug!("h3 connection is ready: {}", name_server);
tokio::spawn(async move {
tokio::select! {
res = future::poll_fn(|cx| driver.poll_close(cx)) => {
res.map_err(|e| warn!("h3 connection failed: {e}"))
}
_ = shutdown_rx.recv() => {
debug!("h3 connection is shutting down: {}", name_server);
Ok(())
}
}
});

Ok(H3ClientStream {
name_server_name: Arc::from(dns_name),
name_server,
driver,
send_request,
shutdown_tx,
is_shutdown: false,
})
}
Expand Down

0 comments on commit e088e09

Please sign in to comment.