Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add OnceCell::wait #177

Merged
merged 1 commit into from May 20, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion CHANGELOG.md
Expand Up @@ -2,7 +2,7 @@

## Unreleased

-
- Add `OnceCell::wait`, a blocking variant of `get`.

## 1.11

Expand Down
2 changes: 1 addition & 1 deletion Cargo.toml
@@ -1,6 +1,6 @@
[package]
name = "once_cell"
version = "1.11.0"
version = "1.12.0-pre.1"
authors = ["Aleksey Kladov <aleksey.kladov@gmail.com>"]
license = "MIT OR Apache-2.0"
edition = "2018"
Expand Down
15 changes: 15 additions & 0 deletions src/imp_pl.rs
Expand Up @@ -77,6 +77,21 @@ impl<T> OnceCell<T> {
res
}

#[cold]
pub(crate) fn wait(&self) {
let key = &self.state as *const _ as usize;
unsafe {
parking_lot_core::park(
key,
|| self.state.load(Ordering::Acquire) != COMPLETE,
|| (),
|_, _| (),
parking_lot_core::DEFAULT_PARK_TOKEN,
None,
);
}
}
matklad marked this conversation as resolved.
Show resolved Hide resolved

/// Get the reference to the underlying value, without checking if the cell
/// is initialized.
///
Expand Down
213 changes: 112 additions & 101 deletions src/imp_std.rs
Expand Up @@ -16,9 +16,15 @@ use crate::take_unchecked;

#[derive(Debug)]
pub(crate) struct OnceCell<T> {
// This `state` word is actually an encoded version of just a pointer to a
// `Waiter`, so we add the `PhantomData` appropriately.
state_and_queue: AtomicUsize,
// This `queue` field is the core of the implementation. It encodes two
// pieces of information:
//
// * The current state of the cell (`INCOMPLETE`, `RUNNING`, `COMPLETE`)
// * Linked list of threads waiting for the current cell.
//
// State is encoded in two low bits. Only `INCOMPLETE` and `RUNNING` states
// allow waiters.
queue: AtomicUsize,
_marker: PhantomData<*mut Waiter>,
value: UnsafeCell<Option<T>>,
}
Expand All @@ -34,44 +40,18 @@ unsafe impl<T: Send> Send for OnceCell<T> {}
impl<T: RefUnwindSafe + UnwindSafe> RefUnwindSafe for OnceCell<T> {}
impl<T: UnwindSafe> UnwindSafe for OnceCell<T> {}

// Three states that a OnceCell can be in, encoded into the lower bits of `state` in
// the OnceCell structure.
const INCOMPLETE: usize = 0x0;
const RUNNING: usize = 0x1;
const COMPLETE: usize = 0x2;

// Mask to learn about the state. All other bits are the queue of waiters if
// this is in the RUNNING state.
const STATE_MASK: usize = 0x3;

// Representation of a node in the linked list of waiters in the RUNNING state.
#[repr(align(4))] // Ensure the two lower bits are free to use as state bits.
struct Waiter {
thread: Cell<Option<Thread>>,
signaled: AtomicBool,
next: *const Waiter,
}

// Head of a linked list of waiters.
// Every node is a struct on the stack of a waiting thread.
// Will wake up the waiters when it gets dropped, i.e. also on panic.
struct WaiterQueue<'a> {
state_and_queue: &'a AtomicUsize,
set_state_on_drop_to: usize,
}

