diff --git a/transports/tcp/src/lib.rs b/transports/tcp/src/lib.rs index 6e6eb771d063..9c363dce42e8 100644 --- a/transports/tcp/src/lib.rs +++ b/transports/tcp/src/lib.rs @@ -59,6 +59,7 @@ use std::{ io, net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, TcpListener}, pin::Pin, + sync::{Arc, RwLock}, task::{Context, Poll}, time::Duration, }; @@ -95,7 +96,7 @@ enum PortReuse { Enabled { /// The addresses and ports of the listening sockets /// registered as eligible for port reuse when dialing. - listen_addrs: HashSet<(IpAddr, Port)>, + listen_addrs: Arc>>, }, } @@ -106,7 +107,10 @@ impl PortReuse { fn register(&mut self, ip: IpAddr, port: Port) { if let PortReuse::Enabled { listen_addrs } = self { log::trace!("Registering for port reuse: {}:{}", ip, port); - listen_addrs.insert((ip, port)); + listen_addrs + .write() + .expect("`register()` and `unregister()` never panic while holding the lock") + .insert((ip, port)); } } @@ -116,7 +120,10 @@ impl PortReuse { fn unregister(&mut self, ip: IpAddr, port: Port) { if let PortReuse::Enabled { listen_addrs } = self { log::trace!("Unregistering for port reuse: {}:{}", ip, port); - listen_addrs.remove(&(ip, port)); + listen_addrs + .write() + .expect("`register()` and `unregister()` never panic while holding the lock") + .remove(&(ip, port)); } } @@ -131,7 +138,11 @@ impl PortReuse { /// listening socket address is found. fn local_dial_addr(&self, remote_ip: &IpAddr) -> Option { if let PortReuse::Enabled { listen_addrs } = self { - for (ip, port) in listen_addrs.iter() { + for (ip, port) in listen_addrs + .read() + .expect("`local_dial_addr` never panic while holding the lock") + .iter() + { if ip.is_ipv4() == remote_ip.is_ipv4() && ip.is_loopback() == remote_ip.is_loopback() { @@ -286,7 +297,7 @@ where pub fn port_reuse(mut self, port_reuse: bool) -> Self { self.port_reuse = if port_reuse { PortReuse::Enabled { - listen_addrs: HashSet::new(), + listen_addrs: Arc::new(RwLock::new(HashSet::new())), } } else { PortReuse::Disabled @@ -900,7 +911,11 @@ mod tests { fn port_reuse_dialing() { env_logger::try_init().ok(); - async fn listener(addr: Multiaddr, mut ready_tx: mpsc::Sender) { + async fn listener( + addr: Multiaddr, + mut ready_tx: mpsc::Sender, + mut port_reuse_rx: mpsc::Receiver>, + ) { let mut tcp = GenTcpConfig::::new(); let mut listener = tcp.listen_on(addr).unwrap(); loop { @@ -908,7 +923,16 @@ mod tests { ListenerEvent::NewAddress(listen_addr) => { ready_tx.send(listen_addr).await.ok(); } - ListenerEvent::Upgrade { upgrade, .. } => { + ListenerEvent::Upgrade { + upgrade, + local_addr: _, + mut remote_addr, + } => { + // Receive the dialer tcp port reuse + let remote_port_reuse = port_reuse_rx.next().await.unwrap(); + // And check it is the same as the remote port used for upgrade + assert_eq!(remote_addr.pop().unwrap(), remote_port_reuse); + let mut upgrade = upgrade.await.unwrap(); let mut buf = [0u8; 3]; upgrade.read_exact(&mut buf).await.unwrap(); @@ -921,12 +945,30 @@ mod tests { } } - async fn dialer(addr: Multiaddr, mut ready_rx: mpsc::Receiver) { + async fn dialer( + addr: Multiaddr, + mut ready_rx: mpsc::Receiver, + mut port_reuse_tx: mpsc::Sender>, + ) { let dest_addr = ready_rx.next().await.unwrap(); let mut tcp = GenTcpConfig::::new().port_reuse(true); let mut listener = tcp.clone().listen_on(addr).unwrap(); match listener.next().await.unwrap().unwrap() { ListenerEvent::NewAddress(_) => { + // Check that tcp and listener share the same port reuse SocketAddr + let port_reuse_tcp = tcp.port_reuse.local_dial_addr(&listener.listen_addr.ip()); + let port_reuse_listener = listener + .port_reuse + .local_dial_addr(&listener.listen_addr.ip()); + assert!(port_reuse_tcp.is_some()); + assert_eq!(port_reuse_tcp, port_reuse_listener); + + // Send the dialer tcp port reuse to the listener + port_reuse_tx + .send(Protocol::Tcp(port_reuse_tcp.unwrap().port())) + .await + .ok(); + // Obtain a future socket through dialing let mut socket = tcp.dial(dest_addr).unwrap().await.unwrap(); socket.write_all(&[0x1, 0x2, 0x3]).await.unwrap(); @@ -943,8 +985,9 @@ mod tests { #[cfg(feature = "async-io")] { let (ready_tx, ready_rx) = mpsc::channel(1); - let listener = listener::(addr.clone(), ready_tx); - let dialer = dialer::(addr.clone(), ready_rx); + let (port_reuse_tx, port_reuse_rx) = mpsc::channel(1); + let listener = listener::(addr.clone(), ready_tx, port_reuse_rx); + let dialer = dialer::(addr.clone(), ready_rx, port_reuse_tx); let listener = async_std::task::spawn(listener); async_std::task::block_on(dialer); async_std::task::block_on(listener); @@ -953,8 +996,9 @@ mod tests { #[cfg(feature = "tokio")] { let (ready_tx, ready_rx) = mpsc::channel(1); - let listener = listener::(addr.clone(), ready_tx); - let dialer = dialer::(addr.clone(), ready_rx); + let (port_reuse_tx, port_reuse_rx) = mpsc::channel(1); + let listener = listener::(addr.clone(), ready_tx, port_reuse_rx); + let dialer = dialer::(addr.clone(), ready_rx, port_reuse_tx); let rt = tokio_crate::runtime::Builder::new_current_thread() .enable_io() .build() @@ -979,6 +1023,15 @@ mod tests { let mut listener1 = tcp.clone().listen_on(addr).unwrap(); match listener1.next().await.unwrap().unwrap() { ListenerEvent::NewAddress(addr1) => { + // Check that tcp and listener share the same port reuse SocketAddr + let port_reuse_tcp = + tcp.port_reuse.local_dial_addr(&listener1.listen_addr.ip()); + let port_reuse_listener1 = listener1 + .port_reuse + .local_dial_addr(&listener1.listen_addr.ip()); + assert!(port_reuse_tcp.is_some()); + assert_eq!(port_reuse_tcp, port_reuse_listener1); + // Listen on the same address a second time. let mut listener2 = tcp.clone().listen_on(addr1.clone()).unwrap(); match listener2.next().await.unwrap().unwrap() {