Skip to content

Commit

Permalink
add OnceCell::wait
Browse files Browse the repository at this point in the history
closes: #102
  • Loading branch information
matklad committed May 20, 2022
1 parent 12f283f commit c935154
Show file tree
Hide file tree
Showing 6 changed files with 157 additions and 107 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Expand Up @@ -2,7 +2,7 @@

## Unreleased

-
- Add `OnceCell::wait`, a blocking variant of `get`.

## 1.11

Expand Down
2 changes: 1 addition & 1 deletion Cargo.toml
@@ -1,6 +1,6 @@
[package]
name = "once_cell"
version = "1.11.0"
version = "1.12.0-pre.1"
authors = ["Aleksey Kladov <aleksey.kladov@gmail.com>"]
license = "MIT OR Apache-2.0"
edition = "2018"
Expand Down
15 changes: 15 additions & 0 deletions src/imp_pl.rs
Expand Up @@ -77,6 +77,21 @@ impl<T> OnceCell<T> {
res
}

#[cold]
pub(crate) fn wait(&self) {
let key = &self.state as *const _ as usize;
unsafe {
parking_lot_core::park(
key,
|| self.state.load(Ordering::Acquire) != COMPLETE,
|| (),
|_, _| (),
parking_lot_core::DEFAULT_PARK_TOKEN,
None,
);
}
}

/// Get the reference to the underlying value, without checking if the cell
/// is initialized.
///
Expand Down
213 changes: 112 additions & 101 deletions src/imp_std.rs
Expand Up @@ -16,9 +16,15 @@ use crate::take_unchecked;

#[derive(Debug)]
pub(crate) struct OnceCell<T> {
// This `state` word is actually an encoded version of just a pointer to a
// `Waiter`, so we add the `PhantomData` appropriately.
state_and_queue: AtomicUsize,
// This `queue` field is the core of the implementation. It encodes two
// pieces of information:
//
// * The current state of the cell (`INCOMPLETE`, `RUNNING`, `COMPLETE`)
// * Linked list of threads waiting for the current cell.
//
// State is encoded in two low bits. Only `INCOMPLETE` and `RUNNING` states
// allow waiters.
queue: AtomicUsize,
_marker: PhantomData<*mut Waiter>,
value: UnsafeCell<Option<T>>,
}
Expand All @@ -34,44 +40,18 @@ unsafe impl<T: Send> Send for OnceCell<T> {}
impl<T: RefUnwindSafe + UnwindSafe> RefUnwindSafe for OnceCell<T> {}
impl<T: UnwindSafe> UnwindSafe for OnceCell<T> {}

// Three states that a OnceCell can be in, encoded into the lower bits of `state` in
// the OnceCell structure.
const INCOMPLETE: usize = 0x0;
const RUNNING: usize = 0x1;
const COMPLETE: usize = 0x2;

// Mask to learn about the state. All other bits are the queue of waiters if
// this is in the RUNNING state.
const STATE_MASK: usize = 0x3;

// Representation of a node in the linked list of waiters in the RUNNING state.
#[repr(align(4))] // Ensure the two lower bits are free to use as state bits.
struct Waiter {
thread: Cell<Option<Thread>>,
signaled: AtomicBool,
next: *const Waiter,
}

// Head of a linked list of waiters.
// Every node is a struct on the stack of a waiting thread.
// Will wake up the waiters when it gets dropped, i.e. also on panic.
struct WaiterQueue<'a> {
state_and_queue: &'a AtomicUsize,
set_state_on_drop_to: usize,
}

