Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add QuicClientStreamBuilder::endpoint() #2003

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
121 changes: 87 additions & 34 deletions crates/proto/src/quic/quic_client_stream.rs
Expand Up @@ -147,22 +147,67 @@ impl Stream for QuicClientStream {

/// A QUIC connection builder for DNS-over-QUIC
#[derive(Clone)]
pub struct QuicClientStreamBuilder {
crypto_config: TlsClientConfig,
transport_config: Arc<TransportConfig>,
bind_addr: Option<SocketAddr>,
pub struct QuicClientStreamBuilder(Config);

#[derive(Clone)]
enum Config {
Partly {
crypto_config: TlsClientConfig,
transport_config: Arc<TransportConfig>,
bind_addr: Option<SocketAddr>,
},
Complete(Endpoint),
}

impl QuicClientStreamBuilder {
/// Constructs a new TlsStreamBuilder with the associated [`Endpoint`]
pub fn endpoint(&mut self, endpoint: Endpoint) -> &mut Self {
self.0 = Config::Complete(endpoint);
self
}

/// Constructs a new TlsStreamBuilder with the associated ClientConfig
pub fn crypto_config(&mut self, crypto_config: TlsClientConfig) -> &mut Self {
self.crypto_config = crypto_config;
pub fn crypto_config(&mut self, new_crypto_config: TlsClientConfig) -> &mut Self {
match self.0 {
Config::Partly {
ref mut crypto_config,
..
} => *crypto_config = new_crypto_config,
Config::Complete(_) => {
let mut transport_config = quic_config::transport();
// clients never accept new bidirectional streams
transport_config.max_concurrent_bidi_streams(VarInt::from_u32(0));

self.0 = Config::Partly {
crypto_config: new_crypto_config,
transport_config: Arc::new(transport_config),
bind_addr: None,
};
}
}

self
}

/// Sets the address to connect from.
pub fn bind_addr(&mut self, bind_addr: SocketAddr) -> &mut Self {
self.bind_addr = Some(bind_addr);
pub fn bind_addr(&mut self, new_bind_addr: SocketAddr) -> &mut Self {
match self.0 {
Config::Partly {
ref mut bind_addr, ..
} => *bind_addr = Some(new_bind_addr),
Config::Complete(_) => {
let mut transport_config = quic_config::transport();
// clients never accept new bidirectional streams
transport_config.max_concurrent_bidi_streams(VarInt::from_u32(0));

self.0 = Config::Partly {
crypto_config: client_config_tls13_webpki_roots(),
transport_config: Arc::new(transport_config),
bind_addr: Some(new_bind_addr),
};
}
}

self
}

Expand Down Expand Up @@ -217,16 +262,22 @@ impl QuicClientStreamBuilder {
name_server: SocketAddr,
dns_name: String,
) -> Result<QuicClientStream, ProtoError> {
let connect = if let Some(bind_addr) = self.bind_addr {
<tokio::net::UdpSocket as UdpSocket>::connect_with_bind(name_server, bind_addr)
} else {
<tokio::net::UdpSocket as UdpSocket>::connect(name_server)
let endpoint = match &self.0 {
Config::Partly { bind_addr, .. } => {
let connect = if let Some(bind_addr) = bind_addr {
<tokio::net::UdpSocket as UdpSocket>::connect_with_bind(name_server, *bind_addr)
} else {
<tokio::net::UdpSocket as UdpSocket>::connect(name_server)
};

let socket = connect.await?;
let socket = socket.into_std()?;
let endpoint_config = quic_config::endpoint();
Endpoint::new(endpoint_config, None, socket, Arc::new(quinn::TokioRuntime))?
}
Config::Complete(endpoint) => endpoint.clone(),
};

let socket = connect.await?;
let socket = socket.into_std()?;
let endpoint_config = quic_config::endpoint();
let endpoint = Endpoint::new(endpoint_config, None, socket, Arc::new(quinn::TokioRuntime))?;
self.connect_inner(endpoint, name_server, dns_name).await
}

Expand All @@ -236,28 +287,30 @@ impl QuicClientStreamBuilder {
name_server: SocketAddr,
dns_name: String,
) -> Result<QuicClientStream, ProtoError> {
// ensure the ALPN protocol is set correctly
let mut crypto_config = self.crypto_config;
if crypto_config.alpn_protocols.is_empty() {
crypto_config.alpn_protocols = vec![quic_stream::DOQ_ALPN.to_vec()];
}
let early_data_enabled = crypto_config.enable_early_data;
if let Config::Partly {
crypto_config,
transport_config,
..
} = self.0
{
// ensure the ALPN protocol is set correctly
let mut crypto_config = crypto_config;
if crypto_config.alpn_protocols.is_empty() {
crypto_config.alpn_protocols = vec![quic_stream::DOQ_ALPN.to_vec()];
}

let mut client_config = ClientConfig::new(Arc::new(crypto_config));
client_config.transport_config(self.transport_config.clone());
let mut client_config = ClientConfig::new(Arc::new(crypto_config));
client_config.transport_config(transport_config.clone());

endpoint.set_default_client_config(client_config);
endpoint.set_default_client_config(client_config);
}

let connecting = endpoint.connect(name_server, &dns_name)?;
// TODO: for Client/Dynamic update, don't use RTT, for queries, do use it.

let quic_connection = if early_data_enabled {
match connecting.into_0rtt() {
Ok((new_connection, _)) => new_connection,
Err(connecting) => connecting.await?,
}
} else {
connecting.await?
let quic_connection = match connecting.into_0rtt() {
daxpedda marked this conversation as resolved.
Show resolved Hide resolved
Ok((new_connection, _)) => new_connection,
Err(connecting) => connecting.await?,
};

Ok(QuicClientStream {
Expand Down Expand Up @@ -298,11 +351,11 @@ impl Default for QuicClientStreamBuilder {

let client_config = client_config_tls13_webpki_roots();

Self {
Self(Config::Partly {
crypto_config: client_config,
transport_config: Arc::new(transport_config),
bind_addr: None,
}
})
}
}

Expand Down