Skip to content

Commit

Permalink
Fix issue #2651 - Make port reuse work again using Arc<RwLock> for
Browse files Browse the repository at this point in the history
listen_addrs
  • Loading branch information
stormshield-pj50 committed May 25, 2022
1 parent ef2afcd commit 3b0c773
Showing 1 changed file with 52 additions and 12 deletions.
64 changes: 52 additions & 12 deletions transports/tcp/src/lib.rs
Expand Up @@ -59,6 +59,7 @@ use std::{
io,
net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, TcpListener},
pin::Pin,
sync::{Arc, RwLock},
task::{Context, Poll},
time::Duration,
};
Expand Down Expand Up @@ -95,7 +96,7 @@ enum PortReuse {
Enabled {
/// The addresses and ports of the listening sockets
/// registered as eligible for port reuse when dialing.
listen_addrs: HashSet<(IpAddr, Port)>,
listen_addrs: Arc<RwLock<HashSet<(IpAddr, Port)>>>,
},
}

Expand All @@ -106,7 +107,10 @@ impl PortReuse {
fn register(&mut self, ip: IpAddr, port: Port) {
if let PortReuse::Enabled { listen_addrs } = self {
log::trace!("Registering for port reuse: {}:{}", ip, port);
listen_addrs.insert((ip, port));
listen_addrs
.write()
.expect("`register()` and `unregister()` never panic while holding the lock")
.insert((ip, port));
}
}

Expand All @@ -116,7 +120,10 @@ impl PortReuse {
fn unregister(&mut self, ip: IpAddr, port: Port) {
if let PortReuse::Enabled { listen_addrs } = self {
log::trace!("Unregistering for port reuse: {}:{}", ip, port);
listen_addrs.remove(&(ip, port));
listen_addrs
.write()
.expect("`register()` and `unregister()` never panic while holding the lock")
.remove(&(ip, port));
}
}

Expand All @@ -131,7 +138,11 @@ impl PortReuse {
/// listening socket address is found.
fn local_dial_addr(&self, remote_ip: &IpAddr) -> Option<SocketAddr> {
if let PortReuse::Enabled { listen_addrs } = self {
for (ip, port) in listen_addrs.iter() {
for (ip, port) in listen_addrs
.read()
.expect("`local_dial_addr` never panic while holding the lock")
.iter()
{
if ip.is_ipv4() == remote_ip.is_ipv4()
&& ip.is_loopback() == remote_ip.is_loopback()
{
Expand Down Expand Up @@ -286,7 +297,7 @@ where
pub fn port_reuse(mut self, port_reuse: bool) -> Self {
self.port_reuse = if port_reuse {
PortReuse::Enabled {
listen_addrs: HashSet::new(),
listen_addrs: Arc::new(RwLock::new(HashSet::new())),
}
} else {
PortReuse::Disabled
Expand Down Expand Up @@ -900,15 +911,21 @@ mod tests {
fn port_reuse_dialing() {
env_logger::try_init().ok();

async fn listener<T: Provider>(addr: Multiaddr, mut ready_tx: mpsc::Sender<Multiaddr>) {
async fn listener<T: Provider>(addr: Multiaddr, mut ready_tx: mpsc::Sender<Multiaddr>,
mut port_reuse_rx: mpsc::Receiver<Protocol<'_>>) {
let mut tcp = GenTcpConfig::<T>::new();
let mut listener = tcp.listen_on(addr).unwrap();
loop {
match listener.next().await.unwrap().unwrap() {
ListenerEvent::NewAddress(listen_addr) => {
ready_tx.send(listen_addr).await.ok();
}
ListenerEvent::Upgrade { upgrade, .. } => {
ListenerEvent::Upgrade { upgrade, local_addr: _, mut remote_addr } => {
// Receive the dialer tcp port reuse
let remote_port_reuse = port_reuse_rx.next().await.unwrap();
// And check it is the same as the remote port used for upgrade
assert_eq!(remote_addr.pop().unwrap(), remote_port_reuse);

let mut upgrade = upgrade.await.unwrap();
let mut buf = [0u8; 3];
upgrade.read_exact(&mut buf).await.unwrap();
Expand All @@ -921,12 +938,24 @@ mod tests {
}
}

async fn dialer<T: Provider>(addr: Multiaddr, mut ready_rx: mpsc::Receiver<Multiaddr>) {
async fn dialer<T: Provider>(addr: Multiaddr, mut ready_rx: mpsc::Receiver<Multiaddr>,
mut port_reuse_tx: mpsc::Sender<Protocol<'_>>) {
let dest_addr = ready_rx.next().await.unwrap();
let mut tcp = GenTcpConfig::<T>::new().port_reuse(true);
let mut listener = tcp.clone().listen_on(addr).unwrap();
match listener.next().await.unwrap().unwrap() {
ListenerEvent::NewAddress(_) => {
// Check that tcp and listener share the same port reuse SocketAddr
let port_reuse_tcp = tcp.port_reuse.local_dial_addr(&listener.listen_addr.ip());
let port_reuse_listener = listener
.port_reuse
.local_dial_addr(&listener.listen_addr.ip());
assert!(port_reuse_tcp.is_some());
assert_eq!(port_reuse_tcp, port_reuse_listener);

// Send the dialer tcp port reuse to the listener
port_reuse_tx.send(Protocol::Tcp(port_reuse_tcp.unwrap().port())).await.ok();

// Obtain a future socket through dialing
let mut socket = tcp.dial(dest_addr).unwrap().await.unwrap();
socket.write_all(&[0x1, 0x2, 0x3]).await.unwrap();
Expand All @@ -943,8 +972,9 @@ mod tests {
#[cfg(feature = "async-io")]
{
let (ready_tx, ready_rx) = mpsc::channel(1);
let listener = listener::<async_io::Tcp>(addr.clone(), ready_tx);
let dialer = dialer::<async_io::Tcp>(addr.clone(), ready_rx);
let (port_reuse_tx, port_reuse_rx) = mpsc::channel(1);
let listener = listener::<async_io::Tcp>(addr.clone(), ready_tx, port_reuse_rx);
let dialer = dialer::<async_io::Tcp>(addr.clone(), ready_rx, port_reuse_tx);
let listener = async_std::task::spawn(listener);
async_std::task::block_on(dialer);
async_std::task::block_on(listener);
Expand All @@ -953,8 +983,9 @@ mod tests {
#[cfg(feature = "tokio")]
{
let (ready_tx, ready_rx) = mpsc::channel(1);
let listener = listener::<tokio::Tcp>(addr.clone(), ready_tx);
let dialer = dialer::<tokio::Tcp>(addr.clone(), ready_rx);
let (port_reuse_tx, port_reuse_rx) = mpsc::channel(1);
let listener = listener::<tokio::Tcp>(addr.clone(), ready_tx, port_reuse_rx);
let dialer = dialer::<tokio::Tcp>(addr.clone(), ready_rx, port_reuse_tx);
let rt = tokio_crate::runtime::Builder::new_current_thread()
.enable_io()
.build()
Expand All @@ -979,6 +1010,15 @@ mod tests {
let mut listener1 = tcp.clone().listen_on(addr).unwrap();
match listener1.next().await.unwrap().unwrap() {
ListenerEvent::NewAddress(addr1) => {
// Check that tcp and listener share the same port reuse SocketAddr
let port_reuse_tcp =
tcp.port_reuse.local_dial_addr(&listener1.listen_addr.ip());
let port_reuse_listener1 = listener1
.port_reuse
.local_dial_addr(&listener1.listen_addr.ip());
assert!(port_reuse_tcp.is_some());
assert_eq!(port_reuse_tcp, port_reuse_listener1);

// Listen on the same address a second time.
let mut listener2 = tcp.clone().listen_on(addr1.clone()).unwrap();
match listener2.next().await.unwrap().unwrap() {
Expand Down

0 comments on commit 3b0c773

Please sign in to comment.