From 58a96285b19e7c7a6b8b588bde19a52271d40672 Mon Sep 17 00:00:00 2001 From: Aleksey Kladov Date: Wed, 18 May 2022 17:25:40 +0100 Subject: [PATCH] add OnceCell::wait closes: #102 --- CHANGELOG.md | 2 +- Cargo.toml | 2 +- src/imp_pl.rs | 15 ++++ src/imp_std.rs | 213 ++++++++++++++++++++++++++----------------------- src/lib.rs | 11 +++ tests/it.rs | 20 ++++- 6 files changed, 156 insertions(+), 107 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 36cc79a..2126a57 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,7 +2,7 @@ ## Unreleased -- +- Add `OnceCell::wait`, a blocking variant of `get`. ## 1.11 diff --git a/Cargo.toml b/Cargo.toml index 34a75c2..daba1e8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "once_cell" -version = "1.11.0" +version = "1.12.0-pre.1" authors = ["Aleksey Kladov "] license = "MIT OR Apache-2.0" edition = "2018" diff --git a/src/imp_pl.rs b/src/imp_pl.rs index c499e0e..2bd80fa 100644 --- a/src/imp_pl.rs +++ b/src/imp_pl.rs @@ -77,6 +77,21 @@ impl OnceCell { res } + #[cold] + pub(crate) fn wait(&self) { + let key = &self.state as *const _ as usize; + unsafe { + parking_lot_core::park( + key, + || self.state.load(Ordering::Acquire) != COMPLETE, + || (), + |_, _| (), + parking_lot_core::DEFAULT_PARK_TOKEN, + None, + ); + } + } + /// Get the reference to the underlying value, without checking if the cell /// is initialized. /// diff --git a/src/imp_std.rs b/src/imp_std.rs index 9a1c83d..c01078c 100644 --- a/src/imp_std.rs +++ b/src/imp_std.rs @@ -16,9 +16,15 @@ use crate::take_unchecked; #[derive(Debug)] pub(crate) struct OnceCell { - // This `state` word is actually an encoded version of just a pointer to a - // `Waiter`, so we add the `PhantomData` appropriately. - state_and_queue: AtomicUsize, + // This `queue` field is the core of the implementation. It encodes two + // pieces of information: + // + // * The current state of the cell (`INCOMPLETE`, `RUNNING`, `COMPLETE`) + // * Linked list of threads waiting for the current cell. + // + // State is encoded in two low bits. Only `INCOMPLETE` and `RUNNING` states + // allow waiters. + queue: AtomicUsize, _marker: PhantomData<*mut Waiter>, value: UnsafeCell>, } @@ -34,36 +40,10 @@ unsafe impl Send for OnceCell {} impl RefUnwindSafe for OnceCell {} impl UnwindSafe for OnceCell {} -// Three states that a OnceCell can be in, encoded into the lower bits of `state` in -// the OnceCell structure. -const INCOMPLETE: usize = 0x0; -const RUNNING: usize = 0x1; -const COMPLETE: usize = 0x2; - -// Mask to learn about the state. All other bits are the queue of waiters if -// this is in the RUNNING state. -const STATE_MASK: usize = 0x3; - -// Representation of a node in the linked list of waiters in the RUNNING state. -#[repr(align(4))] // Ensure the two lower bits are free to use as state bits. -struct Waiter { - thread: Cell>, - signaled: AtomicBool, - next: *const Waiter, -} - -// Head of a linked list of waiters. -// Every node is a struct on the stack of a waiting thread. -// Will wake up the waiters when it gets dropped, i.e. also on panic. -struct WaiterQueue<'a> { - state_and_queue: &'a AtomicUsize, - set_state_on_drop_to: usize, -} - impl OnceCell { pub(crate) const fn new() -> OnceCell { OnceCell { - state_and_queue: AtomicUsize::new(INCOMPLETE), + queue: AtomicUsize::new(INCOMPLETE), _marker: PhantomData, value: UnsafeCell::new(None), } @@ -71,7 +51,7 @@ impl OnceCell { pub(crate) const fn with_value(value: T) -> OnceCell { OnceCell { - state_and_queue: AtomicUsize::new(COMPLETE), + queue: AtomicUsize::new(COMPLETE), _marker: PhantomData, value: UnsafeCell::new(Some(value)), } @@ -84,7 +64,7 @@ impl OnceCell { // operations visible to us, and, this being a fast path, weaker // ordering helps with performance. This `Acquire` synchronizes with // `SeqCst` operations on the slow path. - self.state_and_queue.load(Ordering::Acquire) == COMPLETE + self.queue.load(Ordering::Acquire) == COMPLETE } /// Safety: synchronizes with store to value via SeqCst read from state, @@ -98,22 +78,30 @@ impl OnceCell { let mut f = Some(f); let mut res: Result<(), E> = Ok(()); let slot: *mut Option = self.value.get(); - initialize_inner(&self.state_and_queue, &mut || { - let f = unsafe { take_unchecked(&mut f) }; - match f() { - Ok(value) => { - unsafe { *slot = Some(value) }; - true - } - Err(err) => { - res = Err(err); - false + initialize_or_wait( + &self.queue, + Some(&mut || { + let f = unsafe { take_unchecked(&mut f) }; + match f() { + Ok(value) => { + unsafe { *slot = Some(value) }; + true + } + Err(err) => { + res = Err(err); + false + } } - } - }); + }), + ); res } + #[cold] + pub(crate) fn wait(&self) { + initialize_or_wait(&self.queue, None); + } + /// Get the reference to the underlying value, without checking if the cell /// is initialized. /// @@ -152,67 +140,111 @@ impl OnceCell { } } -// Corresponds to `std::sync::Once::call_inner` +// Three states that a OnceCell can be in, encoded into the lower bits of `queue` in +// the OnceCell structure. +const INCOMPLETE: usize = 0x0; +const RUNNING: usize = 0x1; +const COMPLETE: usize = 0x2; + +// Mask to learn about the state. All other bits are the queue of waiters if +// this is in the RUNNING state. +const STATE_MASK: usize = 0x3; + +/// Representation of a node in the linked list of waiters in the RUNNING state. +/// A waiters is stored on the stack of the waiting threads. +#[repr(align(4))] // Ensure the two lower bits are free to use as state bits. +struct Waiter { + thread: Cell>, + signaled: AtomicBool, + next: *const Waiter, +} + +/// Drains and notifies the queue of waiters on drop. +struct Guard<'a> { + queue: &'a AtomicUsize, + new_queue: usize, +} + +impl Drop for Guard<'_> { + fn drop(&mut self) { + let queue = self.queue.swap(self.new_queue, Ordering::AcqRel); + + assert_eq!(queue & STATE_MASK, RUNNING); + + unsafe { + let mut waiter = (queue & !STATE_MASK) as *const Waiter; + while !waiter.is_null() { + let next = (*waiter).next; + let thread = (*waiter).thread.take().unwrap(); + (*waiter).signaled.store(true, Ordering::Release); + waiter = next; + thread.unpark(); + } + } + } +} + +// Corresponds to `std::sync::Once::call_inner`. +// +// Originally copied from std, but since modified to remove poisoning and to +// support wait. +// // Note: this is intentionally monomorphic #[inline(never)] -fn initialize_inner(my_state_and_queue: &AtomicUsize, init: &mut dyn FnMut() -> bool) -> bool { - let mut state_and_queue = my_state_and_queue.load(Ordering::Acquire); +fn initialize_or_wait(queue: &AtomicUsize, mut init: Option<&mut dyn FnMut() -> bool>) { + let mut curr_queue = queue.load(Ordering::Acquire); loop { - match state_and_queue { - COMPLETE => return true, - INCOMPLETE => { - let exchange = my_state_and_queue.compare_exchange( - state_and_queue, - RUNNING, + let curr_state = curr_queue & STATE_MASK; + match (curr_state, &mut init) { + (COMPLETE, _) => return, + (INCOMPLETE, Some(init)) => { + let exchange = queue.compare_exchange( + curr_queue, + (curr_queue & !STATE_MASK) | RUNNING, Ordering::Acquire, Ordering::Acquire, ); - if let Err(old) = exchange { - state_and_queue = old; + if let Err(new_queue) = exchange { + curr_queue = new_queue; continue; } - let mut waiter_queue = WaiterQueue { - state_and_queue: my_state_and_queue, - set_state_on_drop_to: INCOMPLETE, // Difference, std uses `POISONED` - }; - let success = init(); - - // Difference, std always uses `COMPLETE` - waiter_queue.set_state_on_drop_to = if success { COMPLETE } else { INCOMPLETE }; - return success; + let mut guard = Guard { queue, new_queue: INCOMPLETE }; + if init() { + guard.new_queue = COMPLETE; + } + return; } - _ => { - assert!(state_and_queue & STATE_MASK == RUNNING); - wait(&my_state_and_queue, state_and_queue); - state_and_queue = my_state_and_queue.load(Ordering::Acquire); + (INCOMPLETE, None) | (RUNNING, _) => { + wait(&queue, curr_queue); + curr_queue = queue.load(Ordering::Acquire); } + _ => debug_assert!(false), } } } -// Copy-pasted from std exactly. -fn wait(state_and_queue: &AtomicUsize, mut current_state: usize) { +fn wait(queue: &AtomicUsize, mut curr_queue: usize) { + let curr_state = curr_queue & STATE_MASK; loop { - if current_state & STATE_MASK != RUNNING { - return; - } - let node = Waiter { thread: Cell::new(Some(thread::current())), signaled: AtomicBool::new(false), - next: (current_state & !STATE_MASK) as *const Waiter, + next: (curr_queue & !STATE_MASK) as *const Waiter, }; let me = &node as *const Waiter as usize; - let exchange = state_and_queue.compare_exchange( - current_state, - me | RUNNING, + let exchange = queue.compare_exchange( + curr_queue, + me | curr_state, Ordering::Release, Ordering::Relaxed, ); - if let Err(old) = exchange { - current_state = old; + if let Err(new_queue) = exchange { + if new_queue & STATE_MASK != curr_state { + return; + } + curr_queue = new_queue; continue; } @@ -223,27 +255,6 @@ fn wait(state_and_queue: &AtomicUsize, mut current_state: usize) { } } -// Copy-pasted from std exactly. -impl Drop for WaiterQueue<'_> { - fn drop(&mut self) { - let state_and_queue = - self.state_and_queue.swap(self.set_state_on_drop_to, Ordering::AcqRel); - - assert_eq!(state_and_queue & STATE_MASK, RUNNING); - - unsafe { - let mut queue = (state_and_queue & !STATE_MASK) as *const Waiter; - while !queue.is_null() { - let next = (*queue).next; - let thread = (*queue).thread.replace(None).unwrap(); - (*queue).signaled.store(true, Ordering::Release); - queue = next; - thread.unpark(); - } - } - } -} - // These test are snatched from std as well. #[cfg(test)] mod tests { diff --git a/src/lib.rs b/src/lib.rs index 75d1c06..d716ea0 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -875,6 +875,17 @@ pub mod sync { } } + /// Blocks until the value is set by another thread. + pub fn wait(&self) -> &T { + if !self.0.is_initialized() { + self.0.wait() + } + debug_assert!(self.0.is_initialized()); + // Safe b/c of the wait call above and the fact that we didn't + // relinquish our borrow. + unsafe { self.get_unchecked() } + } + /// Gets the mutable reference to the underlying value. /// /// Returns `None` if the cell is empty. diff --git a/tests/it.rs b/tests/it.rs index 18a8094..e36c42a 100644 --- a/tests/it.rs +++ b/tests/it.rs @@ -319,6 +319,16 @@ mod sync { assert_eq!(cell.get(), Some(&"hello".to_string())); } + #[test] + fn wait() { + let cell: OnceCell = OnceCell::new(); + scope(|s| { + s.spawn(|_| cell.set("hello".to_string())); + let greeting = cell.wait(); + assert_eq!(greeting, "hello") + }); + } + #[test] #[cfg_attr(miri, ignore)] // miri doesn't support Barrier fn get_or_init_stress() { @@ -329,16 +339,18 @@ mod sync { .take(n_cells) .collect(); scope(|s| { - for _ in 0..n_threads { - s.spawn(|_| { + for t in 0..n_threads { + let cells = &cells; + s.spawn(move |_| { for (i, (b, s)) in cells.iter().enumerate() { b.wait(); - let j = s.get_or_init(|| i); + let j = if t % 2 == 0 { s.wait() } else { s.get_or_init(|| i) }; assert_eq!(*j, i); } }); } - }).unwrap(); + }) + .unwrap(); } #[test]