Skip to content

Commit

Permalink
Make H3ClientStream Clonable
Browse files Browse the repository at this point in the history
  • Loading branch information
0xffffharry committed Apr 15, 2024
1 parent 94ac564 commit 816dfc3
Showing 1 changed file with 12 additions and 5 deletions.
17 changes: 12 additions & 5 deletions crates/proto/src/h3/h3_client_stream.rs
Expand Up @@ -10,7 +10,7 @@ use std::future::Future;
use std::net::SocketAddr;
use std::pin::Pin;
use std::str::FromStr;
use std::sync::Arc;
use std::sync::{Arc, Mutex, TryLockError};
use std::task::{Context, Poll};

use bytes::{Buf, BufMut, Bytes, BytesMut};
Expand All @@ -35,11 +35,12 @@ use super::ALPN_H3;

/// A DNS client connection for DNS-over-HTTP/3
#[must_use = "futures do nothing unless polled"]
#[derive(Clone)]
pub struct H3ClientStream {
// Corresponds to the dns-name of the HTTP/3 server
name_server_name: Arc<str>,
name_server: SocketAddr,
driver: Connection<h3_quinn::Connection, Bytes>,
driver: Arc<Mutex<Connection<h3_quinn::Connection, Bytes>>>,
send_request: SendRequest<OpenStreams, Bytes>,
is_shutdown: bool,
}
Expand Down Expand Up @@ -264,13 +265,19 @@ impl DnsRequestSender for H3ClientStream {
impl Stream for H3ClientStream {
type Item = Result<(), ProtoError>;

fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
if self.is_shutdown {
return Poll::Ready(None);
}

let mut driver_locker = match self.driver.try_lock() {
Ok(locker) => locker,
Err(TryLockError::WouldBlock) => return Poll::Pending,
Err(e) => return Poll::Ready(Some(Err(ProtoError::from(format!("driver: lock failed: {e}"))))),
};

// just checking if the connection is ok
match self.driver.poll_close(cx) {
match driver_locker.poll_close(cx) {
Poll::Ready(Ok(())) => Poll::Ready(None),
Poll::Pending => Poll::Pending,
Poll::Ready(Err(e)) => Poll::Ready(Some(Err(ProtoError::from(format!(
Expand Down Expand Up @@ -405,7 +412,7 @@ impl H3ClientStreamBuilder {
Ok(H3ClientStream {
name_server_name: Arc::from(dns_name),
name_server,
driver,
driver: Arc::new(Mutex::new(driver)),
send_request,
is_shutdown: false,
})
Expand Down

0 comments on commit 816dfc3

Please sign in to comment.