Redo locks with dashmap #72

Merged · 4 commits · Aug 4, 2022
22 changes: 22 additions & 0 deletions Cargo.lock


1 change: 1 addition & 0 deletions serial_test/Cargo.toml
@@ -20,6 +20,7 @@ log = { version = "0.4", optional = true }
futures = { version = "^0.3", default_features = false, features = [
"executor",
] }
dashmap = { version = "5"}

[dev-dependencies]
itertools = "0.10"
37 changes: 18 additions & 19 deletions serial_test/src/code_lock.rs
@@ -1,16 +1,15 @@
use crate::rwlock::{Locks, MutexGuardWrapper};
use dashmap::{try_result::TryResult, DashMap};
use lazy_static::lazy_static;
#[cfg(all(feature = "logging", feature = "timeout"))]
use log::debug;
#[cfg(feature = "timeout")]
use parking_lot::RwLock;
use std::sync::{atomic::AtomicU32, Arc};
#[cfg(feature = "timeout")]
use std::time::Duration;
#[cfg(feature = "timeout")]
use std::time::Instant;
use std::{
collections::HashMap,
ops::{Deref, DerefMut},
sync::{atomic::AtomicU32, Arc},
time::Duration,
};

pub(crate) struct UniqueReentrantMutex {
locks: Locks,
@@ -45,8 +44,8 @@ impl UniqueReentrantMutex {
}

lazy_static! {
pub(crate) static ref LOCKS: Arc<RwLock<HashMap<String, UniqueReentrantMutex>>> =
Arc::new(RwLock::new(HashMap::new()));
pub(crate) static ref LOCKS: Arc<DashMap<String, UniqueReentrantMutex>> =
Arc::new(DashMap::new());
static ref MUTEX_ID: Arc<AtomicU32> = Arc::new(AtomicU32::new(1));
}

@@ -93,25 +92,25 @@ pub(crate) fn check_new_key(name: &str) {
debug!("Waiting for '{}' {:?}", name, duration);
}
// Check if a new key is needed. Just need a read lock, which can be done in sync with everyone else
let try_unlock = LOCKS.try_read_recursive_for(Duration::from_secs(1));
if let Some(unlock) = try_unlock {
if unlock.deref().contains_key(name) {
match LOCKS.try_get(name) {
TryResult::Present(_) => {
return;
}
drop(unlock); // so that we don't hold the read lock and so the writer can maybe succeed
} else {
continue; // wasn't able to get read lock
}
TryResult::Locked => {
continue; // wasn't able to get read lock
}
TryResult::Absent => {} // do the write path below
};

// This is the rare path, which avoids the multi-writer situation mostly
let try_lock = LOCKS.try_write_for(Duration::from_secs(1));
let try_entry = LOCKS.try_entry(name.to_string());

if let Some(mut lock) = try_lock {
lock.deref_mut().entry(name.to_string()).or_default();
if let Some(entry) = try_entry {
entry.or_default();
return;
}

// If the try_lock fails, then go around the loop again
// If the try_entry fails, then go around the loop again
// Odds are another test was also locking on the write and has now written the key

#[cfg(feature = "timeout")]
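
For readers skimming the diff, here is a self-contained sketch of the registration pattern the new check_new_key follows (GLOBAL_LOCKS and register_key are illustrative names, not part of serial_test): try_get answers the common already-registered case without a map-wide lock, try_entry covers the rare first registration, and contention on either simply sends the loop around again.

use dashmap::{try_result::TryResult, DashMap};
use lazy_static::lazy_static;

lazy_static! {
    // Illustrative stand-in for the crate's LOCKS map.
    static ref GLOBAL_LOCKS: DashMap<String, u32> = DashMap::new();
}

// Ensure `name` has an entry without ever taking a map-wide write lock.
fn register_key(name: &str) {
    loop {
        // Fast path: the key usually exists already.
        match GLOBAL_LOCKS.try_get(name) {
            TryResult::Present(_) => return,
            TryResult::Locked => continue, // shard busy, retry
            TryResult::Absent => {}        // fall through to insertion
        }

        // Rare path: try to claim the entry; this fails only if the shard is contended.
        if let Some(entry) = GLOBAL_LOCKS.try_entry(name.to_string()) {
            entry.or_default();
            return;
        }
        // Another thread probably inserted the key in the meantime; loop and re-check.
    }
}

fn main() {
    register_key("example");
    assert!(GLOBAL_LOCKS.contains_key("example"));
}

The practical gain over the old Arc<RwLock<HashMap<..>>> is that contention is per-shard rather than map-wide, so looking up one test's lock no longer has to wait out a writer registering a different one.
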
52 changes: 30 additions & 22 deletions serial_test/src/parallel_code_lock.rs
@@ -2,7 +2,7 @@

use crate::code_lock::{check_new_key, LOCKS};
use futures::FutureExt;
use std::{ops::Deref, panic};
use std::panic;

#[doc(hidden)]
pub fn local_parallel_core_with_return<E>(
@@ -11,10 +11,10 @@ pub fn local_parallel_core_with_return<E>(
) -> Result<(), E> {
check_new_key(name);

let unlock = LOCKS.read_recursive();
unlock.deref()[name].start_parallel();
let lock = LOCKS.get(name).unwrap();
lock.start_parallel();
let res = panic::catch_unwind(function);
unlock.deref()[name].end_parallel();
lock.end_parallel();
match res {
Ok(ret) => ret,
Err(err) => {
@@ -27,12 +27,12 @@ pub fn local_parallel_core_with_return<E>(
pub fn local_parallel_core(name: &str, function: fn()) {
check_new_key(name);

let unlock = LOCKS.read_recursive();
unlock.deref()[name].start_parallel();
let lock = LOCKS.get(name).unwrap();
lock.start_parallel();
let res = panic::catch_unwind(|| {
function();
});
unlock.deref()[name].end_parallel();
lock.end_parallel();
if let Err(err) = res {
panic::resume_unwind(err);
}
@@ -45,10 +45,10 @@ pub async fn local_async_parallel_core_with_return<E>(
) -> Result<(), E> {
check_new_key(name);

let unlock = LOCKS.read_recursive();
unlock.deref()[name].start_parallel();
let lock = LOCKS.get(name).unwrap();
lock.start_parallel();
let res = fut.catch_unwind().await;
unlock.deref()[name].end_parallel();
lock.end_parallel();
match res {
Ok(ret) => ret,
Err(err) => {
@@ -64,10 +64,10 @@ pub async fn local_async_parallel_core(
) {
check_new_key(name);

let unlock = LOCKS.read_recursive();
unlock.deref()[name].start_parallel();
let lock = LOCKS.get(name).unwrap();
lock.start_parallel();
let res = fut.catch_unwind().await;
unlock.deref()[name].end_parallel();
lock.end_parallel();
if let Err(err) = res {
panic::resume_unwind(err);
}
@@ -79,7 +79,7 @@ mod tests {
code_lock::LOCKS, local_async_parallel_core, local_async_parallel_core_with_return,
local_parallel_core, local_parallel_core_with_return,
};
use std::{io::Error, ops::Deref, panic};
use std::{io::Error, panic};

#[test]
fn unlock_on_assert_sync_without_return() {
@@ -88,9 +88,11 @@ mod tests {
assert!(false);
})
});
let unlock = LOCKS.read_recursive();
assert_eq!(
unlock.deref()["unlock_on_assert_sync_without_return"].parallel_count(),
LOCKS
.get("unlock_on_assert_sync_without_return")
.unwrap()
.parallel_count(),
0
);
}
@@ -106,9 +108,11 @@ mod tests {
},
)
});
let unlock = LOCKS.read_recursive();
assert_eq!(
unlock.deref()["unlock_on_assert_sync_with_return"].parallel_count(),
LOCKS
.get("unlock_on_assert_sync_with_return")
.unwrap()
.parallel_count(),
0
);
}
@@ -127,9 +131,11 @@ mod tests {
let _enter_guard = handle.enter();
futures::executor::block_on(call_serial_test_fn());
});
let unlock = LOCKS.read_recursive();
assert_eq!(
unlock.deref()["unlock_on_assert_async_without_return"].parallel_count(),
LOCKS
.get("unlock_on_assert_async_without_return")
.unwrap()
.parallel_count(),
0
);
}
@@ -156,9 +162,11 @@ mod tests {
let _enter_guard = handle.enter();
futures::executor::block_on(call_serial_test_fn());
});
let unlock = LOCKS.read_recursive();
assert_eq!(
unlock.deref()["unlock_on_assert_async_with_return"].parallel_count(),
LOCKS
.get("unlock_on_assert_async_with_return")
.unwrap()
.parallel_count(),
0
);
}
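
As a rough sketch of the parallel-path bookkeeping above, with a bare AtomicU32 standing in for the crate's UniqueReentrantMutex (PARALLEL_COUNTS and run_parallel are made-up names): the Ref returned by get is held across the body, and catch_unwind ensures the count is decremented even when the test panics.

use dashmap::DashMap;
use lazy_static::lazy_static;
use std::panic;
use std::sync::atomic::{AtomicU32, Ordering};

lazy_static! {
    // Made-up per-name counters, mirroring how the parallel paths use LOCKS.
    static ref PARALLEL_COUNTS: DashMap<String, AtomicU32> = DashMap::new();
}

fn run_parallel(name: &str, body: fn()) {
    // Register the key if it is missing (check_new_key does this in the real crate).
    PARALLEL_COUNTS.entry(name.to_string()).or_default();

    let counter = PARALLEL_COUNTS.get(name).unwrap();
    counter.fetch_add(1, Ordering::SeqCst);
    let res = panic::catch_unwind(body);
    counter.fetch_sub(1, Ordering::SeqCst); // runs even if `body` panicked
    if let Err(err) = res {
        panic::resume_unwind(err);
    }
}

fn main() {
    run_parallel("demo", || println!("inside a parallel section"));
    assert_eq!(
        PARALLEL_COUNTS.get("demo").unwrap().load(Ordering::SeqCst),
        0
    );
}
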
28 changes: 12 additions & 16 deletions serial_test/src/serial_code_lock.rs
@@ -1,7 +1,6 @@
#![allow(clippy::await_holding_lock)]

use crate::code_lock::{check_new_key, LOCKS};
use std::ops::Deref;

#[doc(hidden)]
pub fn local_serial_core_with_return<E>(
@@ -10,19 +9,19 @@ pub fn local_serial_core_with_return<E>(
) -> Result<(), E> {
check_new_key(name);

let unlock = LOCKS.read_recursive();
let unlock = LOCKS.get(name).expect("key to be set");
// _guard needs to be named to avoid being instant dropped
let _guard = unlock.deref()[name].lock();
let _guard = unlock.lock();
function()
}

#[doc(hidden)]
pub fn local_serial_core(name: &str, function: fn()) {
check_new_key(name);

let unlock = LOCKS.read_recursive();
let unlock = LOCKS.get(name).expect("key to be set");
// _guard needs to be named to avoid being instant dropped
let _guard = unlock.deref()[name].lock();
let _guard = unlock.lock();
function();
}

@@ -33,19 +32,20 @@ pub async fn local_async_serial_core_with_return<E>(
) -> Result<(), E> {
check_new_key(name);

let unlock = LOCKS.read_recursive();
let unlock = LOCKS.get(name).expect("key to be set");
// _guard needs to be named to avoid being instant dropped
let _guard = unlock.deref()[name].lock();
let _guard = unlock.lock();
fut.await
}

#[doc(hidden)]
pub async fn local_async_serial_core(name: &str, fut: impl std::future::Future<Output = ()>) {
check_new_key(name);

let unlock = LOCKS.read_recursive();
let unlock = LOCKS.get(name).expect("key to be set");
// _guard needs to be named to avoid being instant dropped
let _guard = unlock.deref()[name].lock();
let _guard = unlock.lock();

fut.await;
}

@@ -57,7 +57,6 @@ mod tests {
use itertools::Itertools;
use parking_lot::RwLock;
use std::{
ops::Deref,
sync::{Arc, Barrier},
thread,
time::Duration,
@@ -79,10 +78,8 @@ mod tests {
c.wait();
check_new_key("foo");
{
let unlock = local_locks
.try_read_recursive_for(Duration::from_secs(1))
.expect("read lock didn't work");
let mutex = unlock.deref().get("foo").unwrap();
let unlock = local_locks.get("foo").expect("read didn't work");
let mutex = unlock.value();

let mut ptr_guard = local_ptrs
.try_write_for(Duration::from_secs(1))
@@ -111,7 +108,6 @@ mod tests {
assert!(false);
})
});
let unlock = LOCKS.read_recursive();
assert!(!unlock.deref()["assert"].is_locked());
assert!(!LOCKS.get("assert").unwrap().is_locked());
}
}
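
Finally, a matching sketch of the serial path, with a plain parking_lot::Mutex standing in for the crate's UniqueReentrantMutex (SERIAL_LOCKS and run_serial are illustrative names): look up the per-name mutex in the DashMap, keep the guard bound to a named variable so it lives for the whole call, then run the body.

use dashmap::DashMap;
use lazy_static::lazy_static;
use parking_lot::Mutex;

lazy_static! {
    // Illustrative per-name mutexes, mirroring the serial paths' use of LOCKS.
    static ref SERIAL_LOCKS: DashMap<String, Mutex<()>> = DashMap::new();
}

fn run_serial(name: &str, body: fn()) {
    // Register the key if it is missing (check_new_key does this in the real crate).
    SERIAL_LOCKS.entry(name.to_string()).or_default();

    let entry = SERIAL_LOCKS.get(name).expect("key to be set");
    // The guard must be named so it is not dropped immediately.
    let _guard = entry.lock();
    body();
}

fn main() {
    run_serial("demo", || println!("running exclusively"));
}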