diff --git a/Cargo.lock b/Cargo.lock index 3cbe3e9..afc40b2 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -61,6 +61,17 @@ version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" +[[package]] +name = "dashmap" +version = "5.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4c8858831f7781322e539ea39e72449c46b059638250c14344fec8d0aa6e539c" +dependencies = [ + "cfg-if 1.0.0", + "num_cpus", + "parking_lot", +] + [[package]] name = "document-features" version = "0.2.1" @@ -277,6 +288,16 @@ dependencies = [ "winapi", ] +[[package]] +name = "num_cpus" +version = "1.13.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "19e64526ebdee182341572e50e9ad03965aa510cd94427a4549448f285e957a1" +dependencies = [ + "hermit-abi", + "libc", +] + [[package]] name = "once_cell" version = "1.10.0" @@ -439,6 +460,7 @@ dependencies = [ name = "serial_test" version = "0.8.0" dependencies = [ + "dashmap", "document-features", "fslock", "futures", diff --git a/serial_test/Cargo.toml b/serial_test/Cargo.toml index d2777e8..528fca2 100644 --- a/serial_test/Cargo.toml +++ b/serial_test/Cargo.toml @@ -20,6 +20,7 @@ log = { version = "0.4", optional = true } futures = { version = "^0.3", default_features = false, features = [ "executor", ] } +dashmap = { version = "5"} [dev-dependencies] itertools = "0.10" diff --git a/serial_test/src/code_lock.rs b/serial_test/src/code_lock.rs index 279b3fb..1fea4da 100644 --- a/serial_test/src/code_lock.rs +++ b/serial_test/src/code_lock.rs @@ -1,16 +1,15 @@ use crate::rwlock::{Locks, MutexGuardWrapper}; +use dashmap::{try_result::TryResult, DashMap}; use lazy_static::lazy_static; #[cfg(all(feature = "logging", feature = "timeout"))] use log::debug; +#[cfg(feature = "timeout")] use parking_lot::RwLock; +use std::sync::{atomic::AtomicU32, Arc}; +#[cfg(feature = "timeout")] +use std::time::Duration; #[cfg(feature = "timeout")] use std::time::Instant; -use std::{ - collections::HashMap, - ops::{Deref, DerefMut}, - sync::{atomic::AtomicU32, Arc}, - time::Duration, -}; pub(crate) struct UniqueReentrantMutex { locks: Locks, @@ -45,8 +44,8 @@ impl UniqueReentrantMutex { } lazy_static! { - pub(crate) static ref LOCKS: Arc>> = - Arc::new(RwLock::new(HashMap::new())); + pub(crate) static ref LOCKS: Arc> = + Arc::new(DashMap::new()); static ref MUTEX_ID: Arc = Arc::new(AtomicU32::new(1)); } @@ -93,25 +92,25 @@ pub(crate) fn check_new_key(name: &str) { debug!("Waiting for '{}' {:?}", name, duration); } // Check if a new key is needed. Just need a read lock, which can be done in sync with everyone else - let try_unlock = LOCKS.try_read_recursive_for(Duration::from_secs(1)); - if let Some(unlock) = try_unlock { - if unlock.deref().contains_key(name) { + match LOCKS.try_get(name) { + TryResult::Present(_) => { return; } - drop(unlock); // so that we don't hold the read lock and so the writer can maybe succeed - } else { - continue; // wasn't able to get read lock - } + TryResult::Locked => { + continue; // wasn't able to get read lock + } + TryResult::Absent => {} // do the write path below + }; // This is the rare path, which avoids the multi-writer situation mostly - let try_lock = LOCKS.try_write_for(Duration::from_secs(1)); + let try_entry = LOCKS.try_entry(name.to_string()); - if let Some(mut lock) = try_lock { - lock.deref_mut().entry(name.to_string()).or_default(); + if let Some(entry) = try_entry { + entry.or_default(); return; } - // If the try_lock fails, then go around the loop again + // If the try_entry fails, then go around the loop again // Odds are another test was also locking on the write and has now written the key #[cfg(feature = "timeout")] diff --git a/serial_test/src/parallel_code_lock.rs b/serial_test/src/parallel_code_lock.rs index e854cec..202cc83 100644 --- a/serial_test/src/parallel_code_lock.rs +++ b/serial_test/src/parallel_code_lock.rs @@ -2,7 +2,7 @@ use crate::code_lock::{check_new_key, LOCKS}; use futures::FutureExt; -use std::{ops::Deref, panic}; +use std::panic; #[doc(hidden)] pub fn local_parallel_core_with_return( @@ -11,10 +11,10 @@ pub fn local_parallel_core_with_return( ) -> Result<(), E> { check_new_key(name); - let unlock = LOCKS.read_recursive(); - unlock.deref()[name].start_parallel(); + let lock = LOCKS.get(name).unwrap(); + lock.start_parallel(); let res = panic::catch_unwind(function); - unlock.deref()[name].end_parallel(); + lock.end_parallel(); match res { Ok(ret) => ret, Err(err) => { @@ -27,12 +27,12 @@ pub fn local_parallel_core_with_return( pub fn local_parallel_core(name: &str, function: fn()) { check_new_key(name); - let unlock = LOCKS.read_recursive(); - unlock.deref()[name].start_parallel(); + let lock = LOCKS.get(name).unwrap(); + lock.start_parallel(); let res = panic::catch_unwind(|| { function(); }); - unlock.deref()[name].end_parallel(); + lock.end_parallel(); if let Err(err) = res { panic::resume_unwind(err); } @@ -45,10 +45,10 @@ pub async fn local_async_parallel_core_with_return( ) -> Result<(), E> { check_new_key(name); - let unlock = LOCKS.read_recursive(); - unlock.deref()[name].start_parallel(); + let lock = LOCKS.get(name).unwrap(); + lock.start_parallel(); let res = fut.catch_unwind().await; - unlock.deref()[name].end_parallel(); + lock.end_parallel(); match res { Ok(ret) => ret, Err(err) => { @@ -64,10 +64,10 @@ pub async fn local_async_parallel_core( ) { check_new_key(name); - let unlock = LOCKS.read_recursive(); - unlock.deref()[name].start_parallel(); + let lock = LOCKS.get(name).unwrap(); + lock.start_parallel(); let res = fut.catch_unwind().await; - unlock.deref()[name].end_parallel(); + lock.end_parallel(); if let Err(err) = res { panic::resume_unwind(err); } @@ -79,7 +79,7 @@ mod tests { code_lock::LOCKS, local_async_parallel_core, local_async_parallel_core_with_return, local_parallel_core, local_parallel_core_with_return, }; - use std::{io::Error, ops::Deref, panic}; + use std::{io::Error, panic}; #[test] fn unlock_on_assert_sync_without_return() { @@ -88,9 +88,11 @@ mod tests { assert!(false); }) }); - let unlock = LOCKS.read_recursive(); assert_eq!( - unlock.deref()["unlock_on_assert_sync_without_return"].parallel_count(), + LOCKS + .get("unlock_on_assert_sync_without_return") + .unwrap() + .parallel_count(), 0 ); } @@ -106,9 +108,11 @@ mod tests { }, ) }); - let unlock = LOCKS.read_recursive(); assert_eq!( - unlock.deref()["unlock_on_assert_sync_with_return"].parallel_count(), + LOCKS + .get("unlock_on_assert_sync_with_return") + .unwrap() + .parallel_count(), 0 ); } @@ -127,9 +131,11 @@ mod tests { let _enter_guard = handle.enter(); futures::executor::block_on(call_serial_test_fn()); }); - let unlock = LOCKS.read_recursive(); assert_eq!( - unlock.deref()["unlock_on_assert_async_without_return"].parallel_count(), + LOCKS + .get("unlock_on_assert_async_without_return") + .unwrap() + .parallel_count(), 0 ); } @@ -156,9 +162,11 @@ mod tests { let _enter_guard = handle.enter(); futures::executor::block_on(call_serial_test_fn()); }); - let unlock = LOCKS.read_recursive(); assert_eq!( - unlock.deref()["unlock_on_assert_async_with_return"].parallel_count(), + LOCKS + .get("unlock_on_assert_async_with_return") + .unwrap() + .parallel_count(), 0 ); } diff --git a/serial_test/src/serial_code_lock.rs b/serial_test/src/serial_code_lock.rs index 5c494bd..4a328ca 100644 --- a/serial_test/src/serial_code_lock.rs +++ b/serial_test/src/serial_code_lock.rs @@ -1,7 +1,6 @@ #![allow(clippy::await_holding_lock)] use crate::code_lock::{check_new_key, LOCKS}; -use std::ops::Deref; #[doc(hidden)] pub fn local_serial_core_with_return( @@ -10,9 +9,9 @@ pub fn local_serial_core_with_return( ) -> Result<(), E> { check_new_key(name); - let unlock = LOCKS.read_recursive(); + let unlock = LOCKS.get(name).expect("key to be set"); // _guard needs to be named to avoid being instant dropped - let _guard = unlock.deref()[name].lock(); + let _guard = unlock.lock(); function() } @@ -20,9 +19,9 @@ pub fn local_serial_core_with_return( pub fn local_serial_core(name: &str, function: fn()) { check_new_key(name); - let unlock = LOCKS.read_recursive(); + let unlock = LOCKS.get(name).expect("key to be set"); // _guard needs to be named to avoid being instant dropped - let _guard = unlock.deref()[name].lock(); + let _guard = unlock.lock(); function(); } @@ -33,9 +32,9 @@ pub async fn local_async_serial_core_with_return( ) -> Result<(), E> { check_new_key(name); - let unlock = LOCKS.read_recursive(); + let unlock = LOCKS.get(name).expect("key to be set"); // _guard needs to be named to avoid being instant dropped - let _guard = unlock.deref()[name].lock(); + let _guard = unlock.lock(); fut.await } @@ -43,9 +42,10 @@ pub async fn local_async_serial_core_with_return( pub async fn local_async_serial_core(name: &str, fut: impl std::future::Future) { check_new_key(name); - let unlock = LOCKS.read_recursive(); + let unlock = LOCKS.get(name).expect("key to be set"); // _guard needs to be named to avoid being instant dropped - let _guard = unlock.deref()[name].lock(); + let _guard = unlock.lock(); + fut.await; } @@ -57,7 +57,6 @@ mod tests { use itertools::Itertools; use parking_lot::RwLock; use std::{ - ops::Deref, sync::{Arc, Barrier}, thread, time::Duration, @@ -79,10 +78,8 @@ mod tests { c.wait(); check_new_key("foo"); { - let unlock = local_locks - .try_read_recursive_for(Duration::from_secs(1)) - .expect("read lock didn't work"); - let mutex = unlock.deref().get("foo").unwrap(); + let unlock = local_locks.get("foo").expect("read didn't work"); + let mutex = unlock.value(); let mut ptr_guard = local_ptrs .try_write_for(Duration::from_secs(1)) @@ -111,7 +108,6 @@ mod tests { assert!(false); }) }); - let unlock = LOCKS.read_recursive(); - assert!(!unlock.deref()["assert"].is_locked()); + assert!(!LOCKS.get("assert").unwrap().is_locked()); } }