Merge #834
834: Use `UnsafeCell<MaybeUninit<T>>` in AtomicCell r=taiki-e a=taiki-e

Fixes #833

Note: This contains two breaking changes:
- Unsized values are no longer allowed.
  This is because `MaybeUninit<T>` requires `T: Sized`.
- `AtomicCell` now implements `Drop`.
  This is because `MaybeUninit` prevents `T` from being dropped, so we need to implement `Drop` for `AtomicCell` to avoid leaks of non-`Copy` types.

Breakage is acceptable because this is a soundness bug fix. However, given the amount of breakage, we would not be able to yank the affected releases, so we will only publish an advisory.
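A minimal sketch of the new `Drop` behavior against the post-fix API (the `Arc` counter is only an illustration, not part of this PR):

```rust
use std::sync::Arc;
use crossbeam_utils::atomic::AtomicCell;

fn main() {
    let tracker = Arc::new(());
    let cell = AtomicCell::new(Arc::clone(&tracker));
    assert_eq!(Arc::strong_count(&tracker), 2);
    // The explicit `Drop` impl still runs `T`'s destructor even though
    // `MaybeUninit` itself never drops its contents.
    drop(cell);
    assert_eq!(Arc::strong_count(&tracker), 1);
}
```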

Co-authored-by: Taiki Endo <te316e89@gmail.com>
bors[bot] and taiki-e committed Jun 21, 2022
2 parents 23400ce + 012b0c2 commit 9e9ff76
Showing 2 changed files with 132 additions and 54 deletions.
138 changes: 84 additions & 54 deletions crossbeam-utils/src/atomic/atomic_cell.rs
@@ -5,7 +5,7 @@ use crate::primitive::sync::atomic::{self, AtomicBool};
use core::cell::UnsafeCell;
use core::cmp;
use core::fmt;
-use core::mem;
+use core::mem::{self, ManuallyDrop, MaybeUninit};
use core::sync::atomic::Ordering;

use core::ptr;
@@ -30,13 +30,20 @@ use super::seq_lock::SeqLock;
/// [`Acquire`]: std::sync::atomic::Ordering::Acquire
/// [`Release`]: std::sync::atomic::Ordering::Release
#[repr(transparent)]
-pub struct AtomicCell<T: ?Sized> {
+pub struct AtomicCell<T> {
/// The inner value.
///
/// If this value can be transmuted into a primitive atomic type, it will be treated as such.
/// Otherwise, all potentially concurrent operations on this data will be protected by a global
/// lock.
-value: UnsafeCell<T>,
+///
+/// `MaybeUninit` prevents code outside the cell from observing a partially initialized state:
+/// <https://github.com/crossbeam-rs/crossbeam/issues/833>
+///
+/// Note:
+/// - we'll never store an uninitialized `T`, because our API only accepts initialized values.
+/// - this `MaybeUninit` does *not* fix <https://github.com/crossbeam-rs/crossbeam/issues/315>.
+value: UnsafeCell<MaybeUninit<T>>,
}

unsafe impl<T: Send> Send for AtomicCell<T> {}
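The doc comment above describes the two execution strategies. A quick sketch of how they surface through the existing `is_lock_free` API (the `usize` result is platform-dependent, though it holds on mainstream targets):

```rust
use crossbeam_utils::atomic::AtomicCell;

fn main() {
    // Fits a native atomic, so operations bypass the lock.
    assert!(AtomicCell::<usize>::is_lock_free());
    // Too large for any native atomic: falls back to the global lock.
    assert!(!AtomicCell::<[u8; 1024]>::is_lock_free());
}
```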
@@ -59,12 +66,15 @@ impl<T> AtomicCell<T> {
/// ```
pub const fn new(val: T) -> AtomicCell<T> {
AtomicCell {
-value: UnsafeCell::new(val),
+value: UnsafeCell::new(MaybeUninit::new(val)),
}
}

/// Consumes the atomic and returns the contained value.
///
+/// This is safe because passing `self` by value guarantees that no other threads are
+/// concurrently accessing the atomic data.
///
/// # Examples
///
/// ```
@@ -76,7 +86,13 @@ impl<T> AtomicCell<T> {
/// assert_eq!(v, 7);
/// ```
pub fn into_inner(self) -> T {
-self.value.into_inner()
+let this = ManuallyDrop::new(self);
+// SAFETY:
+// - passing `self` by value guarantees that no other threads are concurrently
+// accessing the atomic data
+// - the raw pointer passed in is valid because we got it from an owned value.
+// - `ManuallyDrop` prevents double dropping `T`
+unsafe { this.as_ptr().read() }
}
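The `ManuallyDrop` + `read` combination above is the standard pattern for moving a field out of a type that implements `Drop`; a self-contained sketch using a hypothetical `Wrapper` type:

```rust
use std::mem::ManuallyDrop;
use std::ptr;

struct Wrapper {
    value: String,
}

impl Drop for Wrapper {
    fn drop(&mut self) {
        println!("dropping Wrapper");
    }
}

impl Wrapper {
    /// Moves `value` out without running `Wrapper`'s destructor.
    fn into_value(self) -> String {
        let this = ManuallyDrop::new(self);
        // SAFETY: `this.value` is initialized, and `ManuallyDrop` ensures
        // it is not dropped a second time when `this` goes out of scope.
        unsafe { ptr::read(&this.value) }
    }
}

fn main() {
    let s = Wrapper { value: "hello".into() }.into_value();
    assert_eq!(s, "hello"); // no "dropping Wrapper" is printed
}
```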

/// Returns `true` if operations on values of this type are lock-free.
@@ -129,7 +145,7 @@ impl<T> AtomicCell<T> {
drop(self.swap(val));
} else {
unsafe {
-atomic_store(self.value.get(), val);
+atomic_store(self.as_ptr(), val);
}
}
}
@@ -148,11 +164,9 @@
/// assert_eq!(a.load(), 8);
/// ```
pub fn swap(&self, val: T) -> T {
-unsafe { atomic_swap(self.value.get(), val) }
+unsafe { atomic_swap(self.as_ptr(), val) }
}
-}
-
-impl<T: ?Sized> AtomicCell<T> {
/// Returns a raw pointer to the underlying data in this atomic cell.
///
/// # Examples
@@ -166,7 +180,7 @@ impl<T: ?Sized> AtomicCell<T> {
/// ```
#[inline]
pub fn as_ptr(&self) -> *mut T {
-self.value.get()
+self.value.get() as *mut T
}
}

@@ -202,7 +216,7 @@ impl<T: Copy> AtomicCell<T> {
/// assert_eq!(a.load(), 7);
/// ```
pub fn load(&self) -> T {
-unsafe { atomic_load(self.value.get()) }
+unsafe { atomic_load(self.as_ptr()) }
}
}

@@ -254,7 +268,7 @@ impl<T: Copy + Eq> AtomicCell<T> {
/// assert_eq!(a.load(), 2);
/// ```
pub fn compare_exchange(&self, current: T, new: T) -> Result<T, T> {
-unsafe { atomic_compare_exchange_weak(self.value.get(), current, new) }
+unsafe { atomic_compare_exchange_weak(self.as_ptr(), current, new) }
}

/// Fetches the value, and applies a function to it that returns an optional
@@ -292,6 +306,22 @@ impl<T: Copy + Eq> AtomicCell<T> {
}
}
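`compare_exchange` is the primitive behind read-modify-write loops such as the `fetch_update` helper above; a hedged usage sketch with the public API:

```rust
use crossbeam_utils::atomic::AtomicCell;

fn main() {
    let a = AtomicCell::new(7i32);
    // Retry loop: on failure, `compare_exchange` returns the value it
    // actually found, which becomes the next `current`.
    let mut old = a.load();
    loop {
        match a.compare_exchange(old, old * 2) {
            Ok(_) => break,
            Err(actual) => old = actual,
        }
    }
    assert_eq!(a.load(), 14);
}
```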

+// `MaybeUninit` prevents `T` from being dropped, so we need to implement `Drop`
+// for `AtomicCell` to avoid leaks of non-`Copy` types.
+impl<T> Drop for AtomicCell<T> {
+fn drop(&mut self) {
+if mem::needs_drop::<T>() {
+// SAFETY:
+// - the mutable reference guarantees that no other threads are concurrently accessing the atomic data
+// - the raw pointer passed in is valid because we got it from a reference
+// - `MaybeUninit` prevents double dropping `T`
+unsafe {
+self.as_ptr().drop_in_place();
+}
+}
+}
+}
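`mem::needs_drop` above lets the whole destructor compile away for types with no drop glue; a small standard-library-only demonstration:

```rust
use std::mem::needs_drop;

fn main() {
    assert!(!needs_drop::<u64>());   // no drop glue: the branch is eliminated
    assert!(needs_drop::<String>()); // owns heap memory, so it must be dropped
}
```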

macro_rules! impl_arithmetic {
($t:ty, fallback, $example:tt) => {
impl AtomicCell<$t> {
@@ -311,8 +341,8 @@ macro_rules! impl_arithmetic {
/// ```
#[inline]
pub fn fetch_add(&self, val: $t) -> $t {
-let _guard = lock(self.value.get() as usize).write();
-let value = unsafe { &mut *(self.value.get()) };
+let _guard = lock(self.as_ptr() as usize).write();
+let value = unsafe { &mut *(self.as_ptr()) };
let old = *value;
*value = value.wrapping_add(val);
old
@@ -334,8 +364,8 @@
/// ```
#[inline]
pub fn fetch_sub(&self, val: $t) -> $t {
-let _guard = lock(self.value.get() as usize).write();
-let value = unsafe { &mut *(self.value.get()) };
+let _guard = lock(self.as_ptr() as usize).write();
+let value = unsafe { &mut *(self.as_ptr()) };
let old = *value;
*value = value.wrapping_sub(val);
old
@@ -355,8 +385,8 @@
/// ```
#[inline]
pub fn fetch_and(&self, val: $t) -> $t {
-let _guard = lock(self.value.get() as usize).write();
-let value = unsafe { &mut *(self.value.get()) };
+let _guard = lock(self.as_ptr() as usize).write();
+let value = unsafe { &mut *(self.as_ptr()) };
let old = *value;
*value &= val;
old
@@ -376,8 +406,8 @@
/// ```
#[inline]
pub fn fetch_nand(&self, val: $t) -> $t {
-let _guard = lock(self.value.get() as usize).write();
-let value = unsafe { &mut *(self.value.get()) };
+let _guard = lock(self.as_ptr() as usize).write();
+let value = unsafe { &mut *(self.as_ptr()) };
let old = *value;
*value = !(old & val);
old
@@ -397,8 +427,8 @@
/// ```
#[inline]
pub fn fetch_or(&self, val: $t) -> $t {
-let _guard = lock(self.value.get() as usize).write();
-let value = unsafe { &mut *(self.value.get()) };
+let _guard = lock(self.as_ptr() as usize).write();
+let value = unsafe { &mut *(self.as_ptr()) };
let old = *value;
*value |= val;
old
@@ -418,8 +448,8 @@
/// ```
#[inline]
pub fn fetch_xor(&self, val: $t) -> $t {
-let _guard = lock(self.value.get() as usize).write();
-let value = unsafe { &mut *(self.value.get()) };
+let _guard = lock(self.as_ptr() as usize).write();
+let value = unsafe { &mut *(self.as_ptr()) };
let old = *value;
*value ^= val;
old
@@ -440,8 +470,8 @@
/// ```
#[inline]
pub fn fetch_max(&self, val: $t) -> $t {
-let _guard = lock(self.value.get() as usize).write();
-let value = unsafe { &mut *(self.value.get()) };
+let _guard = lock(self.as_ptr() as usize).write();
+let value = unsafe { &mut *(self.as_ptr()) };
let old = *value;
*value = cmp::max(old, val);
old
@@ -462,8 +492,8 @@
/// ```
#[inline]
pub fn fetch_min(&self, val: $t) -> $t {
-let _guard = lock(self.value.get() as usize).write();
-let value = unsafe { &mut *(self.value.get()) };
+let _guard = lock(self.as_ptr() as usize).write();
+let value = unsafe { &mut *(self.as_ptr()) };
let old = *value;
*value = cmp::min(old, val);
old
@@ -489,11 +519,11 @@
#[inline]
pub fn fetch_add(&self, val: $t) -> $t {
if can_transmute::<$t, $atomic>() {
-let a = unsafe { &*(self.value.get() as *const $atomic) };
+let a = unsafe { &*(self.as_ptr() as *const $atomic) };
a.fetch_add(val, Ordering::AcqRel)
} else {
-let _guard = lock(self.value.get() as usize).write();
-let value = unsafe { &mut *(self.value.get()) };
+let _guard = lock(self.as_ptr() as usize).write();
+let value = unsafe { &mut *(self.as_ptr()) };
let old = *value;
*value = value.wrapping_add(val);
old
@@ -517,11 +547,11 @@
#[inline]
pub fn fetch_sub(&self, val: $t) -> $t {
if can_transmute::<$t, $atomic>() {
-let a = unsafe { &*(self.value.get() as *const $atomic) };
+let a = unsafe { &*(self.as_ptr() as *const $atomic) };
a.fetch_sub(val, Ordering::AcqRel)
} else {
-let _guard = lock(self.value.get() as usize).write();
-let value = unsafe { &mut *(self.value.get()) };
+let _guard = lock(self.as_ptr() as usize).write();
+let value = unsafe { &mut *(self.as_ptr()) };
let old = *value;
*value = value.wrapping_sub(val);
old
@@ -543,11 +573,11 @@
#[inline]
pub fn fetch_and(&self, val: $t) -> $t {
if can_transmute::<$t, $atomic>() {
-let a = unsafe { &*(self.value.get() as *const $atomic) };
+let a = unsafe { &*(self.as_ptr() as *const $atomic) };
a.fetch_and(val, Ordering::AcqRel)
} else {
-let _guard = lock(self.value.get() as usize).write();
-let value = unsafe { &mut *(self.value.get()) };
+let _guard = lock(self.as_ptr() as usize).write();
+let value = unsafe { &mut *(self.as_ptr()) };
let old = *value;
*value &= val;
old
@@ -569,11 +599,11 @@
#[inline]
pub fn fetch_nand(&self, val: $t) -> $t {
if can_transmute::<$t, $atomic>() {
-let a = unsafe { &*(self.value.get() as *const $atomic) };
+let a = unsafe { &*(self.as_ptr() as *const $atomic) };
a.fetch_nand(val, Ordering::AcqRel)
} else {
-let _guard = lock(self.value.get() as usize).write();
-let value = unsafe { &mut *(self.value.get()) };
+let _guard = lock(self.as_ptr() as usize).write();
+let value = unsafe { &mut *(self.as_ptr()) };
let old = *value;
*value = !(old & val);
old
@@ -595,11 +625,11 @@
#[inline]
pub fn fetch_or(&self, val: $t) -> $t {
if can_transmute::<$t, $atomic>() {
-let a = unsafe { &*(self.value.get() as *const $atomic) };
+let a = unsafe { &*(self.as_ptr() as *const $atomic) };
a.fetch_or(val, Ordering::AcqRel)
} else {
-let _guard = lock(self.value.get() as usize).write();
-let value = unsafe { &mut *(self.value.get()) };
+let _guard = lock(self.as_ptr() as usize).write();
+let value = unsafe { &mut *(self.as_ptr()) };
let old = *value;
*value |= val;
old
@@ -621,11 +651,11 @@
#[inline]
pub fn fetch_xor(&self, val: $t) -> $t {
if can_transmute::<$t, $atomic>() {
-let a = unsafe { &*(self.value.get() as *const $atomic) };
+let a = unsafe { &*(self.as_ptr() as *const $atomic) };
a.fetch_xor(val, Ordering::AcqRel)
} else {
-let _guard = lock(self.value.get() as usize).write();
-let value = unsafe { &mut *(self.value.get()) };
+let _guard = lock(self.as_ptr() as usize).write();
+let value = unsafe { &mut *(self.as_ptr()) };
let old = *value;
*value ^= val;
old
@@ -651,8 +681,8 @@
// TODO: Atomic*::fetch_max requires Rust 1.45.
self.fetch_update(|old| Some(cmp::max(old, val))).unwrap()
} else {
-let _guard = lock(self.value.get() as usize).write();
-let value = unsafe { &mut *(self.value.get()) };
+let _guard = lock(self.as_ptr() as usize).write();
+let value = unsafe { &mut *(self.as_ptr()) };
let old = *value;
*value = cmp::max(old, val);
old
@@ -678,8 +708,8 @@
// TODO: Atomic*::fetch_min requires Rust 1.45.
self.fetch_update(|old| Some(cmp::min(old, val))).unwrap()
} else {
-let _guard = lock(self.value.get() as usize).write();
-let value = unsafe { &mut *(self.value.get()) };
+let _guard = lock(self.as_ptr() as usize).write();
+let value = unsafe { &mut *(self.as_ptr()) };
let old = *value;
*value = cmp::min(old, val);
old
@@ -738,7 +768,7 @@ impl AtomicCell<bool> {
/// ```
#[inline]
pub fn fetch_and(&self, val: bool) -> bool {
-let a = unsafe { &*(self.value.get() as *const AtomicBool) };
+let a = unsafe { &*(self.as_ptr() as *const AtomicBool) };
a.fetch_and(val, Ordering::AcqRel)
}

@@ -762,7 +792,7 @@ impl AtomicCell<bool> {
/// ```
#[inline]
pub fn fetch_nand(&self, val: bool) -> bool {
-let a = unsafe { &*(self.value.get() as *const AtomicBool) };
+let a = unsafe { &*(self.as_ptr() as *const AtomicBool) };
a.fetch_nand(val, Ordering::AcqRel)
}

@@ -783,7 +813,7 @@
/// ```
#[inline]
pub fn fetch_or(&self, val: bool) -> bool {
-let a = unsafe { &*(self.value.get() as *const AtomicBool) };
+let a = unsafe { &*(self.as_ptr() as *const AtomicBool) };
a.fetch_or(val, Ordering::AcqRel)
}

@@ -804,7 +834,7 @@
/// ```
#[inline]
pub fn fetch_xor(&self, val: bool) -> bool {
-let a = unsafe { &*(self.value.get() as *const AtomicBool) };
+let a = unsafe { &*(self.as_ptr() as *const AtomicBool) };
a.fetch_xor(val, Ordering::AcqRel)
}
}
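Because `bool` shares `AtomicBool`'s layout, the methods above delegate straight to the native atomic; a short usage sketch:

```rust
use crossbeam_utils::atomic::AtomicCell;

fn main() {
    let flag = AtomicCell::new(false);
    assert!(!flag.fetch_or(true)); // returns the previous value
    assert!(flag.load());
    assert!(flag.fetch_xor(true)); // true ^ true == false
    assert!(!flag.load());
}
```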
