diff --git a/src/race.rs b/src/race.rs index dff5847..ee3d51a 100644 --- a/src/race.rs +++ b/src/race.rs @@ -24,10 +24,11 @@ use atomic_polyfill as atomic; #[cfg(not(feature = "critical-section"))] use core::sync::atomic; -use atomic::{AtomicUsize, Ordering}; +use atomic::{AtomicPtr, AtomicUsize, Ordering}; use core::cell::UnsafeCell; use core::marker::PhantomData; use core::num::NonZeroUsize; +use core::ptr; /// A thread-safe cell which can be written to only once. #[derive(Default, Debug)] @@ -176,7 +177,7 @@ impl OnceBool { /// A thread-safe cell which can be written to only once. pub struct OnceRef<'a, T> { - inner: OnceNonZeroUsize, + inner: AtomicPtr, ghost: PhantomData>, } @@ -198,12 +199,13 @@ impl<'a, T> Default for OnceRef<'a, T> { impl<'a, T> OnceRef<'a, T> { /// Creates a new empty cell. pub const fn new() -> OnceRef<'a, T> { - OnceRef { inner: OnceNonZeroUsize::new(), ghost: PhantomData } + OnceRef { inner: AtomicPtr::new(ptr::null_mut()), ghost: PhantomData } } /// Gets a reference to the underlying value. pub fn get(&self) -> Option<&'a T> { - self.inner.get().map(|ptr| unsafe { &*(ptr.get() as *const T) }) + let ptr = self.inner.load(Ordering::Acquire); + unsafe { ptr.as_ref() } } /// Sets the contents of this cell to `value`. @@ -211,8 +213,13 @@ impl<'a, T> OnceRef<'a, T> { /// Returns `Ok(())` if the cell was empty and `Err(value)` if it was /// full. pub fn set(&self, value: &'a T) -> Result<(), ()> { - let ptr = NonZeroUsize::new(value as *const T as usize).unwrap(); - self.inner.set(ptr) + let ptr = value as *const T as *mut T; + let exchange = + self.inner.compare_exchange(ptr::null_mut(), ptr, Ordering::AcqRel, Ordering::Acquire); + match exchange { + Ok(_) => Ok(()), + Err(_) => Err(()), + } } /// Gets the contents of the cell, initializing it with `f` if the cell was @@ -225,9 +232,11 @@ impl<'a, T> OnceRef<'a, T> { where F: FnOnce() -> &'a T, { - let f = || NonZeroUsize::new(f() as *const T as usize).unwrap(); - let ptr = self.inner.get_or_init(f); - unsafe { &*(ptr.get() as *const T) } + enum Void {} + match self.get_or_try_init(|| Ok::<&'a T, Void>(f())) { + Ok(val) => val, + Err(void) => match void {}, + } } /// Gets the contents of the cell, initializing it with `f` if @@ -241,9 +250,23 @@ impl<'a, T> OnceRef<'a, T> { where F: FnOnce() -> Result<&'a T, E>, { - let f = || f().map(|value| NonZeroUsize::new(value as *const T as usize).unwrap()); - let ptr = self.inner.get_or_try_init(f)?; - unsafe { Ok(&*(ptr.get() as *const T)) } + let mut ptr = self.inner.load(Ordering::Acquire); + + if ptr.is_null() { + // TODO replace with `cast_mut` when MSRV reaches 1.65.0 (also in `set`) + ptr = f()? as *const T as *mut T; + let exchange = self.inner.compare_exchange( + ptr::null_mut(), + ptr, + Ordering::AcqRel, + Ordering::Acquire, + ); + if let Err(old) = exchange { + ptr = old; + } + } + + Ok(unsafe { &*ptr }) } /// ```compile_fail