Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use UnsafeCell<MaybeUninit<T>> in AtomicCell #834

Merged
merged 4 commits into from Jun 21, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
138 changes: 84 additions & 54 deletions crossbeam-utils/src/atomic/atomic_cell.rs
Expand Up @@ -5,7 +5,7 @@ use crate::primitive::sync::atomic::{self, AtomicBool};
use core::cell::UnsafeCell;
use core::cmp;
use core::fmt;
use core::mem;
use core::mem::{self, ManuallyDrop, MaybeUninit};
use core::sync::atomic::Ordering;

use core::ptr;
Expand All @@ -30,13 +30,20 @@ use super::seq_lock::SeqLock;
/// [`Acquire`]: std::sync::atomic::Ordering::Acquire
/// [`Release`]: std::sync::atomic::Ordering::Release
#[repr(transparent)]
pub struct AtomicCell<T: ?Sized> {
pub struct AtomicCell<T> {
/// The inner value.
///
/// If this value can be transmuted into a primitive atomic type, it will be treated as such.
/// Otherwise, all potentially concurrent operations on this data will be protected by a global
/// lock.
value: UnsafeCell<T>,
///
/// Using MaybeUninit to prevent code outside the cell from observing partially initialized state:
/// <https://github.com/crossbeam-rs/crossbeam/issues/833>
///
/// Note:
/// - we'll never store uninitialized `T` due to our API only using initialized `T`.
/// - this `MaybeUninit` does *not* fix <https://github.com/crossbeam-rs/crossbeam/issues/315>.
value: UnsafeCell<MaybeUninit<T>>,
}

unsafe impl<T: Send> Send for AtomicCell<T> {}
Expand All @@ -59,12 +66,15 @@ impl<T> AtomicCell<T> {
/// ```
pub const fn new(val: T) -> AtomicCell<T> {
AtomicCell {
value: UnsafeCell::new(val),
value: UnsafeCell::new(MaybeUninit::new(val)),
}
}

/// Consumes the atomic and returns the contained value.
///
/// This is safe because passing `self` by value guarantees that no other threads are
/// concurrently accessing the atomic data.
///
/// # Examples
///
/// ```
Expand All @@ -76,7 +86,13 @@ impl<T> AtomicCell<T> {
/// assert_eq!(v, 7);
/// ```
pub fn into_inner(self) -> T {
self.value.into_inner()
let this = ManuallyDrop::new(self);
// SAFETY:
// - passing `self` by value guarantees that no other threads are concurrently
// accessing the atomic data
// - the raw pointer passed in is valid because we got it from an owned value.
// - `ManuallyDrop` prevents double dropping `T`
unsafe { this.as_ptr().read() }
}

/// Returns `true` if operations on values of this type are lock-free.
Expand Down Expand Up @@ -129,7 +145,7 @@ impl<T> AtomicCell<T> {
drop(self.swap(val));
} else {
unsafe {
atomic_store(self.value.get(), val);
atomic_store(self.as_ptr(), val);
}
}
}
Expand All @@ -148,11 +164,9 @@ impl<T> AtomicCell<T> {
/// assert_eq!(a.load(), 8);
/// ```
pub fn swap(&self, val: T) -> T {
unsafe { atomic_swap(self.value.get(), val) }
unsafe { atomic_swap(self.as_ptr(), val) }
}
}

impl<T: ?Sized> AtomicCell<T> {
/// Returns a raw pointer to the underlying data in this atomic cell.
///
/// # Examples
Expand All @@ -166,7 +180,7 @@ impl<T: ?Sized> AtomicCell<T> {
/// ```
#[inline]
pub fn as_ptr(&self) -> *mut T {
self.value.get()
self.value.get() as *mut T
}
}

Expand Down Expand Up @@ -202,7 +216,7 @@ impl<T: Copy> AtomicCell<T> {
/// assert_eq!(a.load(), 7);
/// ```
pub fn load(&self) -> T {
unsafe { atomic_load(self.value.get()) }
unsafe { atomic_load(self.as_ptr()) }
}
}

Expand Down Expand Up @@ -254,7 +268,7 @@ impl<T: Copy + Eq> AtomicCell<T> {
/// assert_eq!(a.load(), 2);
/// ```
pub fn compare_exchange(&self, current: T, new: T) -> Result<T, T> {
unsafe { atomic_compare_exchange_weak(self.value.get(), current, new) }
unsafe { atomic_compare_exchange_weak(self.as_ptr(), current, new) }
}

/// Fetches the value, and applies a function to it that returns an optional
Expand Down Expand Up @@ -292,6 +306,22 @@ impl<T: Copy + Eq> AtomicCell<T> {
}
}

// `MaybeUninit` prevents `T` from being dropped, so we need to implement `Drop`
// for `AtomicCell` to avoid leaks of non-`Copy` types.
impl<T> Drop for AtomicCell<T> {
fn drop(&mut self) {
if mem::needs_drop::<T>() {
// SAFETY:
// - the mutable reference guarantees that no other threads are concurrently accessing the atomic data
// - the raw pointer passed in is valid because we got it from a reference
// - `MaybeUninit` prevents double dropping `T`
unsafe {
self.as_ptr().drop_in_place();
}
}
}
}

macro_rules! impl_arithmetic {
($t:ty, fallback, $example:tt) => {
impl AtomicCell<$t> {
Expand All @@ -311,8 +341,8 @@ macro_rules! impl_arithmetic {
/// ```
#[inline]
pub fn fetch_add(&self, val: $t) -> $t {
let _guard = lock(self.value.get() as usize).write();
let value = unsafe { &mut *(self.value.get()) };
let _guard = lock(self.as_ptr() as usize).write();
let value = unsafe { &mut *(self.as_ptr()) };
let old = *value;
*value = value.wrapping_add(val);
old
Expand All @@ -334,8 +364,8 @@ macro_rules! impl_arithmetic {
/// ```
#[inline]
pub fn fetch_sub(&self, val: $t) -> $t {
let _guard = lock(self.value.get() as usize).write();
let value = unsafe { &mut *(self.value.get()) };
let _guard = lock(self.as_ptr() as usize).write();
let value = unsafe { &mut *(self.as_ptr()) };
let old = *value;
*value = value.wrapping_sub(val);
old
Expand All @@ -355,8 +385,8 @@ macro_rules! impl_arithmetic {
/// ```
#[inline]
pub fn fetch_and(&self, val: $t) -> $t {
let _guard = lock(self.value.get() as usize).write();
let value = unsafe { &mut *(self.value.get()) };
let _guard = lock(self.as_ptr() as usize).write();
let value = unsafe { &mut *(self.as_ptr()) };
let old = *value;
*value &= val;
old
Expand All @@ -376,8 +406,8 @@ macro_rules! impl_arithmetic {
/// ```
#[inline]
pub fn fetch_nand(&self, val: $t) -> $t {
let _guard = lock(self.value.get() as usize).write();
let value = unsafe { &mut *(self.value.get()) };
let _guard = lock(self.as_ptr() as usize).write();
let value = unsafe { &mut *(self.as_ptr()) };
let old = *value;
*value = !(old & val);
old
Expand All @@ -397,8 +427,8 @@ macro_rules! impl_arithmetic {
/// ```
#[inline]
pub fn fetch_or(&self, val: $t) -> $t {
let _guard = lock(self.value.get() as usize).write();
let value = unsafe { &mut *(self.value.get()) };
let _guard = lock(self.as_ptr() as usize).write();
let value = unsafe { &mut *(self.as_ptr()) };
let old = *value;
*value |= val;
old
Expand All @@ -418,8 +448,8 @@ macro_rules! impl_arithmetic {
/// ```
#[inline]
pub fn fetch_xor(&self, val: $t) -> $t {
let _guard = lock(self.value.get() as usize).write();
let value = unsafe { &mut *(self.value.get()) };
let _guard = lock(self.as_ptr() as usize).write();
let value = unsafe { &mut *(self.as_ptr()) };
let old = *value;
*value ^= val;
old
Expand All @@ -440,8 +470,8 @@ macro_rules! impl_arithmetic {
/// ```
#[inline]
pub fn fetch_max(&self, val: $t) -> $t {
let _guard = lock(self.value.get() as usize).write();
let value = unsafe { &mut *(self.value.get()) };
let _guard = lock(self.as_ptr() as usize).write();
let value = unsafe { &mut *(self.as_ptr()) };
let old = *value;
*value = cmp::max(old, val);
old
Expand All @@ -462,8 +492,8 @@ macro_rules! impl_arithmetic {
/// ```
#[inline]
pub fn fetch_min(&self, val: $t) -> $t {
let _guard = lock(self.value.get() as usize).write();
let value = unsafe { &mut *(self.value.get()) };
let _guard = lock(self.as_ptr() as usize).write();
let value = unsafe { &mut *(self.as_ptr()) };
let old = *value;
*value = cmp::min(old, val);
old
Expand All @@ -489,11 +519,11 @@ macro_rules! impl_arithmetic {
#[inline]
pub fn fetch_add(&self, val: $t) -> $t {
if can_transmute::<$t, $atomic>() {
let a = unsafe { &*(self.value.get() as *const $atomic) };
let a = unsafe { &*(self.as_ptr() as *const $atomic) };
a.fetch_add(val, Ordering::AcqRel)
} else {
let _guard = lock(self.value.get() as usize).write();
let value = unsafe { &mut *(self.value.get()) };
let _guard = lock(self.as_ptr() as usize).write();
let value = unsafe { &mut *(self.as_ptr()) };
let old = *value;
*value = value.wrapping_add(val);
old
Expand All @@ -517,11 +547,11 @@ macro_rules! impl_arithmetic {
#[inline]
pub fn fetch_sub(&self, val: $t) -> $t {
if can_transmute::<$t, $atomic>() {
let a = unsafe { &*(self.value.get() as *const $atomic) };
let a = unsafe { &*(self.as_ptr() as *const $atomic) };
a.fetch_sub(val, Ordering::AcqRel)
} else {
let _guard = lock(self.value.get() as usize).write();
let value = unsafe { &mut *(self.value.get()) };
let _guard = lock(self.as_ptr() as usize).write();
let value = unsafe { &mut *(self.as_ptr()) };
let old = *value;
*value = value.wrapping_sub(val);
old
Expand All @@ -543,11 +573,11 @@ macro_rules! impl_arithmetic {
#[inline]
pub fn fetch_and(&self, val: $t) -> $t {
if can_transmute::<$t, $atomic>() {
let a = unsafe { &*(self.value.get() as *const $atomic) };
let a = unsafe { &*(self.as_ptr() as *const $atomic) };
a.fetch_and(val, Ordering::AcqRel)
} else {
let _guard = lock(self.value.get() as usize).write();
let value = unsafe { &mut *(self.value.get()) };
let _guard = lock(self.as_ptr() as usize).write();
let value = unsafe { &mut *(self.as_ptr()) };
let old = *value;
*value &= val;
old
Expand All @@ -569,11 +599,11 @@ macro_rules! impl_arithmetic {
#[inline]
pub fn fetch_nand(&self, val: $t) -> $t {
if can_transmute::<$t, $atomic>() {
let a = unsafe { &*(self.value.get() as *const $atomic) };
let a = unsafe { &*(self.as_ptr() as *const $atomic) };
a.fetch_nand(val, Ordering::AcqRel)
} else {
let _guard = lock(self.value.get() as usize).write();
let value = unsafe { &mut *(self.value.get()) };
let _guard = lock(self.as_ptr() as usize).write();
let value = unsafe { &mut *(self.as_ptr()) };
let old = *value;
*value = !(old & val);
old
Expand All @@ -595,11 +625,11 @@ macro_rules! impl_arithmetic {
#[inline]
pub fn fetch_or(&self, val: $t) -> $t {
if can_transmute::<$t, $atomic>() {
let a = unsafe { &*(self.value.get() as *const $atomic) };
let a = unsafe { &*(self.as_ptr() as *const $atomic) };
a.fetch_or(val, Ordering::AcqRel)
} else {
let _guard = lock(self.value.get() as usize).write();
let value = unsafe { &mut *(self.value.get()) };
let _guard = lock(self.as_ptr() as usize).write();
let value = unsafe { &mut *(self.as_ptr()) };
let old = *value;
*value |= val;
old
Expand All @@ -621,11 +651,11 @@ macro_rules! impl_arithmetic {
#[inline]
pub fn fetch_xor(&self, val: $t) -> $t {
if can_transmute::<$t, $atomic>() {
let a = unsafe { &*(self.value.get() as *const $atomic) };
let a = unsafe { &*(self.as_ptr() as *const $atomic) };
a.fetch_xor(val, Ordering::AcqRel)
} else {
let _guard = lock(self.value.get() as usize).write();
let value = unsafe { &mut *(self.value.get()) };
let _guard = lock(self.as_ptr() as usize).write();
let value = unsafe { &mut *(self.as_ptr()) };
let old = *value;
*value ^= val;
old
Expand All @@ -651,8 +681,8 @@ macro_rules! impl_arithmetic {
// TODO: Atomic*::fetch_max requires Rust 1.45.
self.fetch_update(|old| Some(cmp::max(old, val))).unwrap()
} else {
let _guard = lock(self.value.get() as usize).write();
let value = unsafe { &mut *(self.value.get()) };
let _guard = lock(self.as_ptr() as usize).write();
let value = unsafe { &mut *(self.as_ptr()) };
let old = *value;
*value = cmp::max(old, val);
old
Expand All @@ -678,8 +708,8 @@ macro_rules! impl_arithmetic {
// TODO: Atomic*::fetch_min requires Rust 1.45.
self.fetch_update(|old| Some(cmp::min(old, val))).unwrap()
} else {
let _guard = lock(self.value.get() as usize).write();
let value = unsafe { &mut *(self.value.get()) };
let _guard = lock(self.as_ptr() as usize).write();
let value = unsafe { &mut *(self.as_ptr()) };
let old = *value;
*value = cmp::min(old, val);
old
Expand Down Expand Up @@ -738,7 +768,7 @@ impl AtomicCell<bool> {
/// ```
#[inline]
pub fn fetch_and(&self, val: bool) -> bool {
let a = unsafe { &*(self.value.get() as *const AtomicBool) };
let a = unsafe { &*(self.as_ptr() as *const AtomicBool) };
a.fetch_and(val, Ordering::AcqRel)
}

Expand All @@ -762,7 +792,7 @@ impl AtomicCell<bool> {
/// ```
#[inline]
pub fn fetch_nand(&self, val: bool) -> bool {
let a = unsafe { &*(self.value.get() as *const AtomicBool) };
let a = unsafe { &*(self.as_ptr() as *const AtomicBool) };
a.fetch_nand(val, Ordering::AcqRel)
}

Expand All @@ -783,7 +813,7 @@ impl AtomicCell<bool> {
/// ```
#[inline]
pub fn fetch_or(&self, val: bool) -> bool {
let a = unsafe { &*(self.value.get() as *const AtomicBool) };
let a = unsafe { &*(self.as_ptr() as *const AtomicBool) };
a.fetch_or(val, Ordering::AcqRel)
}

Expand All @@ -804,7 +834,7 @@ impl AtomicCell<bool> {
/// ```
#[inline]
pub fn fetch_xor(&self, val: bool) -> bool {
let a = unsafe { &*(self.value.get() as *const AtomicBool) };
let a = unsafe { &*(self.as_ptr() as *const AtomicBool) };
a.fetch_xor(val, Ordering::AcqRel)
}
}
Expand Down