Skip to content

Commit

Permalink
Use AtomicPtr for race::OnceRef to avoid ptr-int-ptr casts (keeping
Browse files Browse the repository at this point in the history
provenance intact)
  • Loading branch information
Imberflur committed Feb 14, 2023
1 parent d706539 commit 6e351e5
Showing 1 changed file with 35 additions and 12 deletions.
47 changes: 35 additions & 12 deletions src/race.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,11 @@ use atomic_polyfill as atomic;
#[cfg(not(feature = "critical-section"))]
use core::sync::atomic;

use atomic::{AtomicUsize, Ordering};
use atomic::{AtomicPtr, AtomicUsize, Ordering};
use core::cell::UnsafeCell;
use core::marker::PhantomData;
use core::num::NonZeroUsize;
use core::ptr;

/// A thread-safe cell which can be written to only once.
#[derive(Default, Debug)]
Expand Down Expand Up @@ -176,7 +177,7 @@ impl OnceBool {

/// A thread-safe cell which can be written to only once.
pub struct OnceRef<'a, T> {
inner: OnceNonZeroUsize,
inner: AtomicPtr<T>,
ghost: PhantomData<UnsafeCell<&'a T>>,
}

Expand All @@ -198,21 +199,27 @@ impl<'a, T> Default for OnceRef<'a, T> {
impl<'a, T> OnceRef<'a, T> {
/// Creates a new empty cell.
pub const fn new() -> OnceRef<'a, T> {
OnceRef { inner: OnceNonZeroUsize::new(), ghost: PhantomData }
OnceRef { inner: AtomicPtr::new(ptr::null_mut()), ghost: PhantomData }
}

/// Gets a reference to the underlying value.
pub fn get(&self) -> Option<&'a T> {
self.inner.get().map(|ptr| unsafe { &*(ptr.get() as *const T) })
let ptr = self.inner.load(Ordering::Acquire);
unsafe { ptr.as_ref() }
}

/// Sets the contents of this cell to `value`.
///
/// Returns `Ok(())` if the cell was empty and `Err(value)` if it was
/// full.
pub fn set(&self, value: &'a T) -> Result<(), ()> {
let ptr = NonZeroUsize::new(value as *const T as usize).unwrap();
self.inner.set(ptr)
let ptr = value as *const T as *mut T;
let exchange =
self.inner.compare_exchange(ptr::null_mut(), ptr, Ordering::AcqRel, Ordering::Acquire);
match exchange {
Ok(_) => Ok(()),
Err(_) => Err(()),
}
}

/// Gets the contents of the cell, initializing it with `f` if the cell was
Expand All @@ -225,9 +232,11 @@ impl<'a, T> OnceRef<'a, T> {
where
F: FnOnce() -> &'a T,
{
let f = || NonZeroUsize::new(f() as *const T as usize).unwrap();
let ptr = self.inner.get_or_init(f);
unsafe { &*(ptr.get() as *const T) }
enum Void {}
match self.get_or_try_init(|| Ok::<&'a T, Void>(f())) {
Ok(val) => val,
Err(void) => match void {},
}
}

/// Gets the contents of the cell, initializing it with `f` if
Expand All @@ -241,9 +250,23 @@ impl<'a, T> OnceRef<'a, T> {
where
F: FnOnce() -> Result<&'a T, E>,
{
let f = || f().map(|value| NonZeroUsize::new(value as *const T as usize).unwrap());
let ptr = self.inner.get_or_try_init(f)?;
unsafe { Ok(&*(ptr.get() as *const T)) }
let mut ptr = self.inner.load(Ordering::Acquire);

if ptr.is_null() {
// TODO replace with `cast_mut` when MSRV reaches 1.65.0 (also in `set`)
ptr = f()? as *const T as *mut T;
let exchange = self.inner.compare_exchange(
ptr::null_mut(),
ptr,
Ordering::AcqRel,
Ordering::Acquire,
);
if let Err(old) = exchange {
ptr = old;
}
}

Ok(unsafe { &*ptr })
}

/// ```compile_fail
Expand Down

0 comments on commit 6e351e5

Please sign in to comment.