impl<T> OnceCell<T> {
pub(crate) const fn new() -> OnceCell<T> {
OnceCell {
state_and_queue: AtomicUsize::new(INCOMPLETE),
queue: AtomicUsize::new(INCOMPLETE),
_marker: PhantomData,
value: UnsafeCell::new(None),
}
}

pub(crate) const fn with_value(value: T) -> OnceCell<T> {
OnceCell {
state_and_queue: AtomicUsize::new(COMPLETE),
queue: AtomicUsize::new(COMPLETE),
_marker: PhantomData,
value: UnsafeCell::new(Some(value)),
}
Expand All @@ -84,7 +64,7 @@ impl<T> OnceCell<T> {
// operations visible to us, and, this being a fast path, weaker
// ordering helps with performance. This `Acquire` synchronizes with
// `SeqCst` operations on the slow path.
self.state_and_queue.load(Ordering::Acquire) == COMPLETE
self.queue.load(Ordering::Acquire) == COMPLETE
}

/// Safety: synchronizes with store to value via SeqCst read from state,
Expand All @@ -98,22 +78,30 @@ impl<T> OnceCell<T> {
let mut f = Some(f);
let mut res: Result<(), E> = Ok(());
let slot: *mut Option<T> = self.value.get();
initialize_inner(&self.state_and_queue, &mut || {
let f = unsafe { take_unchecked(&mut f) };
match f() {
Ok(value) => {
unsafe { *slot = Some(value) };
true
}
Err(err) => {
res = Err(err);
false
initialize_or_wait(
&self.queue,
Some(&mut || {
let f = unsafe { take_unchecked(&mut f) };
match f() {
Ok(value) => {
unsafe { *slot = Some(value) };
true
}
Err(err) => {
res = Err(err);
false
}
}
}
});
}),
);
res
}

#[cold]
pub(crate) fn wait(&self) {
initialize_or_wait(&self.queue, None);
}

/// Get the reference to the underlying value, without checking if the cell
/// is initialized.
///
Expand Down Expand Up @@ -152,67 +140,111 @@ impl<T> OnceCell<T> {
}
}

// Corresponds to `std::sync::Once::call_inner`
// Three states that a OnceCell can be in, encoded into the lower bits of `queue` in
// the OnceCell structure.
const INCOMPLETE: usize = 0x0;
const RUNNING: usize = 0x1;
const COMPLETE: usize = 0x2;

// Mask to learn about the state. All other bits are the queue of waiters if
// this is in the RUNNING state.
const STATE_MASK: usize = 0x3;

/// Representation of a node in the linked list of waiters in the RUNNING state.
/// A waiters is stored on the stack of the waiting threads.
#[repr(align(4))] // Ensure the two lower bits are free to use as state bits.
struct Waiter {
thread: Cell<Option<Thread>>,
signaled: AtomicBool,
next: *const Waiter,
}

/// Drains and notifies the queue of waiters on drop.
struct Guard<'a> {
queue: &'a AtomicUsize,
new_queue: usize,
}

impl Drop for Guard<'_> {
fn drop(&mut self) {
let queue = self.queue.swap(self.new_queue, Ordering::AcqRel);

assert_eq!(queue & STATE_MASK, RUNNING);

unsafe {
let mut waiter = (queue & !STATE_MASK) as *const Waiter;
while !waiter.is_null() {
let next = (*waiter).next;
let thread = (*waiter).thread.take().unwrap();
(*waiter).signaled.store(true, Ordering::Release);
waiter = next;
thread.unpark();
}
}
}
}