impl<T> OnceCell<T> {
pub(crate) const fn new() -> OnceCell<T> {
OnceCell {
state_and_queue: AtomicUsize::new(INCOMPLETE),
queue: AtomicUsize::new(INCOMPLETE),
_marker: PhantomData,
value: UnsafeCell::new(None),
}
}

pub(crate) const fn with_value(value: T) -> OnceCell<T> {
OnceCell {
state_and_queue: AtomicUsize::new(COMPLETE),
queue: AtomicUsize::new(COMPLETE),
_marker: PhantomData,
value: UnsafeCell::new(Some(value)),
}
Expand All @@ -84,7 +64,7 @@ impl<T> OnceCell<T> {
// operations visible to us, and, this being a fast path, weaker
// ordering helps with performance. This `Acquire` synchronizes with
// `SeqCst` operations on the slow path.
self.state_and_queue.load(Ordering::Acquire) == COMPLETE
self.queue.load(Ordering::Acquire) == COMPLETE
}

/// Safety: synchronizes with store to value via SeqCst read from state,
Expand All @@ -98,22 +78,30 @@ impl<T> OnceCell<T> {
let mut f = Some(f);
let mut res: Result<(), E> = Ok(());
let slot: *mut Option<T> = self.value.get();
initialize_inner(&self.state_and_queue, &mut || {
let f = unsafe { take_unchecked(&mut f) };
match f() {
Ok(value) => {
unsafe { *slot = Some(value) };
true
}
Err(err) => {
res = Err(err);
false
initialize_or_wait(
&self.queue,
Some(&mut || {
let f = unsafe { take_unchecked(&mut f) };
match f() {
Ok(value) => {
unsafe { *slot = Some(value) };
true
}
Err(err) => {
res = Err(err);
false
}
}
}
});
}),
);
res
}

#[cold]
pub(crate) fn wait(&self) {
initialize_or_wait(&self.queue, None);
}

/// Get the reference to the underlying value, without checking if the cell
/// is initialized.
///
Expand Down Expand Up @@ -152,67 +140,111 @@ impl<T> OnceCell<T> {
}
}

// Corresponds to `std::sync::Once::call_inner`
// Three states that a OnceCell can be in, encoded into the lower bits of `queue` in
// the OnceCell structure.
const INCOMPLETE: usize = 0x0;
const RUNNING: usize = 0x1;
const COMPLETE: usize = 0x2;

// Mask to learn about the state. All other bits are the queue of waiters if
// this is in the RUNNING state.
const STATE_MASK: usize = 0x3;

/// Representation of a node in the linked list of waiters in the RUNNING state.
/// A waiters is stored on the stack of the waiting threads.
#[repr(align(4))] // Ensure the two lower bits are free to use as state bits.
struct Waiter {
thread: Cell<Option<Thread>>,
signaled: AtomicBool,
next: *const Waiter,
}

/// Drains and notifies the queue of waiters on drop.
struct Guard<'a> {
queue: &'a AtomicUsize,
new_queue: usize,
}

impl Drop for Guard<'_> {
fn drop(&mut self) {
let queue = self.queue.swap(self.new_queue, Ordering::AcqRel);

assert_eq!(queue & STATE_MASK, RUNNING);

unsafe {
let mut waiter = (queue & !STATE_MASK) as *const Waiter;
while !waiter.is_null() {
let next = (*waiter).next;
let thread = (*waiter).thread.take().unwrap();
(*waiter).signaled.store(true, Ordering::Release);
waiter = next;
thread.unpark();
}
}
}
}

