diff --git a/src/imp_std.rs b/src/imp_std.rs index 24066f1..e78c9e9 100644 --- a/src/imp_std.rs +++ b/src/imp_std.rs @@ -16,9 +16,15 @@ use crate::take_unchecked; #[derive(Debug)] pub(crate) struct OnceCell { - // This `state` word is actually an encoded version of just a pointer to a - // `Waiter`, so we add the `PhantomData` appropriately. - state_and_queue: AtomicUsize, + // This `queue` field is the core of the implementation. It encodes two + // pieces of information: + // + // * The current state of the cell (`INCOMPLETE`, `RUNNING`, `COMPLETE`) + // * Linked list of threads waiting for the current cell. + // + // State is encoded in two low bits. Only `INCOMPLETE` and `RUNNING` states + // allow waiters. + queue: AtomicUsize, _marker: PhantomData<*mut Waiter>, value: UnsafeCell>, } @@ -34,36 +40,10 @@ unsafe impl Send for OnceCell {} impl RefUnwindSafe for OnceCell {} impl UnwindSafe for OnceCell {} -// Three states that a OnceCell can be in, encoded into the lower bits of `state` in -// the OnceCell structure. -const INCOMPLETE: usize = 0x0; -const RUNNING: usize = 0x1; -const COMPLETE: usize = 0x2; - -// Mask to learn about the state. All other bits are the queue of waiters if -// this is in the RUNNING state. -const STATE_MASK: usize = 0x3; - -// Representation of a node in the linked list of waiters in the RUNNING state. -#[repr(align(4))] // Ensure the two lower bits are free to use as state bits. -struct Waiter { - thread: Cell>, - signaled: AtomicBool, - next: *const Waiter, -} - -// Head of a linked list of waiters. -// Every node is a struct on the stack of a waiting thread. -// Will wake up the waiters when it gets dropped, i.e. also on panic. -struct WaiterQueue<'a> { - state_and_queue: &'a AtomicUsize, - set_state_on_drop_to: usize, -} - impl OnceCell { pub(crate) const fn new() -> OnceCell { OnceCell { - state_and_queue: AtomicUsize::new(INCOMPLETE), + queue: AtomicUsize::new(INCOMPLETE), _marker: PhantomData, value: UnsafeCell::new(None), } @@ -71,7 +51,7 @@ impl OnceCell { pub(crate) const fn with_value(value: T) -> OnceCell { OnceCell { - state_and_queue: AtomicUsize::new(COMPLETE), + queue: AtomicUsize::new(COMPLETE), _marker: PhantomData, value: UnsafeCell::new(Some(value)), } @@ -84,7 +64,7 @@ impl OnceCell { // operations visible to us, and, this being a fast path, weaker // ordering helps with performance. This `Acquire` synchronizes with // `SeqCst` operations on the slow path. - self.state_and_queue.load(Ordering::Acquire) == COMPLETE + self.queue.load(Ordering::Acquire) == COMPLETE } /// Safety: synchronizes with store to value via SeqCst read from state, @@ -99,7 +79,7 @@ impl OnceCell { let mut res: Result<(), E> = Ok(()); let slot: *mut Option = self.value.get(); initialize_or_wait( - &self.state_and_queue, + &self.queue, Some(&mut || { let f = unsafe { take_unchecked(&mut f) }; match f() { @@ -119,7 +99,7 @@ impl OnceCell { #[cold] pub(crate) fn wait(&self) { - initialize_or_wait(&self.state_and_queue, None); + initialize_or_wait(&self.queue, None); } /// Get the reference to the underlying value, without checking if the cell @@ -160,7 +140,55 @@ impl OnceCell { } } -// Corresponds to `std::sync::Once::call_inner` +// Three states that a OnceCell can be in, encoded into the lower bits of `queue` in +// the OnceCell structure. +const INCOMPLETE: usize = 0x0; +const RUNNING: usize = 0x1; +const COMPLETE: usize = 0x2; + +// Mask to learn about the state. All other bits are the queue of waiters if +// this is in the RUNNING state. +const STATE_MASK: usize = 0x3; + +/// Representation of a node in the linked list of waiters in the RUNNING state. +/// A waiters is stored on the stack of the waiting threads. +#[repr(align(4))] // Ensure the two lower bits are free to use as state bits. +struct Waiter { + thread: Cell>, + signaled: AtomicBool, + next: *const Waiter, +} + +/// Drains and notifies the queue of waiters on drop. +struct Guard<'a> { + queue: &'a AtomicUsize, + new_queue: usize, +} + +impl Drop for Guard<'_> { + fn drop(&mut self) { + let queue = self.queue.swap(self.new_queue, Ordering::AcqRel); + + assert_eq!(queue & STATE_MASK, RUNNING); + + unsafe { + let mut waiter = (queue & !STATE_MASK) as *const Waiter; + while !waiter.is_null() { + let next = (*waiter).next; + let thread = (*waiter).thread.take().unwrap(); + (*waiter).signaled.store(true, Ordering::Release); + waiter = next; + thread.unpark(); + } + } + } +} + +// Corresponds to `std::sync::Once::call_inner`. +// +// Originally copied from std, but since modified to remove poisoning and to +// support wait. +// // Note: this is intentionally monomorphic #[inline(never)] fn initialize_or_wait(queue: &AtomicUsize, mut init: Option<&mut dyn FnMut() -> bool>) { @@ -177,16 +205,13 @@ fn initialize_or_wait(queue: &AtomicUsize, mut init: Option<&mut dyn FnMut() -> Ordering::Acquire, Ordering::Acquire, ); - if let Err(old) = exchange { - curr_queue = old; + if let Err(new_queue) = exchange { + curr_queue = new_queue; continue; } - let mut waiter_queue = WaiterQueue { - state_and_queue: queue, - set_state_on_drop_to: INCOMPLETE, // Difference, std uses `POISONED` - }; + let mut guard = Guard { queue, new_queue: INCOMPLETE }; if init() { - waiter_queue.set_state_on_drop_to = COMPLETE; + guard.new_queue = COMPLETE; } return; } @@ -194,12 +219,13 @@ fn initialize_or_wait(queue: &AtomicUsize, mut init: Option<&mut dyn FnMut() -> wait(&queue, curr_queue); curr_queue = queue.load(Ordering::Acquire); } - _ => unreachable!(), + _ => debug_assert!(false), } } } fn wait(queue: &AtomicUsize, mut curr_queue: usize) { + std::sync::Once let curr_state = curr_queue & STATE_MASK; loop { let node = Waiter { @@ -215,11 +241,11 @@ fn wait(queue: &AtomicUsize, mut curr_queue: usize) { Ordering::Release, Ordering::Relaxed, ); - if let Err(old) = exchange { - if old & STATE_MASK != curr_state { + if let Err(new_queue) = exchange { + if new_queue & STATE_MASK != curr_state { return; } - curr_queue = old; + curr_queue = new_queue; continue; } @@ -230,27 +256,6 @@ fn wait(queue: &AtomicUsize, mut curr_queue: usize) { } } -// Copy-pasted from std exactly. -impl Drop for WaiterQueue<'_> { - fn drop(&mut self) { - let state_and_queue = - self.state_and_queue.swap(self.set_state_on_drop_to, Ordering::AcqRel); - - assert_eq!(state_and_queue & STATE_MASK, RUNNING); - - unsafe { - let mut queue = (state_and_queue & !STATE_MASK) as *const Waiter; - while !queue.is_null() { - let next = (*queue).next; - let thread = (*queue).thread.replace(None).unwrap(); - (*queue).signaled.store(true, Ordering::Release); - queue = next; - thread.unpark(); - } - } - } -} - // These test are snatched from std as well. #[cfg(test)] mod tests {