From 19e63e4a17be80bc4fedae34b6070c3726341c42 Mon Sep 17 00:00:00 2001 From: b-naber Date: Thu, 24 Mar 2022 10:08:22 +0100 Subject: [PATCH 01/13] Implement Weak version of mpsc::Sender --- tokio/src/sync/mpsc/bounded.rs | 71 ++++++++++++++++++++++++++++++++++ tokio/src/sync/mpsc/chan.rs | 69 +++++++++++++++++++++++++++++++-- tokio/src/sync/mpsc/mod.rs | 2 +- 3 files changed, 138 insertions(+), 4 deletions(-) diff --git a/tokio/src/sync/mpsc/bounded.rs b/tokio/src/sync/mpsc/bounded.rs index c2a2f061872..b22228a767f 100644 --- a/tokio/src/sync/mpsc/bounded.rs +++ b/tokio/src/sync/mpsc/bounded.rs @@ -22,6 +22,50 @@ pub struct Sender { chan: chan::Tx, } +/// A Sender that does not influence RAII semantics, i.e. if all [`Sender`] +/// instances of a channel were dropped and only `WeakSender` instances remain, +/// the channel is closed. +/// +/// In order to send messages, the `WeakSender` needs to be upgraded using +/// [`WeakSender::upgrade`], which returns `Option`, `None` if all +/// `Sender`s were already dropped, otherwise `Some` (at which point it does +/// influence RAII semantics again). +/// +/// [`Sender`]: Sender +/// [`WeakSender::upgrade`]: WeakSender::upgrade +/// +/// #Examples +/// +/// ```rust +/// use tokio; +/// use tokio::sync::mpsc::channel; +/// +/// #[tokio::main] +/// async fn main() { +/// let (tx, mut rx) = channel(15); +/// let _ = tx.send(1).await; +/// let tx_weak = tx.downgrade(); +/// +/// let _ = tokio::spawn(async move { +/// for i in 0..2 { +/// if i == 0 { +/// assert_eq!(rx.recv().await.unwrap(), 1); +/// } else if i == 1 { +/// // only WeakSender instance remains -> channel is dropped +/// assert!(rx.recv().await.is_none()); +/// } +/// } +/// }) +/// .await; +/// +/// assert!(tx_weak.upgrade().is_none()); +/// } +/// +/// ``` +pub struct WeakSender { + chan: chan::TxWeak, +} + /// Permits to send one value into the channel. /// /// `Permit` values are returned by [`Sender::reserve()`] and [`Sender::try_reserve()`] @@ -991,6 +1035,18 @@ impl Sender { pub fn capacity(&self) -> usize { self.chan.semaphore().0.available_permits() } + + /// Converts the `Sender` to a [`WeakSender`] that does not count + /// towards RAII semantics, i.e. if all `Sender` instances of the + /// channel were dropped and only `WeakSender` instances remain, + /// the channel is closed. + pub fn downgrade(self) -> WeakSender { + // Note: If this is the last `Sender` instance we want to close the + // channel when downgrading, so it's important to move into `self` here. + + let chan = self.chan.downgrade(); + WeakSender { chan } + } } impl Clone for Sender { @@ -1009,6 +1065,21 @@ impl fmt::Debug for Sender { } } +impl WeakSender { + /// Tries to conver a WeakSender into a [`Sender`]. This will return `Some` + /// if there are other `Sender` instances alive and the channel wasn't + /// previously dropped, otherwise `None` is returned. + pub fn upgrade(&self) -> Option> { + self.chan.upgrade().map(Sender::new) + } +} + +impl fmt::Debug for WeakSender { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt.debug_struct("WeakSender").finish() + } +} + // ===== impl Permit ===== impl Permit<'_, T> { diff --git a/tokio/src/sync/mpsc/chan.rs b/tokio/src/sync/mpsc/chan.rs index 1228cfb6c6f..1b40f4f3a7f 100644 --- a/tokio/src/sync/mpsc/chan.rs +++ b/tokio/src/sync/mpsc/chan.rs @@ -1,7 +1,7 @@ use crate::loom::cell::UnsafeCell; use crate::loom::future::AtomicWaker; use crate::loom::sync::atomic::AtomicUsize; -use crate::loom::sync::Arc; +use crate::loom::sync::{Arc, Weak}; use crate::park::thread::CachedParkThread; use crate::park::Park; use crate::sync::mpsc::error::TryRecvError; @@ -9,8 +9,9 @@ use crate::sync::mpsc::list; use crate::sync::notify::Notify; use std::fmt; +use std::mem; use std::process; -use std::sync::atomic::Ordering::{AcqRel, Relaxed}; +use std::sync::atomic::Ordering::{AcqRel, Acquire, Relaxed, SeqCst}; use std::task::Poll::{Pending, Ready}; use std::task::{Context, Poll}; @@ -25,6 +26,16 @@ impl fmt::Debug for Tx { } } +pub(crate) struct TxWeak { + inner: Weak>, +} + +impl fmt::Debug for TxWeak { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt.debug_struct("TxWeak").finish() + } +} + /// Channel receiver. pub(crate) struct Rx { inner: Arc>, @@ -129,6 +140,14 @@ impl Tx { Tx { inner: chan } } + pub(super) fn downgrade(self) -> TxWeak { + // We don't decrement the `tx_counter` here, but let the counter be decremented + // through the drop of self.inner. + let weak_inner = Arc::>::downgrade(&self.inner); + + TxWeak::new(weak_inner) + } + pub(super) fn semaphore(&self) -> &S { &self.inner.semaphore } @@ -149,6 +168,50 @@ impl Tx { } } +impl TxWeak { + fn new(inner: Weak>) -> Self { + TxWeak { inner } + } + + pub(super) fn upgrade(&self) -> Option> { + let inner = self.inner.upgrade(); + + if let Some(inner) = inner { + // If we were able to upgrade, `Chan` is guaranteed to still exist, + // even though the channel might have been closed in the meantime. + // Need to check here whether the channel was actually closed. + + let mut tx_count = inner.tx_count.load(Relaxed); + loop { + // FIXME Haven't thought the orderings on the CAS through yet + match inner + .tx_count + .compare_exchange(tx_count, tx_count + 1, SeqCst, SeqCst) + { + Ok(prev_count) => { + if prev_count == 0 { + mem::drop(inner); + return None; + } + + return Some(Tx::new(inner)); + } + Err(count) => { + if count == 0 { + mem::drop(inner); + return None; + } + + tx_count = count; + } + } + } + } else { + None + } + } +} + impl Tx { pub(crate) fn is_closed(&self) -> bool { self.inner.semaphore.is_closed() @@ -378,7 +441,7 @@ impl Semaphore for (crate::sync::batch_semaphore::Semaphore, usize) { // ===== impl Semaphore for AtomicUsize ===== -use std::sync::atomic::Ordering::{Acquire, Release}; +use std::sync::atomic::Ordering::Release; use std::usize; impl Semaphore for AtomicUsize { diff --git a/tokio/src/sync/mpsc/mod.rs b/tokio/src/sync/mpsc/mod.rs index b1513a9da51..d37779ca761 100644 --- a/tokio/src/sync/mpsc/mod.rs +++ b/tokio/src/sync/mpsc/mod.rs @@ -90,7 +90,7 @@ pub(super) mod block; mod bounded; -pub use self::bounded::{channel, OwnedPermit, Permit, Receiver, Sender}; +pub use self::bounded::{channel, OwnedPermit, Permit, Receiver, Sender, WeakSender}; mod chan; From a0d94a104cb5a45f103037bf33aa2d21a0fd037e Mon Sep 17 00:00:00 2001 From: b-naber Date: Fri, 1 Apr 2022 11:51:42 +0200 Subject: [PATCH 02/13] add tests --- tokio-util/tests/mpsc.rs | 160 ++++++++++++++++++++++++++++++++++++++- 1 file changed, 159 insertions(+), 1 deletion(-) diff --git a/tokio-util/tests/mpsc.rs b/tokio-util/tests/mpsc.rs index a3c164d3eca..66b874975b8 100644 --- a/tokio-util/tests/mpsc.rs +++ b/tokio-util/tests/mpsc.rs @@ -1,5 +1,6 @@ use futures::future::poll_fn; -use tokio::sync::mpsc::channel; +use tokio::sync::mpsc::{self, channel}; +use tokio::sync::oneshot; use tokio_test::task::spawn; use tokio_test::{assert_pending, assert_ready, assert_ready_err, assert_ready_ok}; use tokio_util::sync::PollSender; @@ -237,3 +238,160 @@ fn start_send_panics_when_acquiring() { assert_pending!(reserve.poll()); send.send_item(2).unwrap(); } + +#[tokio::test] +async fn weak_sender() { + let (tx, mut rx) = channel(11); + let tx_weak = tx.clone().downgrade(); + + let tx_weak = tokio::spawn(async move { + for i in 0..10 { + if let Err(_) = tx.send(i).await { + return None; + } + } + + let tx2 = tx_weak + .upgrade() + .expect("expected to be able to upgrade tx_weak"); + let _ = tx2.send(20).await; + let tx_weak = tx2.downgrade(); + + Some(tx_weak) + }) + .await + .unwrap(); + + for i in 0..12 { + let recvd = rx.recv().await; + + match recvd { + Some(msg) => { + if i == 10 { + assert_eq!(msg, 20); + } + } + None => { + assert_eq!(i, 11); + break; + } + } + } + + if let Some(tx_weak) = tx_weak { + let upgraded = tx_weak.upgrade(); + assert!(upgraded.is_none()); + } +} + +#[tokio::test] +async fn actor_weak_sender() { + pub struct MyActor { + receiver: mpsc::Receiver, + sender: mpsc::WeakSender, + next_id: u32, + pub received_self_msg: bool, + } + + enum ActorMessage { + GetUniqueId { respond_to: oneshot::Sender }, + SelfMessage {}, + } + + impl MyActor { + fn new( + receiver: mpsc::Receiver, + sender: mpsc::WeakSender, + ) -> Self { + MyActor { + receiver, + sender, + next_id: 0, + received_self_msg: false, + } + } + + fn handle_message(&mut self, msg: ActorMessage) { + match msg { + ActorMessage::GetUniqueId { respond_to } => { + self.next_id += 1; + + // The `let _ =` ignores any errors when sending. + // + // This can happen if the `select!` macro is used + // to cancel waiting for the response. + let _ = respond_to.send(self.next_id); + } + ActorMessage::SelfMessage { .. } => { + self.received_self_msg = true; + } + } + } + + async fn send_message_to_self(&mut self) { + let msg = ActorMessage::SelfMessage {}; + + if let Some(sender) = self.sender.upgrade() { + let _ = sender.send(msg).await; + self.sender = sender.downgrade(); + } + } + + async fn run(&mut self) { + let mut i = 0; + loop { + match self.receiver.recv().await { + Some(msg) => { + self.handle_message(msg); + } + None => { + break; + } + } + if i == 0 { + self.send_message_to_self().await; + } + i += 1; + } + + assert!(self.received_self_msg); + } + } + + #[derive(Clone)] + pub struct MyActorHandle { + sender: mpsc::Sender, + } + + impl MyActorHandle { + pub fn new() -> (Self, MyActor) { + let (sender, receiver) = mpsc::channel(8); + let actor = MyActor::new(receiver, sender.clone().downgrade()); + + (Self { sender }, actor) + } + + pub async fn get_unique_id(&self) -> u32 { + let (send, recv) = oneshot::channel(); + let msg = ActorMessage::GetUniqueId { respond_to: send }; + + // Ignore send errors. If this send fails, so does the + // recv.await below. There's no reason to check the + // failure twice. + let _ = self.sender.send(msg).await; + recv.await.expect("Actor task has been killed") + } + } + + let (handle, mut actor) = MyActorHandle::new(); + + let actor_handle = tokio::spawn(async move { actor.run().await }); + + let _ = tokio::spawn(async move { + let _ = handle.get_unique_id().await; + drop(handle); + }) + .await; + + let _ = actor_handle.await; +} From 94eb03e00b0c3b4cfd285f12f6053e69b0a6e45a Mon Sep 17 00:00:00 2001 From: b-naber Date: Fri, 8 Apr 2022 12:01:03 +0200 Subject: [PATCH 03/13] address review and fix clippy failure --- tokio-util/tests/mpsc.rs | 17 ++++++----------- tokio/src/sync/mpsc/chan.rs | 32 +++++++++++++++++--------------- 2 files changed, 23 insertions(+), 26 deletions(-) diff --git a/tokio-util/tests/mpsc.rs b/tokio-util/tests/mpsc.rs index 66b874975b8..a0e21066bd7 100644 --- a/tokio-util/tests/mpsc.rs +++ b/tokio-util/tests/mpsc.rs @@ -246,7 +246,7 @@ async fn weak_sender() { let tx_weak = tokio::spawn(async move { for i in 0..10 { - if let Err(_) = tx.send(i).await { + if tx.send(i).await.is_err() { return None; } } @@ -339,19 +339,14 @@ async fn actor_weak_sender() { async fn run(&mut self) { let mut i = 0; - loop { - match self.receiver.recv().await { - Some(msg) => { - self.handle_message(msg); - } - None => { - break; - } - } + while let Some(msg) = self.receiver.recv().await { + self.handle_message(msg); + if i == 0 { self.send_message_to_self().await; } - i += 1; + + i += 1 } assert!(self.received_self_msg); diff --git a/tokio/src/sync/mpsc/chan.rs b/tokio/src/sync/mpsc/chan.rs index 1b40f4f3a7f..d2394ced899 100644 --- a/tokio/src/sync/mpsc/chan.rs +++ b/tokio/src/sync/mpsc/chan.rs @@ -1,7 +1,7 @@ use crate::loom::cell::UnsafeCell; use crate::loom::future::AtomicWaker; use crate::loom::sync::atomic::AtomicUsize; -use crate::loom::sync::{Arc, Weak}; +use crate::loom::sync::Arc; use crate::park::thread::CachedParkThread; use crate::park::Park; use crate::sync::mpsc::error::TryRecvError; @@ -11,9 +11,11 @@ use crate::sync::notify::Notify; use std::fmt; use std::mem; use std::process; -use std::sync::atomic::Ordering::{AcqRel, Acquire, Relaxed, SeqCst}; +use std::sync::atomic::Ordering::{AcqRel, Acquire, Relaxed, Release}; +use std::sync::Weak; use std::task::Poll::{Pending, Ready}; use std::task::{Context, Poll}; +use std::usize; /// Channel sender. pub(crate) struct Tx { @@ -181,28 +183,31 @@ impl TxWeak { // even though the channel might have been closed in the meantime. // Need to check here whether the channel was actually closed. - let mut tx_count = inner.tx_count.load(Relaxed); + let mut tx_count = inner.tx_count.load(Acquire); + + if tx_count == 0 { + // channel is closed + mem::drop(inner); + return None; + } + loop { - // FIXME Haven't thought the orderings on the CAS through yet match inner .tx_count - .compare_exchange(tx_count, tx_count + 1, SeqCst, SeqCst) + .compare_exchange(tx_count, tx_count + 1, AcqRel, Acquire) { Ok(prev_count) => { - if prev_count == 0 { - mem::drop(inner); - return None; - } + assert!(prev_count != 0); return Some(Tx::new(inner)); } - Err(count) => { - if count == 0 { + Err(prev_count) => { + if prev_count == 0 { mem::drop(inner); return None; } - tx_count = count; + tx_count = prev_count; } } } @@ -441,9 +446,6 @@ impl Semaphore for (crate::sync::batch_semaphore::Semaphore, usize) { // ===== impl Semaphore for AtomicUsize ===== -use std::sync::atomic::Ordering::Release; -use std::usize; - impl Semaphore for AtomicUsize { fn add_permit(&self) { let prev = self.fetch_sub(2, Release); From a01ad3986852896bab29917a16100ac15c49b4c2 Mon Sep 17 00:00:00 2001 From: b-naber Date: Tue, 10 May 2022 16:05:35 +0200 Subject: [PATCH 04/13] use Arc in WeakSender --- tokio-util/tests/mpsc.rs | 5 +- tokio/src/sync/mpsc/bounded.rs | 21 +++++-- tokio/src/sync/mpsc/chan.rs | 104 +++++++++++++-------------------- 3 files changed, 59 insertions(+), 71 deletions(-) diff --git a/tokio-util/tests/mpsc.rs b/tokio-util/tests/mpsc.rs index a0e21066bd7..872964a41f7 100644 --- a/tokio-util/tests/mpsc.rs +++ b/tokio-util/tests/mpsc.rs @@ -331,7 +331,10 @@ async fn actor_weak_sender() { async fn send_message_to_self(&mut self) { let msg = ActorMessage::SelfMessage {}; - if let Some(sender) = self.sender.upgrade() { + let sender = self.sender.clone(); + + // cannot move self.sender here + if let Some(sender) = sender.upgrade() { let _ = sender.send(msg).await; self.sender = sender.downgrade(); } diff --git a/tokio/src/sync/mpsc/bounded.rs b/tokio/src/sync/mpsc/bounded.rs index b22228a767f..92959ac3a2a 100644 --- a/tokio/src/sync/mpsc/bounded.rs +++ b/tokio/src/sync/mpsc/bounded.rs @@ -63,7 +63,7 @@ pub struct Sender { /// /// ``` pub struct WeakSender { - chan: chan::TxWeak, + chan: chan::Tx, } /// Permits to send one value into the channel. @@ -1044,8 +1044,9 @@ impl Sender { // Note: If this is the last `Sender` instance we want to close the // channel when downgrading, so it's important to move into `self` here. - let chan = self.chan.downgrade(); - WeakSender { chan } + WeakSender { + chan: self.chan.downgrade(), + } } } @@ -1065,12 +1066,20 @@ impl fmt::Debug for Sender { } } +impl Clone for WeakSender { + fn clone(&self) -> Self { + WeakSender { + chan: self.chan.clone(), + } + } +} + impl WeakSender { - /// Tries to conver a WeakSender into a [`Sender`]. This will return `Some` + /// Tries to convert a WeakSender into a [`Sender`]. This will return `Some` /// if there are other `Sender` instances alive and the channel wasn't /// previously dropped, otherwise `None` is returned. - pub fn upgrade(&self) -> Option> { - self.chan.upgrade().map(Sender::new) + pub fn upgrade(self) -> Option> { + self.chan.upgrade().map(|tx| Sender::new(tx)) } } diff --git a/tokio/src/sync/mpsc/chan.rs b/tokio/src/sync/mpsc/chan.rs index d2394ced899..d097ce1e4a7 100644 --- a/tokio/src/sync/mpsc/chan.rs +++ b/tokio/src/sync/mpsc/chan.rs @@ -9,10 +9,8 @@ use crate::sync::mpsc::list; use crate::sync::notify::Notify; use std::fmt; -use std::mem; use std::process; use std::sync::atomic::Ordering::{AcqRel, Acquire, Relaxed, Release}; -use std::sync::Weak; use std::task::Poll::{Pending, Ready}; use std::task::{Context, Poll}; use std::usize; @@ -28,16 +26,6 @@ impl fmt::Debug for Tx { } } -pub(crate) struct TxWeak { - inner: Weak>, -} - -impl fmt::Debug for TxWeak { - fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { - fmt.debug_struct("TxWeak").finish() - } -} - /// Channel receiver. pub(crate) struct Rx { inner: Arc>, @@ -142,12 +130,47 @@ impl Tx { Tx { inner: chan } } - pub(super) fn downgrade(self) -> TxWeak { - // We don't decrement the `tx_counter` here, but let the counter be decremented - // through the drop of self.inner. - let weak_inner = Arc::>::downgrade(&self.inner); + pub(super) fn downgrade(self) -> Self { + if self.inner.tx_count.fetch_sub(1, AcqRel) == 1 { + // Close the list, which sends a `Close` message + self.inner.tx.close(); + + // Notify the receiver + self.wake_rx(); + } + + self + } - TxWeak::new(weak_inner) + // Returns a boolean that indicates whether the channel is closed. + pub(super) fn upgrade(self) -> Option { + let mut tx_count = self.inner.tx_count.load(Acquire); + + if tx_count == 0 { + // channel is closed + return None; + } + + loop { + match self + .inner + .tx_count + .compare_exchange(tx_count, tx_count + 1, AcqRel, Acquire) + { + Ok(prev_count) => { + assert!(prev_count != 0); + + return Some(self); + } + Err(prev_count) => { + if prev_count == 0 { + return None; + } + + tx_count = prev_count; + } + } + } } pub(super) fn semaphore(&self) -> &S { @@ -170,53 +193,6 @@ impl Tx { } } -impl TxWeak { - fn new(inner: Weak>) -> Self { - TxWeak { inner } - } - - pub(super) fn upgrade(&self) -> Option> { - let inner = self.inner.upgrade(); - - if let Some(inner) = inner { - // If we were able to upgrade, `Chan` is guaranteed to still exist, - // even though the channel might have been closed in the meantime. - // Need to check here whether the channel was actually closed. - - let mut tx_count = inner.tx_count.load(Acquire); - - if tx_count == 0 { - // channel is closed - mem::drop(inner); - return None; - } - - loop { - match inner - .tx_count - .compare_exchange(tx_count, tx_count + 1, AcqRel, Acquire) - { - Ok(prev_count) => { - assert!(prev_count != 0); - - return Some(Tx::new(inner)); - } - Err(prev_count) => { - if prev_count == 0 { - mem::drop(inner); - return None; - } - - tx_count = prev_count; - } - } - } - } else { - None - } - } -} - impl Tx { pub(crate) fn is_closed(&self) -> bool { self.inner.semaphore.is_closed() From 5006259f843e460b7d7e0899a7296f5ebbb92d70 Mon Sep 17 00:00:00 2001 From: b-naber Date: Wed, 6 Jul 2022 18:38:46 +0200 Subject: [PATCH 05/13] add test to ensure that no pending msgs are kept in the channel after rx was dropped --- tokio-util/tests/mpsc.rs | 63 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 63 insertions(+) diff --git a/tokio-util/tests/mpsc.rs b/tokio-util/tests/mpsc.rs index 872964a41f7..7f7ccd2ece4 100644 --- a/tokio-util/tests/mpsc.rs +++ b/tokio-util/tests/mpsc.rs @@ -1,6 +1,14 @@ use futures::future::poll_fn; +use std::ops::Drop; +use std::sync::atomic::{ + AtomicUsize, + Ordering::{Acquire, Release}, +}; +use std::time::Duration; +use tokio::join; use tokio::sync::mpsc::{self, channel}; use tokio::sync::oneshot; +use tokio::time; use tokio_test::task::spawn; use tokio_test::{assert_pending, assert_ready, assert_ready_err, assert_ready_ok}; use tokio_util::sync::PollSender; @@ -393,3 +401,58 @@ async fn actor_weak_sender() { let _ = actor_handle.await; } + +static NUM_DROPPED: AtomicUsize = AtomicUsize::new(0); + +#[derive(Debug)] +struct Msg; + +impl Drop for Msg { + fn drop(&mut self) { + NUM_DROPPED.fetch_add(1, Release); + } +} + +// Tests that no pending messages are put onto the channel after `Rx` was +// dropped. +// +// Note: After the introduction of `WeakSender`, which internally +// used `Arc` and doesn't call a drop of the channel after the last strong +// `Sender` was dropped while more than one `WeakSender` remains, we want to +// ensure that no messages are kept in the channel, which were sent after +// the receiver was dropped. +#[tokio::test(start_paused = true)] +async fn test_msgs_dropped_on_rx_drop() { + fn ms(millis: u64) -> Duration { + Duration::from_millis(millis) + } + + let (tx, mut rx) = mpsc::channel(3); + + let rx_handle = tokio::spawn(async move { + let _ = rx.recv().await.unwrap(); + let _ = rx.recv().await.unwrap(); + time::sleep(ms(1)).await; + drop(rx); + + time::advance(ms(1)).await; + }); + + let tx_handle = tokio::spawn(async move { + let _ = tx.send(Msg {}).await.unwrap(); + let _ = tx.send(Msg {}).await.unwrap(); + + // This msg will be pending and should be dropped when `rx` is dropped + let _ = tx.send(Msg {}).await.unwrap(); + time::advance(ms(1)).await; + + // This msg will not be put onto `Tx` list anymore, since `Rx` is closed. + time::sleep(ms(1)).await; + let _ = tx.send(Msg {}).await.unwrap(); + + // Ensure that third message isn't put onto the channel anymore + assert_eq!(NUM_DROPPED.load(Acquire), 4); + }); + + let (_, _) = join!(rx_handle, tx_handle); +} From 3a95b5ab525117c516539b670e9f6622d5ca3b10 Mon Sep 17 00:00:00 2001 From: b-naber Date: Mon, 18 Jul 2022 10:30:03 +0200 Subject: [PATCH 06/13] address review --- tokio-util/tests/mpsc.rs | 222 +-------------------------------- tokio/src/sync/mpsc/bounded.rs | 2 +- tokio/src/sync/mpsc/chan.rs | 26 ++-- tokio/tests/sync_mpsc.rs | 220 +++++++++++++++++++++++++++++++- 4 files changed, 230 insertions(+), 240 deletions(-) diff --git a/tokio-util/tests/mpsc.rs b/tokio-util/tests/mpsc.rs index 7f7ccd2ece4..4e4d0090760 100644 --- a/tokio-util/tests/mpsc.rs +++ b/tokio-util/tests/mpsc.rs @@ -1,14 +1,6 @@ use futures::future::poll_fn; -use std::ops::Drop; -use std::sync::atomic::{ - AtomicUsize, - Ordering::{Acquire, Release}, -}; -use std::time::Duration; -use tokio::join; -use tokio::sync::mpsc::{self, channel}; -use tokio::sync::oneshot; -use tokio::time; + +use tokio::sync::mpsc::channel; use tokio_test::task::spawn; use tokio_test::{assert_pending, assert_ready, assert_ready_err, assert_ready_ok}; use tokio_util::sync::PollSender; @@ -246,213 +238,3 @@ fn start_send_panics_when_acquiring() { assert_pending!(reserve.poll()); send.send_item(2).unwrap(); } - -#[tokio::test] -async fn weak_sender() { - let (tx, mut rx) = channel(11); - let tx_weak = tx.clone().downgrade(); - - let tx_weak = tokio::spawn(async move { - for i in 0..10 { - if tx.send(i).await.is_err() { - return None; - } - } - - let tx2 = tx_weak - .upgrade() - .expect("expected to be able to upgrade tx_weak"); - let _ = tx2.send(20).await; - let tx_weak = tx2.downgrade(); - - Some(tx_weak) - }) - .await - .unwrap(); - - for i in 0..12 { - let recvd = rx.recv().await; - - match recvd { - Some(msg) => { - if i == 10 { - assert_eq!(msg, 20); - } - } - None => { - assert_eq!(i, 11); - break; - } - } - } - - if let Some(tx_weak) = tx_weak { - let upgraded = tx_weak.upgrade(); - assert!(upgraded.is_none()); - } -} - -#[tokio::test] -async fn actor_weak_sender() { - pub struct MyActor { - receiver: mpsc::Receiver, - sender: mpsc::WeakSender, - next_id: u32, - pub received_self_msg: bool, - } - - enum ActorMessage { - GetUniqueId { respond_to: oneshot::Sender }, - SelfMessage {}, - } - - impl MyActor { - fn new( - receiver: mpsc::Receiver, - sender: mpsc::WeakSender, - ) -> Self { - MyActor { - receiver, - sender, - next_id: 0, - received_self_msg: false, - } - } - - fn handle_message(&mut self, msg: ActorMessage) { - match msg { - ActorMessage::GetUniqueId { respond_to } => { - self.next_id += 1; - - // The `let _ =` ignores any errors when sending. - // - // This can happen if the `select!` macro is used - // to cancel waiting for the response. - let _ = respond_to.send(self.next_id); - } - ActorMessage::SelfMessage { .. } => { - self.received_self_msg = true; - } - } - } - - async fn send_message_to_self(&mut self) { - let msg = ActorMessage::SelfMessage {}; - - let sender = self.sender.clone(); - - // cannot move self.sender here - if let Some(sender) = sender.upgrade() { - let _ = sender.send(msg).await; - self.sender = sender.downgrade(); - } - } - - async fn run(&mut self) { - let mut i = 0; - while let Some(msg) = self.receiver.recv().await { - self.handle_message(msg); - - if i == 0 { - self.send_message_to_self().await; - } - - i += 1 - } - - assert!(self.received_self_msg); - } - } - - #[derive(Clone)] - pub struct MyActorHandle { - sender: mpsc::Sender, - } - - impl MyActorHandle { - pub fn new() -> (Self, MyActor) { - let (sender, receiver) = mpsc::channel(8); - let actor = MyActor::new(receiver, sender.clone().downgrade()); - - (Self { sender }, actor) - } - - pub async fn get_unique_id(&self) -> u32 { - let (send, recv) = oneshot::channel(); - let msg = ActorMessage::GetUniqueId { respond_to: send }; - - // Ignore send errors. If this send fails, so does the - // recv.await below. There's no reason to check the - // failure twice. - let _ = self.sender.send(msg).await; - recv.await.expect("Actor task has been killed") - } - } - - let (handle, mut actor) = MyActorHandle::new(); - - let actor_handle = tokio::spawn(async move { actor.run().await }); - - let _ = tokio::spawn(async move { - let _ = handle.get_unique_id().await; - drop(handle); - }) - .await; - - let _ = actor_handle.await; -} - -static NUM_DROPPED: AtomicUsize = AtomicUsize::new(0); - -#[derive(Debug)] -struct Msg; - -impl Drop for Msg { - fn drop(&mut self) { - NUM_DROPPED.fetch_add(1, Release); - } -} - -// Tests that no pending messages are put onto the channel after `Rx` was -// dropped. -// -// Note: After the introduction of `WeakSender`, which internally -// used `Arc` and doesn't call a drop of the channel after the last strong -// `Sender` was dropped while more than one `WeakSender` remains, we want to -// ensure that no messages are kept in the channel, which were sent after -// the receiver was dropped. -#[tokio::test(start_paused = true)] -async fn test_msgs_dropped_on_rx_drop() { - fn ms(millis: u64) -> Duration { - Duration::from_millis(millis) - } - - let (tx, mut rx) = mpsc::channel(3); - - let rx_handle = tokio::spawn(async move { - let _ = rx.recv().await.unwrap(); - let _ = rx.recv().await.unwrap(); - time::sleep(ms(1)).await; - drop(rx); - - time::advance(ms(1)).await; - }); - - let tx_handle = tokio::spawn(async move { - let _ = tx.send(Msg {}).await.unwrap(); - let _ = tx.send(Msg {}).await.unwrap(); - - // This msg will be pending and should be dropped when `rx` is dropped - let _ = tx.send(Msg {}).await.unwrap(); - time::advance(ms(1)).await; - - // This msg will not be put onto `Tx` list anymore, since `Rx` is closed. - time::sleep(ms(1)).await; - let _ = tx.send(Msg {}).await.unwrap(); - - // Ensure that third message isn't put onto the channel anymore - assert_eq!(NUM_DROPPED.load(Acquire), 4); - }); - - let (_, _) = join!(rx_handle, tx_handle); -} diff --git a/tokio/src/sync/mpsc/bounded.rs b/tokio/src/sync/mpsc/bounded.rs index 92959ac3a2a..5889d8d9215 100644 --- a/tokio/src/sync/mpsc/bounded.rs +++ b/tokio/src/sync/mpsc/bounded.rs @@ -1079,7 +1079,7 @@ impl WeakSender { /// if there are other `Sender` instances alive and the channel wasn't /// previously dropped, otherwise `None` is returned. pub fn upgrade(self) -> Option> { - self.chan.upgrade().map(|tx| Sender::new(tx)) + self.chan.upgrade().map(Sender::new) } } diff --git a/tokio/src/sync/mpsc/chan.rs b/tokio/src/sync/mpsc/chan.rs index d097ce1e4a7..f6e9e0fab24 100644 --- a/tokio/src/sync/mpsc/chan.rs +++ b/tokio/src/sync/mpsc/chan.rs @@ -146,29 +146,19 @@ impl Tx { pub(super) fn upgrade(self) -> Option { let mut tx_count = self.inner.tx_count.load(Acquire); - if tx_count == 0 { - // channel is closed - return None; - } - loop { + if tx_count == 0 { + // channel is closed + return None; + } + match self .inner .tx_count - .compare_exchange(tx_count, tx_count + 1, AcqRel, Acquire) + .compare_exchange_weak(tx_count, tx_count + 1, AcqRel, Acquire) { - Ok(prev_count) => { - assert!(prev_count != 0); - - return Some(self); - } - Err(prev_count) => { - if prev_count == 0 { - return None; - } - - tx_count = prev_count; - } + Ok(_) => return Some(self), + Err(prev_count) => tx_count = prev_count, } } } diff --git a/tokio/tests/sync_mpsc.rs b/tokio/tests/sync_mpsc.rs index a1510f57d09..23339abe5da 100644 --- a/tokio/tests/sync_mpsc.rs +++ b/tokio/tests/sync_mpsc.rs @@ -10,11 +10,19 @@ use wasm_bindgen_test::wasm_bindgen_test as maybe_tokio_test; #[cfg(not(all(target_arch = "wasm32", not(target_os = "wasi"))))] use tokio::test as maybe_tokio_test; -use tokio::sync::mpsc; +use tokio::join; use tokio::sync::mpsc::error::{TryRecvError, TrySendError}; +use tokio::sync::mpsc::{self, channel}; +use tokio::sync::oneshot; +use tokio::time; use tokio_test::*; +use std::sync::atomic::{ + AtomicUsize, + Ordering::{Acquire, Release}, +}; use std::sync::Arc; +use std::time::Duration; #[cfg(not(target_arch = "wasm32"))] mod support { @@ -657,3 +665,213 @@ fn recv_timeout_panic() { let (tx, _rx) = mpsc::channel(5); tx.send_timeout(10, Duration::from_secs(1)).now_or_never(); } + +#[tokio::test] +async fn weak_sender() { + let (tx, mut rx) = channel(11); + let tx_weak = tx.clone().downgrade(); + + let tx_weak = tokio::spawn(async move { + for i in 0..10 { + if tx.send(i).await.is_err() { + return None; + } + } + + let tx2 = tx_weak + .upgrade() + .expect("expected to be able to upgrade tx_weak"); + let _ = tx2.send(20).await; + let tx_weak = tx2.downgrade(); + + Some(tx_weak) + }) + .await + .unwrap(); + + for i in 0..12 { + let recvd = rx.recv().await; + + match recvd { + Some(msg) => { + if i == 10 { + assert_eq!(msg, 20); + } + } + None => { + assert_eq!(i, 11); + break; + } + } + } + + if let Some(tx_weak) = tx_weak { + let upgraded = tx_weak.upgrade(); + assert!(upgraded.is_none()); + } +} + +#[tokio::test] +async fn actor_weak_sender() { + pub struct MyActor { + receiver: mpsc::Receiver, + sender: mpsc::WeakSender, + next_id: u32, + pub received_self_msg: bool, + } + + enum ActorMessage { + GetUniqueId { respond_to: oneshot::Sender }, + SelfMessage {}, + } + + impl MyActor { + fn new( + receiver: mpsc::Receiver, + sender: mpsc::WeakSender, + ) -> Self { + MyActor { + receiver, + sender, + next_id: 0, + received_self_msg: false, + } + } + + fn handle_message(&mut self, msg: ActorMessage) { + match msg { + ActorMessage::GetUniqueId { respond_to } => { + self.next_id += 1; + + // The `let _ =` ignores any errors when sending. + // + // This can happen if the `select!` macro is used + // to cancel waiting for the response. + let _ = respond_to.send(self.next_id); + } + ActorMessage::SelfMessage { .. } => { + self.received_self_msg = true; + } + } + } + + async fn send_message_to_self(&mut self) { + let msg = ActorMessage::SelfMessage {}; + + let sender = self.sender.clone(); + + // cannot move self.sender here + if let Some(sender) = sender.upgrade() { + let _ = sender.send(msg).await; + self.sender = sender.downgrade(); + } + } + + async fn run(&mut self) { + let mut i = 0; + while let Some(msg) = self.receiver.recv().await { + self.handle_message(msg); + + if i == 0 { + self.send_message_to_self().await; + } + + i += 1 + } + + assert!(self.received_self_msg); + } + } + + #[derive(Clone)] + pub struct MyActorHandle { + sender: mpsc::Sender, + } + + impl MyActorHandle { + pub fn new() -> (Self, MyActor) { + let (sender, receiver) = mpsc::channel(8); + let actor = MyActor::new(receiver, sender.clone().downgrade()); + + (Self { sender }, actor) + } + + pub async fn get_unique_id(&self) -> u32 { + let (send, recv) = oneshot::channel(); + let msg = ActorMessage::GetUniqueId { respond_to: send }; + + // Ignore send errors. If this send fails, so does the + // recv.await below. There's no reason to check the + // failure twice. + let _ = self.sender.send(msg).await; + recv.await.expect("Actor task has been killed") + } + } + + let (handle, mut actor) = MyActorHandle::new(); + + let actor_handle = tokio::spawn(async move { actor.run().await }); + + let _ = tokio::spawn(async move { + let _ = handle.get_unique_id().await; + drop(handle); + }) + .await; + + let _ = actor_handle.await; +} + +static NUM_DROPPED: AtomicUsize = AtomicUsize::new(0); + +#[derive(Debug)] +struct Msg; + +impl Drop for Msg { + fn drop(&mut self) { + NUM_DROPPED.fetch_add(1, Release); + } +} + +// Tests that no pending messages are put onto the channel after `Rx` was +// dropped. +// +// Note: After the introduction of `WeakSender`, which internally +// used `Arc` and doesn't call a drop of the channel after the last strong +// `Sender` was dropped while more than one `WeakSender` remains, we want to +// ensure that no messages are kept in the channel, which were sent after +// the receiver was dropped. +#[tokio::test(start_paused = true)] +async fn test_msgs_dropped_on_rx_drop() { + fn ms(millis: u64) -> Duration { + Duration::from_millis(millis) + } + + let (tx, mut rx) = mpsc::channel(3); + + let rx_handle = tokio::spawn(async move { + let _ = rx.recv().await.unwrap(); + let _ = rx.recv().await.unwrap(); + time::sleep(ms(1)).await; + drop(rx); + + time::advance(ms(1)).await; + }); + + let tx_handle = tokio::spawn(async move { + let _ = tx.send(Msg {}).await.unwrap(); + let _ = tx.send(Msg {}).await.unwrap(); + + // This msg will be pending and should be dropped when `rx` is dropped + let _ = tx.send(Msg {}).await.unwrap(); + time::advance(ms(1)).await; + + // This msg will not be put onto `Tx` list anymore, since `Rx` is closed. + time::sleep(ms(1)).await; + let _ = tx.send(Msg {}).await.unwrap(); + + // Ensure that third message isn't put onto the channel anymore + assert_eq!(NUM_DROPPED.load(Acquire), 4); + }); + + let (_, _) = join!(rx_handle, tx_handle); +} From f37d0130dd95311c913a0ef2ca551c710f576b69 Mon Sep 17 00:00:00 2001 From: b-naber Date: Mon, 18 Jul 2022 16:53:31 +0200 Subject: [PATCH 07/13] add and update tests --- tokio/tests/sync_mpsc.rs | 29 ++++++++++++++++++++++++++++- 1 file changed, 28 insertions(+), 1 deletion(-) diff --git a/tokio/tests/sync_mpsc.rs b/tokio/tests/sync_mpsc.rs index 23339abe5da..db62bb312cd 100644 --- a/tokio/tests/sync_mpsc.rs +++ b/tokio/tests/sync_mpsc.rs @@ -867,7 +867,7 @@ async fn test_msgs_dropped_on_rx_drop() { // This msg will not be put onto `Tx` list anymore, since `Rx` is closed. time::sleep(ms(1)).await; - let _ = tx.send(Msg {}).await.unwrap(); + assert!(tx.send(Msg {}).await.is_err()); // Ensure that third message isn't put onto the channel anymore assert_eq!(NUM_DROPPED.load(Acquire), 4); @@ -875,3 +875,30 @@ async fn test_msgs_dropped_on_rx_drop() { let (_, _) = join!(rx_handle, tx_handle); } + +#[tokio::test] +// Tests that a `WeakSender` is upgradeable when other `Sender`s exist. +async fn downgrade_upgrade_sender_success() { + let (tx, _rx) = mpsc::channel::(1); + let weak_tx = tx.clone().downgrade(); + assert!(weak_tx.upgrade().is_some()); +} + +#[tokio::test] +// Tests that a `WeakSender` fails to upgrade when no other `Sender` exists. +async fn downgrade_upgrade_sender_failure() { + let (tx, _rx) = mpsc::channel::(1); + let weak_tx = tx.downgrade(); + assert!(weak_tx.upgrade().is_none()); +} + +#[tokio::test] +// Tests that a `WeakSender` cannot be upgraded after a `Sender` was dropped, +// which existed at the time of the `downgrade` call. +async fn downgrade_drop_upgrade() { + let (tx, _rx) = mpsc::channel::(1); + + let weak_tx = tx.clone().downgrade(); + drop(tx); + assert!(weak_tx.upgrade().is_none()); +} From b76aeb71a1a5501312f362793ef967f00d7ce0b0 Mon Sep 17 00:00:00 2001 From: b-naber Date: Wed, 20 Jul 2022 17:08:58 +0200 Subject: [PATCH 08/13] address other comments --- tokio-util/tests/mpsc.rs | 1 - tokio/tests/sync_mpsc.rs | 6 ++---- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/tokio-util/tests/mpsc.rs b/tokio-util/tests/mpsc.rs index 4e4d0090760..a3c164d3eca 100644 --- a/tokio-util/tests/mpsc.rs +++ b/tokio-util/tests/mpsc.rs @@ -1,5 +1,4 @@ use futures::future::poll_fn; - use tokio::sync::mpsc::channel; use tokio_test::task::spawn; use tokio_test::{assert_pending, assert_ready, assert_ready_err, assert_ready_ok}; diff --git a/tokio/tests/sync_mpsc.rs b/tokio/tests/sync_mpsc.rs index db62bb312cd..5229ebc38b7 100644 --- a/tokio/tests/sync_mpsc.rs +++ b/tokio/tests/sync_mpsc.rs @@ -17,10 +17,8 @@ use tokio::sync::oneshot; use tokio::time; use tokio_test::*; -use std::sync::atomic::{ - AtomicUsize, - Ordering::{Acquire, Release}, -}; +use std::sync::atomic::AtomicUsize; +use std::sync::atomic::Ordering::{Acquire, Release}; use std::sync::Arc; use std::time::Duration; From 631fc0340d2a16f53aaeb8c21f79e7d389f51c95 Mon Sep 17 00:00:00 2001 From: b-naber Date: Thu, 21 Jul 2022 09:35:27 +0200 Subject: [PATCH 09/13] add permit tests --- tokio/tests/sync_mpsc.rs | 27 ++++++++++++++++++++++++--- 1 file changed, 24 insertions(+), 3 deletions(-) diff --git a/tokio/tests/sync_mpsc.rs b/tokio/tests/sync_mpsc.rs index 5229ebc38b7..9be51fc1d03 100644 --- a/tokio/tests/sync_mpsc.rs +++ b/tokio/tests/sync_mpsc.rs @@ -874,25 +874,25 @@ async fn test_msgs_dropped_on_rx_drop() { let (_, _) = join!(rx_handle, tx_handle); } -#[tokio::test] // Tests that a `WeakSender` is upgradeable when other `Sender`s exist. +#[tokio::test] async fn downgrade_upgrade_sender_success() { let (tx, _rx) = mpsc::channel::(1); let weak_tx = tx.clone().downgrade(); assert!(weak_tx.upgrade().is_some()); } -#[tokio::test] // Tests that a `WeakSender` fails to upgrade when no other `Sender` exists. +#[tokio::test] async fn downgrade_upgrade_sender_failure() { let (tx, _rx) = mpsc::channel::(1); let weak_tx = tx.downgrade(); assert!(weak_tx.upgrade().is_none()); } -#[tokio::test] // Tests that a `WeakSender` cannot be upgraded after a `Sender` was dropped, // which existed at the time of the `downgrade` call. +#[tokio::test] async fn downgrade_drop_upgrade() { let (tx, _rx) = mpsc::channel::(1); @@ -900,3 +900,24 @@ async fn downgrade_drop_upgrade() { drop(tx); assert!(weak_tx.upgrade().is_none()); } + +// Tests that we can upgrade a weak sender with an outstanding permit +// but no other strong senders. +#[tokio::test] +async fn downgrade_get_permit_upgrade_no_senders() { + let (tx, _rx) = mpsc::channel::(1); + let weak_tx = tx.clone().downgrade(); + let _permit = tx.reserve_owned().await.unwrap(); + assert!(weak_tx.upgrade().is_some()); +} + +// Tests that you can downgrade and upgrade a sender with an outstanding permit +// but no other senders left. +#[tokio::test] +async fn downgrade_upgrade_get_permit_no_senders() { + let (tx, _rx) = mpsc::channel::(1); + let tx2 = tx.clone(); + let _permit = tx.reserve_owned().await.unwrap(); + let weak_tx = tx2.downgrade(); + assert!(weak_tx.upgrade().is_some()); +} From ef6fabbb82ad0cad1bc197d3e6769893f1a1fe03 Mon Sep 17 00:00:00 2001 From: b-naber Date: Fri, 22 Jul 2022 17:47:01 +0200 Subject: [PATCH 10/13] docs --- tokio/src/sync/mpsc/bounded.rs | 41 +++++++++++++--------------------- tokio/src/sync/mpsc/chan.rs | 2 +- 2 files changed, 17 insertions(+), 26 deletions(-) diff --git a/tokio/src/sync/mpsc/bounded.rs b/tokio/src/sync/mpsc/bounded.rs index 5889d8d9215..e82317d699d 100644 --- a/tokio/src/sync/mpsc/bounded.rs +++ b/tokio/src/sync/mpsc/bounded.rs @@ -22,43 +22,34 @@ pub struct Sender { chan: chan::Tx, } -/// A Sender that does not influence RAII semantics, i.e. if all [`Sender`] -/// instances of a channel were dropped and only `WeakSender` instances remain, -/// the channel is closed. +/// A sender that does not prevent the channel from being closed. +/// +/// If all [`Sender`] instances of a channel were dropped and only `WeakSender` +/// instances remain, the channel is closed. /// /// In order to send messages, the `WeakSender` needs to be upgraded using -/// [`WeakSender::upgrade`], which returns `Option`, `None` if all -/// `Sender`s were already dropped, otherwise `Some` (at which point it does -/// influence RAII semantics again). +/// [`WeakSender::upgrade`], which returns `Option`. It returns `None` +/// if all `Sender`s have been dropped, and otherwise it returns a `Sender`. /// /// [`Sender`]: Sender /// [`WeakSender::upgrade`]: WeakSender::upgrade /// /// #Examples /// -/// ```rust -/// use tokio; +/// ``` /// use tokio::sync::mpsc::channel; /// /// #[tokio::main] /// async fn main() { -/// let (tx, mut rx) = channel(15); -/// let _ = tx.send(1).await; -/// let tx_weak = tx.downgrade(); -/// -/// let _ = tokio::spawn(async move { -/// for i in 0..2 { -/// if i == 0 { -/// assert_eq!(rx.recv().await.unwrap(), 1); -/// } else if i == 1 { -/// // only WeakSender instance remains -> channel is dropped -/// assert!(rx.recv().await.is_none()); -/// } -/// } -/// }) -/// .await; -/// -/// assert!(tx_weak.upgrade().is_none()); +/// let (tx, _rx) = channel::(15); +/// let tx_weak = tx.clone().downgrade(); +/// +/// // Upgrading will succeed because `tx` still exists. +/// assert!(tx_weak.clone().upgrade().is_some()); +/// +/// // If we drop `tx`, then it will fail. +/// drop(tx); +/// assert!(tx_weak.clone().upgrade().is_none()); /// } /// /// ``` diff --git a/tokio/src/sync/mpsc/chan.rs b/tokio/src/sync/mpsc/chan.rs index f6e9e0fab24..9ba55d898a2 100644 --- a/tokio/src/sync/mpsc/chan.rs +++ b/tokio/src/sync/mpsc/chan.rs @@ -142,7 +142,7 @@ impl Tx { self } - // Returns a boolean that indicates whether the channel is closed. + // Returns the upgraded channel or None if the upgrade failed. pub(super) fn upgrade(self) -> Option { let mut tx_count = self.inner.tx_count.load(Acquire); From fa8bdf9aab8223ee9fde34d5c34acb0fb4c63202 Mon Sep 17 00:00:00 2001 From: b-naber Date: Tue, 26 Jul 2022 15:24:25 +0200 Subject: [PATCH 11/13] take self as receiver types on downgrade and upgrade --- tokio/src/sync/mpsc/bounded.rs | 9 +++++---- tokio/src/sync/mpsc/chan.rs | 23 +++++++---------------- tokio/tests/sync_mpsc.rs | 22 +++++++++++++++++----- 3 files changed, 29 insertions(+), 25 deletions(-) diff --git a/tokio/src/sync/mpsc/bounded.rs b/tokio/src/sync/mpsc/bounded.rs index e82317d699d..015f01aedf5 100644 --- a/tokio/src/sync/mpsc/bounded.rs +++ b/tokio/src/sync/mpsc/bounded.rs @@ -1,3 +1,4 @@ +use crate::loom::sync::Arc; use crate::sync::batch_semaphore::{self as semaphore, TryAcquireError}; use crate::sync::mpsc::chan; use crate::sync::mpsc::error::{SendError, TryRecvError, TrySendError}; @@ -54,7 +55,7 @@ pub struct Sender { /// /// ``` pub struct WeakSender { - chan: chan::Tx, + chan: Arc>, } /// Permits to send one value into the channel. @@ -1031,7 +1032,7 @@ impl Sender { /// towards RAII semantics, i.e. if all `Sender` instances of the /// channel were dropped and only `WeakSender` instances remain, /// the channel is closed. - pub fn downgrade(self) -> WeakSender { + pub fn downgrade(&self) -> WeakSender { // Note: If this is the last `Sender` instance we want to close the // channel when downgrading, so it's important to move into `self` here. @@ -1069,8 +1070,8 @@ impl WeakSender { /// Tries to convert a WeakSender into a [`Sender`]. This will return `Some` /// if there are other `Sender` instances alive and the channel wasn't /// previously dropped, otherwise `None` is returned. - pub fn upgrade(self) -> Option> { - self.chan.upgrade().map(Sender::new) + pub fn upgrade(&self) -> Option> { + chan::Tx::upgrade(self.chan.clone()).map(Sender::new) } } diff --git a/tokio/src/sync/mpsc/chan.rs b/tokio/src/sync/mpsc/chan.rs index 9ba55d898a2..df511916c64 100644 --- a/tokio/src/sync/mpsc/chan.rs +++ b/tokio/src/sync/mpsc/chan.rs @@ -47,7 +47,7 @@ pub(crate) trait Semaphore { fn is_closed(&self) -> bool; } -struct Chan { +pub(crate) struct Chan { /// Notifies all tasks listening for the receiver being dropped. notify_rx_closed: Notify, @@ -130,21 +130,13 @@ impl Tx { Tx { inner: chan } } - pub(super) fn downgrade(self) -> Self { - if self.inner.tx_count.fetch_sub(1, AcqRel) == 1 { - // Close the list, which sends a `Close` message - self.inner.tx.close(); - - // Notify the receiver - self.wake_rx(); - } - - self + pub(super) fn downgrade(&self) -> Arc> { + self.inner.clone() } // Returns the upgraded channel or None if the upgrade failed. - pub(super) fn upgrade(self) -> Option { - let mut tx_count = self.inner.tx_count.load(Acquire); + pub(super) fn upgrade(chan: Arc>) -> Option { + let mut tx_count = chan.tx_count.load(Acquire); loop { if tx_count == 0 { @@ -152,12 +144,11 @@ impl Tx { return None; } - match self - .inner + match chan .tx_count .compare_exchange_weak(tx_count, tx_count + 1, AcqRel, Acquire) { - Ok(_) => return Some(self), + Ok(_) => return Some(Tx { inner: chan }), Err(prev_count) => tx_count = prev_count, } } diff --git a/tokio/tests/sync_mpsc.rs b/tokio/tests/sync_mpsc.rs index 9be51fc1d03..c4b9baf01df 100644 --- a/tokio/tests/sync_mpsc.rs +++ b/tokio/tests/sync_mpsc.rs @@ -667,9 +667,10 @@ fn recv_timeout_panic() { #[tokio::test] async fn weak_sender() { let (tx, mut rx) = channel(11); - let tx_weak = tx.clone().downgrade(); let tx_weak = tokio::spawn(async move { + let tx_weak = tx.clone().downgrade(); + for i in 0..10 { if tx.send(i).await.is_err() { return None; @@ -703,10 +704,9 @@ async fn weak_sender() { } } - if let Some(tx_weak) = tx_weak { - let upgraded = tx_weak.upgrade(); - assert!(upgraded.is_none()); - } + let tx_weak = tx_weak.unwrap(); + let upgraded = tx_weak.upgrade(); + assert!(upgraded.is_none()); } #[tokio::test] @@ -887,6 +887,7 @@ async fn downgrade_upgrade_sender_success() { async fn downgrade_upgrade_sender_failure() { let (tx, _rx) = mpsc::channel::(1); let weak_tx = tx.downgrade(); + drop(tx); assert!(weak_tx.upgrade().is_none()); } @@ -921,3 +922,14 @@ async fn downgrade_upgrade_get_permit_no_senders() { let weak_tx = tx2.downgrade(); assert!(weak_tx.upgrade().is_some()); } + +// Tests that `Clone` of `WeakSender` doesn't decrement `tx_count`. +#[tokio::test] +async fn test_weak_sender_clone() { + let (tx, _rx) = mpsc::channel::(1); + let tx_weak = tx.downgrade(); + let tx_weak2 = tx.downgrade(); + drop(tx); + + assert!(tx_weak.upgrade().is_none() && tx_weak2.upgrade().is_none()); +} From 59a6dfec3cb421967870771b441a68ae64381c72 Mon Sep 17 00:00:00 2001 From: b-naber Date: Wed, 27 Jul 2022 11:35:57 +0200 Subject: [PATCH 12/13] address review --- tokio/src/sync/mpsc/bounded.rs | 4 --- tokio/src/sync/mpsc/chan.rs | 2 +- tokio/tests/sync_mpsc.rs | 52 +++++++++++++--------------------- 3 files changed, 21 insertions(+), 37 deletions(-) diff --git a/tokio/src/sync/mpsc/bounded.rs b/tokio/src/sync/mpsc/bounded.rs index 015f01aedf5..a62952c51e8 100644 --- a/tokio/src/sync/mpsc/bounded.rs +++ b/tokio/src/sync/mpsc/bounded.rs @@ -52,7 +52,6 @@ pub struct Sender { /// drop(tx); /// assert!(tx_weak.clone().upgrade().is_none()); /// } -/// /// ``` pub struct WeakSender { chan: Arc>, @@ -1033,9 +1032,6 @@ impl Sender { /// channel were dropped and only `WeakSender` instances remain, /// the channel is closed. pub fn downgrade(&self) -> WeakSender { - // Note: If this is the last `Sender` instance we want to close the - // channel when downgrading, so it's important to move into `self` here. - WeakSender { chan: self.chan.downgrade(), } diff --git a/tokio/src/sync/mpsc/chan.rs b/tokio/src/sync/mpsc/chan.rs index df511916c64..a10ffb7d797 100644 --- a/tokio/src/sync/mpsc/chan.rs +++ b/tokio/src/sync/mpsc/chan.rs @@ -47,7 +47,7 @@ pub(crate) trait Semaphore { fn is_closed(&self) -> bool; } -pub(crate) struct Chan { +pub(super) struct Chan { /// Notifies all tasks listening for the receiver being dropped. notify_rx_closed: Notify, diff --git a/tokio/tests/sync_mpsc.rs b/tokio/tests/sync_mpsc.rs index c4b9baf01df..2dd0f6f57b6 100644 --- a/tokio/tests/sync_mpsc.rs +++ b/tokio/tests/sync_mpsc.rs @@ -10,17 +10,14 @@ use wasm_bindgen_test::wasm_bindgen_test as maybe_tokio_test; #[cfg(not(all(target_arch = "wasm32", not(target_os = "wasi"))))] use tokio::test as maybe_tokio_test; -use tokio::join; use tokio::sync::mpsc::error::{TryRecvError, TrySendError}; use tokio::sync::mpsc::{self, channel}; use tokio::sync::oneshot; -use tokio::time; use tokio_test::*; use std::sync::atomic::AtomicUsize; use std::sync::atomic::Ordering::{Acquire, Release}; use std::sync::Arc; -use std::time::Duration; #[cfg(not(target_arch = "wasm32"))] mod support { @@ -838,47 +835,36 @@ impl Drop for Msg { // `Sender` was dropped while more than one `WeakSender` remains, we want to // ensure that no messages are kept in the channel, which were sent after // the receiver was dropped. -#[tokio::test(start_paused = true)] +#[tokio::test] async fn test_msgs_dropped_on_rx_drop() { - fn ms(millis: u64) -> Duration { - Duration::from_millis(millis) - } - let (tx, mut rx) = mpsc::channel(3); - let rx_handle = tokio::spawn(async move { - let _ = rx.recv().await.unwrap(); - let _ = rx.recv().await.unwrap(); - time::sleep(ms(1)).await; - drop(rx); + let _ = tx.send(Msg {}).await.unwrap(); + let _ = tx.send(Msg {}).await.unwrap(); - time::advance(ms(1)).await; - }); + // This msg will be pending and should be dropped when `rx` is dropped + let sent_fut = tx.send(Msg {}); - let tx_handle = tokio::spawn(async move { - let _ = tx.send(Msg {}).await.unwrap(); - let _ = tx.send(Msg {}).await.unwrap(); + let _ = rx.recv().await.unwrap(); + let _ = rx.recv().await.unwrap(); - // This msg will be pending and should be dropped when `rx` is dropped - let _ = tx.send(Msg {}).await.unwrap(); - time::advance(ms(1)).await; + let _ = sent_fut.await.unwrap(); - // This msg will not be put onto `Tx` list anymore, since `Rx` is closed. - time::sleep(ms(1)).await; - assert!(tx.send(Msg {}).await.is_err()); + drop(rx); - // Ensure that third message isn't put onto the channel anymore - assert_eq!(NUM_DROPPED.load(Acquire), 4); - }); + assert_eq!(NUM_DROPPED.load(Acquire), 3); - let (_, _) = join!(rx_handle, tx_handle); + // This msg will not be put onto `Tx` list anymore, since `Rx` is closed. + assert!(tx.send(Msg {}).await.is_err()); + + assert_eq!(NUM_DROPPED.load(Acquire), 4); } // Tests that a `WeakSender` is upgradeable when other `Sender`s exist. #[tokio::test] async fn downgrade_upgrade_sender_success() { let (tx, _rx) = mpsc::channel::(1); - let weak_tx = tx.clone().downgrade(); + let weak_tx = tx.downgrade(); assert!(weak_tx.upgrade().is_some()); } @@ -897,6 +883,7 @@ async fn downgrade_upgrade_sender_failure() { async fn downgrade_drop_upgrade() { let (tx, _rx) = mpsc::channel::(1); + // the cloned `Tx` is dropped right away let weak_tx = tx.clone().downgrade(); drop(tx); assert!(weak_tx.upgrade().is_none()); @@ -907,7 +894,7 @@ async fn downgrade_drop_upgrade() { #[tokio::test] async fn downgrade_get_permit_upgrade_no_senders() { let (tx, _rx) = mpsc::channel::(1); - let weak_tx = tx.clone().downgrade(); + let weak_tx = tx.downgrade(); let _permit = tx.reserve_owned().await.unwrap(); assert!(weak_tx.upgrade().is_some()); } @@ -920,12 +907,13 @@ async fn downgrade_upgrade_get_permit_no_senders() { let tx2 = tx.clone(); let _permit = tx.reserve_owned().await.unwrap(); let weak_tx = tx2.downgrade(); + drop(tx2); assert!(weak_tx.upgrade().is_some()); } -// Tests that `Clone` of `WeakSender` doesn't decrement `tx_count`. +// Tests that `downgrade` does not change the `tx_count` of the channel. #[tokio::test] -async fn test_weak_sender_clone() { +async fn test_tx_count_weak_sender() { let (tx, _rx) = mpsc::channel::(1); let tx_weak = tx.downgrade(); let tx_weak2 = tx.downgrade(); From 188f4eea6200ed15216d00a7b04eb070ae121221 Mon Sep 17 00:00:00 2001 From: b-naber Date: Wed, 27 Jul 2022 12:14:05 +0200 Subject: [PATCH 13/13] fix doc test for WeakSender --- tokio/src/sync/mpsc/bounded.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tokio/src/sync/mpsc/bounded.rs b/tokio/src/sync/mpsc/bounded.rs index a62952c51e8..47d7938158a 100644 --- a/tokio/src/sync/mpsc/bounded.rs +++ b/tokio/src/sync/mpsc/bounded.rs @@ -43,10 +43,10 @@ pub struct Sender { /// #[tokio::main] /// async fn main() { /// let (tx, _rx) = channel::(15); -/// let tx_weak = tx.clone().downgrade(); +/// let tx_weak = tx.downgrade(); /// /// // Upgrading will succeed because `tx` still exists. -/// assert!(tx_weak.clone().upgrade().is_some()); +/// assert!(tx_weak.upgrade().is_some()); /// /// // If we drop `tx`, then it will fail. /// drop(tx);