diff --git a/tokio/src/sync/mutex.rs b/tokio/src/sync/mutex.rs index 21e44ca932c..15a418d096b 100644 --- a/tokio/src/sync/mutex.rs +++ b/tokio/src/sync/mutex.rs @@ -161,13 +161,22 @@ unsafe impl Sync for Mutex where T: ?Sized + Send {} unsafe impl Sync for MutexGuard<'_, T> where T: ?Sized + Send + Sync {} unsafe impl Sync for OwnedMutexGuard where T: ?Sized + Send + Sync {} -/// Error returned from the [`Mutex::try_lock`] function. +/// Error returned from the [`Mutex::try_lock`], [`RwLock::try_read`] and +/// [`RwLock::try_write`] functions. /// -/// A `try_lock` operation can only fail if the mutex is already locked. +/// `Mutex::try_lock` operation will only fail if the mutex is already locked. +/// +/// `RwLock::try_read` operation will only fail if the lock is currently held +/// by an exclusive writer. +/// +/// `RwLock::try_write` operation will if lock is held by any reader or by an +/// exclusive writer. /// /// [`Mutex::try_lock`]: Mutex::try_lock +/// [`RwLock::try_read`]: fn@super::RwLock::try_read +/// [`RwLock::try_write`]: fn@super::RwLock::try_write #[derive(Debug)] -pub struct TryLockError(()); +pub struct TryLockError(pub(super) ()); impl fmt::Display for TryLockError { fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { diff --git a/tokio/src/sync/rwlock.rs b/tokio/src/sync/rwlock.rs index 2e72cf75d85..bafae8f4b57 100644 --- a/tokio/src/sync/rwlock.rs +++ b/tokio/src/sync/rwlock.rs @@ -1,4 +1,5 @@ -use crate::sync::batch_semaphore::Semaphore; +use crate::sync::batch_semaphore::{Semaphore, TryAcquireError}; +use crate::sync::mutex::TryLockError; use std::cell::UnsafeCell; use std::fmt; use std::marker; @@ -422,7 +423,7 @@ impl RwLock { /// // While main has an active read lock, we acquire one too. /// let r = c_lock.read().await; /// assert_eq!(*r, 1); - /// }).await.expect("The spawned task has paniced"); + /// }).await.expect("The spawned task has panicked"); /// /// // Drop the guard after the spawned task finishes. /// drop(n); @@ -441,6 +442,52 @@ impl RwLock { } } + /// Attempts to acquire this `RwLock` with shared read access. + /// + /// If the access couldn't be acquired immediately, returns [`TryLockError`]. + /// Otherwise, an RAII guard is returned which will release read access + /// when dropped. + /// + /// [`TryLockError`]: TryLockError + /// + /// # Examples + /// + /// ``` + /// use std::sync::Arc; + /// use tokio::sync::RwLock; + /// + /// #[tokio::main] + /// async fn main() { + /// let lock = Arc::new(RwLock::new(1)); + /// let c_lock = lock.clone(); + /// + /// let v = lock.try_read().unwrap(); + /// assert_eq!(*v, 1); + /// + /// tokio::spawn(async move { + /// // While main has an active read lock, we acquire one too. + /// let n = c_lock.read().await; + /// assert_eq!(*n, 1); + /// }).await.expect("The spawned task has panicked"); + /// + /// // Drop the guard when spawned task finishes. + /// drop(v); + /// } + /// ``` + pub fn try_read(&self) -> Result, TryLockError> { + match self.s.try_acquire(1) { + Ok(permit) => permit, + Err(TryAcquireError::NoPermits) => return Err(TryLockError(())), + Err(TryAcquireError::Closed) => unreachable!(), + } + + Ok(RwLockReadGuard { + s: &self.s, + data: self.c.get(), + marker: marker::PhantomData, + }) + } + /// Locks this rwlock with exclusive write access, causing the current task /// to yield until the lock has been acquired. /// @@ -476,6 +523,43 @@ impl RwLock { } } + /// Attempts to acquire this `RwLock` with exclusive write access. + /// + /// If the access couldn't be acquired immediately, returns [`TryLockError`]. + /// Otherwise, an RAII guard is returned which will release write access + /// when dropped. + /// + /// [`TryLockError`]: TryLockError + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::RwLock; + /// + /// #[tokio::main] + /// async fn main() { + /// let rw = RwLock::new(1); + /// + /// let v = rw.read().await; + /// assert_eq!(*v, 1); + /// + /// assert!(rw.try_write().is_err()); + /// } + /// ``` + pub fn try_write(&self) -> Result, TryLockError> { + match self.s.try_acquire(MAX_READS as u32) { + Ok(permit) => permit, + Err(TryAcquireError::NoPermits) => return Err(TryLockError(())), + Err(TryAcquireError::Closed) => unreachable!(), + } + + Ok(RwLockWriteGuard { + s: &self.s, + data: self.c.get(), + marker: marker::PhantomData, + }) + } + /// Returns a mutable reference to the underlying data. /// /// Since this call borrows the `RwLock` mutably, no actual locking needs to diff --git a/tokio/tests/sync_rwlock.rs b/tokio/tests/sync_rwlock.rs index 76760351680..872b845cfdd 100644 --- a/tokio/tests/sync_rwlock.rs +++ b/tokio/tests/sync_rwlock.rs @@ -235,3 +235,36 @@ async fn multithreaded() { let g = rwlock.read().await; assert_eq!(*g, 17_000); } + +#[tokio::test] +async fn try_write() { + let lock = RwLock::new(0); + let read_guard = lock.read().await; + assert!(lock.try_write().is_err()); + drop(read_guard); + assert!(lock.try_write().is_ok()); +} + +#[test] +fn try_read_try_write() { + let lock: RwLock = RwLock::new(15); + + { + let rg1 = lock.try_read().unwrap(); + assert_eq!(*rg1, 15); + + assert!(lock.try_write().is_err()); + + let rg2 = lock.try_read().unwrap(); + assert_eq!(*rg2, 15) + } + + { + let mut wg = lock.try_write().unwrap(); + *wg = 1515; + + assert!(lock.try_read().is_err()) + } + + assert_eq!(*lock.try_read().unwrap(), 1515); +}