diff --git a/transports/tcp/src/lib.rs b/transports/tcp/src/lib.rs index 6e6eb771d063..72bdf6a74701 100644 --- a/transports/tcp/src/lib.rs +++ b/transports/tcp/src/lib.rs @@ -59,6 +59,7 @@ use std::{ io, net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, TcpListener}, pin::Pin, + sync::{Arc, RwLock}, task::{Context, Poll}, time::Duration, }; @@ -95,7 +96,7 @@ enum PortReuse { Enabled { /// The addresses and ports of the listening sockets /// registered as eligible for port reuse when dialing. - listen_addrs: HashSet<(IpAddr, Port)>, + listen_addrs: Arc>>, }, } @@ -106,7 +107,10 @@ impl PortReuse { fn register(&mut self, ip: IpAddr, port: Port) { if let PortReuse::Enabled { listen_addrs } = self { log::trace!("Registering for port reuse: {}:{}", ip, port); - listen_addrs.insert((ip, port)); + listen_addrs + .write() + .expect("`register()` and `unregister()` never panic while holding the lock") + .insert((ip, port)); } } @@ -116,7 +120,10 @@ impl PortReuse { fn unregister(&mut self, ip: IpAddr, port: Port) { if let PortReuse::Enabled { listen_addrs } = self { log::trace!("Unregistering for port reuse: {}:{}", ip, port); - listen_addrs.remove(&(ip, port)); + listen_addrs + .write() + .expect("`register()` and `unregister()` never panic while holding the lock") + .remove(&(ip, port)); } } @@ -131,7 +138,11 @@ impl PortReuse { /// listening socket address is found. fn local_dial_addr(&self, remote_ip: &IpAddr) -> Option { if let PortReuse::Enabled { listen_addrs } = self { - for (ip, port) in listen_addrs.iter() { + for (ip, port) in listen_addrs + .read() + .expect("`local_dial_addr` never panic while holding the lock") + .iter() + { if ip.is_ipv4() == remote_ip.is_ipv4() && ip.is_loopback() == remote_ip.is_loopback() { @@ -286,7 +297,7 @@ where pub fn port_reuse(mut self, port_reuse: bool) -> Self { self.port_reuse = if port_reuse { PortReuse::Enabled { - listen_addrs: HashSet::new(), + listen_addrs: Arc::new(RwLock::new(HashSet::new())), } } else { PortReuse::Disabled @@ -927,7 +938,16 @@ mod tests { let mut listener = tcp.clone().listen_on(addr).unwrap(); match listener.next().await.unwrap().unwrap() { ListenerEvent::NewAddress(_) => { + // Check that tcp and listener share the same port reuse SocketAddr + let port_reuse_tcp = tcp.port_reuse.local_dial_addr(&listener.listen_addr.ip()); + let port_reuse_listener = listener + .port_reuse + .local_dial_addr(&listener.listen_addr.ip()); + assert!(port_reuse_tcp.is_some()); + assert_eq!(port_reuse_tcp, port_reuse_listener); + // Obtain a future socket through dialing + // FIXME: Check that the port used by socket is the same as the port returned by local_dial_addr let mut socket = tcp.dial(dest_addr).unwrap().await.unwrap(); socket.write_all(&[0x1, 0x2, 0x3]).await.unwrap(); // socket.flush().await; @@ -979,6 +999,15 @@ mod tests { let mut listener1 = tcp.clone().listen_on(addr).unwrap(); match listener1.next().await.unwrap().unwrap() { ListenerEvent::NewAddress(addr1) => { + // Check that tcp and listener share the same port reuse SocketAddr + let port_reuse_tcp = + tcp.port_reuse.local_dial_addr(&listener1.listen_addr.ip()); + let port_reuse_listener1 = listener1 + .port_reuse + .local_dial_addr(&listener1.listen_addr.ip()); + assert!(port_reuse_tcp.is_some()); + assert_eq!(port_reuse_tcp, port_reuse_listener1); + // Listen on the same address a second time. let mut listener2 = tcp.clone().listen_on(addr1.clone()).unwrap(); match listener2.next().await.unwrap().unwrap() {