Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace dashmap with scc #109

Merged
merged 5 commits into from Mar 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
29 changes: 9 additions & 20 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 6 additions & 2 deletions serial_test/Cargo.toml
Expand Up @@ -20,7 +20,8 @@ log = { version = "0.4", optional = true }
futures = { version = "^0.3", default_features = false, features = [
"executor",
], optional = true}
dashmap = { version = "5"}
scc = { version = "2"}
env_logger = {version="0.10", optional=true}

[dev-dependencies]
itertools = "0.10"
Expand All @@ -32,6 +33,9 @@ default = ["logging", "async"]
## Switches on debug logging (and requires the `log` package)
logging = ["log"]

## Switches on debug with env_logger. Generally only needed by internal serial_test work.
test_logging = ["logging", "env_logger", "serial_test_derive/test_logging"]

## Enables async features (and requires the `futures` package)
async = ["futures", "serial_test_derive/async"]

Expand All @@ -48,4 +52,4 @@ rustdoc-args = ["--cfg", "docsrs"]

[package.metadata.cargo-all-features]
skip_optional_dependencies = true
denylist = ["docsrs"]
denylist = ["docsrs", "test_logging"]
59 changes: 19 additions & 40 deletions serial_test/src/code_lock.rs
@@ -1,11 +1,7 @@
use crate::rwlock::{Locks, MutexGuardWrapper};
use dashmap::{try_result::TryResult, DashMap};
#[cfg(feature = "logging")]
use log::debug;
use once_cell::sync::OnceCell;
use scc::{hash_map::Entry, HashMap};
use std::sync::atomic::AtomicU32;
#[cfg(feature = "logging")]
use std::time::Instant;

#[derive(Clone)]
pub(crate) struct UniqueReentrantMutex {
Expand Down Expand Up @@ -41,51 +37,34 @@ impl UniqueReentrantMutex {
}

#[inline]
pub(crate) fn global_locks() -> &'static DashMap<String, UniqueReentrantMutex> {
static LOCKS: OnceCell<DashMap<String, UniqueReentrantMutex>> = OnceCell::new();
LOCKS.get_or_init(DashMap::new)
pub(crate) fn global_locks() -> &'static HashMap<String, UniqueReentrantMutex> {
#[cfg(feature = "test_logging")]
let _ = env_logger::builder().try_init();
static LOCKS: OnceCell<HashMap<String, UniqueReentrantMutex>> = OnceCell::new();
LOCKS.get_or_init(HashMap::new)
}

static MUTEX_ID: AtomicU32 = AtomicU32::new(1);

