Skip to content

Commit

Permalink
Merge pull request #72 from palfrey/dashmap
Browse files Browse the repository at this point in the history
Redo locks with dashmap
  • Loading branch information
palfrey committed Aug 4, 2022
2 parents 9bee51f + 36edc84 commit 34d6995
Show file tree
Hide file tree
Showing 5 changed files with 83 additions and 57 deletions.
22 changes: 22 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions serial_test/Cargo.toml
Expand Up @@ -20,6 +20,7 @@ log = { version = "0.4", optional = true }
futures = { version = "^0.3", default_features = false, features = [
"executor",
] }
dashmap = { version = "5"}

[dev-dependencies]
itertools = "0.10"
Expand Down
37 changes: 18 additions & 19 deletions serial_test/src/code_lock.rs
@@ -1,16 +1,15 @@
use crate::rwlock::{Locks, MutexGuardWrapper};
use dashmap::{try_result::TryResult, DashMap};
use lazy_static::lazy_static;
#[cfg(all(feature = "logging", feature = "timeout"))]
use log::debug;
#[cfg(feature = "timeout")]
use parking_lot::RwLock;
use std::sync::{atomic::AtomicU32, Arc};
#[cfg(feature = "timeout")]
use std::time::Duration;
#[cfg(feature = "timeout")]
use std::time::Instant;
use std::{
collections::HashMap,
ops::{Deref, DerefMut},
sync::{atomic::AtomicU32, Arc},
time::Duration,
};

