Skip to content

Commit

Permalink
Make {Mutex, Condvar, RwLock}::new const
Browse files Browse the repository at this point in the history
These constructors have been const since Rust 1.63
(rust-lang/rust#93740). It's pretty easy for
us to make them const too, which allows code that relies on them being
const to correctly compile with Shuttle.

The one exception is that HashMap::new isn't const, and our Condvar
implementation uses a HashMap to track waiters. I took the easy way out
and just used a vector as an association list instead -- we shouldn't
expect large numbers of waiters on the same condvar, so this shouldn't
be too inefficient.
  • Loading branch information
jamesbornholt committed May 19, 2023
1 parent b9205aa commit dcdac00
Show file tree
Hide file tree
Showing 6 changed files with 45 additions and 44 deletions.
3 changes: 2 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ keywords = ["concurrency", "lock", "thread", "async"]
categories = ["asynchronous", "concurrency", "development-tools::testing"]

[dependencies]
assoc = "0.1.3"
bitvec = "1.0.1"
generator = "0.7.1"
hex = "0.4.2"
Expand All @@ -17,7 +18,7 @@ rand_core = "0.6.4"
rand = "0.8.5"
rand_pcg = "0.3.1"
scoped-tls = "1.0.0"
smallvec = "1.6.1"
smallvec = { version = "1.10.0", features = ["const_new"] }
tracing = { version = "0.1.21", default-features = false, features = ["std"] }

[dev-dependencies]
Expand Down
6 changes: 4 additions & 2 deletions src/runtime/task/clock.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,10 @@ pub struct VectorClock {
}

impl VectorClock {
pub(crate) fn new() -> Self {
Self { time: SmallVec::new() }
pub(crate) const fn new() -> Self {
Self {
time: SmallVec::new_const(),
}
}

#[cfg(test)]
Expand Down
8 changes: 3 additions & 5 deletions src/runtime/task/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -427,10 +427,8 @@ pub(crate) struct TaskSet {
}

impl TaskSet {
pub fn new() -> Self {
Self {
tasks: BitVec::from_bitslice(bits![0; DEFAULT_INLINE_TASKS]),
}
pub const fn new() -> Self {
Self { tasks: BitVec::EMPTY }
}

pub fn contains(&self, tid: TaskId) -> bool {
Expand All @@ -446,7 +444,7 @@ impl TaskSet {
/// the set did have this value present, `false` is returned.
pub fn insert(&mut self, tid: TaskId) -> bool {
if tid.0 >= self.tasks.len() {
self.tasks.resize(1 + tid.0, false);
self.tasks.resize(DEFAULT_INLINE_TASKS.max(1 + tid.0), false);
}
!std::mem::replace(&mut *self.tasks.get_mut(tid.0).unwrap(), true)
}
Expand Down
20 changes: 11 additions & 9 deletions src/sync/condvar.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@ use crate::runtime::task::clock::VectorClock;
use crate::runtime::task::TaskId;
use crate::runtime::thread;
use crate::sync::MutexGuard;
use assoc::AssocExt;
use std::cell::RefCell;
use std::collections::{HashMap, VecDeque};
use std::rc::Rc;
use std::collections::VecDeque;
use std::sync::{LockResult, PoisonError};
use std::time::Duration;
use tracing::trace;
Expand All @@ -15,12 +15,13 @@ use tracing::trace;
/// waiting for an event to occur.
#[derive(Debug)]
pub struct Condvar {
state: Rc<RefCell<CondvarState>>,
state: RefCell<CondvarState>,
}

#[derive(Debug)]
struct CondvarState {
waiters: HashMap<TaskId, CondvarWaitStatus>,
// TODO: this should be a HashMap but [HashMap::new] is not const
waiters: Vec<(TaskId, CondvarWaitStatus)>,
next_epoch: usize,
}

Expand Down Expand Up @@ -114,14 +115,14 @@ enum CondvarWaitStatus {
// and can run in any order (because they are all contending on the same mutex).
impl Condvar {
/// Creates a new condition variable which is ready to be waited on and notified.
pub fn new() -> Self {
pub const fn new() -> Self {
let state = CondvarState {
waiters: HashMap::new(),
waiters: Vec::new(),
next_epoch: 0,
};

Self {
state: Rc::new(RefCell::new(state)),
state: RefCell::new(state),
}
}

Expand All @@ -133,7 +134,8 @@ impl Condvar {

trace!(waiters=?state.waiters, next_epoch=state.next_epoch, "waiting on condvar {:p}", self);

assert!(state.waiters.insert(me, CondvarWaitStatus::Waiting).is_none());
debug_assert!(<_ as AssocExt<_, _>>::get(&state.waiters, &me).is_none());
state.waiters.push((me, CondvarWaitStatus::Waiting));
// TODO: Condvar::wait should allow for spurious wakeups.
ExecutionState::with(|s| s.current_mut().block(false));
drop(state);
Expand All @@ -144,7 +146,7 @@ impl Condvar {
// After the context switch, consume whichever signal that woke this thread
let mut state = self.state.borrow_mut();
trace!(waiters=?state.waiters, next_epoch=state.next_epoch, "woken from condvar {:p}", self);
let my_status = state.waiters.remove(&me).expect("should be waiting");
let my_status = <_ as AssocExt<_, _>>::remove(&mut state.waiters, &me).expect("should be waiting");
match my_status {
CondvarWaitStatus::Broadcast(clock) => {
// Woken by a broadcast, so nothing to do except update the clock
Expand Down
7 changes: 3 additions & 4 deletions src/sync/mutex.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,12 @@ use std::cell::RefCell;
use std::fmt::{Debug, Display};
use std::ops::{Deref, DerefMut};
use std::panic::{RefUnwindSafe, UnwindSafe};
use std::rc::Rc;
use std::sync::{LockResult, PoisonError, TryLockError, TryLockResult};
use tracing::trace;

/// A mutex, the same as [`std::sync::Mutex`].
pub struct Mutex<T: ?Sized> {
state: Rc<RefCell<MutexState>>,
state: RefCell<MutexState>,
inner: std::sync::Mutex<T>,
}

Expand All @@ -31,7 +30,7 @@ struct MutexState {

impl<T> Mutex<T> {
/// Creates a new mutex in an unlocked state ready for use.
pub fn new(value: T) -> Self {
pub const fn new(value: T) -> Self {
let state = MutexState {
holder: None,
waiters: TaskSet::new(),
Expand All @@ -40,7 +39,7 @@ impl<T> Mutex<T> {

Self {
inner: std::sync::Mutex::new(value),
state: Rc::new(RefCell::new(state)),
state: RefCell::new(state),
}
}
}
Expand Down
45 changes: 22 additions & 23 deletions src/sync/rwlock.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ use std::cell::RefCell;
use std::fmt::{Debug, Display};
use std::ops::{Deref, DerefMut};
use std::panic::{RefUnwindSafe, UnwindSafe};
use std::rc::Rc;
use std::sync::{LockResult, PoisonError, TryLockError, TryLockResult};
use tracing::trace;

Expand All @@ -16,7 +15,7 @@ use tracing::trace;
/// `RwLock` more than once. The `std` version is ambiguous about what behavior is allowed here, so
/// we choose the most conservative one.
pub struct RwLock<T: ?Sized> {
state: Rc<RefCell<RwLockState>>,
state: RefCell<RwLockState>,
inner: std::sync::RwLock<T>,
}

Expand All @@ -43,7 +42,7 @@ enum RwLockType {

impl<T> RwLock<T> {
/// Create a new instance of an `RwLock<T>` which is unlocked.
pub fn new(value: T) -> Self {
pub const fn new(value: T) -> Self {
let state = RwLockState {
holder: RwLockHolder::None,
waiting_readers: TaskSet::new(),
Expand All @@ -53,7 +52,7 @@ impl<T> RwLock<T> {

Self {
inner: std::sync::RwLock::new(value),
state: Rc::new(RefCell::new(state)),
state: RefCell::new(state),
}
}
}
Expand All @@ -67,12 +66,12 @@ impl<T: ?Sized> RwLock<T> {
match self.inner.try_read() {
Ok(guard) => Ok(RwLockReadGuard {
inner: Some(guard),
state: Rc::clone(&self.state),
rwlock: self,
me: ExecutionState::me(),
}),
Err(TryLockError::Poisoned(err)) => Err(PoisonError::new(RwLockReadGuard {
inner: Some(err.into_inner()),
state: Rc::clone(&self.state),
rwlock: self,
me: ExecutionState::me(),
})),
Err(TryLockError::WouldBlock) => panic!("rwlock state out of sync"),
Expand All @@ -87,12 +86,12 @@ impl<T: ?Sized> RwLock<T> {
match self.inner.try_write() {
Ok(guard) => Ok(RwLockWriteGuard {
inner: Some(guard),
state: Rc::clone(&self.state),
rwlock: self,
me: ExecutionState::me(),
}),
Err(TryLockError::Poisoned(err)) => Err(PoisonError::new(RwLockWriteGuard {
inner: Some(err.into_inner()),
state: Rc::clone(&self.state),
rwlock: self,
me: ExecutionState::me(),
})),
Err(TryLockError::WouldBlock) => panic!("rwlock state out of sync"),
Expand All @@ -111,12 +110,12 @@ impl<T: ?Sized> RwLock<T> {
match self.inner.try_read() {
Ok(guard) => Ok(RwLockReadGuard {
inner: Some(guard),
state: Rc::clone(&self.state),
rwlock: self,
me: ExecutionState::me(),
}),
Err(TryLockError::Poisoned(err)) => Err(TryLockError::Poisoned(PoisonError::new(RwLockReadGuard {
inner: Some(err.into_inner()),
state: Rc::clone(&self.state),
rwlock: self,
me: ExecutionState::me(),
}))),
Err(TryLockError::WouldBlock) => panic!("rwlock state out of sync"),
Expand All @@ -135,12 +134,12 @@ impl<T: ?Sized> RwLock<T> {
match self.inner.try_write() {
Ok(guard) => Ok(RwLockWriteGuard {
inner: Some(guard),
state: Rc::clone(&self.state),
rwlock: self,
me: ExecutionState::me(),
}),
Err(TryLockError::Poisoned(err)) => Err(TryLockError::Poisoned(PoisonError::new(RwLockWriteGuard {
inner: Some(err.into_inner()),
state: Rc::clone(&self.state),
rwlock: self,
me: ExecutionState::me(),
}))),
Err(TryLockError::WouldBlock) => panic!("rwlock state out of sync"),
Expand Down Expand Up @@ -175,7 +174,7 @@ impl<T: ?Sized> RwLock<T> {
waiting_writers = ?state.waiting_writers,
"acquiring {:?} lock on rwlock {:p}",
typ,
self.state,
self,
);

// We are waiting for the lock
Expand Down Expand Up @@ -250,7 +249,7 @@ impl<T: ?Sized> RwLock<T> {
waiting_writers = ?state.waiting_writers,
"acquired {:?} lock on rwlock {:p}",
typ,
self.state
self
);

// Increment the current thread's clock and update this RwLock's clock to match.
Expand Down Expand Up @@ -283,7 +282,7 @@ impl<T: ?Sized> RwLock<T> {
waiting_writers = ?state.waiting_writers,
"trying to acquire {:?} lock on rwlock {:p}",
typ,
self.state,
self,
);

let acquired = match (typ, &mut state.holder) {
Expand All @@ -309,7 +308,7 @@ impl<T: ?Sized> RwLock<T> {
"{} {:?} lock on rwlock {:p}",
if acquired { "acquired" } else { "failed to acquire" },
typ,
self.state,
self,
);

// Update this thread's clock with the clock stored in the RwLock.
Expand Down Expand Up @@ -403,9 +402,9 @@ impl<T: ?Sized + Debug> Debug for RwLock<T> {

/// RAII structure used to release the shared read access of a `RwLock` when dropped.
pub struct RwLockReadGuard<'a, T: ?Sized> {
state: Rc<RefCell<RwLockState>>,
me: TaskId,
inner: Option<std::sync::RwLockReadGuard<'a, T>>,
rwlock: &'a RwLock<T>,
me: TaskId,
}

impl<T: ?Sized> Deref for RwLockReadGuard<'_, T> {
Expand All @@ -432,14 +431,14 @@ impl<T: ?Sized> Drop for RwLockReadGuard<'_, T> {
fn drop(&mut self) {
self.inner = None;

let mut state = self.state.borrow_mut();
let mut state = self.rwlock.state.borrow_mut();

trace!(
holder = ?state.holder,
waiting_readers = ?state.waiting_readers,
waiting_writers = ?state.waiting_writers,
"releasing Read lock on rwlock {:p}",
self.state
self.rwlock
);

match &mut state.holder {
Expand Down Expand Up @@ -471,7 +470,7 @@ impl<T: ?Sized> Drop for RwLockReadGuard<'_, T> {
/// RAII structure used to release the exclusive write access of a `RwLock` when dropped.
pub struct RwLockWriteGuard<'a, T: ?Sized> {
inner: Option<std::sync::RwLockWriteGuard<'a, T>>,
state: Rc<RefCell<RwLockState>>,
rwlock: &'a RwLock<T>,
me: TaskId,
}

Expand Down Expand Up @@ -505,13 +504,13 @@ impl<T: ?Sized> Drop for RwLockWriteGuard<'_, T> {
fn drop(&mut self) {
self.inner = None;

let mut state = self.state.borrow_mut();
let mut state = self.rwlock.state.borrow_mut();
trace!(
holder = ?state.holder,
waiting_readers = ?state.waiting_readers,
waiting_writers = ?state.waiting_writers,
"releasing Write lock on rwlock {:p}",
self.state
self.rwlock
);

assert_eq!(state.holder, RwLockHolder::Write(self.me));
Expand Down

0 comments on commit dcdac00

Please sign in to comment.