Skip to content

Commit

Permalink
use Arc in WeakSender
Browse files Browse the repository at this point in the history
  • Loading branch information
b-naber committed Jul 6, 2022
1 parent 45d005d commit fe51bd7
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 71 deletions.
5 changes: 4 additions & 1 deletion tokio-util/tests/mpsc.rs
Expand Up @@ -331,7 +331,10 @@ async fn actor_weak_sender() {
async fn send_message_to_self(&mut self) {
let msg = ActorMessage::SelfMessage {};

if let Some(sender) = self.sender.upgrade() {
let sender = self.sender.clone();

// cannot move self.sender here
if let Some(sender) = sender.upgrade() {
let _ = sender.send(msg).await;
self.sender = sender.downgrade();
}
Expand Down
21 changes: 15 additions & 6 deletions tokio/src/sync/mpsc/bounded.rs
Expand Up @@ -63,7 +63,7 @@ pub struct Sender<T> {
///
/// ```
pub struct WeakSender<T> {
chan: chan::TxWeak<T, Semaphore>,
chan: chan::Tx<T, Semaphore>,
}

/// Permits to send one value into the channel.
Expand Down Expand Up @@ -1044,8 +1044,9 @@ impl<T> Sender<T> {
// Note: If this is the last `Sender` instance we want to close the
// channel when downgrading, so it's important to move into `self` here.

let chan = self.chan.downgrade();
WeakSender { chan }
WeakSender {
chan: self.chan.downgrade(),
}
}
}

Expand All @@ -1065,12 +1066,20 @@ impl<T> fmt::Debug for Sender<T> {
}
}

impl<T> Clone for WeakSender<T> {
fn clone(&self) -> Self {
WeakSender {
chan: self.chan.clone(),
}
}
}

impl<T> WeakSender<T> {
/// Tries to conver a WeakSender into a [`Sender`]. This will return `Some`
/// Tries to convert a WeakSender into a [`Sender`]. This will return `Some`
/// if there are other `Sender` instances alive and the channel wasn't
/// previously dropped, otherwise `None` is returned.
pub fn upgrade(&self) -> Option<Sender<T>> {
self.chan.upgrade().map(Sender::new)
pub fn upgrade(self) -> Option<Sender<T>> {
self.chan.upgrade().map(|tx| Sender::new(tx))
}
}

Expand Down
104 changes: 40 additions & 64 deletions tokio/src/sync/mpsc/chan.rs
Expand Up @@ -9,10 +9,8 @@ use crate::sync::mpsc::list;
use crate::sync::notify::Notify;

use std::fmt;
use std::mem;
use std::process;
use std::sync::atomic::Ordering::{AcqRel, Acquire, Relaxed, Release};
use std::sync::Weak;
use std::task::Poll::{Pending, Ready};
use std::task::{Context, Poll};
use std::usize;
Expand All @@ -28,16 +26,6 @@ impl<T, S: fmt::Debug> fmt::Debug for Tx<T, S> {
}
}

pub(crate) struct TxWeak<T, S> {
inner: Weak<Chan<T, S>>,
}

impl<T, S: fmt::Debug> fmt::Debug for TxWeak<T, S> {
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
fmt.debug_struct("TxWeak").finish()
}
}

/// Channel receiver.
pub(crate) struct Rx<T, S: Semaphore> {
inner: Arc<Chan<T, S>>,
Expand Down Expand Up @@ -142,12 +130,47 @@ impl<T, S> Tx<T, S> {
Tx { inner: chan }
}

pub(super) fn downgrade(self) -> TxWeak<T, S> {
// We don't decrement the `tx_counter` here, but let the counter be decremented
// through the drop of self.inner.
let weak_inner = Arc::<Chan<T, S>>::downgrade(&self.inner);
pub(super) fn downgrade(self) -> Self {
if self.inner.tx_count.fetch_sub(1, AcqRel) == 1 {
// Close the list, which sends a `Close` message
self.inner.tx.close();

// Notify the receiver
self.wake_rx();
}

self
}

TxWeak::new(weak_inner)
// Returns a boolean that indicates whether the channel is closed.
pub(super) fn upgrade(self) -> Option<Self> {
let mut tx_count = self.inner.tx_count.load(Acquire);

if tx_count == 0 {
// channel is closed
return None;
}

loop {
match self
.inner
.tx_count
.compare_exchange(tx_count, tx_count + 1, AcqRel, Acquire)
{
Ok(prev_count) => {
assert!(prev_count != 0);

return Some(self);
}
Err(prev_count) => {
if prev_count == 0 {
return None;
}

tx_count = prev_count;
}
}
}
}

pub(super) fn semaphore(&self) -> &S {
Expand All @@ -170,53 +193,6 @@ impl<T, S> Tx<T, S> {
}
}

impl<T, S> TxWeak<T, S> {
fn new(inner: Weak<Chan<T, S>>) -> Self {
TxWeak { inner }
}

pub(super) fn upgrade(&self) -> Option<Tx<T, S>> {
let inner = self.inner.upgrade();

if let Some(inner) = inner {
// If we were able to upgrade, `Chan` is guaranteed to still exist,
// even though the channel might have been closed in the meantime.
// Need to check here whether the channel was actually closed.

let mut tx_count = inner.tx_count.load(Acquire);

if tx_count == 0 {
// channel is closed
mem::drop(inner);
return None;
}

loop {
match inner
.tx_count
.compare_exchange(tx_count, tx_count + 1, AcqRel, Acquire)
{
Ok(prev_count) => {
assert!(prev_count != 0);

return Some(Tx::new(inner));
}
Err(prev_count) => {
if prev_count == 0 {
mem::drop(inner);
return None;
}

tx_count = prev_count;
}
}
}
} else {
None
}
}
}

impl<T, S: Semaphore> Tx<T, S> {
pub(crate) fn is_closed(&self) -> bool {
self.inner.semaphore.is_closed()
Expand Down

0 comments on commit fe51bd7

Please sign in to comment.