pub(crate) struct UniqueReentrantMutex {
locks: Locks,
Expand Down Expand Up @@ -45,8 +44,8 @@ impl UniqueReentrantMutex {
}

lazy_static! {
pub(crate) static ref LOCKS: Arc<RwLock<HashMap<String, UniqueReentrantMutex>>> =
Arc::new(RwLock::new(HashMap::new()));
pub(crate) static ref LOCKS: Arc<DashMap<String, UniqueReentrantMutex>> =
Arc::new(DashMap::new());
static ref MUTEX_ID: Arc<AtomicU32> = Arc::new(AtomicU32::new(1));
}

Expand Down Expand Up @@ -93,25 +92,25 @@ pub(crate) fn check_new_key(name: &str) {
debug!("Waiting for '{}' {:?}", name, duration);
}
// Check if a new key is needed. Just need a read lock, which can be done in sync with everyone else
let try_unlock = LOCKS.try_read_recursive_for(Duration::from_secs(1));
if let Some(unlock) = try_unlock {
if unlock.deref().contains_key(name) {
match LOCKS.try_get(name) {
TryResult::Present(_) => {
return;
}
drop(unlock); // so that we don't hold the read lock and so the writer can maybe succeed
} else {
continue; // wasn't able to get read lock
}
TryResult::Locked => {
continue; // wasn't able to get read lock
}
TryResult::Absent => {} // do the write path below
};

// This is the rare path, which avoids the multi-writer situation mostly
let try_lock = LOCKS.try_write_for(Duration::from_secs(1));
let try_entry = LOCKS.try_entry(name.to_string());

if let Some(mut lock) = try_lock {
lock.deref_mut().entry(name.to_string()).or_default();
if let Some(entry) = try_entry {
entry.or_default();
return;
}

// If the try_lock fails, then go around the loop again
// If the try_entry fails, then go around the loop again
// Odds are another test was also locking on the write and has now written the key

#[cfg(feature = "timeout")]
Expand Down
52 changes: 30 additions & 22 deletions serial_test/src/parallel_code_lock.rs
Expand Up @@ -2,7 +2,7 @@

use crate::code_lock::{check_new_key, LOCKS};
use futures::FutureExt;
use std::{ops::Deref, panic};
use std::panic;

#[doc(hidden)]
pub fn local_parallel_core_with_return<E>(
Expand All @@ -11,10 +11,10 @@ pub fn local_parallel_core_with_return<E>(
) -> Result<(), E> {
check_new_key(name);

let unlock = LOCKS.read_recursive();
unlock.deref()[name].start_parallel();
let lock = LOCKS.get(name).unwrap();
lock.start_parallel();
let res = panic::catch_unwind(function);
unlock.deref()[name].end_parallel();
lock.end_parallel();
match res {
Ok(ret) => ret,
Err(err) => {
Expand All @@ -27,12 +27,12 @@ pub fn local_parallel_core_with_return<E>(
pub fn local_parallel_core(name: &str, function: fn()) {
check_new_key(name);

let unlock = LOCKS.read_recursive();
unlock.deref()[name].start_parallel();
let lock = LOCKS.get(name).unwrap();
lock.start_parallel();
let res = panic::catch_unwind(|| {
function();
});
unlock.deref()[name].end_parallel();
lock.end_parallel();
if let Err(err) = res {
panic::resume_unwind(err);
}
Expand All @@ -45,10 +45,10 @@ pub async fn local_async_parallel_core_with_return<E>(
) -> Result<(), E> {
check_new_key(name);

let unlock = LOCKS.read_recursive();
unlock.deref()[name].start_parallel();
let lock = LOCKS.get(name).unwrap();
lock.start_parallel();
let res = fut.catch_unwind().await;
unlock.deref()[name].end_parallel();
lock.end_parallel();
match res {
Ok(ret) => ret,
Err(err) => {
Expand All @@ -64,10 +64,10 @@ pub async fn local_async_parallel_core(
) {
check_new_key(name);

let unlock = LOCKS.read_recursive();
unlock.deref()[name].start_parallel();
let lock = LOCKS.get(name).unwrap();
lock.start_parallel();
let res = fut.catch_unwind().await;
unlock.deref()[name].end_parallel();
lock.end_parallel();
if let Err(err) = res {
panic::resume_unwind(err);
}
Expand All @@ -79,7 +79,7 @@ mod tests {
code_lock::LOCKS, local_async_parallel_core, local_async_parallel_core_with_return,
local_parallel_core, local_parallel_core_with_return,
};
use std::{io::Error, ops::Deref, panic};
use std::{io::Error, panic};

#[test]
fn unlock_on_assert_sync_without_return() {
Expand All @@ -88,9 +88,11 @@ mod tests {
assert!(false);
})
});
let unlock = LOCKS.read_recursive();
assert_eq!(
unlock.deref()["unlock_on_assert_sync_without_return"].parallel_count(),
LOCKS
.get("unlock_on_assert_sync_without_return")
.unwrap()
.parallel_count(),
0
);
}
Expand All @@ -106,9 +108,11 @@ mod tests {
},
)
});
let unlock = LOCKS.read_recursive();
assert_eq!(
unlock.deref()["unlock_on_assert_sync_with_return"].parallel_count(),
LOCKS
.get("unlock_on_assert_sync_with_return")
.unwrap()
.parallel_count(),
0
);
}
Expand All @@ -127,9 +131,11 @@ mod tests {
let _enter_guard = handle.enter();
futures::executor::block_on(call_serial_test_fn());
});
let unlock = LOCKS.read_recursive();
assert_eq!(
unlock.deref()["unlock_on_assert_async_without_return"].parallel_count(),
LOCKS
.get("unlock_on_assert_async_without_return")
.unwrap()
.parallel_count(),
0
);
}
Expand All @@ -156,9 +162,11 @@ mod tests {
let _enter_guard = handle.enter();
futures::executor::block_on(call_serial_test_fn());
});
let unlock = LOCKS.read_recursive();
assert_eq!(
unlock.deref()["unlock_on_assert_async_with_return"].parallel_count(),
LOCKS
.get("unlock_on_assert_async_with_return")
.unwrap()
.parallel_count(),
0
);
}
Expand Down
28 changes: 12 additions & 16 deletions serial_test/src/serial_code_lock.rs
@@ -1,7 +1,6 @@
#![allow(clippy::await_holding_lock)]

use crate::code_lock::{check_new_key, LOCKS};
use std::ops::Deref;

#[doc(hidden)]
pub fn local_serial_core_with_return<E>(
Expand All @@ -10,19 +9,19 @@ pub fn local_serial_core_with_return<E>(
) -> Result<(), E> {
check_new_key(name);

let unlock = LOCKS.read_recursive();
let unlock = LOCKS.get(name).expect("key to be set");
// _guard needs to be named to avoid being instant dropped
let _guard = unlock.deref()[name].lock();
let _guard = unlock.lock();
function()
}

#[doc(hidden)]
pub fn local_serial_core(name: &str, function: fn()) {
check_new_key(name);

let unlock = LOCKS.read_recursive();
let unlock = LOCKS.get(name).expect("key to be set");
// _guard needs to be named to avoid being instant dropped
let _guard = unlock.deref()[name].lock();
let _guard = unlock.lock();
function();
}

Expand All @@ -33,19 +32,20 @@ pub async fn local_async_serial_core_with_return<E>(
) -> Result<(), E> {
check_new_key(name);

let unlock = LOCKS.read_recursive();
let unlock = LOCKS.get(name).expect("key to be set");
// _guard needs to be named to avoid being instant dropped
let _guard = unlock.deref()[name].lock();
let _guard = unlock.lock();
fut.await
}

#[doc(hidden)]
pub async fn local_async_serial_core(name: &str, fut: impl std::future::Future<Output = ()>) {
check_new_key(name);

let unlock = LOCKS.read_recursive();
let unlock = LOCKS.get(name).expect("key to be set");
// _guard needs to be named to avoid being instant dropped
let _guard = unlock.deref()[name].lock();
let _guard = unlock.lock();

fut.await;
}

Expand All @@ -57,7 +57,6 @@ mod tests {
use itertools::Itertools;
use parking_lot::RwLock;
use std::{
ops::Deref,
sync::{Arc, Barrier},
thread,
time::Duration,
Expand All @@ -79,10 +78,8 @@ mod tests {
c.wait();
check_new_key("foo");
{
let unlock = local_locks
.try_read_recursive_for(Duration::from_secs(1))
.expect("read lock didn't work");
let mutex = unlock.deref().get("foo").unwrap();
let unlock = local_locks.get("foo").expect("read didn't work");
let mutex = unlock.value();

let mut ptr_guard = local_ptrs
.try_write_for(Duration::from_secs(1))
Expand Down Expand Up @@ -111,7 +108,6 @@ mod tests {
assert!(false);
})
});
let unlock = LOCKS.read_recursive();
assert!(!unlock.deref()["assert"].is_locked());
assert!(!LOCKS.get("assert").unwrap().is_locked());
}
}

0 comments on commit 34d6995

Please sign in to comment.