diff --git a/src/imp_pl.rs b/src/imp_pl.rs index a9bae85..b562a20 100644 --- a/src/imp_pl.rs +++ b/src/imp_pl.rs @@ -86,6 +86,10 @@ impl OnceCell { res } + pub(crate) fn wait(&self) { + unimplemented!() + } + /// Get the reference to the underlying value, without checking if the cell /// is initialized. /// diff --git a/src/imp_std.rs b/src/imp_std.rs index 9a1c83d..24066f1 100644 --- a/src/imp_std.rs +++ b/src/imp_std.rs @@ -98,22 +98,30 @@ impl OnceCell { let mut f = Some(f); let mut res: Result<(), E> = Ok(()); let slot: *mut Option = self.value.get(); - initialize_inner(&self.state_and_queue, &mut || { - let f = unsafe { take_unchecked(&mut f) }; - match f() { - Ok(value) => { - unsafe { *slot = Some(value) }; - true + initialize_or_wait( + &self.state_and_queue, + Some(&mut || { + let f = unsafe { take_unchecked(&mut f) }; + match f() { + Ok(value) => { + unsafe { *slot = Some(value) }; + true + } + Err(err) => { + res = Err(err); + false + } } - Err(err) => { - res = Err(err); - false - } - } - }); + }), + ); res } + #[cold] + pub(crate) fn wait(&self) { + initialize_or_wait(&self.state_and_queue, None); + } + /// Get the reference to the underlying value, without checking if the cell /// is initialized. /// @@ -155,64 +163,63 @@ impl OnceCell { // Corresponds to `std::sync::Once::call_inner` // Note: this is intentionally monomorphic #[inline(never)] -fn initialize_inner(my_state_and_queue: &AtomicUsize, init: &mut dyn FnMut() -> bool) -> bool { - let mut state_and_queue = my_state_and_queue.load(Ordering::Acquire); +fn initialize_or_wait(queue: &AtomicUsize, mut init: Option<&mut dyn FnMut() -> bool>) { + let mut curr_queue = queue.load(Ordering::Acquire); loop { - match state_and_queue { - COMPLETE => return true, - INCOMPLETE => { - let exchange = my_state_and_queue.compare_exchange( - state_and_queue, - RUNNING, + let curr_state = curr_queue & STATE_MASK; + match (curr_state, &mut init) { + (COMPLETE, _) => return, + (INCOMPLETE, Some(init)) => { + let exchange = queue.compare_exchange( + curr_queue, + (curr_queue & !STATE_MASK) | RUNNING, Ordering::Acquire, Ordering::Acquire, ); if let Err(old) = exchange { - state_and_queue = old; + curr_queue = old; continue; } let mut waiter_queue = WaiterQueue { - state_and_queue: my_state_and_queue, + state_and_queue: queue, set_state_on_drop_to: INCOMPLETE, // Difference, std uses `POISONED` }; - let success = init(); - - // Difference, std always uses `COMPLETE` - waiter_queue.set_state_on_drop_to = if success { COMPLETE } else { INCOMPLETE }; - return success; + if init() { + waiter_queue.set_state_on_drop_to = COMPLETE; + } + return; } - _ => { - assert!(state_and_queue & STATE_MASK == RUNNING); - wait(&my_state_and_queue, state_and_queue); - state_and_queue = my_state_and_queue.load(Ordering::Acquire); + (INCOMPLETE, None) | (RUNNING, _) => { + wait(&queue, curr_queue); + curr_queue = queue.load(Ordering::Acquire); } + _ => unreachable!(), } } } -// Copy-pasted from std exactly. -fn wait(state_and_queue: &AtomicUsize, mut current_state: usize) { +fn wait(queue: &AtomicUsize, mut curr_queue: usize) { + let curr_state = curr_queue & STATE_MASK; loop { - if current_state & STATE_MASK != RUNNING { - return; - } - let node = Waiter { thread: Cell::new(Some(thread::current())), signaled: AtomicBool::new(false), - next: (current_state & !STATE_MASK) as *const Waiter, + next: (curr_queue & !STATE_MASK) as *const Waiter, }; let me = &node as *const Waiter as usize; - let exchange = state_and_queue.compare_exchange( - current_state, - me | RUNNING, + let exchange = queue.compare_exchange( + curr_queue, + me | curr_state, Ordering::Release, Ordering::Relaxed, ); if let Err(old) = exchange { - current_state = old; + if old & STATE_MASK != curr_state { + return; + } + curr_queue = old; continue; } diff --git a/src/lib.rs b/src/lib.rs index b344251..0d1a7d6 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -870,6 +870,17 @@ pub mod sync { } } + /// Blocks until the value is set by another thread. + pub fn wait(&self) -> &T { + if !self.0.is_initialized() { + self.0.wait() + } + debug_assert!(self.0.is_initialized()); + // Safe b/c of the wait call above and the fact that we didn't + // relinquish our borrow. + unsafe { self.get_unchecked() } + } + /// Gets the mutable reference to the underlying value. /// /// Returns `None` if the cell is empty. diff --git a/tests/it.rs b/tests/it.rs index 4d6efdc..118ae96 100644 --- a/tests/it.rs +++ b/tests/it.rs @@ -319,6 +319,22 @@ mod sync { assert_eq!(cell.get(), Some(&"hello".to_string())); } + #[test] + fn wait() { + let x = OnceCell::new(); + + scope(|s| { + let w1 = s.spawn(|_| x.wait()); + s.spawn(|_| x.set("hello".to_string())); + let w2 = s.spawn(|_| x.wait()); + s.spawn(|_| x.set("world".to_string())); + let w1 = w1.join().unwrap(); + let w2 = w2.join().unwrap(); + assert_eq!(w1, w2) + }) + .unwrap(); + } + #[test] fn from_impl() { assert_eq!(OnceCell::from("value").get(), Some(&"value"));