Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

sync: fix notify_waiters notifying sequential awaits #5404

Closed
wants to merge 21 commits into from
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
94 changes: 57 additions & 37 deletions tokio/src/sync/notify.rs
Original file line number Diff line number Diff line change
Expand Up @@ -222,8 +222,11 @@ struct Waiter {
/// Waiting task's waker.
waker: Option<Waker>,

/// `true` if the notification has been assigned to this waiter.
notified: Option<NotificationType>,
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Renamed this field because it was confusing to me because of the similarly named future.

/// Lowest `notify_waiters` call number which should
/// notify this waiter.
notify_waiters_calls: usize,

notification: Option<NotificationType>,

/// Should not be `Unpin`.
_p: PhantomPinned,
Expand Down Expand Up @@ -258,7 +261,7 @@ unsafe impl<'a> Sync for Notified<'a> {}

#[derive(Debug)]
enum State {
Init(usize),
Init,
Waiting,
Done,
}
Expand Down Expand Up @@ -288,12 +291,8 @@ fn get_num_notify_waiters_calls(data: usize) -> usize {
(data & NOTIFY_WAITERS_CALLS_MASK) >> NOTIFY_WAITERS_SHIFT
}

fn inc_num_notify_waiters_calls(data: usize) -> usize {
data + (1 << NOTIFY_WAITERS_SHIFT)
}

fn atomic_inc_num_notify_waiters_calls(data: &AtomicUsize) {
data.fetch_add(1 << NOTIFY_WAITERS_SHIFT, SeqCst);
fn atomic_fetch_inc_num_notify_waiters_calls(data: &AtomicUsize) -> usize {
data.fetch_add(1 << NOTIFY_WAITERS_SHIFT, SeqCst)
}

impl Notify {
Expand Down Expand Up @@ -387,11 +386,12 @@ impl Notify {
let state = self.state.load(SeqCst);
Notified {
notify: self,
state: State::Init(state >> NOTIFY_WAITERS_SHIFT),
state: State::Init,
waiter: UnsafeCell::new(Waiter {
pointers: linked_list::Pointers::new(),
waker: None,
notified: None,
notify_waiters_calls: state >> NOTIFY_WAITERS_SHIFT,
notification: None,
_p: PhantomPinned,
}),
}
Expand Down Expand Up @@ -500,35 +500,40 @@ impl Notify {
/// }
/// ```
pub fn notify_waiters(&self) {
let mut wakers = WakeList::new();

// There are waiters, the lock must be acquired to notify.
let mut waiters = self.waiters.lock();

// The state must be reloaded while the lock is held. The state may only
// transition out of WAITING while the lock is held.
let curr = self.state.load(SeqCst);
let mut curr = self.state.load(SeqCst);

// Increment the number of times this method was called to prevent newer
// waiters from being notified by this call.
let call_num_bound = atomic_fetch_inc_num_notify_waiters_calls(&self.state);

if matches!(get_state(curr), EMPTY | NOTIFIED) {
// There are no waiting tasks. All we need to do is increment the
// number of times this method was called.
atomic_inc_num_notify_waiters_calls(&self.state);
// There are no waiting tasks.
return;
}

// At this point, it is guaranteed that the state will not
// concurrently change, as holding the lock is required to
// transition **out** of `WAITING`.
let mut wakers = WakeList::new();

// At this point we are holding the lock, but we may release it
// inside the loop, so the state can concurrently change out
// of `WAITING`.
'outer: loop {
// Filter out waiters created *after* this call.
let mut iter = waiters.drain_filter(|w| w.notify_waiters_calls <= call_num_bound);

while wakers.can_push() {
match waiters.pop_back() {
match iter.next() {
Some(mut waiter) => {
// Safety: `waiters` lock is still held.
// Safety: `waiters` lock is held.
let waiter = unsafe { waiter.as_mut() };

assert!(waiter.notified.is_none());
assert!(waiter.notification.is_none());

waiter.notified = Some(NotificationType::AllWaiters);
waiter.notification = Some(NotificationType::AllWaiters);

if let Some(waker) = waiter.waker.take() {
wakers.push(waker);
Expand All @@ -546,13 +551,18 @@ impl Notify {

// Acquire the lock again.
waiters = self.waiters.lock();

// The state must be reloaded as it could have changed while
// the lock was released.
curr = self.state.load(SeqCst);
}

// All waiters will be notified, the state must be transitioned to
// If all waiters have been notified, the state must be transitioned to
// `EMPTY`. As transitioning **from** `WAITING` requires the lock to be
// held, a `store` is sufficient.
let new = set_state(inc_num_notify_waiters_calls(curr), EMPTY);
self.state.store(new, SeqCst);
if waiters.is_empty() && get_state(curr) == WAITING {
self.state.store(set_state(curr, EMPTY), SeqCst);
}

// Release the lock before notifying
drop(waiters);
Expand Down Expand Up @@ -597,9 +607,9 @@ fn notify_locked(waiters: &mut WaitList, state: &AtomicUsize, curr: usize) -> Op
// Safety: `waiters` lock is still held.
let waiter = unsafe { waiter.as_mut() };

assert!(waiter.notified.is_none());
assert!(waiter.notification.is_none());

waiter.notified = Some(NotificationType::OneWaiter);
waiter.notification = Some(NotificationType::OneWaiter);
let waker = waiter.waker.take();

if waiters.is_empty() {
Expand Down Expand Up @@ -749,7 +759,7 @@ impl Notified<'_> {

loop {
match *state {
Init(initial_notify_waiters_calls) => {
Init => {
let curr = notify.state.load(SeqCst);

// Optimistically try acquiring a pending notification
Expand All @@ -766,6 +776,17 @@ impl Notified<'_> {
return Poll::Ready(());
}

// Safety: the waiter is still not inserted
let initial_notify_waiters_calls =
unsafe { (*waiter.get()).notify_waiters_calls };

// Optimistically check if notify_waiters has been called
// after the future was created.
if get_num_notify_waiters_calls(curr) != initial_notify_waiters_calls {
*state = Done;
return Poll::Ready(());
}
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This additional check is basically free and should help to avoid waiting for the lock.


// Clone the waker before locking, a waker clone can be
// triggering arbitrary code.
let waker = waker.cloned();
Expand All @@ -777,8 +798,7 @@ impl Notified<'_> {
// Reload the state with the lock held
let mut curr = notify.state.load(SeqCst);

// if notify_waiters has been called after the future
// was created, then we are done
// Check again if notify_waiters has been called in the meantime.
if get_num_notify_waiters_calls(curr) != initial_notify_waiters_calls {
*state = Done;
return Poll::Ready(());
Expand Down Expand Up @@ -856,11 +876,11 @@ impl Notified<'_> {
// Safety: called while locked
let w = unsafe { &mut *waiter.get() };

if w.notified.is_some() {
// Our waker has been notified. Reset the fields and
// remove it from the list.
if w.notification.is_some() {
// Our waker has been notified and our waiter is already removed from
// the list. Reset the fields and convert to `Done`.
w.waker = None;
w.notified = None;
w.notification = None;

*state = Done;
} else {
Expand Down Expand Up @@ -933,7 +953,7 @@ impl Drop for Notified<'_> {
// Safety: with the entry removed from the linked list, there can be
// no concurrent access to the entry
if matches!(
unsafe { (*waiter.get()).notified },
unsafe { (*waiter.get()).notification },
Some(NotificationType::OneWaiter)
) {
if let Some(waker) = notify_locked(&mut waiters, &notify.state, notify_state) {
Expand Down
4 changes: 3 additions & 1 deletion tokio/src/util/linked_list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,9 @@ impl<L: Link> Default for LinkedList<L, L::Target> {

// ===== impl DrainFilter =====

cfg_io_readiness! {
feature! {
#![any(feature = "net", feature = "sync")]

pub(crate) struct DrainFilter<'a, T: Link, F> {
list: &'a mut LinkedList<T, T::Target>,
filter: F,
Expand Down
42 changes: 41 additions & 1 deletion tokio/tests/sync_notify.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
#[cfg(tokio_wasm_not_wasi)]
use wasm_bindgen_test::wasm_bindgen_test as test;

use tokio::sync::Notify;
use std::sync::Arc;
use tokio::sync::{oneshot, Notify};
use tokio_test::task::spawn;
use tokio_test::*;

Expand Down Expand Up @@ -225,3 +226,42 @@ fn test_waker_update() {

assert!(future.is_woken());
}

// tokio-rs/tokio#5396
#[tokio::test(flavor = "multi_thread")]
satakuma marked this conversation as resolved.
Show resolved Hide resolved
async fn notify_waiters_sequential() {
let notify = Arc::new(Notify::new());

let (tx, rx) = oneshot::channel();

let receiver = tokio::spawn({
let notify = notify.clone();
async move {
notify.notified().await;

// Poll the second `Notified` future to try to insert
// it to the waiters queue.
let mut second_notified = spawn(notify.notified());
assert_pending!(second_notified.poll());

// Wait for the `notify_waiters` to end and check if we
// are woken up.
rx.await.unwrap();
assert_pending!(second_notified.poll());
}
});

for _ in 1..100000 {
let notify = notify.clone();
tokio::spawn(async move {
notify.notified().await;
});
}

tokio::task::yield_now().await;

notify.notify_waiters();
tx.send(()).unwrap();

receiver.await.unwrap();
}