From 3a95b5ab525117c516539b670e9f6622d5ca3b10 Mon Sep 17 00:00:00 2001 From: b-naber Date: Mon, 18 Jul 2022 10:30:03 +0200 Subject: [PATCH] address review --- tokio-util/tests/mpsc.rs | 222 +-------------------------------- tokio/src/sync/mpsc/bounded.rs | 2 +- tokio/src/sync/mpsc/chan.rs | 26 ++-- tokio/tests/sync_mpsc.rs | 220 +++++++++++++++++++++++++++++++- 4 files changed, 230 insertions(+), 240 deletions(-) diff --git a/tokio-util/tests/mpsc.rs b/tokio-util/tests/mpsc.rs index 7f7ccd2ece4..4e4d0090760 100644 --- a/tokio-util/tests/mpsc.rs +++ b/tokio-util/tests/mpsc.rs @@ -1,14 +1,6 @@ use futures::future::poll_fn; -use std::ops::Drop; -use std::sync::atomic::{ - AtomicUsize, - Ordering::{Acquire, Release}, -}; -use std::time::Duration; -use tokio::join; -use tokio::sync::mpsc::{self, channel}; -use tokio::sync::oneshot; -use tokio::time; + +use tokio::sync::mpsc::channel; use tokio_test::task::spawn; use tokio_test::{assert_pending, assert_ready, assert_ready_err, assert_ready_ok}; use tokio_util::sync::PollSender; @@ -246,213 +238,3 @@ fn start_send_panics_when_acquiring() { assert_pending!(reserve.poll()); send.send_item(2).unwrap(); } - -#[tokio::test] -async fn weak_sender() { - let (tx, mut rx) = channel(11); - let tx_weak = tx.clone().downgrade(); - - let tx_weak = tokio::spawn(async move { - for i in 0..10 { - if tx.send(i).await.is_err() { - return None; - } - } - - let tx2 = tx_weak - .upgrade() - .expect("expected to be able to upgrade tx_weak"); - let _ = tx2.send(20).await; - let tx_weak = tx2.downgrade(); - - Some(tx_weak) - }) - .await - .unwrap(); - - for i in 0..12 { - let recvd = rx.recv().await; - - match recvd { - Some(msg) => { - if i == 10 { - assert_eq!(msg, 20); - } - } - None => { - assert_eq!(i, 11); - break; - } - } - } - - if let Some(tx_weak) = tx_weak { - let upgraded = tx_weak.upgrade(); - assert!(upgraded.is_none()); - } -} - -#[tokio::test] -async fn actor_weak_sender() { - pub struct MyActor { - receiver: mpsc::Receiver, - sender: mpsc::WeakSender, - next_id: u32, - pub received_self_msg: bool, - } - - enum ActorMessage { - GetUniqueId { respond_to: oneshot::Sender }, - SelfMessage {}, - } - - impl MyActor { - fn new( - receiver: mpsc::Receiver, - sender: mpsc::WeakSender, - ) -> Self { - MyActor { - receiver, - sender, - next_id: 0, - received_self_msg: false, - } - } - - fn handle_message(&mut self, msg: ActorMessage) { - match msg { - ActorMessage::GetUniqueId { respond_to } => { - self.next_id += 1; - - // The `let _ =` ignores any errors when sending. - // - // This can happen if the `select!` macro is used - // to cancel waiting for the response. - let _ = respond_to.send(self.next_id); - } - ActorMessage::SelfMessage { .. } => { - self.received_self_msg = true; - } - } - } - - async fn send_message_to_self(&mut self) { - let msg = ActorMessage::SelfMessage {}; - - let sender = self.sender.clone(); - - // cannot move self.sender here - if let Some(sender) = sender.upgrade() { - let _ = sender.send(msg).await; - self.sender = sender.downgrade(); - } - } - - async fn run(&mut self) { - let mut i = 0; - while let Some(msg) = self.receiver.recv().await { - self.handle_message(msg); - - if i == 0 { - self.send_message_to_self().await; - } - - i += 1 - } - - assert!(self.received_self_msg); - } - } - - #[derive(Clone)] - pub struct MyActorHandle { - sender: mpsc::Sender, - } - - impl MyActorHandle { - pub fn new() -> (Self, MyActor) { - let (sender, receiver) = mpsc::channel(8); - let actor = MyActor::new(receiver, sender.clone().downgrade()); - - (Self { sender }, actor) - } - - pub async fn get_unique_id(&self) -> u32 { - let (send, recv) = oneshot::channel(); - let msg = ActorMessage::GetUniqueId { respond_to: send }; - - // Ignore send errors. If this send fails, so does the - // recv.await below. There's no reason to check the - // failure twice. - let _ = self.sender.send(msg).await; - recv.await.expect("Actor task has been killed") - } - } - - let (handle, mut actor) = MyActorHandle::new(); - - let actor_handle = tokio::spawn(async move { actor.run().await }); - - let _ = tokio::spawn(async move { - let _ = handle.get_unique_id().await; - drop(handle); - }) - .await; - - let _ = actor_handle.await; -} - -static NUM_DROPPED: AtomicUsize = AtomicUsize::new(0); - -#[derive(Debug)] -struct Msg; - -impl Drop for Msg { - fn drop(&mut self) { - NUM_DROPPED.fetch_add(1, Release); - } -} - -// Tests that no pending messages are put onto the channel after `Rx` was -// dropped. -// -// Note: After the introduction of `WeakSender`, which internally -// used `Arc` and doesn't call a drop of the channel after the last strong -// `Sender` was dropped while more than one `WeakSender` remains, we want to -// ensure that no messages are kept in the channel, which were sent after -// the receiver was dropped. -#[tokio::test(start_paused = true)] -async fn test_msgs_dropped_on_rx_drop() { - fn ms(millis: u64) -> Duration { - Duration::from_millis(millis) - } - - let (tx, mut rx) = mpsc::channel(3); - - let rx_handle = tokio::spawn(async move { - let _ = rx.recv().await.unwrap(); - let _ = rx.recv().await.unwrap(); - time::sleep(ms(1)).await; - drop(rx); - - time::advance(ms(1)).await; - }); - - let tx_handle = tokio::spawn(async move { - let _ = tx.send(Msg {}).await.unwrap(); - let _ = tx.send(Msg {}).await.unwrap(); - - // This msg will be pending and should be dropped when `rx` is dropped - let _ = tx.send(Msg {}).await.unwrap(); - time::advance(ms(1)).await; - - // This msg will not be put onto `Tx` list anymore, since `Rx` is closed. - time::sleep(ms(1)).await; - let _ = tx.send(Msg {}).await.unwrap(); - - // Ensure that third message isn't put onto the channel anymore - assert_eq!(NUM_DROPPED.load(Acquire), 4); - }); - - let (_, _) = join!(rx_handle, tx_handle); -} diff --git a/tokio/src/sync/mpsc/bounded.rs b/tokio/src/sync/mpsc/bounded.rs index 92959ac3a2a..5889d8d9215 100644 --- a/tokio/src/sync/mpsc/bounded.rs +++ b/tokio/src/sync/mpsc/bounded.rs @@ -1079,7 +1079,7 @@ impl WeakSender { /// if there are other `Sender` instances alive and the channel wasn't /// previously dropped, otherwise `None` is returned. pub fn upgrade(self) -> Option> { - self.chan.upgrade().map(|tx| Sender::new(tx)) + self.chan.upgrade().map(Sender::new) } } diff --git a/tokio/src/sync/mpsc/chan.rs b/tokio/src/sync/mpsc/chan.rs index d097ce1e4a7..f6e9e0fab24 100644 --- a/tokio/src/sync/mpsc/chan.rs +++ b/tokio/src/sync/mpsc/chan.rs @@ -146,29 +146,19 @@ impl Tx { pub(super) fn upgrade(self) -> Option { let mut tx_count = self.inner.tx_count.load(Acquire); - if tx_count == 0 { - // channel is closed - return None; - } - loop { + if tx_count == 0 { + // channel is closed + return None; + } + match self .inner .tx_count - .compare_exchange(tx_count, tx_count + 1, AcqRel, Acquire) + .compare_exchange_weak(tx_count, tx_count + 1, AcqRel, Acquire) { - Ok(prev_count) => { - assert!(prev_count != 0); - - return Some(self); - } - Err(prev_count) => { - if prev_count == 0 { - return None; - } - - tx_count = prev_count; - } + Ok(_) => return Some(self), + Err(prev_count) => tx_count = prev_count, } } } diff --git a/tokio/tests/sync_mpsc.rs b/tokio/tests/sync_mpsc.rs index a1510f57d09..23339abe5da 100644 --- a/tokio/tests/sync_mpsc.rs +++ b/tokio/tests/sync_mpsc.rs @@ -10,11 +10,19 @@ use wasm_bindgen_test::wasm_bindgen_test as maybe_tokio_test; #[cfg(not(all(target_arch = "wasm32", not(target_os = "wasi"))))] use tokio::test as maybe_tokio_test; -use tokio::sync::mpsc; +use tokio::join; use tokio::sync::mpsc::error::{TryRecvError, TrySendError}; +use tokio::sync::mpsc::{self, channel}; +use tokio::sync::oneshot; +use tokio::time; use tokio_test::*; +use std::sync::atomic::{ + AtomicUsize, + Ordering::{Acquire, Release}, +}; use std::sync::Arc; +use std::time::Duration; #[cfg(not(target_arch = "wasm32"))] mod support { @@ -657,3 +665,213 @@ fn recv_timeout_panic() { let (tx, _rx) = mpsc::channel(5); tx.send_timeout(10, Duration::from_secs(1)).now_or_never(); } + +#[tokio::test] +async fn weak_sender() { + let (tx, mut rx) = channel(11); + let tx_weak = tx.clone().downgrade(); + + let tx_weak = tokio::spawn(async move { + for i in 0..10 { + if tx.send(i).await.is_err() { + return None; + } + } + + let tx2 = tx_weak + .upgrade() + .expect("expected to be able to upgrade tx_weak"); + let _ = tx2.send(20).await; + let tx_weak = tx2.downgrade(); + + Some(tx_weak) + }) + .await + .unwrap(); + + for i in 0..12 { + let recvd = rx.recv().await; + + match recvd { + Some(msg) => { + if i == 10 { + assert_eq!(msg, 20); + } + } + None => { + assert_eq!(i, 11); + break; + } + } + } + + if let Some(tx_weak) = tx_weak { + let upgraded = tx_weak.upgrade(); + assert!(upgraded.is_none()); + } +} + +#[tokio::test] +async fn actor_weak_sender() { + pub struct MyActor { + receiver: mpsc::Receiver, + sender: mpsc::WeakSender, + next_id: u32, + pub received_self_msg: bool, + } + + enum ActorMessage { + GetUniqueId { respond_to: oneshot::Sender }, + SelfMessage {}, + } + + impl MyActor { + fn new( + receiver: mpsc::Receiver, + sender: mpsc::WeakSender, + ) -> Self { + MyActor { + receiver, + sender, + next_id: 0, + received_self_msg: false, + } + } + + fn handle_message(&mut self, msg: ActorMessage) { + match msg { + ActorMessage::GetUniqueId { respond_to } => { + self.next_id += 1; + + // The `let _ =` ignores any errors when sending. + // + // This can happen if the `select!` macro is used + // to cancel waiting for the response. + let _ = respond_to.send(self.next_id); + } + ActorMessage::SelfMessage { .. } => { + self.received_self_msg = true; + } + } + } + + async fn send_message_to_self(&mut self) { + let msg = ActorMessage::SelfMessage {}; + + let sender = self.sender.clone(); + + // cannot move self.sender here + if let Some(sender) = sender.upgrade() { + let _ = sender.send(msg).await; + self.sender = sender.downgrade(); + } + } + + async fn run(&mut self) { + let mut i = 0; + while let Some(msg) = self.receiver.recv().await { + self.handle_message(msg); + + if i == 0 { + self.send_message_to_self().await; + } + + i += 1 + } + + assert!(self.received_self_msg); + } + } + + #[derive(Clone)] + pub struct MyActorHandle { + sender: mpsc::Sender, + } + + impl MyActorHandle { + pub fn new() -> (Self, MyActor) { + let (sender, receiver) = mpsc::channel(8); + let actor = MyActor::new(receiver, sender.clone().downgrade()); + + (Self { sender }, actor) + } + + pub async fn get_unique_id(&self) -> u32 { + let (send, recv) = oneshot::channel(); + let msg = ActorMessage::GetUniqueId { respond_to: send }; + + // Ignore send errors. If this send fails, so does the + // recv.await below. There's no reason to check the + // failure twice. + let _ = self.sender.send(msg).await; + recv.await.expect("Actor task has been killed") + } + } + + let (handle, mut actor) = MyActorHandle::new(); + + let actor_handle = tokio::spawn(async move { actor.run().await }); + + let _ = tokio::spawn(async move { + let _ = handle.get_unique_id().await; + drop(handle); + }) + .await; + + let _ = actor_handle.await; +} + +static NUM_DROPPED: AtomicUsize = AtomicUsize::new(0); + +#[derive(Debug)] +struct Msg; + +impl Drop for Msg { + fn drop(&mut self) { + NUM_DROPPED.fetch_add(1, Release); + } +} + +// Tests that no pending messages are put onto the channel after `Rx` was +// dropped. +// +// Note: After the introduction of `WeakSender`, which internally +// used `Arc` and doesn't call a drop of the channel after the last strong +// `Sender` was dropped while more than one `WeakSender` remains, we want to +// ensure that no messages are kept in the channel, which were sent after +// the receiver was dropped. +#[tokio::test(start_paused = true)] +async fn test_msgs_dropped_on_rx_drop() { + fn ms(millis: u64) -> Duration { + Duration::from_millis(millis) + } + + let (tx, mut rx) = mpsc::channel(3); + + let rx_handle = tokio::spawn(async move { + let _ = rx.recv().await.unwrap(); + let _ = rx.recv().await.unwrap(); + time::sleep(ms(1)).await; + drop(rx); + + time::advance(ms(1)).await; + }); + + let tx_handle = tokio::spawn(async move { + let _ = tx.send(Msg {}).await.unwrap(); + let _ = tx.send(Msg {}).await.unwrap(); + + // This msg will be pending and should be dropped when `rx` is dropped + let _ = tx.send(Msg {}).await.unwrap(); + time::advance(ms(1)).await; + + // This msg will not be put onto `Tx` list anymore, since `Rx` is closed. + time::sleep(ms(1)).await; + let _ = tx.send(Msg {}).await.unwrap(); + + // Ensure that third message isn't put onto the channel anymore + assert_eq!(NUM_DROPPED.load(Acquire), 4); + }); + + let (_, _) = join!(rx_handle, tx_handle); +}