diff --git a/Cargo.toml b/Cargo.toml index 004aa77..1ad618f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -32,15 +32,19 @@ parking_lot_core = { version = "0.9.3", optional = true, default_features = fals # and make sure you understand all the implications atomic-polyfill = { version = "1", optional = true } +# Uses `critical-section` to implement `once_cell::sync::OnceCell`. +critical_section = { package = "critical-section", version = "1", optional = true } + [dev-dependencies] lazy_static = "1.0.0" crossbeam-utils = "0.8.7" regex = "1.2.0" +critical_section = { package = "critical-section", version = "1.1.1", features = ["std"] } [features] default = ["std"] # Enables `once_cell::sync` module. -std = ["alloc"] +std = ["alloc", "sync"] # Enables `once_cell::race::OnceBox` type. alloc = ["race"] # Enables `once_cell::race` module. @@ -49,8 +53,12 @@ race = [] # At the moment, this feature is unused. unstable = [] +sync = [] + parking_lot = ["parking_lot_core"] +critical-section = ["critical_section", "atomic-polyfill", "sync"] + [[example]] name = "bench" required-features = ["std"] diff --git a/src/imp_cs.rs b/src/imp_cs.rs new file mode 100644 index 0000000..2548f58 --- /dev/null +++ b/src/imp_cs.rs @@ -0,0 +1,78 @@ +use core::panic::{RefUnwindSafe, UnwindSafe}; + +use atomic_polyfill::{AtomicBool, Ordering}; +use critical_section::{CriticalSection, Mutex}; + +use crate::unsync; + +pub(crate) struct OnceCell { + initialized: AtomicBool, + // Use `unsync::OnceCell` internally since `Mutex` does not provide + // interior mutability and to be able to re-use `get_or_try_init`. + value: Mutex>, +} + +// Why do we need `T: Send`? +// Thread A creates a `OnceCell` and shares it with +// scoped thread B, which fills the cell, which is +// then destroyed by A. That is, destructor observes +// a sent value. +unsafe impl Sync for OnceCell {} +unsafe impl Send for OnceCell {} + +impl RefUnwindSafe for OnceCell {} +impl UnwindSafe for OnceCell {} + +impl OnceCell { + pub(crate) const fn new() -> OnceCell { + OnceCell { initialized: AtomicBool::new(false), value: Mutex::new(unsync::OnceCell::new()) } + } + + pub(crate) const fn with_value(value: T) -> OnceCell { + OnceCell { + initialized: AtomicBool::new(true), + value: Mutex::new(unsync::OnceCell::with_value(value)), + } + } + + #[inline] + pub(crate) fn is_initialized(&self) -> bool { + self.initialized.load(Ordering::Acquire) + } + + #[cold] + pub(crate) fn initialize(&self, f: F) -> Result<(), E> + where + F: FnOnce() -> Result, + { + critical_section::with(|cs| { + let cell = self.value.borrow(cs); + cell.get_or_try_init(f).map(|_| { + self.initialized.store(true, Ordering::Release); + }) + }) + } + + /// Get the reference to the underlying value, without checking if the cell + /// is initialized. + /// + /// # Safety + /// + /// Caller must ensure that the cell is in initialized state, and that + /// the contents are acquired by (synchronized to) this thread. + pub(crate) unsafe fn get_unchecked(&self) -> &T { + debug_assert!(self.is_initialized()); + // SAFETY: The caller ensures that the value is initialized and access synchronized. + self.value.borrow(CriticalSection::new()).get_unchecked() + } + + #[inline] + pub(crate) fn get_mut(&mut self) -> Option<&mut T> { + self.value.get_mut().get_mut() + } + + #[inline] + pub(crate) fn into_inner(self) -> Option { + self.value.into_inner().into_inner() + } +} diff --git a/src/imp_pl.rs b/src/imp_pl.rs index d80ca5e..b723587 100644 --- a/src/imp_pl.rs +++ b/src/imp_pl.rs @@ -1,6 +1,5 @@ use std::{ cell::UnsafeCell, - hint, panic::{RefUnwindSafe, UnwindSafe}, sync::atomic::{AtomicU8, Ordering}, }; @@ -101,15 +100,8 @@ impl OnceCell { /// the contents are acquired by (synchronized to) this thread. pub(crate) unsafe fn get_unchecked(&self) -> &T { debug_assert!(self.is_initialized()); - let slot: &Option = &*self.value.get(); - match slot { - Some(value) => value, - // This unsafe does improve performance, see `examples/bench`. - None => { - debug_assert!(false); - hint::unreachable_unchecked() - } - } + let slot = &*self.value.get(); + crate::unwrap_unchecked(slot.as_ref()) } /// Gets the mutable reference to the underlying value. diff --git a/src/imp_std.rs b/src/imp_std.rs index 4d5b5fd..f828d45 100644 --- a/src/imp_std.rs +++ b/src/imp_std.rs @@ -5,15 +5,12 @@ use std::{ cell::{Cell, UnsafeCell}, - hint::unreachable_unchecked, marker::PhantomData, panic::{RefUnwindSafe, UnwindSafe}, sync::atomic::{AtomicBool, AtomicPtr, Ordering}, thread::{self, Thread}, }; -use crate::take_unchecked; - #[derive(Debug)] pub(crate) struct OnceCell { // This `queue` field is the core of the implementation. It encodes two @@ -81,7 +78,7 @@ impl OnceCell { initialize_or_wait( &self.queue, Some(&mut || { - let f = unsafe { take_unchecked(&mut f) }; + let f = unsafe { crate::take_unchecked(&mut f) }; match f() { Ok(value) => { unsafe { *slot = Some(value) }; @@ -111,15 +108,8 @@ impl OnceCell { /// the contents are acquired by (synchronized to) this thread. pub(crate) unsafe fn get_unchecked(&self) -> &T { debug_assert!(self.is_initialized()); - let slot: &Option = &*self.value.get(); - match slot { - Some(value) => value, - // This unsafe does improve performance, see `examples/bench`. - None => { - debug_assert!(false); - unreachable_unchecked() - } - } + let slot = &*self.value.get(); + crate::unwrap_unchecked(slot.as_ref()) } /// Gets the mutable reference to the underlying value. diff --git a/src/lib.rs b/src/lib.rs index 6de1e3e..f44bc4c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -332,13 +332,15 @@ #[cfg(feature = "alloc")] extern crate alloc; -#[cfg(feature = "std")] -#[cfg(feature = "parking_lot")] +#[cfg(all(feature = "critical-section", not(feature = "std")))] +#[path = "imp_cs.rs"] +mod imp; + +#[cfg(all(feature = "std", feature = "parking_lot"))] #[path = "imp_pl.rs"] mod imp; -#[cfg(feature = "std")] -#[cfg(not(feature = "parking_lot"))] +#[cfg(all(feature = "std", not(feature = "parking_lot")))] #[path = "imp_std.rs"] mod imp; @@ -346,11 +348,13 @@ mod imp; pub mod unsync { use core::{ cell::{Cell, UnsafeCell}, - fmt, hint, mem, + fmt, mem, ops::{Deref, DerefMut}, panic::{RefUnwindSafe, UnwindSafe}, }; + use super::unwrap_unchecked; + /// A cell which can be written to only once. It is not thread safe. /// /// Unlike [`std::cell::RefCell`], a `OnceCell` provides simple `&` @@ -442,6 +446,7 @@ pub mod unsync { /// Gets a reference to the underlying value. /// /// Returns `None` if the cell is empty. + #[inline] pub fn get(&self) -> Option<&T> { // Safe due to `inner`'s invariant unsafe { &*self.inner.get() }.as_ref() @@ -463,11 +468,24 @@ pub mod unsync { /// *cell.get_mut().unwrap() = 93; /// assert_eq!(cell.get(), Some(&93)); /// ``` + #[inline] pub fn get_mut(&mut self) -> Option<&mut T> { // Safe because we have unique access unsafe { &mut *self.inner.get() }.as_mut() } + /// Get the reference to the underlying value, without checking if the + /// cell is initialized. + /// + /// # Safety + /// + /// Caller must ensure that the cell is in initialized state. + #[cfg(feature = "critical-section")] + #[inline] + pub(crate) unsafe fn get_unchecked(&self) -> &T { + crate::unwrap_unchecked(self.get()) + } + /// Sets the contents of this cell to `value`. /// /// Returns `Ok(())` if the cell was empty and `Err(value)` if it was @@ -510,16 +528,14 @@ pub mod unsync { if let Some(old) = self.get() { return Err((old, value)); } + let slot = unsafe { &mut *self.inner.get() }; // This is the only place where we set the slot, no races // due to reentrancy/concurrency are possible, and we've // checked that slot is currently `None`, so this write // maintains the `inner`'s invariant. *slot = Some(value); - Ok(match &*slot { - Some(value) => value, - None => unsafe { hint::unreachable_unchecked() }, - }) + Ok(unsafe { unwrap_unchecked(slot.as_ref()) }) } /// Gets the contents of the cell, initializing it with `f` @@ -592,7 +608,7 @@ pub mod unsync { // `assert`, while keeping `set/get` would be sound, but it seems // better to panic, rather than to silently use an old value. assert!(self.set(val).is_ok(), "reentrant init"); - Ok(self.get().unwrap()) + Ok(unsafe { unwrap_unchecked(self.get()) }) } /// Takes the value out of this `OnceCell`, moving it back to an uninitialized state. @@ -814,16 +830,16 @@ pub mod unsync { } /// Thread-safe, blocking version of `OnceCell`. -#[cfg(feature = "std")] +#[cfg(feature = "sync")] pub mod sync { - use std::{ + use core::{ cell::Cell, fmt, mem, ops::{Deref, DerefMut}, panic::RefUnwindSafe, }; - use crate::{imp::OnceCell as Imp, take_unchecked}; + use super::{imp::OnceCell as Imp, take_unchecked}; /// A thread-safe cell which can be written to only once. /// @@ -942,8 +958,9 @@ pub mod sync { /// // Will return 92, but might block until the other thread does `.set`. /// let value: &u32 = cell.wait(); /// assert_eq!(*value, 92); - /// t.join().unwrap();; + /// t.join().unwrap(); /// ``` + #[cfg(not(feature = "critical-section"))] pub fn wait(&self) -> &T { if !self.0.is_initialized() { self.0.wait() @@ -969,6 +986,7 @@ pub mod sync { /// cell.set(92).unwrap(); /// cell = OnceCell::new(); /// ``` + #[inline] pub fn get_mut(&mut self) -> Option<&mut T> { self.0.get_mut() } @@ -980,6 +998,7 @@ pub mod sync { /// /// Caller must ensure that the cell is in initialized state, and that /// the contents are acquired by (synchronized to) this thread. + #[inline] pub unsafe fn get_unchecked(&self) -> &T { self.0.get_unchecked() } @@ -1109,6 +1128,7 @@ pub mod sync { if let Some(value) = self.get() { return Ok(value); } + self.0.initialize(f)?; // Safe b/c value is initialized. @@ -1164,6 +1184,7 @@ pub mod sync { /// cell.set("hello".to_string()).unwrap(); /// assert_eq!(cell.into_inner(), Some("hello".to_string())); /// ``` + #[inline] pub fn into_inner(self) -> Option { self.0.into_inner() } @@ -1356,13 +1377,19 @@ pub mod sync { #[cfg(feature = "race")] pub mod race; -#[cfg(feature = "std")] +#[cfg(feature = "sync")] +#[inline] unsafe fn take_unchecked(val: &mut Option) -> T { - match val.take() { - Some(it) => it, + unwrap_unchecked(val.take()) +} + +#[inline] +unsafe fn unwrap_unchecked(val: Option) -> T { + match val { + Some(value) => value, None => { debug_assert!(false); - std::hint::unreachable_unchecked() + core::hint::unreachable_unchecked() } } } diff --git a/src/race.rs b/src/race.rs index e83e0b9..28acf62 100644 --- a/src/race.rs +++ b/src/race.rs @@ -165,6 +165,7 @@ impl OnceBool { fn from_usize(value: NonZeroUsize) -> bool { value.get() == 1 } + #[inline] fn to_usize(value: bool) -> NonZeroUsize { unsafe { NonZeroUsize::new_unchecked(if value { 1 } else { 2 }) } diff --git a/tests/it.rs b/tests/it.rs index 410b93b..118be00 100644 --- a/tests/it.rs +++ b/tests/it.rs @@ -249,10 +249,16 @@ mod unsync { } } -#[cfg(feature = "std")] +#[cfg(feature = "sync")] mod sync { use std::sync::atomic::{AtomicUsize, Ordering::SeqCst}; + #[cfg(not(feature = "critical-section"))] + use std::sync::Barrier; + + #[cfg(feature = "critical-section")] + use core::cell::Cell; + use crossbeam_utils::thread::scope; use once_cell::sync::{Lazy, OnceCell}; @@ -354,6 +360,7 @@ mod sync { assert_eq!(cell.get(), Some(&"hello".to_string())); } + #[cfg(not(feature = "critical-section"))] #[test] fn wait() { let cell: OnceCell = OnceCell::new(); @@ -365,9 +372,9 @@ mod sync { .unwrap(); } + #[cfg(not(feature = "critical-section"))] #[test] fn get_or_init_stress() { - use std::sync::Barrier; let n_threads = if cfg!(miri) { 30 } else { 1_000 }; let n_cells = if cfg!(miri) { 30 } else { 1_000 }; let cells: Vec<_> = std::iter::repeat_with(|| (Barrier::new(n_threads), OnceCell::new())) @@ -430,6 +437,7 @@ mod sync { #[test] #[cfg_attr(miri, ignore)] // miri doesn't support processes + #[cfg(not(feature = "critical-section"))] fn reentrant_init() { let examples_dir = { let mut exe = std::env::current_exe().unwrap(); @@ -457,6 +465,20 @@ mod sync { } } + #[cfg(feature = "critical-section")] + #[test] + #[should_panic(expected = "reentrant init")] + fn reentrant_init() { + let x: OnceCell> = OnceCell::new(); + let dangling_ref: Cell> = Cell::new(None); + x.get_or_init(|| { + let r = x.get_or_init(|| Box::new(92)); + dangling_ref.set(Some(r)); + Box::new(62) + }); + eprintln!("use after free: {:?}", dangling_ref.get().unwrap()); + } + #[test] fn lazy_new() { let called = AtomicUsize::new(0); @@ -636,10 +658,9 @@ mod sync { } } + #[cfg(not(feature = "critical-section"))] #[test] fn get_does_not_block() { - use std::sync::Barrier; - let cell = OnceCell::new(); let barrier = Barrier::new(2); scope(|scope| { @@ -671,12 +692,11 @@ mod sync { #[cfg(feature = "race")] mod race { + #[cfg(not(feature = "critical-section"))] + use std::sync::Barrier; use std::{ num::NonZeroUsize, - sync::{ - atomic::{AtomicUsize, Ordering::SeqCst}, - Barrier, - }, + sync::atomic::{AtomicUsize, Ordering::SeqCst}, }; use crossbeam_utils::thread::scope; @@ -728,6 +748,7 @@ mod race { assert_eq!(cell.get(), Some(val1)); } + #[cfg(not(feature = "critical-section"))] #[test] fn once_non_zero_usize_first_wins() { let val1 = NonZeroUsize::new(92).unwrap(); @@ -807,12 +828,16 @@ mod race { #[cfg(all(feature = "race", feature = "alloc"))] mod race_once_box { + #[cfg(not(feature = "critical-section"))] + use std::sync::Barrier; use std::sync::{ atomic::{AtomicUsize, Ordering::SeqCst}, - Arc, Barrier, + Arc, }; + #[cfg(not(feature = "critical-section"))] use crossbeam_utils::thread::scope; + use once_cell::race::OnceBox; #[derive(Default)] @@ -842,6 +867,7 @@ mod race_once_box { } } + #[cfg(not(feature = "critical-section"))] #[test] fn once_box_smoke_test() { let heap = Heap::default(); @@ -896,6 +922,7 @@ mod race_once_box { assert_eq!(heap.total(), 0); } + #[cfg(not(feature = "critical-section"))] #[test] fn once_box_first_wins() { let cell = OnceBox::new(); diff --git a/xtask/src/main.rs b/xtask/src/main.rs index 654efa9..aaae075 100644 --- a/xtask/src/main.rs +++ b/xtask/src/main.rs @@ -32,6 +32,9 @@ fn main() -> xshell::Result<()> { // Skip doctests for no_std tests as those don't work cmd!(sh, "cargo test --no-default-features --features unstable --test it").run()?; cmd!(sh, "cargo test --no-default-features --features unstable,alloc --test it").run()?; + + cmd!(sh, "cargo test --no-default-features --features critical-section").run()?; + cmd!(sh, "cargo test --no-default-features --features critical-section --release").run()?; } {