Skip to content

Commit

Permalink
Use try_clone UdpSocket instead of PORT_REUSE
Browse files Browse the repository at this point in the history
  • Loading branch information
kpp committed Jun 7, 2023
1 parent c48b8e6 commit af684ce
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 36 deletions.
1 change: 0 additions & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 0 additions & 1 deletion transports/quic/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ parking_lot = "0.12.0"
quinn = { version = "0.10.1", default-features = false, features = ["tls-rustls", "futures-io"] }
rand = "0.8.5"
rustls = { version = "0.21.1", default-features = false }
socket2 = { version = "0.5.3", features = ["all"] }
thiserror = "1.0.40"
tokio = { version = "1.28.1", default-features = false, features = ["net", "rt"], optional = true }

Expand Down
16 changes: 6 additions & 10 deletions transports/quic/src/hole_punching.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ use rand::{distributions, Rng};
use std::{
collections::HashMap,
io,
net::SocketAddr,
net::{SocketAddr, UdpSocket},
pin::Pin,
sync::{Arc, Mutex},
task::{Context, Poll},
Expand Down Expand Up @@ -65,25 +65,21 @@ impl FusedFuture for MaybeHolePunchedConnection {
}

pub(crate) struct HolePuncher {
socket: socket2::Socket,
socket: UdpSocket,
remote_addr: SocketAddr,
timeout: Delay,
interval_timeout: Delay,
}

impl HolePuncher {
pub(crate) fn new(
local_addr: SocketAddr,
socket: UdpSocket,
remote_addr: SocketAddr,
timeout: Duration,
) -> io::Result<Self> {
let domain = socket2::Domain::for_address(remote_addr);
let socket = socket2::Socket::new(domain, socket2::Type::DGRAM, None)?;
socket.set_reuse_port(true)?;
socket.bind(&local_addr.into())?;
socket.connect(&remote_addr.into())?;

Ok(Self {
socket,
remote_addr,
timeout: Delay::new(timeout),
interval_timeout: Delay::new(Duration::from_secs(0)),
})
Expand All @@ -107,7 +103,7 @@ impl Future for HolePuncher {
.take(64)
.collect();

if let Err(e) = self.socket.send(&contents) {
if let Err(e) = self.socket.send_to(&contents, self.remote_addr) {
if !matches!(e.kind(), io::ErrorKind::WouldBlock) {
return Poll::Ready(Error::Io(e));
}
Expand Down
55 changes: 31 additions & 24 deletions transports/quic/src/transport.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ use std::collections::hash_map::{DefaultHasher, Entry};
use std::collections::HashMap;
use std::fmt;
use std::hash::{Hash, Hasher};
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, UdpSocket};
use std::time::Duration;
use std::{
net::SocketAddr,
Expand Down Expand Up @@ -98,38 +98,28 @@ impl<P: Provider> GenTransport<P> {
fn new_endpoint(
endpoint_config: quinn::EndpointConfig,
server_config: Option<quinn::ServerConfig>,
socket_addr: SocketAddr,
socket: UdpSocket,
) -> Result<quinn::Endpoint, Error> {
use crate::provider::Runtime;
match P::runtime() {
#[cfg(feature = "tokio")]
Runtime::Tokio => {
let runtime = std::sync::Arc::new(quinn::TokioRuntime);
let domain = socket2::Domain::for_address(socket_addr);
let socket = socket2::Socket::new(domain, socket2::Type::DGRAM, None)?;
socket.set_reuse_port(true)?;
socket.bind(&socket_addr.into())?;
let socket = socket.into();
let endpoint =
quinn::Endpoint::new(endpoint_config, server_config, socket, runtime)?;
Ok(endpoint)
}
#[cfg(feature = "async-std")]
Runtime::AsyncStd => {
let runtime = std::sync::Arc::new(quinn::AsyncStdRuntime);
let domain = socket2::Domain::for_address(socket_addr);
let socket = socket2::Socket::new(domain, socket2::Type::DGRAM, None)?;
socket.set_reuse_port(true)?;
socket.bind(&socket_addr.into())?;
let socket = socket.into();
let endpoint =
quinn::Endpoint::new(endpoint_config, server_config, socket, runtime)?;
Ok(endpoint)
}
Runtime::Dummy => {
let _ = endpoint_config;
let _ = server_config;
let _ = socket_addr;
let _ = socket;
let err = std::io::Error::new(std::io::ErrorKind::Other, "no async runtime found");
Err(Error::Io(err))
}
Expand All @@ -153,13 +143,17 @@ impl<P: Provider> Transport for GenTransport<P> {
.ok_or(TransportError::MultiaddrNotSupported(addr))?;
let endpoint_config = self.quinn_config.endpoint_config.clone();
let server_config = self.quinn_config.server_config.clone();
let endpoint = Self::new_endpoint(endpoint_config, Some(server_config), socket_addr)?;
let need_if_watcher = socket_addr.ip().is_unspecified();
let socket = UdpSocket::bind(socket_addr).map_err(Self::Error::from)?;
let socket_c = socket.try_clone().map_err(Self::Error::from)?;
let endpoint = Self::new_endpoint(endpoint_config, Some(server_config), socket)?;
let listener = Listener::new(
listener_id,
socket_addr,
socket_c,
endpoint,
self.hole_punch_map.clone(),
self.handshake_timeout,
need_if_watcher,
version,
)?;
self.listeners.push(listener);
Expand Down Expand Up @@ -231,9 +225,10 @@ impl<P: Provider> Transport for GenTransport<P> {
SocketFamily::Ipv4 => SocketAddr::new(Ipv4Addr::UNSPECIFIED.into(), 0),
SocketFamily::Ipv6 => SocketAddr::new(Ipv6Addr::UNSPECIFIED.into(), 0),
};
let socket =
UdpSocket::bind(listen_socket_addr).map_err(Self::Error::from)?;
let endpoint_config = self.quinn_config.endpoint_config.clone();
let endpoint =
Self::new_endpoint(endpoint_config, None, listen_socket_addr)?;
let endpoint = Self::new_endpoint(endpoint_config, None, socket)?;

vacant.insert(endpoint.clone());
endpoint
Expand Down Expand Up @@ -293,7 +288,7 @@ impl<P: Provider> Transport for GenTransport<P> {
})
.collect::<Vec<_>>();

let endpoint_addr = match listeners.len() {
let socket = match listeners.len() {
0 => {
return Err(TransportError::MultiaddrNotSupported(addr)); // FIXME return correct error
}
Expand All @@ -303,11 +298,13 @@ impl<P: Provider> Transport for GenTransport<P> {
let mut hasher = DefaultHasher::new();
socket_addr.hash(&mut hasher);
let index = hasher.finish() as usize % listeners.len();
listeners[index].endpoint.local_addr().unwrap()
listeners[index]
.try_clone_socket()
.map_err(Self::Error::from)?
}
};

let hole_puncher = HolePuncher::new(endpoint_addr, socket_addr, self.handshake_timeout)
let hole_puncher = HolePuncher::new(socket, socket_addr, self.handshake_timeout)
.map_err(|e| TransportError::Other(Error::Io(e)))?;

let (sender, receiver) = oneshot::channel();
Expand Down Expand Up @@ -356,6 +353,9 @@ struct Listener<P: Provider> {
/// Endpoint
endpoint: quinn::Endpoint,

/// An underlying copy of the socket to be able to hole punch with
socket: UdpSocket,

/// A future to poll new incoming connections.
accept: BoxFuture<'static, Option<quinn::Connecting>>,
/// Timeout for connection establishment on inbound connections.
Expand All @@ -382,20 +382,21 @@ struct Listener<P: Provider> {
impl<P: Provider> Listener<P> {
fn new(
listener_id: ListenerId,
socket_addr: SocketAddr,
socket: UdpSocket,
endpoint: quinn::Endpoint,
hole_punch_map: HolePunchMap,
handshake_timeout: Duration,
need_if_watcher: bool,
version: ProtocolVersion,
) -> Result<Self, Error> {
let if_watcher;
let pending_event;
if socket_addr.ip().is_unspecified() {
if need_if_watcher {
if_watcher = Some(P::new_if_watcher()?);
pending_event = None;
} else {
if_watcher = None;
let ma = socketaddr_to_multiaddr(&endpoint.local_addr()?, version);
let ma = socketaddr_to_multiaddr(&socket.local_addr()?, version);
pending_event = Some(TransportEvent::NewAddress {
listener_id,
listen_addr: ma,
Expand All @@ -408,6 +409,7 @@ impl<P: Provider> Listener<P> {
Ok(Listener {
endpoint,
hole_punch_map,
socket,
accept,
listener_id,
version,
Expand Down Expand Up @@ -438,8 +440,13 @@ impl<P: Provider> Listener<P> {
}
}

/// Clone underlying socket (for hole punching).
fn try_clone_socket(&self) -> std::io::Result<UdpSocket> {
self.socket.try_clone()
}

fn socket_addr(&self) -> SocketAddr {
self.endpoint.local_addr().unwrap()
self.socket.local_addr().unwrap()
}

/// Poll for a next If Event.
Expand Down

0 comments on commit af684ce

Please sign in to comment.