impl Default for UniqueReentrantMutex {
fn default() -> Self {
impl UniqueReentrantMutex {
fn new_mutex(name: &str) -> Self {
Self {
locks: Locks::new(),
locks: Locks::new(name),
id: MUTEX_ID.fetch_add(1, std::sync::atomic::Ordering::SeqCst),
}
}
}

pub(crate) fn check_new_key(name: &str) {
#[cfg(feature = "logging")]
let start = Instant::now();
loop {
#[cfg(feature = "logging")]
{
let duration = start.elapsed();
debug!("Waiting for '{}' {:?}", name, duration);
}
// Check if a new key is needed. Just need a read lock, which can be done in sync with everyone else
match global_locks().try_get(name) {
TryResult::Present(_) => {
return;
}
TryResult::Locked => {
continue; // wasn't able to get read lock
}
TryResult::Absent => {} // do the write path below
};

// This is the rare path, which avoids the multi-writer situation mostly
let try_entry = global_locks().try_entry(name.to_string());
// Check if a new key is needed. Just need a read lock, which can be done in sync with everyone else
if global_locks().contains(name) {
return;
};

if let Some(entry) = try_entry {
entry.or_default();
return;
}

// If the try_entry fails, then go around the loop again
// Odds are another test was also locking on the write and has now written the key
}
// This is the rare path, which avoids the multi-writer situation mostly
let entry = global_locks().entry(name.to_owned());
match entry {
Entry::Occupied(o) => o,
Entry::Vacant(v) => v.insert_entry(UniqueReentrantMutex::new_mutex(name)),
};
}
14 changes: 10 additions & 4 deletions serial_test/src/parallel_code_lock.rs
Expand Up @@ -5,14 +5,16 @@ use crate::code_lock::{check_new_key, global_locks};
use futures::FutureExt;
use std::panic;

fn get_locks(
names: Vec<&str>,
) -> Vec<dashmap::mapref::one::Ref<'_, String, crate::code_lock::UniqueReentrantMutex>> {
fn get_locks(names: Vec<&str>) -> Vec<crate::code_lock::UniqueReentrantMutex> {
names
.into_iter()
.map(|name| {
check_new_key(name);
global_locks().get(name).expect("key to be set")
global_locks()
.get(name)
.expect("key to be set")
.get()
.clone()
})
.collect::<Vec<_>>()
}
Expand Down Expand Up @@ -103,6 +105,7 @@ mod tests {
global_locks()
.get("unlock_on_assert_sync_without_return")
.unwrap()
.get()
.parallel_count(),
0
);
Expand All @@ -124,6 +127,7 @@ mod tests {
global_locks()
.get("unlock_on_assert_sync_with_return")
.unwrap()
.get()
.parallel_count(),
0
);
Expand Down Expand Up @@ -153,6 +157,7 @@ mod tests {
global_locks()
.get("unlock_on_assert_async_without_return")
.unwrap()
.get()
.parallel_count(),
0
);
Expand Down Expand Up @@ -186,6 +191,7 @@ mod tests {
global_locks()
.get("unlock_on_assert_async_with_return")
.unwrap()
.get()
.parallel_count(),
0
);
Expand Down
4 changes: 2 additions & 2 deletions serial_test/src/parallel_file_lock.rs
Expand Up @@ -102,10 +102,10 @@ mod tests {

#[test]
fn unlock_on_assert_sync_without_return() {
let lock_path = path_for_name("unlock_on_assert_sync_without_return");
let lock_path = path_for_name("parallel_unlock_on_assert_sync_without_return");
let _ = panic::catch_unwind(|| {
fs_parallel_core(
vec!["unlock_on_assert_sync_without_return"],
vec!["parallel_unlock_on_assert_sync_without_return"],
Some(&lock_path),
|| {
assert!(false);
Expand Down
34 changes: 33 additions & 1 deletion serial_test/src/rwlock.rs
@@ -1,3 +1,5 @@
#[cfg(feature = "logging")]
use log::debug;
use parking_lot::{Condvar, Mutex, ReentrantMutex, ReentrantMutexGuard};
use std::{sync::Arc, time::Duration};

Expand All @@ -14,6 +16,9 @@ struct LockData {
#[derive(Clone)]
pub(crate) struct Locks {
arc: Arc<LockData>,
// Name we're locking for (mostly test usage)
#[cfg(feature = "logging")]
pub(crate) name: String,
}

pub(crate) struct MutexGuardWrapper<'a> {
Expand All @@ -24,18 +29,23 @@ pub(crate) struct MutexGuardWrapper<'a> {

impl<'a> Drop for MutexGuardWrapper<'a> {
fn drop(&mut self) {
#[cfg(feature = "logging")]
debug!("End serial");
self.locks.arc.condvar.notify_one();
}
}

impl Locks {
pub fn new() -> Locks {
#[allow(unused_variables)]
pub fn new(name: &str) -> Locks {
Locks {
arc: Arc::new(LockData {
mutex: Mutex::new(LockState { parallels: 0 }),
condvar: Condvar::new(),
serial: Default::default(),
}),
#[cfg(feature = "logging")]
name: name.to_owned(),
}
}

Expand All @@ -45,16 +55,25 @@ impl Locks {
}

pub fn serial(&self) -> MutexGuardWrapper {
#[cfg(feature = "logging")]
debug!("Get serial lock '{}'", self.name);
let mut lock_state = self.arc.mutex.lock();
loop {
#[cfg(feature = "logging")]
debug!("Serial acquire {} {}", lock_state.parallels, self.name);
// If all the things we want are true, try to lock out serial
if lock_state.parallels == 0 {
let possible_serial_lock = self.arc.serial.try_lock();
if let Some(serial_lock) = possible_serial_lock {
#[cfg(feature = "logging")]
debug!("Got serial '{}'", self.name);
return MutexGuardWrapper {
mutex_guard: serial_lock,
locks: self.clone(),
};
} else {
#[cfg(feature = "logging")]
debug!("Someone else has serial '{}'", self.name);
}
}

Expand All @@ -65,8 +84,15 @@ impl Locks {
}

pub fn start_parallel(&self) {
#[cfg(feature = "logging")]
debug!("Get parallel lock '{}'", self.name);
let mut lock_state = self.arc.mutex.lock();
loop {
#[cfg(feature = "logging")]
debug!(
"Parallel, existing {} '{}'",
lock_state.parallels, self.name
);
if lock_state.parallels > 0 {
// fast path, as someone else already has it locked
lock_state.parallels += 1;
Expand All @@ -75,18 +101,24 @@ impl Locks {

let possible_serial_lock = self.arc.serial.try_lock();
if possible_serial_lock.is_some() {
#[cfg(feature = "logging")]
debug!("Parallel first '{}'", self.name);
// We now know no-one else has the serial lock, so we can add to parallel
lock_state.parallels = 1; // Had to have been 0 before, as otherwise we'd have hit the fast path
return;
}

#[cfg(feature = "logging")]
debug!("Parallel waiting '{}'", self.name);
self.arc
.condvar
.wait_for(&mut lock_state, Duration::from_secs(1));
}
}

pub fn end_parallel(&self) {
#[cfg(feature = "logging")]
debug!("End parallel '{}", self.name);
let mut lock_state = self.arc.mutex.lock();
assert!(lock_state.parallels > 0);
lock_state.parallels -= 1;
Expand Down
10 changes: 7 additions & 3 deletions serial_test/src/serial_code_lock.rs
Expand Up @@ -9,7 +9,11 @@ macro_rules! core_internal {
.into_iter()
.map(|name| {
check_new_key(name);
global_locks().get(name).expect("key to be set")
global_locks()
.get(name)
.expect("key to be set")
.get()
.clone()
})
.collect();
let _guards: Vec<_> = unlocks.iter().map(|unlock| unlock.lock()).collect();
Expand Down Expand Up @@ -84,7 +88,7 @@ mod tests {
check_new_key("foo");
{
let unlock = local_locks.get("foo").expect("read didn't work");
let mutex = unlock.value();
let mutex = unlock.get();

let mut ptr_guard = local_ptrs
.try_write_for(Duration::from_secs(1))
Expand Down Expand Up @@ -113,6 +117,6 @@ mod tests {
assert!(false);
})
});
assert!(!global_locks().get("assert").unwrap().is_locked());
assert!(!global_locks().get("assert").unwrap().get().is_locked());
}
}