Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use AtomicPtr for race::OnceRef to avoid ptr-int-ptr #219

Merged
merged 1 commit into from Feb 14, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
47 changes: 35 additions & 12 deletions src/race.rs
Expand Up @@ -24,10 +24,11 @@ use atomic_polyfill as atomic;
#[cfg(not(feature = "critical-section"))]
use core::sync::atomic;

use atomic::{AtomicUsize, Ordering};
use atomic::{AtomicPtr, AtomicUsize, Ordering};
use core::cell::UnsafeCell;
use core::marker::PhantomData;
use core::num::NonZeroUsize;
use core::ptr;

/// A thread-safe cell which can be written to only once.
#[derive(Default, Debug)]
Expand Down Expand Up @@ -176,7 +177,7 @@ impl OnceBool {

/// A thread-safe cell which can be written to only once.
pub struct OnceRef<'a, T> {
inner: OnceNonZeroUsize,
inner: AtomicPtr<T>,
ghost: PhantomData<UnsafeCell<&'a T>>,
}

Expand All @@ -198,21 +199,27 @@ impl<'a, T> Default for OnceRef<'a, T> {
impl<'a, T> OnceRef<'a, T> {
/// Creates a new empty cell.
pub const fn new() -> OnceRef<'a, T> {
OnceRef { inner: OnceNonZeroUsize::new(), ghost: PhantomData }
OnceRef { inner: AtomicPtr::new(ptr::null_mut()), ghost: PhantomData }
}

/// Gets a reference to the underlying value.
pub fn get(&self) -> Option<&'a T> {
self.inner.get().map(|ptr| unsafe { &*(ptr.get() as *const T) })
let ptr = self.inner.load(Ordering::Acquire);
unsafe { ptr.as_ref() }
}

/// Sets the contents of this cell to `value`.
///
/// Returns `Ok(())` if the cell was empty and `Err(value)` if it was
/// full.
pub fn set(&self, value: &'a T) -> Result<(), ()> {
let ptr = NonZeroUsize::new(value as *const T as usize).unwrap();
self.inner.set(ptr)
let ptr = value as *const T as *mut T;
let exchange =
self.inner.compare_exchange(ptr::null_mut(), ptr, Ordering::AcqRel, Ordering::Acquire);
match exchange {
Ok(_) => Ok(()),
Err(_) => Err(()),
}
}

/// Gets the contents of the cell, initializing it with `f` if the cell was
Expand All @@ -225,9 +232,11 @@ impl<'a, T> OnceRef<'a, T> {
where
F: FnOnce() -> &'a T,
{
let f = || NonZeroUsize::new(f() as *const T as usize).unwrap();
let ptr = self.inner.get_or_init(f);
unsafe { &*(ptr.get() as *const T) }
enum Void {}
match self.get_or_try_init(|| Ok::<&'a T, Void>(f())) {
Ok(val) => val,
Err(void) => match void {},
}
}

/// Gets the contents of the cell, initializing it with `f` if
Expand All @@ -241,9 +250,23 @@ impl<'a, T> OnceRef<'a, T> {
where
F: FnOnce() -> Result<&'a T, E>,
{
let f = || f().map(|value| NonZeroUsize::new(value as *const T as usize).unwrap());
let ptr = self.inner.get_or_try_init(f)?;
unsafe { Ok(&*(ptr.get() as *const T)) }
let mut ptr = self.inner.load(Ordering::Acquire);

if ptr.is_null() {
// TODO replace with `cast_mut` when MSRV reaches 1.65.0 (also in `set`)
ptr = f()? as *const T as *mut T;
let exchange = self.inner.compare_exchange(
ptr::null_mut(),
ptr,
Ordering::AcqRel,
Ordering::Acquire,
);
if let Err(old) = exchange {
ptr = old;
}
}

Ok(unsafe { &*ptr })
}

/// ```compile_fail
Expand Down