// Corresponds to `std::sync::Once::call_inner`.
//
// Originally copied from std, but since modified to remove poisoning and to
// support wait.
//
// Note: this is intentionally monomorphic
#[inline(never)]
fn initialize_inner(my_state_and_queue: &AtomicUsize, init: &mut dyn FnMut() -> bool) -> bool {
let mut state_and_queue = my_state_and_queue.load(Ordering::Acquire);
fn initialize_or_wait(queue: &AtomicUsize, mut init: Option<&mut dyn FnMut() -> bool>) {
let mut curr_queue = queue.load(Ordering::Acquire);

loop {
match state_and_queue {
COMPLETE => return true,
INCOMPLETE => {
let exchange = my_state_and_queue.compare_exchange(
state_and_queue,
RUNNING,
let curr_state = curr_queue & STATE_MASK;
match (curr_state, &mut init) {
(COMPLETE, _) => return,
(INCOMPLETE, Some(init)) => {
let exchange = queue.compare_exchange(
curr_queue,
(curr_queue & !STATE_MASK) | RUNNING,
Ordering::Acquire,
Ordering::Acquire,
);
if let Err(old) = exchange {
state_and_queue = old;
if let Err(new_queue) = exchange {
curr_queue = new_queue;
continue;
}
let mut waiter_queue = WaiterQueue {
state_and_queue: my_state_and_queue,
set_state_on_drop_to: INCOMPLETE, // Difference, std uses `POISONED`
};
let success = init();

// Difference, std always uses `COMPLETE`
waiter_queue.set_state_on_drop_to = if success { COMPLETE } else { INCOMPLETE };
return success;
let mut guard = Guard { queue, new_queue: INCOMPLETE };
if init() {
guard.new_queue = COMPLETE;
}
return;
}
_ => {
assert!(state_and_queue & STATE_MASK == RUNNING);
wait(&my_state_and_queue, state_and_queue);
state_and_queue = my_state_and_queue.load(Ordering::Acquire);
(INCOMPLETE, None) | (RUNNING, _) => {
wait(&queue, curr_queue);
curr_queue = queue.load(Ordering::Acquire);
}
_ => debug_assert!(false),
}
}
}

// Copy-pasted from std exactly.
fn wait(state_and_queue: &AtomicUsize, mut current_state: usize) {
fn wait(queue: &AtomicUsize, mut curr_queue: usize) {
let curr_state = curr_queue & STATE_MASK;
loop {
if current_state & STATE_MASK != RUNNING {
return;
}

let node = Waiter {
thread: Cell::new(Some(thread::current())),
signaled: AtomicBool::new(false),
next: (current_state & !STATE_MASK) as *const Waiter,
next: (curr_queue & !STATE_MASK) as *const Waiter,
};
let me = &node as *const Waiter as usize;

let exchange = state_and_queue.compare_exchange(
current_state,
me | RUNNING,
let exchange = queue.compare_exchange(
curr_queue,
me | curr_state,
Ordering::Release,
Ordering::Relaxed,
);
if let Err(old) = exchange {
current_state = old;
if let Err(new_queue) = exchange {
if new_queue & STATE_MASK != curr_state {
return;
}
curr_queue = new_queue;
continue;
}

Expand All @@ -223,27 +255,6 @@ fn wait(state_and_queue: &AtomicUsize, mut current_state: usize) {
}
}

// Copy-pasted from std exactly.
impl Drop for WaiterQueue<'_> {
fn drop(&mut self) {
let state_and_queue =
self.state_and_queue.swap(self.set_state_on_drop_to, Ordering::AcqRel);

assert_eq!(state_and_queue & STATE_MASK, RUNNING);

unsafe {
let mut queue = (state_and_queue & !STATE_MASK) as *const Waiter;
while !queue.is_null() {
let next = (*queue).next;
let thread = (*queue).thread.replace(None).unwrap();
(*queue).signaled.store(true, Ordering::Release);
queue = next;
thread.unpark();
}
}
}
}

// These test are snatched from std as well.
#[cfg(test)]
mod tests {
Expand Down
11 changes: 11 additions & 0 deletions src/lib.rs
Expand Up @@ -875,6 +875,17 @@ pub mod sync {
}
}

/// Blocks until the value is set by another thread.
pub fn wait(&self) -> &T {
if !self.0.is_initialized() {
self.0.wait()
}
debug_assert!(self.0.is_initialized());
// Safe b/c of the wait call above and the fact that we didn't
// relinquish our borrow.
unsafe { self.get_unchecked() }
}

/// Gets the mutable reference to the underlying value.
///
/// Returns `None` if the cell is empty.
Expand Down
21 changes: 17 additions & 4 deletions tests/it.rs
Expand Up @@ -319,6 +319,17 @@ mod sync {
assert_eq!(cell.get(), Some(&"hello".to_string()));
}

#[test]
fn wait() {
let cell: OnceCell<String> = OnceCell::new();
scope(|s| {
s.spawn(|_| cell.set("hello".to_string()));
let greeting = cell.wait();
assert_eq!(greeting, "hello")
})
.unwrap();
}

#[test]
#[cfg_attr(miri, ignore)] // miri doesn't support Barrier
fn get_or_init_stress() {
Expand All @@ -329,16 +340,18 @@ mod sync {
.take(n_cells)
.collect();
scope(|s| {
for _ in 0..n_threads {
s.spawn(|_| {
for t in 0..n_threads {
let cells = &cells;
s.spawn(move |_| {
for (i, (b, s)) in cells.iter().enumerate() {
b.wait();
let j = s.get_or_init(|| i);
let j = if t % 2 == 0 { s.wait() } else { s.get_or_init(|| i) };
assert_eq!(*j, i);
}
});
}
}).unwrap();
})
.unwrap();
}

#[test]
Expand Down

0 comments on commit c935154

Please sign in to comment.