From 17aee3f7603e5fdbcea620b63bd853fef3f7ac46 Mon Sep 17 00:00:00 2001
From: Bruno Dutra
Date: Sat, 19 Feb 2022 18:48:58 +0100
Subject: [PATCH] replace the use of crossbeam::ArrayQueue

see https://github.com/crossbeam-rs/crossbeam/pull/789
---
 src/buffer.rs   | 278 ++++++++++++++++++++++++++++++++++++++++---------
 src/waitlist.rs |  47 ++++----
 2 files changed, 252 insertions(+), 73 deletions(-)

diff --git a/src/buffer.rs b/src/buffer.rs
index 440e293..31fc006 100644
--- a/src/buffer.rs
+++ b/src/buffer.rs
@@ -1,15 +1,203 @@
 use alloc::boxed::Box;
-use crossbeam_queue::ArrayQueue;
-use crossbeam_utils::atomic::AtomicCell;
+use core::sync::atomic::{self, AtomicUsize, Ordering};
+use core::{cell::UnsafeCell, mem::MaybeUninit};
+use crossbeam_utils::{atomic::AtomicCell, Backoff, CachePadded};
 use derivative::Derivative;
 
+struct Slot<T> {
+    // If the stamp equals the tail, this node will be next written to.
+    // If it equals head + 1, this node will be next read from.
+    stamp: AtomicUsize,
+    value: UnsafeCell<MaybeUninit<T>>,
+}
+
+impl<T> Slot<T> {
+    fn new(stamp: usize) -> Self {
+        Slot {
+            stamp: AtomicUsize::new(stamp),
+            value: UnsafeCell::new(MaybeUninit::uninit()),
+        }
+    }
+}
+
+pub struct CircularQueue<T> {
+    head: CachePadded<AtomicUsize>,
+    tail: CachePadded<AtomicUsize>,
+    buffer: Box<[CachePadded<Slot<T>>]>,
+    lap: usize,
+}
+
+unsafe impl<T: Send> Sync for CircularQueue<T> {}
+unsafe impl<T: Send> Send for CircularQueue<T> {}
+
+impl<T> CircularQueue<T> {
+    fn new(capacity: usize) -> CircularQueue<T> {
+        CircularQueue {
+            buffer: (0..capacity).map(Slot::new).map(CachePadded::new).collect(),
+            head: Default::default(),
+            tail: Default::default(),
+            lap: (capacity + 1).next_power_of_two(),
+        }
+    }
+
+    #[inline]
+    fn capacity(&self) -> usize {
+        self.buffer.len()
+    }
+
+    #[inline]
+    fn get(&self, cursor: usize) -> &Slot<T> {
+        let index = cursor & (self.lap - 1);
+        debug_assert!(index < self.capacity());
+        unsafe { self.buffer.get_unchecked(index) }
+    }
+
+    fn advance(&self, cursor: usize) -> usize {
+        let index = cursor & (self.lap - 1);
+        let stamp = cursor & !(self.lap - 1);
+
+        if index + 1 < self.capacity() {
+            // Same lap, incremented index.
+            // Set to `{ stamp: stamp, index: index + 1 }`.
+            cursor + 1
+        } else {
+            // One lap forward, index wraps around to zero.
+            // Set to `{ stamp: stamp.wrapping_add(1), index: 0 }`.
+            stamp.wrapping_add(self.lap)
+        }
+    }
+
+    fn push_or_swap(&self, value: T) -> Option<T> {
+        let backoff = Backoff::new();
+        let mut tail = self.tail.load(Ordering::Relaxed);
+
+        loop {
+            let new_tail = self.advance(tail);
+            let slot = self.get(tail);
+            let stamp = slot.stamp.load(Ordering::Acquire);
+
+            // If the stamp matches the tail, we may attempt to push.
+            if stamp == tail {
+                // Try advancing the tail.
+                match self.tail.compare_exchange_weak(
+                    tail,
+                    new_tail,
+                    Ordering::SeqCst,
+                    Ordering::Relaxed,
+                ) {
+                    Ok(_) => {
+                        // Write the value into the slot.
+                        unsafe { slot.value.get().write(MaybeUninit::new(value)) };
+                        slot.stamp.store(tail + 1, Ordering::Release);
+                        return None;
+                    }
+
+                    Err(t) => {
+                        tail = t;
+                        backoff.spin();
+                        continue;
+                    }
+                }
+            // If the stamp lags one lap behind the tail, we may attempt to swap.
+            } else if stamp.wrapping_add(self.lap) == tail + 1 {
+                atomic::fence(Ordering::SeqCst);
+
+                // Try advancing the head, if it lags one lap behind the tail as well.
+                if self
+                    .head
+                    .compare_exchange_weak(
+                        tail.wrapping_sub(self.lap),
+                        new_tail.wrapping_sub(self.lap),
+                        Ordering::SeqCst,
+                        Ordering::Relaxed,
+                    )
+                    .is_ok()
+                {
+                    // Advance the tail.
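+                    // This store cannot race with other producers: while the
+                    // stamp lags behind the tail no push can succeed, and no
+                    // other swap can succeed before the head advances again,
+                    // so the tail cannot have moved since it was last read,
+                    // as asserted below.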
+                    debug_assert_eq!(self.tail.load(Ordering::SeqCst), tail);
+                    self.tail.store(new_tail, Ordering::SeqCst);
+
+                    // Replace the value in the slot.
+                    let new = MaybeUninit::new(value);
+                    let old = unsafe { slot.value.get().replace(new).assume_init() };
+                    slot.stamp.store(tail + 1, Ordering::Release);
+                    return Some(old);
+                }
+            }
+
+            backoff.snooze();
+            tail = self.tail.load(Ordering::Relaxed);
+        }
+    }
+
+    fn pop(&self) -> Option<T> {
+        let backoff = Backoff::new();
+        let mut head = self.head.load(Ordering::Relaxed);
+
+        loop {
+            let slot = self.get(head);
+            let stamp = slot.stamp.load(Ordering::Acquire);
+
+            // If the stamp is ahead of the head by 1, we may attempt to pop.
+            if stamp == head + 1 {
+                // Try advancing the head.
+                match self.head.compare_exchange_weak(
+                    head,
+                    self.advance(head),
+                    Ordering::SeqCst,
+                    Ordering::Relaxed,
+                ) {
+                    Ok(_) => {
+                        // Read the value from the slot.
+                        let msg = unsafe { slot.value.get().read().assume_init() };
+                        slot.stamp
+                            .store(head.wrapping_add(self.lap), Ordering::Release);
+                        return Some(msg);
+                    }
+
+                    Err(h) => {
+                        head = h;
+                        backoff.spin();
+                        continue;
+                    }
+                }
+            // If the stamp matches the head, the queue may be empty.
+            } else if stamp == head {
+                atomic::fence(Ordering::SeqCst);
+
+                // If the tail matches the head as well, the queue is empty.
+                if self.tail.load(Ordering::Relaxed) == head {
+                    return None;
+                }
+            }
+
+            backoff.snooze();
+            head = self.head.load(Ordering::Relaxed);
+        }
+    }
+}
+
+impl<T> Drop for CircularQueue<T> {
+    fn drop(&mut self) {
+        let mut cursor = self.head.load(Ordering::Relaxed);
+        let end = self.tail.load(Ordering::Relaxed);
+
+        // Loop over all slots that hold a message and drop them.
+        while cursor != end {
+            let slot = self.get(cursor);
+            unsafe { (&mut *slot.value.get()).as_mut_ptr().drop_in_place() };
+            cursor = self.advance(cursor);
+        }
+    }
+}
+
 #[derive(Derivative)]
 #[derivative(Debug)]
 #[allow(clippy::large_enum_variant)]
 pub(super) enum RingBuffer<T> {
     Atomic(#[derivative(Debug = "ignore")] AtomicCell<Option<T>>),
     Boxed(#[derivative(Debug = "ignore")] AtomicCell<Option<Box<T>>>),
-    Queue(#[derivative(Debug = "ignore")] ArrayQueue<T>),
+    Queue(#[derivative(Debug = "ignore")] CircularQueue<T>),
 }
 
 impl<T> RingBuffer<T> {
@@ -22,7 +210,7 @@ impl<T> RingBuffer<T> {
             debug_assert!(AtomicCell::<Option<Box<T>>>::is_lock_free());
             RingBuffer::Boxed(AtomicCell::new(None))
         } else {
-            RingBuffer::Queue(ArrayQueue::new(capacity))
+            RingBuffer::Queue(CircularQueue::new(capacity))
         }
     }
 
@@ -35,22 +223,11 @@ impl<T> RingBuffer<T> {
         }
     }
 
-    pub(super) fn push(&self, mut value: T) {
+    pub(super) fn push(&self, value: T) -> Option<T> {
         match self {
-            RingBuffer::Atomic(c) => {
-                c.store(Some(value));
-            }
-
-            RingBuffer::Boxed(b) => {
-                b.store(Some(Box::new(value)));
-            }
-
-            RingBuffer::Queue(q) => {
-                while let Err(v) = q.push(value) {
-                    self.pop();
-                    value = v;
-                }
-            }
+            RingBuffer::Atomic(c) => c.swap(Some(value)),
+            RingBuffer::Boxed(b) => Some(*b.swap(Some(Box::new(value)))?),
+            RingBuffer::Queue(q) => q.push_or_swap(value),
         }
     }
 
@@ -67,9 +244,9 @@
 mod tests {
     use super::*;
     use crate::{RingReceiver, RingSender};
-    use alloc::{sync::Arc, vec::Vec};
-    use core::{cmp::max, mem::discriminant};
-    use futures::{future::try_join, prelude::*, stream::repeat};
+    use alloc::{collections::BinaryHeap, sync::Arc, vec::Vec};
+    use core::{iter, mem::discriminant};
+    use futures::future::try_join_all;
     use proptest::collection::size_range;
     use test_strategy::proptest;
     use tokio::{runtime, task::spawn_blocking};
@@ -89,7 +266,7 @@
 
         assert_eq!(
             discriminant(&RingBuffer::<[char; 1]>::new(2)),
-            discriminant(&RingBuffer::Queue(ArrayQueue::new(2)))
+            discriminant(&RingBuffer::Queue(CircularQueue::new(2)))
         );
 
         assert_eq!(
@@ -99,7 +276,7 @@
 
         assert_eq!(
             discriminant(&RingBuffer::<[char; 4]>::new(2)),
-            discriminant(&RingBuffer::Queue(ArrayQueue::new(2)))
+            discriminant(&RingBuffer::Queue(CircularQueue::new(2)))
         );
 
         assert_eq!(
@@ -121,48 +298,49 @@
 
     #[proptest]
     fn oldest_items_are_overwritten_on_overflow(
-        #[any(size_range(1..=10).lift())] items: Vec<char>,
         #[strategy(1..=10usize)] capacity: usize,
+        #[any(size_range(#capacity..=10).lift())] items: Vec<char>,
     ) {
         let buffer = RingBuffer::new(capacity);
 
-        for &item in &items {
-            buffer.push(item);
+        for &item in &items[..capacity] {
+            assert_eq!(buffer.push(item), None);
         }
 
-        for &item in items.iter().skip(max(items.len(), capacity) - capacity) {
-            assert_eq!(buffer.pop(), Some(item));
+        for (i, &item) in (0..(items.len() - capacity)).zip(&items[capacity..]) {
+            assert_eq!(buffer.push(item), Some(items[i]));
        }
 
-        for _ in items.len()..max(items.len(), capacity) {
-            assert_eq!(buffer.pop(), None);
-        }
+        assert_eq!(
+            iter::from_fn(|| buffer.pop()).collect::<Vec<_>>(),
+            items[(items.len() - capacity)..]
+        );
    }
 
-    #[cfg(not(miri))] // https://github.com/rust-lang/miri/issues/1388
     #[proptest]
-    fn buffer_is_thread_safe(
-        #[strategy(1..=10usize)] m: usize,
+    fn buffer_is_linearizable(
         #[strategy(1..=10usize)] n: usize,
         #[strategy(1..=10usize)] capacity: usize,
     ) {
         let rt = runtime::Builder::new_multi_thread().build()?;
         let buffer = Arc::new(RingBuffer::new(capacity));
 
-        rt.block_on(try_join(
-            repeat(buffer.clone())
-                .enumerate()
-                .take(m)
-                .map(Ok)
-                .try_for_each_concurrent(None, |(item, b)| spawn_blocking(move || b.push(item))),
-            repeat(buffer)
-                .take(n)
-                .map(Ok)
-                .try_for_each_concurrent(None, |b| {
-                    spawn_blocking(move || {
-                        b.pop();
-                    })
-                }),
-        ))?;
+        let items = rt.block_on(async {
+            try_join_all(iter::repeat(buffer).enumerate().take(n).map(|(i, b)| {
+                spawn_blocking(move || match b.push(i) {
+                    None => b.pop(),
+                    item => item,
+                })
+            }))
+            .await
+        })?;
+
+        let sorted = items
+            .into_iter()
+            .flatten()
+            .collect::<BinaryHeap<_>>()
+            .into_sorted_vec();
+
+        assert_eq!(sorted, (0..n).collect::<Vec<_>>());
     }
 }
diff --git a/src/waitlist.rs b/src/waitlist.rs
index 7e5c892..b53b98c 100644
--- a/src/waitlist.rs
+++ b/src/waitlist.rs
@@ -19,14 +19,14 @@ impl<T> Waitlist<T> {
 
     // The queue is cleared, even if the iterator is not fully consumed.
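+    //
+    // Note that `len` is swapped to zero up front, so a given `Drain` yields
+    // at most the number of items that were registered when it was created;
+    // items pushed concurrently are left for subsequent calls to `drain`.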
     pub(super) fn drain(&self) -> impl Iterator<Item = T> + '_ {
         Drain {
-            registry: self,
+            waitlist: self,
             count: self.len.swap(0, Ordering::AcqRel),
         }
     }
 }
 
 struct Drain<'a, T> {
-    registry: &'a Waitlist<T>,
+    waitlist: &'a Waitlist<T>,
     count: usize,
 }
 
@@ -45,7 +45,7 @@
         }
 
         loop {
-            if let item @ Some(_) = self.registry.queue.pop() {
+            if let item @ Some(_) = self.waitlist.queue.pop() {
                 self.count -= 1;
                 return item;
             }
@@ -62,9 +62,9 @@ impl<'a, T> ExactSizeIterator for Drain<'a, T> {}
 #[cfg(test)]
 mod tests {
     use super::*;
-    use alloc::{sync::Arc, vec::Vec};
-    use core::sync::atomic::Ordering;
-    use futures::{future::try_join, prelude::*, stream::repeat};
+    use alloc::{collections::BinaryHeap, sync::Arc, vec::Vec};
+    use core::{iter, sync::atomic::Ordering};
+    use futures::future::try_join_all;
     use proptest::collection::size_range;
     use test_strategy::proptest;
     use tokio::{runtime, task::spawn_blocking};
@@ -123,25 +123,26 @@
 
     #[cfg(not(miri))] // https://github.com/rust-lang/miri/issues/1388
     #[proptest]
-    fn waitlist_is_thread_safe(
-        #[strategy(1..=10usize)] m: usize,
-        #[strategy(1..=10usize)] n: usize,
-    ) {
+    fn waitlist_is_linearizable(#[strategy(1..=10usize)] n: usize) {
         let rt = runtime::Builder::new_multi_thread().build()?;
         let waitlist = Arc::new(Waitlist::new());
 
-        rt.block_on(try_join(
-            repeat(waitlist.clone())
-                .enumerate()
-                .take(m)
-                .map(Ok)
-                .try_for_each_concurrent(None, |(item, w)| spawn_blocking(move || w.push(item))),
-            repeat(waitlist)
-                .take(n)
-                .map(Ok)
-                .try_for_each_concurrent(None, |w| {
-                    spawn_blocking(move || w.drain().for_each(drop))
-                }),
-        ))?;
+        let items = rt.block_on(async {
+            try_join_all(iter::repeat(waitlist).enumerate().take(n).map(|(i, w)| {
+                spawn_blocking(move || {
+                    w.push(i);
+                    w.drain().collect::<Vec<_>>()
+                })
+            }))
+            .await
+        })?;
+
+        let sorted = items
+            .into_iter()
+            .flatten()
+            .collect::<BinaryHeap<_>>()
+            .into_sorted_vec();
+
+        assert_eq!(sorted, (0..n).collect::<Vec<_>>());
     }
 }
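
A minimal sketch of the intended `push_or_swap`/`pop` semantics, assuming a
`CircularQueue` built as in this patch (the type is internal to the crate, so
this snippet is illustrative rather than part of the public API):

    let q = CircularQueue::new(2);

    // While there is spare capacity, nothing is displaced.
    assert_eq!(q.push_or_swap(1), None);
    assert_eq!(q.push_or_swap(2), None);

    // Once the queue is full, pushing overwrites and returns the oldest item.
    assert_eq!(q.push_or_swap(3), Some(1));

    // Items pop in FIFO order; an empty queue yields `None`.
    assert_eq!(q.pop(), Some(2));
    assert_eq!(q.pop(), Some(3));
    assert_eq!(q.pop(), None);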