diff --git a/futures-util/src/lock/mutex.rs b/futures-util/src/lock/mutex.rs index a78de6283c..e136cc4fa1 100644 --- a/futures-util/src/lock/mutex.rs +++ b/futures-util/src/lock/mutex.rs @@ -8,6 +8,12 @@ use std::ops::{Deref, DerefMut}; use std::pin::Pin; use std::sync::Mutex as StdMutex; use std::sync::atomic::{AtomicUsize, Ordering}; +use std::collections::VecDeque; + +struct Inner { + slab: Slab, + queue: VecDeque, +} /// A futures-aware mutex. /// @@ -19,7 +25,7 @@ use std::sync::atomic::{AtomicUsize, Ordering}; /// indefinitely. pub struct Mutex { state: AtomicUsize, - waiters: StdMutex>, + waiters: StdMutex, value: UnsafeCell, } @@ -51,17 +57,24 @@ enum Waiter { } impl Waiter { - fn register(&mut self, waker: &Waker) { + fn register(&mut self, waker: &Waker) -> bool { match self { - Self::Waiting(w) if waker.will_wake(w) => {}, - _ => *self = Self::Waiting(waker.clone()), + Self::Waiting(w) if waker.will_wake(w) => false, + Self::Waiting(_) => { + *self = Self::Waiting(waker.clone()); + false + } + Self::Woken => { + *self = Self::Waiting(waker.clone()); + true + } } } fn wake(&mut self) { match mem::replace(self, Self::Woken) { Self::Waiting(waker) => waker.wake(), - Self::Woken => {}, + Self::Woken => {} } } } @@ -75,7 +88,7 @@ impl Mutex { pub fn new(t: T) -> Self { Self { state: AtomicUsize::new(0), - waiters: StdMutex::new(Slab::new()), + waiters: StdMutex::new(Inner { slab: Slab::new(), queue: VecDeque::new() }), value: UnsafeCell::new(t), } } @@ -144,20 +157,23 @@ impl Mutex { fn remove_waker(&self, wait_key: usize, wake_another: bool) { if wait_key != WAIT_KEY_NONE { let mut waiters = self.waiters.lock().unwrap(); - match waiters.remove(wait_key) { - Waiter::Waiting(_) => {}, + match waiters.slab.remove(wait_key) { + Waiter::Waiting(_) => {} Waiter::Woken => { // We were awoken, but then dropped before we could // wake up to acquire the lock. Wake up another // waiter. if wake_another { - if let Some((_i, waiter)) = waiters.iter_mut().next() { - waiter.wake(); + while let Some(other_key) = waiters.queue.pop_front() { + if waiters.slab.contains(other_key) { + waiters.slab.remove(other_key).wake(); + break; + } } } } } - if waiters.is_empty() { + if waiters.slab.is_empty() { self.state.fetch_and(!HAS_WAITERS, Ordering::Relaxed); // released by mutex unlock } } @@ -169,8 +185,11 @@ impl Mutex { let old_state = self.state.fetch_and(!IS_LOCKED, Ordering::AcqRel); if (old_state & HAS_WAITERS) != 0 { let mut waiters = self.waiters.lock().unwrap(); - if let Some((_i, waiter)) = waiters.iter_mut().next() { - waiter.wake(); + while let Some(wait_key) = waiters.queue.pop_front() { + if waiters.slab.contains(wait_key) { + waiters.slab[wait_key].wake(); + break; + } } } } @@ -192,12 +211,12 @@ impl fmt::Debug for MutexLockFuture<'_, T> { .field("was_acquired", &self.mutex.is_none()) .field("mutex", &self.mutex) .field("wait_key", &( - if self.wait_key == WAIT_KEY_NONE { - None - } else { - Some(self.wait_key) - } - )) + if self.wait_key == WAIT_KEY_NONE { + None + } else { + Some(self.wait_key) + } + )) .finish() } } @@ -223,12 +242,15 @@ impl<'a, T: ?Sized> Future for MutexLockFuture<'a, T> { { let mut waiters = mutex.waiters.lock().unwrap(); if self.wait_key == WAIT_KEY_NONE { - self.wait_key = waiters.insert(Waiter::Waiting(cx.waker().clone())); - if waiters.len() == 1 { + self.wait_key = waiters.slab.insert(Waiter::Waiting(cx.waker().clone())); + waiters.queue.push_back(self.wait_key); + if waiters.slab.len() == 1 { mutex.state.fetch_or(HAS_WAITERS, Ordering::Relaxed); // released by mutex unlock } } else { - waiters[self.wait_key].register(cx.waker()); + if waiters.slab[self.wait_key].register(cx.waker()) { + waiters.queue.push_back(self.wait_key); + } } } @@ -281,8 +303,8 @@ impl<'a, T: ?Sized> MutexGuard<'a, T> { /// ``` #[inline] pub fn map(this: Self, f: F) -> MappedMutexGuard<'a, T, U> - where - F: FnOnce(&mut T) -> &mut U, + where + F: FnOnce(&mut T) -> &mut U, { let mutex = this.mutex; let value = f(unsafe { &mut *this.mutex.value.get() }); @@ -348,8 +370,8 @@ impl<'a, T: ?Sized, U: ?Sized> MappedMutexGuard<'a, T, U> { /// ``` #[inline] pub fn map(this: Self, f: F) -> MappedMutexGuard<'a, T, V> - where - F: FnOnce(&mut U) -> &mut V, + where + F: FnOnce(&mut U) -> &mut V, { let mutex = this.mutex; let value = f(unsafe { &mut *this.value }); @@ -391,19 +413,24 @@ impl DerefMut for MappedMutexGuard<'_, T, U> { // Mutexes can be moved freely between threads and acquired on any thread so long // as the inner value can be safely sent between threads. unsafe impl Send for Mutex {} + unsafe impl Sync for Mutex {} // It's safe to switch which thread the acquire is being attempted on so long as // `T` can be accessed on that thread. unsafe impl Send for MutexLockFuture<'_, T> {} + // doesn't have any interesting `&self` methods (only Debug) unsafe impl Sync for MutexLockFuture<'_, T> {} // Safe to send since we don't track any thread-specific details-- the inner // lock is essentially spinlock-equivalent (attempt to flip an atomic bool) unsafe impl Send for MutexGuard<'_, T> {} + unsafe impl Sync for MutexGuard<'_, T> {} + unsafe impl Send for MappedMutexGuard<'_, T, U> {} + unsafe impl Sync for MappedMutexGuard<'_, T, U> {} #[test] diff --git a/futures/tests/lock_mutex.rs b/futures/tests/lock_mutex.rs index 7c33864c76..08f5f00d00 100644 --- a/futures/tests/lock_mutex.rs +++ b/futures/tests/lock_mutex.rs @@ -7,6 +7,9 @@ use futures::task::{Context, SpawnExt}; use futures_test::future::FutureTestExt; use futures_test::task::{new_count_waker, panic_context}; use std::sync::Arc; +use futures::stream::futures_unordered::FuturesUnordered; +use std::time::Instant; + #[test] fn mutex_acquire_uncontested() { @@ -53,7 +56,7 @@ fn mutex_contested() { tx.unbounded_send(()).unwrap(); drop(lock); }) - .unwrap(); + .unwrap(); } block_on(async { @@ -64,3 +67,18 @@ fn mutex_contested() { assert_eq!(num_tasks, *lock); }) } + +#[test] +fn quadratic_performance_test() { + for &count in &[10, 100, 1000, 10000, 100000, 1000000] { + let mutex = Mutex::new(()); + let start = Instant::now(); + block_on((0..count).map(|_| { + async { + let _guard = mutex.lock().await; + ready(()).pending_once().await; + } + }).collect::>().collect::<()>()); + println!("{}\t{:?}", count, start.elapsed()); + } +}