// Corresponds to `std::sync::Once::call_inner`.
//
// Originally copied from std, but since modified to remove poisoning and to
// support wait.
//
// Note: this is intentionally monomorphic
#[inline(never)]
fn initialize_inner(my_state_and_queue: &AtomicUsize, init: &mut dyn FnMut() -> bool) -> bool {
let mut state_and_queue = my_state_and_queue.load(Ordering::Acquire);
fn initialize_or_wait(queue: &AtomicUsize, mut init: Option<&mut dyn FnMut() -> bool>) {
let mut curr_queue = queue.load(Ordering::Acquire);

loop {
match state_and_queue {
COMPLETE => return true,
INCOMPLETE => {
let exchange = my_state_and_queue.compare_exchange(
state_and_queue,
RUNNING,
let curr_state = curr_queue & STATE_MASK;
match (curr_state, &mut init) {
(COMPLETE, _) => return,
(INCOMPLETE, Some(init)) => {
let exchange = queue.compare_exchange(
curr_queue,
(curr_queue & !STATE_MASK) | RUNNING,
Ordering::Acquire,
Ordering::Acquire,
);
if let Err(old) = exchange {
state_and_queue = old;
if let Err(new_queue) = exchange {
curr_queue = new_queue;
continue;
}
let mut waiter_queue = WaiterQueue {
state_and_queue: my_state_and_queue,
set_state_on_drop_to: INCOMPLETE, // Difference, std uses `POISONED`
};
let success = init();

// Difference, std always uses `COMPLETE`
waiter_queue.set_state_on_drop_to = if success { COMPLETE } else { INCOMPLETE };
return success;
let mut guard = Guard { queue, new_queue: INCOMPLETE };
if init() {
guard.new_queue = COMPLETE;
}
return;
}
_ => {
assert!(state_and_queue & STATE_MASK == RUNNING);
wait(&my_state_and_queue, state_and_queue);
state_and_queue = my_state_and_queue.load(Ordering::Acquire);
(INCOMPLETE, None) | (RUNNING, _) => {
wait(&queue, curr_queue);
curr_queue = queue.load(Ordering::Acquire);
}
_ => debug_assert!(false),
}
}
}

// Copy-pasted from std exactly.
fn wait(state_and_queue: &AtomicUsize, mut current_state: usize) {
fn wait(queue: &AtomicUsize, mut curr_queue: usize) {
let curr_state = curr_queue & STATE_MASK;
loop {
if current_state & STATE_MASK != RUNNING {
return;
}

let node = Waiter {
thread: Cell::new(Some(thread::current())),
signaled: AtomicBool::new(false),
next: (current_state & !STATE_MASK) as *const Waiter,
next: (curr_queue & !STATE_MASK) as *const Waiter,
};
let me = &node as *const Waiter as usize;

let exchange = state_and_queue.compare_exchange(
current_state,
me | RUNNING,
let exchange = queue.compare_exchange(
curr_queue,
me | curr_state,
Ordering::Release,
Ordering::Relaxed,
);
if let Err(old) = exchange {
current_state = old;
if let Err(new_queue) = exchange {
if new_queue & STATE_MASK != curr_state {
return;
}
curr_queue = new_queue;
continue;
}

Expand All @@ -223,27 +255,6 @@ fn wait(state_and_queue: &AtomicUsize, mut current_state: usize) {
}
}

// Copy-pasted from std exactly.
impl Drop for WaiterQueue<'_> {
fn drop(&mut self) {
let state_and_queue =
self.state_and_queue.swap(self.set_state_on_drop_to, Ordering::AcqRel);

assert_eq!(state_and_queue & STATE_MASK, RUNNING);

unsafe {
let mut queue = (state_and_queue & !STATE_MASK) as *const Waiter;
while !queue.is_null() {
let next = (*queue).next;
let thread = (*queue).thread.replace(None).unwrap();
(*queue).signaled.store(true, Ordering::Release);
queue = next;
thread.unpark();
}
}
}
}

// These test are snatched from std as well.
#[cfg(test)]
mod tests {
Expand Down
11 changes: 11 additions & 0 deletions src/lib.rs
Expand Up @@ -875,6 +875,17 @@ pub mod sync {
}
}

/// Blocks until the value is set by another thread.
pub fn wait(&self) -> &T {
if !self.0.is_initialized() {
self.0.wait()
}
debug_assert!(self.0.is_initialized());
// Safe b/c of the wait call above and the fact that we didn't
// relinquish our borrow.
unsafe { self.get_unchecked() }
}

/// Gets the mutable reference to the underlying value.
///
/// Returns `None` if the cell is empty.
Expand Down
21 changes: 17 additions & 4 deletions tests/it.rs
Expand Up @@ -319,6 +319,17 @@ mod sync {
assert_eq!(cell.get(), Some(&"hello".to_string()));
}

#[test]
fn wait() {
let cell: OnceCell<String> = OnceCell::new();
scope(|s| {
s.spawn(|_| cell.set("hello".to_string()));
let greeting = cell.wait();
assert_eq!(greeting, "hello")
})
.unwrap();
}

#[test]
#[cfg_attr(miri, ignore)] // miri doesn't support Barrier
fn get_or_init_stress() {
Expand All @@ -329,16 +340,18 @@ mod sync {
.take(n_cells)
.collect();
scope(|s| {
for _ in 0..n_threads {
s.spawn(|_| {
for t in 0..n_threads {
let cells = &cells;
s.spawn(move |_| {
for (i, (b, s)) in cells.iter().enumerate() {
b.wait();
let j = s.get_or_init(|| i);
let j = if t % 2 == 0 { s.wait() } else { s.get_or_init(|| i) };
assert_eq!(*j, i);
}
});
}
}).unwrap();
})
.unwrap();
}

#[test]
Expand Down