Skip to content

Commit

Permalink
fix a race with send2
Browse files Browse the repository at this point in the history
  • Loading branch information
estk committed Apr 19, 2022
1 parent 76d5829 commit faa6f62
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 20 deletions.
31 changes: 12 additions & 19 deletions tokio/src/sync/broadcast.rs
Expand Up @@ -1026,11 +1026,19 @@ impl<T> Clone for Receiver<T> {
fn clone(&self) -> Self {
let next = self.next;
let shared = self.shared.clone();
let mut tail = shared.tail.lock();

// register interest in the slot that next points to
// let this be lock-free since we're not yet operating on the tail.
let tail_pos = shared.tail.lock().pos;
for n in next..tail_pos {
// register the new receiver with `Tail`
if tail.rx_cnt == MAX_RECEIVERS {
panic!("max receivers");
}
tail.rx_cnt = tail.rx_cnt.checked_add(1).expect("overflow");

// Register interest in the slots from next to tail.pos.

// We need to hold the lock here to prevent a race with send2 where send2 overwrites
// next or moves past tail before we register interest in the slot.
for n in next..tail.pos {
let idx = (n & shared.mask as u64) as usize;
let slot = shared.buffer[idx].read().unwrap();

Expand All @@ -1040,21 +1048,6 @@ impl<T> Clone for Receiver<T> {
// called concurrently.
slot.rem.fetch_add(1, SeqCst);
}
// tail pos may have changed, we need a locked section here to prevent a race with `Sender::send2`
let mut n = tail_pos.wrapping_sub(1);
let mut tail = shared.tail.lock();
while n <= tail.pos {
let idx = (n & shared.mask as u64) as usize;
let slot = self.shared.buffer[idx].read().unwrap();
slot.rem.fetch_add(1, SeqCst);
n = n.wrapping_add(1);
}

// register the new receiver with `Tail`
if tail.rx_cnt == MAX_RECEIVERS {
panic!("max receivers");
}
tail.rx_cnt = tail.rx_cnt.checked_add(1).expect("overflow");

drop(tail);

Expand Down
48 changes: 47 additions & 1 deletion tokio/src/sync/tests/loom_broadcast.rs
Expand Up @@ -92,6 +92,52 @@ fn broadcast_two() {
});
}

// An `Arc` is used as the value in order to detect memory leaks.
#[test]
fn broadcast_two_cloned() {
loom::model(|| {
let (tx, mut rx1) = broadcast::channel::<Arc<&'static str>>(16);
let mut rx2 = rx1.clone();

let th1 = thread::spawn(move || {
block_on(async {
let v = assert_ok!(rx1.recv().await);
assert_eq!(*v, "hello");

let v = assert_ok!(rx1.recv().await);
assert_eq!(*v, "world");

match assert_err!(rx1.recv().await) {
Closed => {}
_ => panic!(),
}
});
});

let th2 = thread::spawn(move || {
block_on(async {
let v = assert_ok!(rx2.recv().await);
assert_eq!(*v, "hello");

let v = assert_ok!(rx2.recv().await);
assert_eq!(*v, "world");

match assert_err!(rx2.recv().await) {
Closed => {}
_ => panic!(),
}
});
});

assert_ok!(tx.send(Arc::new("hello")));
assert_ok!(tx.send(Arc::new("world")));
drop(tx);

assert_ok!(th1.join());
assert_ok!(th2.join());
});
}

#[test]
fn broadcast_wrap() {
loom::model(|| {
Expand Down Expand Up @@ -274,7 +320,7 @@ fn drop_multiple_cloned_rx_with_overflow() {

#[test]
fn send_and_rx_clone() {
// test the interraction of Sender::send and Rx::clone
// test the interaction of Sender::send and Rx::clone
loom::model(move || {
let (tx, mut rx) = broadcast::channel(2);

Expand Down

0 comments on commit faa6f62

Please sign in to comment.