Skip to content

Commit

Permalink
Use UnsafeCell<MaybeUninit<T>> in AtomicCell
Browse files Browse the repository at this point in the history
  • Loading branch information
taiki-e committed May 24, 2022
1 parent 1176693 commit d670b12
Showing 1 changed file with 56 additions and 54 deletions.
110 changes: 56 additions & 54 deletions crossbeam-utils/src/atomic/atomic_cell.rs
Expand Up @@ -5,7 +5,7 @@ use crate::primitive::sync::atomic::{self, AtomicBool};
use core::cell::UnsafeCell;
use core::cmp;
use core::fmt;
use core::mem;
use core::mem::{self, MaybeUninit};
use core::sync::atomic::Ordering;

use core::ptr;
Expand All @@ -30,13 +30,16 @@ use super::seq_lock::SeqLock;
/// [`Acquire`]: std::sync::atomic::Ordering::Acquire
/// [`Release`]: std::sync::atomic::Ordering::Release
#[repr(transparent)]
pub struct AtomicCell<T: ?Sized> {
pub struct AtomicCell<T> {
/// The inner value.
///
/// If this value can be transmuted into a primitive atomic type, it will be treated as such.
/// Otherwise, all potentially concurrent operations on this data will be protected by a global
/// lock.
value: UnsafeCell<T>,
///
/// Using MaybeUninit to prevent code outside the cell from observing partially initialized state:
/// <https://github.com/crossbeam-rs/crossbeam/issues/833>
value: UnsafeCell<MaybeUninit<T>>,
}

unsafe impl<T: Send> Send for AtomicCell<T> {}
Expand All @@ -59,7 +62,7 @@ impl<T> AtomicCell<T> {
/// ```
pub const fn new(val: T) -> AtomicCell<T> {
AtomicCell {
value: UnsafeCell::new(val),
value: UnsafeCell::new(MaybeUninit::new(val)),
}
}

Expand All @@ -76,7 +79,8 @@ impl<T> AtomicCell<T> {
/// assert_eq!(v, 7);
/// ```
pub fn into_inner(self) -> T {
self.value.into_inner()
// SAFETY: we'll never store uninitialized `T`
unsafe { self.value.into_inner().assume_init() }
}

/// Returns `true` if operations on values of this type are lock-free.
Expand Down Expand Up @@ -129,7 +133,7 @@ impl<T> AtomicCell<T> {
drop(self.swap(val));
} else {
unsafe {
atomic_store(self.value.get(), val);
atomic_store(self.as_ptr(), val);
}
}
}
Expand All @@ -148,11 +152,9 @@ impl<T> AtomicCell<T> {
/// assert_eq!(a.load(), 8);
/// ```
pub fn swap(&self, val: T) -> T {
unsafe { atomic_swap(self.value.get(), val) }
unsafe { atomic_swap(self.as_ptr(), val) }
}
}

impl<T: ?Sized> AtomicCell<T> {
/// Returns a raw pointer to the underlying data in this atomic cell.
///
/// # Examples
Expand All @@ -166,7 +168,7 @@ impl<T: ?Sized> AtomicCell<T> {
/// ```
#[inline]
pub fn as_ptr(&self) -> *mut T {
self.value.get()
self.value.get() as *mut T
}
}

Expand Down Expand Up @@ -202,7 +204,7 @@ impl<T: Copy> AtomicCell<T> {
/// assert_eq!(a.load(), 7);
/// ```
pub fn load(&self) -> T {
unsafe { atomic_load(self.value.get()) }
unsafe { atomic_load(self.as_ptr()) }
}
}

Expand Down Expand Up @@ -254,7 +256,7 @@ impl<T: Copy + Eq> AtomicCell<T> {
/// assert_eq!(a.load(), 2);
/// ```
pub fn compare_exchange(&self, current: T, new: T) -> Result<T, T> {
unsafe { atomic_compare_exchange_weak(self.value.get(), current, new) }
unsafe { atomic_compare_exchange_weak(self.as_ptr(), current, new) }
}

/// Fetches the value, and applies a function to it that returns an optional
Expand Down Expand Up @@ -311,8 +313,8 @@ macro_rules! impl_arithmetic {
/// ```
#[inline]
pub fn fetch_add(&self, val: $t) -> $t {
let _guard = lock(self.value.get() as usize).write();
let value = unsafe { &mut *(self.value.get()) };
let _guard = lock(self.as_ptr() as usize).write();
let value = unsafe { &mut *(self.as_ptr()) };
let old = *value;
*value = value.wrapping_add(val);
old
Expand All @@ -334,8 +336,8 @@ macro_rules! impl_arithmetic {
/// ```
#[inline]
pub fn fetch_sub(&self, val: $t) -> $t {
let _guard = lock(self.value.get() as usize).write();
let value = unsafe { &mut *(self.value.get()) };
let _guard = lock(self.as_ptr() as usize).write();
let value = unsafe { &mut *(self.as_ptr()) };
let old = *value;
*value = value.wrapping_sub(val);
old
Expand All @@ -355,8 +357,8 @@ macro_rules! impl_arithmetic {
/// ```
#[inline]
pub fn fetch_and(&self, val: $t) -> $t {
let _guard = lock(self.value.get() as usize).write();
let value = unsafe { &mut *(self.value.get()) };
let _guard = lock(self.as_ptr() as usize).write();
let value = unsafe { &mut *(self.as_ptr()) };
let old = *value;
*value &= val;
old
Expand All @@ -376,8 +378,8 @@ macro_rules! impl_arithmetic {
/// ```
#[inline]
pub fn fetch_nand(&self, val: $t) -> $t {
let _guard = lock(self.value.get() as usize).write();
let value = unsafe { &mut *(self.value.get()) };
let _guard = lock(self.as_ptr() as usize).write();
let value = unsafe { &mut *(self.as_ptr()) };
let old = *value;
*value = !(old & val);
old
Expand All @@ -397,8 +399,8 @@ macro_rules! impl_arithmetic {
/// ```
#[inline]
pub fn fetch_or(&self, val: $t) -> $t {
let _guard = lock(self.value.get() as usize).write();
let value = unsafe { &mut *(self.value.get()) };
let _guard = lock(self.as_ptr() as usize).write();
let value = unsafe { &mut *(self.as_ptr()) };
let old = *value;
*value |= val;
old
Expand All @@ -418,8 +420,8 @@ macro_rules! impl_arithmetic {
/// ```
#[inline]
pub fn fetch_xor(&self, val: $t) -> $t {
let _guard = lock(self.value.get() as usize).write();
let value = unsafe { &mut *(self.value.get()) };
let _guard = lock(self.as_ptr() as usize).write();
let value = unsafe { &mut *(self.as_ptr()) };
let old = *value;
*value ^= val;
old
Expand All @@ -440,8 +442,8 @@ macro_rules! impl_arithmetic {
/// ```
#[inline]
pub fn fetch_max(&self, val: $t) -> $t {
let _guard = lock(self.value.get() as usize).write();
let value = unsafe { &mut *(self.value.get()) };
let _guard = lock(self.as_ptr() as usize).write();
let value = unsafe { &mut *(self.as_ptr()) };
let old = *value;
*value = cmp::max(old, val);
old
Expand All @@ -462,8 +464,8 @@ macro_rules! impl_arithmetic {
/// ```
#[inline]
pub fn fetch_min(&self, val: $t) -> $t {
let _guard = lock(self.value.get() as usize).write();
let value = unsafe { &mut *(self.value.get()) };
let _guard = lock(self.as_ptr() as usize).write();
let value = unsafe { &mut *(self.as_ptr()) };
let old = *value;
*value = cmp::min(old, val);
old
Expand All @@ -489,11 +491,11 @@ macro_rules! impl_arithmetic {
#[inline]
pub fn fetch_add(&self, val: $t) -> $t {
if can_transmute::<$t, $atomic>() {
let a = unsafe { &*(self.value.get() as *const $atomic) };
let a = unsafe { &*(self.as_ptr() as *const $atomic) };
a.fetch_add(val, Ordering::AcqRel)
} else {
let _guard = lock(self.value.get() as usize).write();
let value = unsafe { &mut *(self.value.get()) };
let _guard = lock(self.as_ptr() as usize).write();
let value = unsafe { &mut *(self.as_ptr()) };
let old = *value;
*value = value.wrapping_add(val);
old
Expand All @@ -517,11 +519,11 @@ macro_rules! impl_arithmetic {
#[inline]
pub fn fetch_sub(&self, val: $t) -> $t {
if can_transmute::<$t, $atomic>() {
let a = unsafe { &*(self.value.get() as *const $atomic) };
let a = unsafe { &*(self.as_ptr() as *const $atomic) };
a.fetch_sub(val, Ordering::AcqRel)
} else {
let _guard = lock(self.value.get() as usize).write();
let value = unsafe { &mut *(self.value.get()) };
let _guard = lock(self.as_ptr() as usize).write();
let value = unsafe { &mut *(self.as_ptr()) };
let old = *value;
*value = value.wrapping_sub(val);
old
Expand All @@ -543,11 +545,11 @@ macro_rules! impl_arithmetic {
#[inline]
pub fn fetch_and(&self, val: $t) -> $t {
if can_transmute::<$t, $atomic>() {
let a = unsafe { &*(self.value.get() as *const $atomic) };
let a = unsafe { &*(self.as_ptr() as *const $atomic) };
a.fetch_and(val, Ordering::AcqRel)
} else {
let _guard = lock(self.value.get() as usize).write();
let value = unsafe { &mut *(self.value.get()) };
let _guard = lock(self.as_ptr() as usize).write();
let value = unsafe { &mut *(self.as_ptr()) };
let old = *value;
*value &= val;
old
Expand All @@ -569,11 +571,11 @@ macro_rules! impl_arithmetic {
#[inline]
pub fn fetch_nand(&self, val: $t) -> $t {
if can_transmute::<$t, $atomic>() {
let a = unsafe { &*(self.value.get() as *const $atomic) };
let a = unsafe { &*(self.as_ptr() as *const $atomic) };
a.fetch_nand(val, Ordering::AcqRel)
} else {
let _guard = lock(self.value.get() as usize).write();
let value = unsafe { &mut *(self.value.get()) };
let _guard = lock(self.as_ptr() as usize).write();
let value = unsafe { &mut *(self.as_ptr()) };
let old = *value;
*value = !(old & val);
old
Expand All @@ -595,11 +597,11 @@ macro_rules! impl_arithmetic {
#[inline]
pub fn fetch_or(&self, val: $t) -> $t {
if can_transmute::<$t, $atomic>() {
let a = unsafe { &*(self.value.get() as *const $atomic) };
let a = unsafe { &*(self.as_ptr() as *const $atomic) };
a.fetch_or(val, Ordering::AcqRel)
} else {
let _guard = lock(self.value.get() as usize).write();
let value = unsafe { &mut *(self.value.get()) };
let _guard = lock(self.as_ptr() as usize).write();
let value = unsafe { &mut *(self.as_ptr()) };
let old = *value;
*value |= val;
old
Expand All @@ -621,11 +623,11 @@ macro_rules! impl_arithmetic {
#[inline]
pub fn fetch_xor(&self, val: $t) -> $t {
if can_transmute::<$t, $atomic>() {
let a = unsafe { &*(self.value.get() as *const $atomic) };
let a = unsafe { &*(self.as_ptr() as *const $atomic) };
a.fetch_xor(val, Ordering::AcqRel)
} else {
let _guard = lock(self.value.get() as usize).write();
let value = unsafe { &mut *(self.value.get()) };
let _guard = lock(self.as_ptr() as usize).write();
let value = unsafe { &mut *(self.as_ptr()) };
let old = *value;
*value ^= val;
old
Expand All @@ -651,8 +653,8 @@ macro_rules! impl_arithmetic {
// TODO: Atomic*::fetch_max requires Rust 1.45.
self.fetch_update(|old| Some(cmp::max(old, val))).unwrap()
} else {
let _guard = lock(self.value.get() as usize).write();
let value = unsafe { &mut *(self.value.get()) };
let _guard = lock(self.as_ptr() as usize).write();
let value = unsafe { &mut *(self.as_ptr()) };
let old = *value;
*value = cmp::max(old, val);
old
Expand All @@ -678,8 +680,8 @@ macro_rules! impl_arithmetic {
// TODO: Atomic*::fetch_min requires Rust 1.45.
self.fetch_update(|old| Some(cmp::min(old, val))).unwrap()
} else {
let _guard = lock(self.value.get() as usize).write();
let value = unsafe { &mut *(self.value.get()) };
let _guard = lock(self.as_ptr() as usize).write();
let value = unsafe { &mut *(self.as_ptr()) };
let old = *value;
*value = cmp::min(old, val);
old
Expand Down Expand Up @@ -738,7 +740,7 @@ impl AtomicCell<bool> {
/// ```
#[inline]
pub fn fetch_and(&self, val: bool) -> bool {
let a = unsafe { &*(self.value.get() as *const AtomicBool) };
let a = unsafe { &*(self.as_ptr() as *const AtomicBool) };
a.fetch_and(val, Ordering::AcqRel)
}

Expand All @@ -762,7 +764,7 @@ impl AtomicCell<bool> {
/// ```
#[inline]
pub fn fetch_nand(&self, val: bool) -> bool {
let a = unsafe { &*(self.value.get() as *const AtomicBool) };
let a = unsafe { &*(self.as_ptr() as *const AtomicBool) };
a.fetch_nand(val, Ordering::AcqRel)
}

Expand All @@ -783,7 +785,7 @@ impl AtomicCell<bool> {
/// ```
#[inline]
pub fn fetch_or(&self, val: bool) -> bool {
let a = unsafe { &*(self.value.get() as *const AtomicBool) };
let a = unsafe { &*(self.as_ptr() as *const AtomicBool) };
a.fetch_or(val, Ordering::AcqRel)
}

Expand All @@ -804,7 +806,7 @@ impl AtomicCell<bool> {
/// ```
#[inline]
pub fn fetch_xor(&self, val: bool) -> bool {
let a = unsafe { &*(self.value.get() as *const AtomicBool) };
let a = unsafe { &*(self.as_ptr() as *const AtomicBool) };
a.fetch_xor(val, Ordering::AcqRel)
}
}
Expand Down

0 comments on commit d670b12

Please sign in to comment.