From aabfb299f86f211b21f863696d6e8b1f4d98aa8c Mon Sep 17 00:00:00 2001
From: Bruno Dutra
Date: Sat, 19 Feb 2022 18:48:58 +0100
Subject: [PATCH] replace the use of crossbeam::ArrayQueue

see https://github.com/crossbeam-rs/crossbeam/pull/789
---
 src/buffer.rs   | 300 +++++++++++++++++++++++++++++++++++++++---------
 src/control.rs  |   1 +
 src/waitlist.rs |  49 ++++----
 3 files changed, 269 insertions(+), 81 deletions(-)

diff --git a/src/buffer.rs b/src/buffer.rs
index 5e48bbd..7ab31ab 100644
--- a/src/buffer.rs
+++ b/src/buffer.rs
@@ -1,55 +1,235 @@
-use crossbeam_queue::ArrayQueue;
+use alloc::boxed::Box;
+use core::sync::atomic::{self, AtomicUsize, Ordering};
+use core::{cell::UnsafeCell, mem::MaybeUninit};
+use crossbeam_utils::{Backoff, CachePadded};
 use derivative::Derivative;
 
+struct Slot<T> {
+    // If the stamp equals the tail, this node will be next written to.
+    // If it equals head + 1, this node will be next read from.
+    stamp: AtomicUsize,
+    value: UnsafeCell<MaybeUninit<T>>,
+}
+
+impl<T> Slot<T> {
+    fn new(stamp: usize) -> Self {
+        Slot {
+            stamp: AtomicUsize::new(stamp),
+            value: UnsafeCell::new(MaybeUninit::uninit()),
+        }
+    }
+}
+
+pub struct AtomicQueue<T> {
+    head: CachePadded<AtomicUsize>,
+    tail: CachePadded<AtomicUsize>,
+    buffer: Box<[CachePadded<Slot<T>>]>,
+    one_lap: usize,
+}
+
+unsafe impl<T: Send> Sync for AtomicQueue<T> {}
+unsafe impl<T: Send> Send for AtomicQueue<T> {}
+
+impl<T> AtomicQueue<T> {
+    fn new(capacity: usize) -> AtomicQueue<T> {
+        AtomicQueue {
+            buffer: (0..capacity).map(Slot::new).map(CachePadded::new).collect(),
+            head: Default::default(),
+            tail: Default::default(),
+            one_lap: (capacity + 1).next_power_of_two(),
+        }
+    }
+
+    #[inline]
+    fn capacity(&self) -> usize {
+        self.buffer.len()
+    }
+
+    #[inline]
+    fn get(&self, cursor: usize) -> &Slot<T> {
+        let index = cursor & (self.one_lap - 1);
+        debug_assert!(index < self.capacity());
+        unsafe { self.buffer.get_unchecked(index) }
+    }
+
+    #[inline]
+    fn advance(&self, cursor: usize) -> usize {
+        let index = cursor & (self.one_lap - 1);
+        let lap = cursor & !(self.one_lap - 1);
+
+        if index + 1 < self.capacity() {
+            // Same lap, incremented index.
+            // Set to `{ lap: lap, index: index + 1 }`.
+            cursor + 1
+        } else {
+            // One lap forward, index wraps around to zero.
+            // Set to `{ lap: lap.wrapping_add(1), index: 0 }`.
+            lap.wrapping_add(self.one_lap)
+        }
+    }
+
+    fn push_or_swap(&self, value: T) -> Option<T> {
+        let backoff = Backoff::new();
+        let mut tail = self.tail.load(Ordering::Relaxed);
+
+        loop {
+            let new_tail = self.advance(tail);
+            let slot = self.get(tail);
+            let stamp = slot.stamp.load(Ordering::Acquire);
+
+            // If the tail and the stamp match, we may attempt to push.
+            if stamp == tail {
+                // Try moving the tail.
+                match self.tail.compare_exchange_weak(
+                    tail,
+                    new_tail,
+                    Ordering::SeqCst,
+                    Ordering::Relaxed,
+                ) {
+                    Ok(_) => {
+                        // Write the value into the slot.
+                        unsafe { slot.value.get().write(MaybeUninit::new(value)) };
+                        slot.stamp.store(tail + 1, Ordering::Release);
+                        return None;
+                    }
+                    Err(t) => {
+                        tail = t;
+                        backoff.spin();
+                        continue;
+                    }
+                }
+            } else if stamp.wrapping_add(self.one_lap) == tail + 1 {
+                atomic::fence(Ordering::SeqCst);
+                let head = self.head.load(Ordering::Relaxed);
+
+                // If the head lags one lap behind the tail as well, the queue is full.
+                if head.wrapping_add(self.one_lap) == tail {
+                    let new_head = new_tail.wrapping_sub(self.one_lap);
+
+                    // Try moving the head.
+                    if self
+                        .head
+                        .compare_exchange_weak(head, new_head, Ordering::SeqCst, Ordering::Relaxed)
+                        .is_ok()
+                    {
+                        // Move the tail.
+                        debug_assert_eq!(self.tail.load(Ordering::SeqCst), tail);
+                        self.tail.store(new_tail, Ordering::SeqCst);
+
+                        // Replace the value in the slot.
+                        let new = MaybeUninit::new(value);
+                        let old = unsafe { slot.value.get().replace(new).assume_init() };
+                        slot.stamp.store(tail + 1, Ordering::Release);
+                        return Some(old);
+                    }
+                }
+            }
+
+            backoff.snooze();
+            tail = self.tail.load(Ordering::Relaxed);
+        }
+    }
+
+    fn pop(&self) -> Option<T> {
+        let backoff = Backoff::new();
+        let mut head = self.head.load(Ordering::Relaxed);
+
+        loop {
+            let slot = self.get(head);
+            let stamp = slot.stamp.load(Ordering::Acquire);
+
+            // If the stamp is ahead of the head by 1, we may attempt to pop.
+            if stamp == head + 1 {
+                let new_head = self.advance(head);
+                match self.head.compare_exchange_weak(
+                    head,
+                    new_head,
+                    Ordering::SeqCst,
+                    Ordering::Relaxed,
+                ) {
+                    Ok(_) => {
+                        // Read the value from the slot.
+                        let msg = unsafe { slot.value.get().read().assume_init() };
+                        slot.stamp
+                            .store(head.wrapping_add(self.one_lap), Ordering::Release);
+                        return Some(msg);
+                    }
+                    Err(h) => {
+                        head = h;
+                        backoff.spin();
+                        continue;
+                    }
+                }
+            } else if stamp == head {
+                atomic::fence(Ordering::SeqCst);
+                let tail = self.tail.load(Ordering::Relaxed);
+
+                // If the tail equals the head, that means the queue is empty.
+                if tail == head {
+                    return None;
+                }
+            }
+
+            backoff.snooze();
+            head = self.head.load(Ordering::Relaxed);
+        }
+    }
+}
+
+impl<T> Drop for AtomicQueue<T> {
+    fn drop(&mut self) {
+        let mut cursor = self.head.load(Ordering::Relaxed);
+        let end = self.tail.load(Ordering::Relaxed);
+
+        // Loop over all slots that hold a message and drop them.
+        while cursor != end {
+            let slot = self.get(cursor);
+            unsafe { (&mut *slot.value.get()).as_mut_ptr().drop_in_place() };
+            cursor = self.advance(cursor);
+        }
+    }
+}
+
 type AtomicOption<T> = crossbeam_utils::atomic::AtomicCell<Option<T>>;
 
 #[derive(Derivative)]
 #[derivative(Debug)]
 #[allow(clippy::large_enum_variant)]
 pub(super) enum RingBuffer<T> {
-    Queue(#[derivative(Debug = "ignore")] ArrayQueue<T>),
-    Cell(#[derivative(Debug = "ignore")] AtomicOption<T>),
+    AtomicOption(#[derivative(Debug = "ignore")] AtomicOption<T>),
+    AtomicQueue(#[derivative(Debug = "ignore")] AtomicQueue<T>),
 }
 
 impl<T> RingBuffer<T> {
     pub(super) fn new(capacity: usize) -> Self {
-        if capacity > 1 || !AtomicOption::<T>::is_lock_free() {
-            RingBuffer::Queue(ArrayQueue::new(capacity))
+        assert!(capacity > 0, "capacity must be non-zero");
+
+        if capacity == 1 && AtomicOption::<T>::is_lock_free() {
+            RingBuffer::AtomicOption(AtomicOption::new(None))
         } else {
-            RingBuffer::Cell(AtomicOption::new(None))
+            RingBuffer::AtomicQueue(AtomicQueue::new(capacity))
         }
     }
 
     #[cfg(test)]
     pub(super) fn capacity(&self) -> usize {
-        use RingBuffer::*;
         match self {
-            Queue(q) => q.capacity(),
-            Cell(_) => 1,
+            RingBuffer::AtomicOption(_) => 1,
+            RingBuffer::AtomicQueue(q) => q.capacity(),
         }
     }
 
-    pub(super) fn push(&self, mut value: T) {
-        use RingBuffer::*;
+    pub(super) fn push(&self, value: T) -> Option<T> {
         match self {
-            Queue(q) => {
-                while let Err(v) = q.push(value) {
-                    self.pop();
-                    value = v;
-                }
-            }
-
-            Cell(c) => {
-                c.swap(Some(value));
-            }
+            RingBuffer::AtomicOption(c) => c.swap(Some(value)),
+            RingBuffer::AtomicQueue(q) => q.push_or_swap(value),
         }
     }
 
     pub(super) fn pop(&self) -> Option<T> {
-        use RingBuffer::*;
         match self {
-            Queue(q) => q.pop(),
-            Cell(c) => c.swap(None),
+            RingBuffer::AtomicOption(c) => c.swap(None),
+            RingBuffer::AtomicQueue(q) => q.pop(),
         }
     }
 }
@@ -58,38 +238,44 @@ impl<T> RingBuffer<T> {
 mod tests {
     use super::*;
     use crate::{RingReceiver, RingSender};
-    use alloc::{sync::Arc, vec::Vec};
-    use core::{cmp::max, mem::discriminant};
-    use futures::{future::try_join, prelude::*, stream::repeat};
+    use alloc::{collections::BinaryHeap, sync::Arc, vec::Vec};
+    use core::{iter, mem::discriminant};
+    use futures::future::try_join_all;
     use proptest::collection::size_range;
     use test_strategy::proptest;
     use tokio::{runtime, task::spawn_blocking};
 
+    #[should_panic]
+    #[proptest]
+    fn new_panics_if_capacity_is_zero() {
+        RingBuffer::<()>::new(0);
+    }
+
     #[proptest]
     fn new_uses_atomic_cell_when_possible() {
         assert_eq!(
             discriminant(&RingBuffer::<[char; 1]>::new(1)),
-            discriminant(&RingBuffer::Cell(Default::default()))
+            discriminant(&RingBuffer::AtomicOption(Default::default()))
         );
 
         assert_eq!(
             discriminant(&RingBuffer::<[char; 1]>::new(2)),
-            discriminant(&RingBuffer::Queue(ArrayQueue::new(2)))
+            discriminant(&RingBuffer::AtomicQueue(AtomicQueue::new(2)))
         );
 
         assert_eq!(
            discriminant(&RingBuffer::<[char; 4]>::new(1)),
-            discriminant(&RingBuffer::Queue(ArrayQueue::new(1)))
+            discriminant(&RingBuffer::AtomicQueue(AtomicQueue::new(1)))
         );
 
         assert_eq!(
             discriminant(&RingBuffer::>::new(1)),
-            discriminant(&RingBuffer::Cell(Default::default()))
+            discriminant(&RingBuffer::AtomicOption(Default::default()))
         );
 
         assert_eq!(
             discriminant(&RingBuffer::>::new(1)),
-            discriminant(&RingBuffer::Cell(Default::default()))
+            discriminant(&RingBuffer::AtomicOption(Default::default()))
         );
     }
@@ -101,48 +287,48 @@ mod tests {
     #[proptest]
     fn oldest_items_are_overwritten_on_overflow(
-        #[any(size_range(1..=10).lift())] items: Vec,
         #[strategy(1..=10usize)] capacity: usize,
+        #[any(size_range(#capacity..=10).lift())] items: Vec,
     ) {
         let buffer = RingBuffer::new(capacity);
 
-        for &item in &items {
-            buffer.push(item);
+        for &item in &items[..capacity] {
+            assert_eq!(buffer.push(item), None);
         }
 
-        for &item in items.iter().skip(max(items.len(), capacity) - capacity) {
-            assert_eq!(buffer.pop(), Some(item));
+        for (i, &item) in (0..(items.len() - capacity)).zip(&items[capacity..]) {
+            assert_eq!(buffer.push(item), Some(items[i]));
         }
 
-        for _ in items.len()..max(items.len(), capacity) {
-            assert_eq!(buffer.pop(), None);
+        for &item in &items[(items.len() - capacity)..] {
+            assert_eq!(buffer.pop(), Some(item));
         }
     }
 
-    #[cfg(not(miri))] // https://github.com/rust-lang/miri/issues/1388
     #[proptest]
-    fn buffer_is_thread_safe(
-        #[strategy(1..=10usize)] m: usize,
+    fn buffer_is_linearizable(
         #[strategy(1..=10usize)] n: usize,
         #[strategy(1..=10usize)] capacity: usize,
     ) {
         let rt = runtime::Builder::new_multi_thread().build()?;
         let buffer = Arc::new(RingBuffer::new(capacity));
 
-        rt.block_on(try_join(
-            repeat(buffer.clone())
-                .enumerate()
-                .take(m)
-                .map(Ok)
-                .try_for_each_concurrent(None, |(item, b)| spawn_blocking(move || b.push(item))),
-            repeat(buffer)
-                .take(n)
-                .map(Ok)
-                .try_for_each_concurrent(None, |b| {
-                    spawn_blocking(move || {
-                        b.pop();
-                    })
-                }),
-        ))?;
+        let items = rt.block_on(async {
+            try_join_all(iter::repeat(buffer).enumerate().take(n).map(|(i, b)| {
+                spawn_blocking(move || match b.push(i) {
+                    None => b.pop(),
+                    item => item,
+                })
+            }))
+            .await
+        })?;
+
+        let sorted = items
+            .into_iter()
+            .flatten()
+            .collect::<BinaryHeap<_>>()
+            .into_sorted_vec();
+
+        assert_eq!(sorted, (0..n).collect::<Vec<_>>());
     }
 }
diff --git a/src/control.rs b/src/control.rs
index 34ec6e6..e82bd98 100644
--- a/src/control.rs
+++ b/src/control.rs
@@ -20,6 +20,7 @@ pub(super) struct ControlBlock<T> {
     pub(super) buffer: RingBuffer<T>,
 
     #[cfg(feature = "futures_api")]
+    #[derivative(Debug = "ignore")]
     pub(super) waitlist: Waitlist<Waker>,
 }
diff --git a/src/waitlist.rs b/src/waitlist.rs
index 7e5c892..a8a6186 100644
--- a/src/waitlist.rs
+++ b/src/waitlist.rs
@@ -3,7 +3,7 @@ use crossbeam_queue::SegQueue;
 use crossbeam_utils::CachePadded;
 use derivative::Derivative;
 
-#[derive(Derivative, Debug)]
+#[derive(Derivative)]
 #[derivative(Default(bound = "", new = "true"))]
 pub(super) struct Waitlist<T> {
     len: CachePadded<AtomicUsize>,
@@ -19,14 +19,14 @@ impl<T> Waitlist<T> {
     // The queue is cleared, even if the iterator is not fully consumed.
     pub(super) fn drain(&self) -> impl Iterator<Item = T> + '_ {
         Drain {
-            registry: self,
+            waitlist: self,
             count: self.len.swap(0, Ordering::AcqRel),
         }
     }
 }
 
 struct Drain<'a, T> {
-    registry: &'a Waitlist<T>,
+    waitlist: &'a Waitlist<T>,
     count: usize,
 }
@@ -45,7 +45,7 @@ impl<'a, T> Iterator for Drain<'a, T> {
         }
 
         loop {
-            if let item @ Some(_) = self.registry.queue.pop() {
+            if let item @ Some(_) = self.waitlist.queue.pop() {
                 self.count -= 1;
                 return item;
             }
@@ -62,9 +62,9 @@ impl<'a, T> ExactSizeIterator for Drain<'a, T> {}
 #[cfg(test)]
 mod tests {
     use super::*;
-    use alloc::{sync::Arc, vec::Vec};
-    use core::sync::atomic::Ordering;
-    use futures::{future::try_join, prelude::*, stream::repeat};
+    use alloc::{collections::BinaryHeap, sync::Arc, vec::Vec};
+    use core::{iter, sync::atomic::Ordering};
+    use futures::future::try_join_all;
     use proptest::collection::size_range;
     use test_strategy::proptest;
     use tokio::{runtime, task::spawn_blocking};
@@ -123,25 +123,26 @@ mod tests {
     #[cfg(not(miri))] // https://github.com/rust-lang/miri/issues/1388
     #[proptest]
-    fn waitlist_is_thread_safe(
-        #[strategy(1..=10usize)] m: usize,
-        #[strategy(1..=10usize)] n: usize,
-    ) {
+    fn waitlist_is_linearizable(#[strategy(1..=10usize)] n: usize) {
         let rt = runtime::Builder::new_multi_thread().build()?;
         let waitlist = Arc::new(Waitlist::new());
 
-        rt.block_on(try_join(
-            repeat(waitlist.clone())
-                .enumerate()
-                .take(m)
-                .map(Ok)
-                .try_for_each_concurrent(None, |(item, w)| spawn_blocking(move || w.push(item))),
-            repeat(waitlist)
-                .take(n)
-                .map(Ok)
-                .try_for_each_concurrent(None, |w| {
-                    spawn_blocking(move || w.drain().for_each(drop))
-                }),
-        ))?;
+        let items = rt.block_on(async {
+            try_join_all(iter::repeat(waitlist).enumerate().take(n).map(|(i, w)| {
+                spawn_blocking(move || {
+                    w.push(i);
+                    w.drain().collect::<Vec<_>>()
+                })
+            }))
+            .await
+        })?;
+
+        let sorted = items
+            .into_iter()
+            .flatten()
+            .collect::<BinaryHeap<_>>()
+            .into_sorted_vec();
+
+        assert_eq!(sorted, (0..n).collect::<Vec<_>>());
     }
 }
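Not part of the patch: the user-visible change in src/buffer.rs is that RingBuffer::push now returns the element it displaced instead of silently dropping it, which is what the linearizability test above relies on. Below is a minimal sketch of that contract, written as a hypothetical in-crate unit test against the RingBuffer type introduced by the patch; the module and test names are illustrative only and do not come from the repository.

    #[cfg(test)]
    mod push_contract_sketch {
        use super::*;

        #[test]
        fn push_returns_the_displaced_element_once_full() {
            // Capacity 2 selects the AtomicQueue-backed variant.
            let buffer = RingBuffer::new(2);

            // While there is spare capacity, push stores the value and returns None.
            assert_eq!(buffer.push(1), None);
            assert_eq!(buffer.push(2), None);

            // Once the buffer is full, push overwrites the oldest element and hands
            // it back, rather than discarding it as the ArrayQueue-based code did.
            assert_eq!(buffer.push(3), Some(1));

            // The surviving elements still come out in FIFO order.
            assert_eq!(buffer.pop(), Some(2));
            assert_eq!(buffer.pop(), Some(3));
            assert_eq!(buffer.pop(), None);
        }
    }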