Keep provenance intact by avoiding ptr-int-ptr
once_cell is an extremely widely-used crate, so it would be very nice if
it conformed to the strictest/simplest/checkable model we have for
aliasing in Rust. To do this, we need to avoid creating a pointer from
an integer by cast or transmute. Pointers created this way can be valid,
but are a nightmare for a checker like Miri. The situation is generally
explained by these docs: https://doc.rust-lang.org/nightly/std/ptr/fn.from_exposed_addr.html
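
Concretely, the difference looks something like the sketch below (Node and the constants are stand-ins for illustration; the diff applies the same wrapping-offset trick to its Waiter pointers):

#[repr(align(4))]
struct Node; // align(4) keeps the two low address bits free for the tag

const STATE_MASK: usize = 0x3;

// Int-to-ptr round trip: the commit removes casts like this one, because the
// resulting pointer has no provenance a checker like Miri can track.
fn tag_via_int(ptr: *mut Node, state: usize) -> *mut Node {
    ((ptr as usize) | state) as *mut Node
}

// Provenance-preserving alternative (what the diff does): only the address
// changes, so the pointer still remembers which allocation it came from.
fn tag_via_offset(ptr: *mut Node, state: usize) -> *mut Node {
    ptr.cast::<u8>().wrapping_add(state).cast()
}

// Reading the tag back only needs ptr-to-int, which is unproblematic.
fn state_of(tagged: *mut Node) -> usize {
    tagged as usize & STATE_MASK
}

fn main() {
    let mut slot = Node;
    let p: *mut Node = &mut slot;
    let tagged = tag_via_offset(p, 0x1); // RUNNING
    assert_eq!(state_of(tagged), 0x1);
    let _ = tag_via_int; // the cast-based version the commit removes
}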

This implementation is mostly gross because all the APIs that would make
this ergonomic are behind #![feature(strict_provenance)]
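
With those nightly-only APIs (addr, map_addr, ptr::invalid_mut behind #![feature(strict_provenance)]), the same tagging would read roughly like this sketch (again, Node and the constants are stand-ins, not the crate's code):

#![feature(strict_provenance)] // nightly-only at the time of this commit

#[repr(align(4))]
struct Node; // hypothetical stand-in for Waiter

const STATE_MASK: usize = 0x3;
const RUNNING: usize = 0x1;
const COMPLETE: usize = 0x2;

fn demo(ptr: *mut Node) {
    // Transform only the address; provenance is carried along automatically.
    let tagged = ptr.map_addr(|a| a | RUNNING);
    let state = tagged.addr() & STATE_MASK;
    let untagged = tagged.map_addr(|a| a & !STATE_MASK);

    // A pure sentinel that is never dereferenced, e.g. the COMPLETE marker:
    let complete: *mut Node = core::ptr::invalid_mut(COMPLETE);

    let _ = (state, untagged, complete);
}

fn main() {
    demo(core::ptr::null_mut()); // nothing is dereferenced here
}
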
saethlin committed Jun 28, 2022
1 parent 2cdfc1e commit 16b6312
Showing 1 changed file with 25 additions and 21 deletions.
src/imp_std.rs
@@ -8,7 +8,7 @@ use std::{
hint::unreachable_unchecked,
marker::PhantomData,
panic::{RefUnwindSafe, UnwindSafe},
-sync::atomic::{AtomicBool, AtomicUsize, Ordering},
+sync::atomic::{AtomicBool, AtomicPtr, Ordering},
thread::{self, Thread},
};

@@ -24,7 +24,7 @@ pub(crate) struct OnceCell<T> {
//
// State is encoded in two low bits. Only `INCOMPLETE` and `RUNNING` states
// allow waiters.
-queue: AtomicUsize,
+queue: AtomicPtr<Waiter>,
_marker: PhantomData<*mut Waiter>,
value: UnsafeCell<Option<T>>,
}
@@ -43,15 +43,15 @@ impl<T: UnwindSafe> UnwindSafe for OnceCell<T> {}
impl<T> OnceCell<T> {
pub(crate) const fn new() -> OnceCell<T> {
OnceCell {
-queue: AtomicUsize::new(INCOMPLETE),
+queue: AtomicPtr::new(INCOMPLETE_PTR),
_marker: PhantomData,
value: UnsafeCell::new(None),
}
}

pub(crate) const fn with_value(value: T) -> OnceCell<T> {
OnceCell {
-queue: AtomicUsize::new(COMPLETE),
+queue: AtomicPtr::new(COMPLETE_PTR),
_marker: PhantomData,
value: UnsafeCell::new(Some(value)),
}
@@ -64,7 +64,7 @@ impl<T> OnceCell<T> {
// operations visible to us, and, this being a fast path, weaker
// ordering helps with performance. This `Acquire` synchronizes with
// `SeqCst` operations on the slow path.
-self.queue.load(Ordering::Acquire) == COMPLETE
+self.queue.load(Ordering::Acquire) == COMPLETE_PTR
}

/// Safety: synchronizes with store to value via SeqCst read from state,
@@ -145,6 +145,9 @@ impl<T> OnceCell<T> {
const INCOMPLETE: usize = 0x0;
const RUNNING: usize = 0x1;
const COMPLETE: usize = 0x2;
+const INCOMPLETE_PTR: *mut Waiter = core::ptr::null_mut();
+// const RUNNING_PTR: *mut Waiter = core::ptr::null_mut::<u8>().wrapping_add(0x1).cast();
+const COMPLETE_PTR: *mut Waiter = core::ptr::null_mut::<u8>().wrapping_add(0x2).cast();

// Mask to learn about the state. All other bits are the queue of waiters if
// this is in the RUNNING state.
@@ -156,23 +159,24 @@ const STATE_MASK: usize = 0x3;
struct Waiter {
thread: Cell<Option<Thread>>,
signaled: AtomicBool,
-next: *const Waiter,
+next: *mut Waiter,
}

/// Drains and notifies the queue of waiters on drop.
struct Guard<'a> {
-queue: &'a AtomicUsize,
-new_queue: usize,
+queue: &'a AtomicPtr<Waiter>,
+new_queue: *mut Waiter,
}

impl Drop for Guard<'_> {
fn drop(&mut self) {
let queue = self.queue.swap(self.new_queue, Ordering::AcqRel);

-assert_eq!(queue & STATE_MASK, RUNNING);
+let state = queue as usize & STATE_MASK;
+assert_eq!(state, RUNNING as usize);

unsafe {
-let mut waiter = (queue & !STATE_MASK) as *const Waiter;
+let mut waiter = queue.cast::<u8>().wrapping_sub(state).cast::<Waiter>();
while !waiter.is_null() {
let next = (*waiter).next;
let thread = (*waiter).thread.take().unwrap();
@@ -191,27 +195,27 @@ impl Drop for Guard<'_> {
//
// Note: this is intentionally monomorphic
#[inline(never)]
-fn initialize_or_wait(queue: &AtomicUsize, mut init: Option<&mut dyn FnMut() -> bool>) {
+fn initialize_or_wait(queue: &AtomicPtr<Waiter>, mut init: Option<&mut dyn FnMut() -> bool>) {
let mut curr_queue = queue.load(Ordering::Acquire);

loop {
-let curr_state = curr_queue & STATE_MASK;
+let curr_state = curr_queue as usize & STATE_MASK;
match (curr_state, &mut init) {
(COMPLETE, _) => return,
(INCOMPLETE, Some(init)) => {
let exchange = queue.compare_exchange(
curr_queue,
-(curr_queue & !STATE_MASK) | RUNNING,
+curr_queue.cast::<u8>().wrapping_sub(curr_state).wrapping_add(RUNNING).cast(),
Ordering::Acquire,
Ordering::Acquire,
);
if let Err(new_queue) = exchange {
curr_queue = new_queue;
continue;
}
-let mut guard = Guard { queue, new_queue: INCOMPLETE };
+let mut guard = Guard { queue, new_queue: INCOMPLETE_PTR };
if init() {
-guard.new_queue = COMPLETE;
+guard.new_queue = COMPLETE_PTR;
}
return;
}
@@ -224,24 +228,24 @@ fn initialize_or_wait(queue: &AtomicUsize, mut init: Option<&mut dyn FnMut() ->
}
}

-fn wait(queue: &AtomicUsize, mut curr_queue: usize) {
-let curr_state = curr_queue & STATE_MASK;
+fn wait(queue: &AtomicPtr<Waiter>, mut curr_queue: *mut Waiter) {
+let curr_state = curr_queue as usize & STATE_MASK;
loop {
let node = Waiter {
thread: Cell::new(Some(thread::current())),
signaled: AtomicBool::new(false),
-next: (curr_queue & !STATE_MASK) as *const Waiter,
+next: curr_queue.cast::<u8>().wrapping_sub(curr_state).cast(),
};
-let me = &node as *const Waiter as usize;
+let me = &node as *const Waiter as *mut Waiter;

let exchange = queue.compare_exchange(
curr_queue,
-me | curr_state,
+me.cast::<u8>().wrapping_add(curr_state).cast(),
Ordering::Release,
Ordering::Relaxed,
);
if let Err(new_queue) = exchange {
-if new_queue & STATE_MASK != curr_state {
+if new_queue as usize & STATE_MASK != curr_state {
return;
}
curr_queue = new_queue;
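
For reference, a condensed, self-contained sketch of the pattern the patch switches to (Node and the constants are stand-ins for Waiter and the state values above); nothing in it creates a pointer from an integer, so Miri can keep tracking the waiter pointer through the tagged AtomicPtr:

use std::sync::atomic::{AtomicPtr, Ordering};

#[repr(align(4))]
struct Node {
    next: *mut Node,
}

const STATE_MASK: usize = 0x3;
const RUNNING: usize = 0x1;

fn main() {
    // Head starts as the RUNNING sentinel: a null pointer offset by the tag.
    let queue: AtomicPtr<Node> =
        AtomicPtr::new(std::ptr::null_mut::<u8>().wrapping_add(RUNNING).cast());

    // Push one node, the way `wait` does: strip the tag off the current head,
    // link to it, then re-apply the tag to our own pointer.
    let curr = queue.load(Ordering::Acquire);
    let state = curr as usize & STATE_MASK; // ptr-to-int read only
    let mut node = Node { next: curr.cast::<u8>().wrapping_sub(state).cast() };
    let node_ptr: *mut Node = &mut node;
    let me: *mut Node = node_ptr.cast::<u8>().wrapping_add(state).cast();
    queue
        .compare_exchange(curr, me, Ordering::Release, Ordering::Relaxed)
        .unwrap();

    // Take the queue back (as Guard::drop does) and check that both the address
    // and the provenance of the original pointer survive the round trip.
    let head = queue.swap(std::ptr::null_mut(), Ordering::AcqRel);
    let head: *mut Node = head.cast::<u8>().wrapping_sub(head as usize & STATE_MASK).cast();
    assert!(std::ptr::eq(head, node_ptr));
    assert!(unsafe { (*head).next }.is_null());
}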
