Commit

replace the use of crossbeam::ArrayQueue

brunocodutra committed Feb 20, 2022
1 parent c0aa14b commit 1aac922
Showing 2 changed files with 264 additions and 77 deletions.
300 changes: 243 additions & 57 deletions src/buffer.rs
@@ -1,55 +1,235 @@
use alloc::boxed::Box;
use core::sync::atomic::{self, AtomicUsize, Ordering};
use core::{cell::UnsafeCell, mem::MaybeUninit};
use crossbeam_utils::{Backoff, CachePadded};
use derivative::Derivative;

struct Slot<T> {
    // If the stamp equals the tail, this node will be next written to.
    // If it equals head + 1, this node will be next read from.
    stamp: AtomicUsize,
    value: UnsafeCell<MaybeUninit<T>>,
}

impl<T> Slot<T> {
    fn new(stamp: usize) -> Self {
        Slot {
            stamp: AtomicUsize::new(stamp),
            value: UnsafeCell::new(MaybeUninit::uninit()),
        }
    }
}

pub struct AtomicQueue<T> {
    head: CachePadded<AtomicUsize>,
    tail: CachePadded<AtomicUsize>,
    buffer: Box<[CachePadded<Slot<T>>]>,
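    // Editorial note: always a power of two strictly greater than `capacity`,
    // so a cursor's low bits hold the slot index and its high bits count laps.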
    one_lap: usize,
}
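
The head and tail fields are plain usize cursors that pack two values: the low bits (below one_lap) index into the buffer, while the high bits count laps around it. Because one_lap is (capacity + 1).next_power_of_two(), the index bits can always hold 0..capacity and bumping the lap never disturbs them. A small editorial sketch of the arithmetic, not part of the commit:

// Editorial sketch: cursor arithmetic for capacity = 3.
let one_lap = (3usize + 1).next_power_of_two(); // == 4
let cursor = 6usize; // lap 1, index 2
assert_eq!(cursor & (one_lap - 1), 2); // index lives in the low bits
assert_eq!(cursor & !(one_lap - 1), 4); // lap lives in the high bits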

unsafe impl<T: Send> Sync for AtomicQueue<T> {}
unsafe impl<T: Send> Send for AtomicQueue<T> {}

impl<T> AtomicQueue<T> {
    fn new(capacity: usize) -> AtomicQueue<T> {
        AtomicQueue {
            buffer: (0..capacity).map(Slot::new).map(CachePadded::new).collect(),
            head: Default::default(),
            tail: Default::default(),
            one_lap: (capacity + 1).next_power_of_two(),
        }
    }

    #[inline]
    fn capacity(&self) -> usize {
        self.buffer.len()
    }

    #[inline]
    fn get(&self, cursor: usize) -> &Slot<T> {
        let index = cursor & (self.one_lap - 1);
        debug_assert!(index < self.capacity());
        unsafe { self.buffer.get_unchecked(index) }
    }

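    // Worked example (editorial, not part of the commit): with capacity = 3,
    // one_lap = 4, so advance(1) == 2 (same lap, next index), while
    // advance(2) == 4 (last index, so the cursor wraps to lap 1, index 0).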
    #[inline]
    fn advance(&self, cursor: usize) -> usize {
        let index = cursor & (self.one_lap - 1);
        let lap = cursor & !(self.one_lap - 1);

        if index + 1 < self.capacity() {
            // Same lap, incremented index.
            // Set to `{ lap: lap, index: index + 1 }`.
            cursor + 1
        } else {
            // One lap forward, index wraps around to zero.
            // Set to `{ lap: lap.wrapping_add(1), index: 0 }`.
            lap.wrapping_add(self.one_lap)
        }
    }

    fn push_or_swap(&self, value: T) -> Option<T> {
        let backoff = Backoff::new();
        let mut tail = self.tail.load(Ordering::Relaxed);

        loop {
            let new_tail = self.advance(tail);
            let slot = self.get(tail);
            let stamp = slot.stamp.load(Ordering::Acquire);

            // If the tail and the stamp match, we may attempt to push.
            if stamp == tail {
                // Try moving the tail.
                match self.tail.compare_exchange_weak(
                    tail,
                    new_tail,
                    Ordering::SeqCst,
                    Ordering::Relaxed,
                ) {
                    Ok(_) => {
                        // Write the value into the slot.
                        unsafe { slot.value.get().write(MaybeUninit::new(value)) };
                        slot.stamp.store(tail + 1, Ordering::Release);
                        return None;
                    }
                    Err(t) => {
                        tail = t;
                        backoff.spin();
                        continue;
                    }
                }
            } else if stamp.wrapping_add(self.one_lap) == tail + 1 {
                atomic::fence(Ordering::SeqCst);
                let head = self.head.load(Ordering::Relaxed);

                // If the head lags one lap behind the tail as well, the queue is full.
                if head.wrapping_add(self.one_lap) == tail {
                    let new_head = new_tail.wrapping_sub(self.one_lap);

                    // Try moving the head.
                    if self
                        .head
                        .compare_exchange_weak(head, new_head, Ordering::SeqCst, Ordering::Relaxed)
                        .is_ok()
                    {
                        // Move the tail.
                        debug_assert_eq!(self.tail.load(Ordering::SeqCst), tail);
                        self.tail.store(new_tail, Ordering::SeqCst);

                        // Replace the value in the slot.
                        let new = MaybeUninit::new(value);
                        let old = unsafe { slot.value.get().replace(new).assume_init() };
                        slot.stamp.store(tail + 1, Ordering::Release);
                        return Some(old);
                    }
                }
            }

            backoff.snooze();
            tail = self.tail.load(Ordering::Relaxed);
        }
    }

    fn pop(&self) -> Option<T> {
        let backoff = Backoff::new();
        let mut head = self.head.load(Ordering::Relaxed);

        loop {
            let slot = self.get(head);
            let stamp = slot.stamp.load(Ordering::Acquire);

            // If the stamp is ahead of the head by 1, we may attempt to pop.
            if stamp == head + 1 {
                let new_head = self.advance(head);
                match self.head.compare_exchange_weak(
                    head,
                    new_head,
                    Ordering::SeqCst,
                    Ordering::Relaxed,
                ) {
                    Ok(_) => {
                        // Read the value from the slot.
                        let msg = unsafe { slot.value.get().read().assume_init() };
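                        // Editorial note: the slot becomes writable again only
                        // when the tail reaches it on the next lap, so pushes
                        // still on the current lap cannot claim it early.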
                        slot.stamp
                            .store(head.wrapping_add(self.one_lap), Ordering::Release);
                        return Some(msg);
                    }
                    Err(h) => {
                        head = h;
                        backoff.spin();
                        continue;
                    }
                }
            } else if stamp == head {
                atomic::fence(Ordering::SeqCst);
                let tail = self.tail.load(Ordering::Relaxed);

                // If the tail equals the head, the queue is empty.
                if tail == head {
                    return None;
                }
            }

            backoff.snooze();
            head = self.head.load(Ordering::Relaxed);
        }
    }
}

impl<T> Drop for AtomicQueue<T> {
    fn drop(&mut self) {
        let mut cursor = self.head.load(Ordering::Relaxed);
        let end = self.tail.load(Ordering::Relaxed);

        // Loop over all slots that hold a message and drop them.
        while cursor != end {
            let slot = self.get(cursor);
            unsafe { (&mut *slot.value.get()).as_mut_ptr().drop_in_place() };
            cursor = self.advance(cursor);
        }
    }
}
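
To make the overwrite semantics concrete, here is a brief editorial sketch, not part of the commit; it assumes module-internal access, since AtomicQueue and its methods are private:

// Editorial sketch: push_or_swap displaces the oldest element when full.
let q = AtomicQueue::new(2);
assert_eq!(q.push_or_swap('a'), None); // ['a']
assert_eq!(q.push_or_swap('b'), None); // ['a', 'b'], queue is now full
assert_eq!(q.push_or_swap('c'), Some('a')); // oldest displaced: ['b', 'c']
assert_eq!(q.pop(), Some('b'));
assert_eq!(q.pop(), Some('c'));
assert_eq!(q.pop(), None); // empty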

type AtomicOption<T> = crossbeam_utils::atomic::AtomicCell<Option<T>>;

#[derive(Derivative)]
#[derivative(Debug)]
#[allow(clippy::large_enum_variant)]
pub(super) enum RingBuffer<T> {
    AtomicOption(#[derivative(Debug = "ignore")] AtomicOption<T>),
    AtomicQueue(#[derivative(Debug = "ignore")] AtomicQueue<T>),
}

impl<T> RingBuffer<T> {
    pub(super) fn new(capacity: usize) -> Self {
        assert!(capacity > 0, "capacity must be non-zero");

        if capacity == 1 && AtomicOption::<T>::is_lock_free() {
            RingBuffer::AtomicOption(AtomicOption::new(None))
        } else {
            RingBuffer::AtomicQueue(AtomicQueue::new(capacity))
        }
    }

    #[cfg(test)]
    pub(super) fn capacity(&self) -> usize {
        match self {
            RingBuffer::AtomicOption(_) => 1,
            RingBuffer::AtomicQueue(q) => q.capacity(),
        }
    }

    pub(super) fn push(&self, value: T) -> Option<T> {
        match self {
            RingBuffer::AtomicOption(c) => c.swap(Some(value)),
            RingBuffer::AtomicQueue(q) => q.push_or_swap(value),
        }
    }

    pub(super) fn pop(&self) -> Option<T> {
        match self {
            RingBuffer::AtomicOption(c) => c.swap(None),
            RingBuffer::AtomicQueue(q) => q.pop(),
        }
    }
}
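
At the crate level, push now reports the displaced item instead of popping internally. A short editorial sketch, again assuming in-crate access:

// Editorial sketch: RingBuffer's capacity-1 fast path behaves the same way.
let buffer = RingBuffer::new(1); // AtomicOption, if AtomicCell<Option<u8>> is lock-free
assert_eq!(buffer.push(1u8), None);
assert_eq!(buffer.push(2u8), Some(1)); // the old value is handed back
assert_eq!(buffer.pop(), Some(2));
assert_eq!(buffer.pop(), None);
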
@@ -58,38 +238,44 @@ impl<T> RingBuffer<T> {
mod tests {
    use super::*;
    use crate::{RingReceiver, RingSender};
    use alloc::{collections::BinaryHeap, sync::Arc, vec::Vec};
    use core::{iter, mem::discriminant};
    use futures::future::try_join_all;
    use proptest::collection::size_range;
    use test_strategy::proptest;
    use tokio::{runtime, task::spawn_blocking};

    #[should_panic]
    #[proptest]
    fn new_panics_if_capacity_is_zero() {
        RingBuffer::<()>::new(0);
    }

    #[proptest]
    fn new_uses_atomic_cell_when_possible() {
        assert_eq!(
            discriminant(&RingBuffer::<[char; 1]>::new(1)),
            discriminant(&RingBuffer::AtomicOption(Default::default()))
        );

        assert_eq!(
            discriminant(&RingBuffer::<[char; 1]>::new(2)),
            discriminant(&RingBuffer::AtomicQueue(AtomicQueue::new(2)))
        );

        assert_eq!(
            discriminant(&RingBuffer::<[char; 4]>::new(1)),
            discriminant(&RingBuffer::AtomicQueue(AtomicQueue::new(1)))
        );

        assert_eq!(
            discriminant(&RingBuffer::<RingSender<()>>::new(1)),
            discriminant(&RingBuffer::AtomicOption(Default::default()))
        );

        assert_eq!(
            discriminant(&RingBuffer::<RingReceiver<()>>::new(1)),
            discriminant(&RingBuffer::AtomicOption(Default::default()))
        );
    }

@@ -101,48 +287,48 @@ mod tests {

    #[proptest]
    fn oldest_items_are_overwritten_on_overflow(
        #[strategy(1..=10usize)] capacity: usize,
        #[any(size_range(#capacity..=10).lift())] items: Vec<char>,
    ) {
        let buffer = RingBuffer::new(capacity);

        for &item in &items[..capacity] {
            assert_eq!(buffer.push(item), None);
        }

        for (i, &item) in (0..(items.len() - capacity)).zip(&items[capacity..]) {
            assert_eq!(buffer.push(item), Some(items[i]));
        }

        for &item in &items[(items.len() - capacity)..] {
            assert_eq!(buffer.pop(), Some(item));
        }
    }

    #[cfg(not(miri))] // https://github.com/rust-lang/miri/issues/1388
    #[proptest]
    fn buffer_is_linearizable(
        #[strategy(1..=10usize)] n: usize,
        #[strategy(1..=10usize)] capacity: usize,
    ) {
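        // Editorial note: every pushed item must surface exactly once, either
        // handed back by a later push that displaces it or returned by a pop,
        // so the recovered items, sorted, must be exactly 0..n.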
        let rt = runtime::Builder::new_multi_thread().build()?;
        let buffer = Arc::new(RingBuffer::new(capacity));

        let items = rt.block_on(async {
            try_join_all(iter::repeat(buffer).enumerate().take(n).map(|(i, b)| {
                spawn_blocking(move || match b.push(i) {
                    None => b.pop(),
                    item => item,
                })
            }))
            .await
        })?;

        let sorted = items
            .into_iter()
            .flatten()
            .collect::<BinaryHeap<_>>()
            .into_sorted_vec();

        assert_eq!(sorted, (0..n).collect::<Vec<_>>());
    }
}
