Skip to content

Commit

Permalink
.
Browse files Browse the repository at this point in the history
  • Loading branch information
matklad committed May 19, 2022
1 parent bf91152 commit 81e9390
Showing 1 changed file with 72 additions and 67 deletions.
139 changes: 72 additions & 67 deletions src/imp_std.rs
Expand Up @@ -16,9 +16,15 @@ use crate::take_unchecked;

#[derive(Debug)]
pub(crate) struct OnceCell<T> {
// This `state` word is actually an encoded version of just a pointer to a
// `Waiter`, so we add the `PhantomData` appropriately.
state_and_queue: AtomicUsize,
// This `queue` field is the core of the implementation. It encodes two
// pieces of information:
//
// * The current state of the cell (`INCOMPLETE`, `RUNNING`, `COMPLETE`)
// * Linked list of threads waiting for the current cell.
//
// State is encoded in two low bits. Only `INCOMPLETE` and `RUNNING` states
// allow waiters.
queue: AtomicUsize,
_marker: PhantomData<*mut Waiter>,
value: UnsafeCell<Option<T>>,
}
Expand All @@ -34,44 +40,18 @@ unsafe impl<T: Send> Send for OnceCell<T> {}
impl<T: RefUnwindSafe + UnwindSafe> RefUnwindSafe for OnceCell<T> {}
impl<T: UnwindSafe> UnwindSafe for OnceCell<T> {}

// Three states that a OnceCell can be in, encoded into the lower bits of `state` in
// the OnceCell structure.
const INCOMPLETE: usize = 0x0;
const RUNNING: usize = 0x1;
const COMPLETE: usize = 0x2;

// Mask to learn about the state. All other bits are the queue of waiters if
// this is in the RUNNING state.
const STATE_MASK: usize = 0x3;

// Representation of a node in the linked list of waiters in the RUNNING state.
#[repr(align(4))] // Ensure the two lower bits are free to use as state bits.
struct Waiter {
thread: Cell<Option<Thread>>,
signaled: AtomicBool,
next: *const Waiter,
}

// Head of a linked list of waiters.
// Every node is a struct on the stack of a waiting thread.
// Will wake up the waiters when it gets dropped, i.e. also on panic.
struct WaiterQueue<'a> {
state_and_queue: &'a AtomicUsize,
set_state_on_drop_to: usize,
}

impl<T> OnceCell<T> {
pub(crate) const fn new() -> OnceCell<T> {
OnceCell {
state_and_queue: AtomicUsize::new(INCOMPLETE),
queue: AtomicUsize::new(INCOMPLETE),
_marker: PhantomData,
value: UnsafeCell::new(None),
}
}

pub(crate) const fn with_value(value: T) -> OnceCell<T> {
OnceCell {
state_and_queue: AtomicUsize::new(COMPLETE),
queue: AtomicUsize::new(COMPLETE),
_marker: PhantomData,
value: UnsafeCell::new(Some(value)),
}
Expand All @@ -84,7 +64,7 @@ impl<T> OnceCell<T> {
// operations visible to us, and, this being a fast path, weaker
// ordering helps with performance. This `Acquire` synchronizes with
// `SeqCst` operations on the slow path.
self.state_and_queue.load(Ordering::Acquire) == COMPLETE
self.queue.load(Ordering::Acquire) == COMPLETE
}

/// Safety: synchronizes with store to value via SeqCst read from state,
Expand All @@ -99,7 +79,7 @@ impl<T> OnceCell<T> {
let mut res: Result<(), E> = Ok(());
let slot: *mut Option<T> = self.value.get();
initialize_or_wait(
&self.state_and_queue,
&self.queue,
Some(&mut || {
let f = unsafe { take_unchecked(&mut f) };
match f() {
Expand All @@ -119,7 +99,7 @@ impl<T> OnceCell<T> {

#[cold]
pub(crate) fn wait(&self) {
initialize_or_wait(&self.state_and_queue, None);
initialize_or_wait(&self.queue, None);
}

/// Get the reference to the underlying value, without checking if the cell
Expand Down Expand Up @@ -160,7 +140,55 @@ impl<T> OnceCell<T> {
}
}

// Corresponds to `std::sync::Once::call_inner`
// Three states that a OnceCell can be in, encoded into the lower bits of `queue` in
// the OnceCell structure.
const INCOMPLETE: usize = 0x0;
const RUNNING: usize = 0x1;
const COMPLETE: usize = 0x2;

// Mask to learn about the state. All other bits are the queue of waiters if
// this is in the RUNNING state.
const STATE_MASK: usize = 0x3;

/// Representation of a node in the linked list of waiters in the RUNNING state.
/// A waiters is stored on the stack of the waiting threads.
#[repr(align(4))] // Ensure the two lower bits are free to use as state bits.
struct Waiter {
thread: Cell<Option<Thread>>,
signaled: AtomicBool,
next: *const Waiter,
}

/// Drains and notifies the queue of waiters on drop.
struct Guard<'a> {
queue: &'a AtomicUsize,
new_queue: usize,
}

impl Drop for Guard<'_> {
fn drop(&mut self) {
let queue = self.queue.swap(self.new_queue, Ordering::AcqRel);

assert_eq!(queue & STATE_MASK, RUNNING);

unsafe {
let mut waiter = (queue & !STATE_MASK) as *const Waiter;
while !waiter.is_null() {
let next = (*waiter).next;
let thread = (*waiter).thread.take().unwrap();
(*waiter).signaled.store(true, Ordering::Release);
waiter = next;
thread.unpark();
}
}
}
}

// Corresponds to `std::sync::Once::call_inner`.
//
// Originally copied from std, but since modified to remove poisoning and to
// support wait.
//
// Note: this is intentionally monomorphic
#[inline(never)]
fn initialize_or_wait(queue: &AtomicUsize, mut init: Option<&mut dyn FnMut() -> bool>) {
Expand All @@ -177,29 +205,27 @@ fn initialize_or_wait(queue: &AtomicUsize, mut init: Option<&mut dyn FnMut() ->
Ordering::Acquire,
Ordering::Acquire,
);
if let Err(old) = exchange {
curr_queue = old;
if let Err(new_queue) = exchange {
curr_queue = new_queue;
continue;
}
let mut waiter_queue = WaiterQueue {
state_and_queue: queue,
set_state_on_drop_to: INCOMPLETE, // Difference, std uses `POISONED`
};
let mut guard = Guard { queue, new_queue: INCOMPLETE };
if init() {
waiter_queue.set_state_on_drop_to = COMPLETE;
guard.new_queue = COMPLETE;
}
return;
}
(INCOMPLETE, None) | (RUNNING, _) => {
wait(&queue, curr_queue);
curr_queue = queue.load(Ordering::Acquire);
}
_ => unreachable!(),
_ => debug_assert!(false),
}
}
}

fn wait(queue: &AtomicUsize, mut curr_queue: usize) {
std::sync::Once
let curr_state = curr_queue & STATE_MASK;
loop {
let node = Waiter {
Expand All @@ -215,11 +241,11 @@ fn wait(queue: &AtomicUsize, mut curr_queue: usize) {
Ordering::Release,
Ordering::Relaxed,
);
if let Err(old) = exchange {
if old & STATE_MASK != curr_state {
if let Err(new_queue) = exchange {
if new_queue & STATE_MASK != curr_state {
return;
}
curr_queue = old;
curr_queue = new_queue;
continue;
}

Expand All @@ -230,27 +256,6 @@ fn wait(queue: &AtomicUsize, mut curr_queue: usize) {
}
}

// Copy-pasted from std exactly.
impl Drop for WaiterQueue<'_> {
fn drop(&mut self) {
let state_and_queue =
self.state_and_queue.swap(self.set_state_on_drop_to, Ordering::AcqRel);

assert_eq!(state_and_queue & STATE_MASK, RUNNING);

unsafe {
let mut queue = (state_and_queue & !STATE_MASK) as *const Waiter;
while !queue.is_null() {
let next = (*queue).next;
let thread = (*queue).thread.replace(None).unwrap();
(*queue).signaled.store(true, Ordering::Release);
queue = next;
thread.unpark();
}
}
}
}

// These test are snatched from std as well.
#[cfg(test)]
mod tests {
Expand Down

0 comments on commit 81e9390

Please sign in to comment.