diff --git a/tokio/src/sync/mod.rs b/tokio/src/sync/mod.rs index 8a6aabde47a..ceb97eac426 100644 --- a/tokio/src/sync/mod.rs +++ b/tokio/src/sync/mod.rs @@ -441,7 +441,7 @@ cfg_sync! { pub use semaphore::{Semaphore, SemaphorePermit, OwnedSemaphorePermit}; mod rwlock; - pub use rwlock::{RwLock, RwLockReadGuard, RwLockWriteGuard}; + pub use rwlock::{RwLock, RwLockReadGuard, RwLockWriteGuard, MappedRwLockReadGuard, MappedRwLockWriteGuard}; mod task; pub(crate) use task::AtomicWaker; diff --git a/tokio/src/sync/rwlock.rs b/tokio/src/sync/rwlock.rs index 68cf710e84b..8681493a238 100644 --- a/tokio/src/sync/rwlock.rs +++ b/tokio/src/sync/rwlock.rs @@ -1,6 +1,8 @@ use crate::coop::CoopFutureExt; -use crate::sync::batch_semaphore::{AcquireError, Semaphore}; +use crate::sync::batch_semaphore::Semaphore; use std::cell::UnsafeCell; +use std::fmt; +use std::marker; use std::ops; #[cfg(not(loom))] @@ -77,6 +79,99 @@ pub struct RwLock { c: UnsafeCell, } +/// An RAII read lock guard returned by [`RwLockReadGuard::map`], which can +/// point to a subfield of the protected data. +/// +/// [`RwLockReadGuard::map`]: method@RwLockReadGuard::map +pub struct MappedRwLockReadGuard<'a, T> { + s: &'a Semaphore, + data: *const T, + marker: marker::PhantomData<&'a T>, +} + +impl<'a, T> ops::Deref for MappedRwLockReadGuard<'a, T> { + type Target = T; + + #[inline] + fn deref(&self) -> &T { + unsafe { &*self.data } + } +} + +impl<'a, T> fmt::Debug for MappedRwLockReadGuard<'a, T> +where + T: fmt::Debug, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt::Debug::fmt(&**self, f) + } +} + +impl<'a, T> fmt::Display for MappedRwLockReadGuard<'a, T> +where + T: fmt::Display, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt::Display::fmt(&**self, f) + } +} + +impl<'a, T> Drop for MappedRwLockReadGuard<'a, T> { + fn drop(&mut self) { + self.s.release(1); + } +} + +/// An RAII write lock guard returned by [`RwLockWriteGuard::map`], which can +/// point to a subfield of the protected data. +/// +/// [`RwLockWriteGuard::map`]: method@RwLockWriteGuard::map +pub struct MappedRwLockWriteGuard<'a, T> { + s: &'a Semaphore, + data: *mut T, + marker: marker::PhantomData<&'a mut T>, +} + +impl<'a, T> ops::Deref for MappedRwLockWriteGuard<'a, T> { + type Target = T; + + #[inline] + fn deref(&self) -> &T { + unsafe { &*self.data } + } +} + +impl<'a, T> ops::DerefMut for MappedRwLockWriteGuard<'a, T> { + #[inline] + fn deref_mut(&mut self) -> &mut T { + unsafe { &mut *self.data } + } +} + +impl<'a, T> fmt::Debug for MappedRwLockWriteGuard<'a, T> +where + T: fmt::Debug, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt::Debug::fmt(&**self, f) + } +} + +impl<'a, T> fmt::Display for MappedRwLockWriteGuard<'a, T> +where + T: fmt::Display, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt::Display::fmt(&**self, f) + } +} + +impl<'a, T> Drop for MappedRwLockWriteGuard<'a, T> { + fn drop(&mut self) { + self.s.release(MAX_READS); + } +} + /// RAII structure used to release the shared read access of a lock when /// dropped. /// @@ -84,12 +179,117 @@ pub struct RwLock { /// [`RwLock`]. /// /// [`read`]: method@RwLock::read -#[derive(Debug)] +/// [`RwLock`]: struct@RwLock pub struct RwLockReadGuard<'a, T> { - permit: ReleasingPermit<'a, T>, lock: &'a RwLock, } +impl<'a, T> RwLockReadGuard<'a, T> { + /// Make a new `MappedRwLockReadGuard` for a component of the locked data. + /// + /// This operation cannot fail as the `RwLockReadGuard` passed in already + /// locked the data. + /// + /// This is an associated function that needs to be + /// used as `RwLockReadGuard::map(...)`. A method would interfere with + /// methods of the same name on the contents of the locked data. + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::{RwLock, RwLockReadGuard}; + /// + /// #[derive(Debug, Clone, Copy, PartialEq, Eq)] + /// struct Foo(u32); + /// + /// # #[tokio::main] + /// # async fn main() { + /// let lock = RwLock::new(Foo(1)); + /// let mapped = RwLockReadGuard::map(lock.read().await, |f| &f.0); + /// assert_eq!(1, *mapped); + /// # } + /// ``` + #[inline] + pub fn map(self, f: F) -> MappedRwLockReadGuard<'a, U> + where + F: FnOnce(&T) -> &U, + { + let data = f(unsafe { &*self.lock.c.get() }); + let s = &self.lock.s; + MappedRwLockReadGuard { + s, + data, + marker: marker::PhantomData, + } + } + + /// Attempts to make a new [`MappedRwLockReadGuard`] for a component of the + /// locked data. The original guard is return if the closure returns `None`. + /// + /// This operation cannot fail as the `RwLockReadGuard` passed in already + /// locked the data. + /// + /// This is an associated function that needs to be used as + /// `RwLockReadGuard::map(..)`. A method would interfere with methods of the + /// same name on the contents of the locked data. + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::{RwLock, RwLockReadGuard}; + /// + /// #[derive(Debug, Clone, Copy, PartialEq, Eq)] + /// struct Foo(u32); + /// + /// # #[tokio::main] + /// # async fn main() { + /// let lock = RwLock::new(Foo(1)); + /// let mapped = RwLockReadGuard::try_map(lock.read().await, |f| Some(&f.0)).expect("should not fail"); + /// assert_eq!(1, *mapped); + /// # } + /// ``` + #[inline] + pub fn try_map(self, f: F) -> Result, Self> + where + F: FnOnce(&T) -> Option<&U>, + { + let data = match f(unsafe { &*self.lock.c.get() }) { + Some(data) => data, + None => return Err(self), + }; + let s = &self.lock.s; + Ok(MappedRwLockReadGuard { + s, + data, + marker: marker::PhantomData, + }) + } +} + +impl<'a, T> fmt::Debug for RwLockReadGuard<'a, T> +where + T: fmt::Debug, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt::Debug::fmt(&**self, f) + } +} + +impl<'a, T> fmt::Display for RwLockReadGuard<'a, T> +where + T: fmt::Display, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt::Display::fmt(&**self, f) + } +} + +impl<'a, T> Drop for RwLockReadGuard<'a, T> { + fn drop(&mut self) { + self.lock.s.release(1); + } +} + /// RAII structure used to release the exclusive write access of a lock when /// dropped. /// @@ -98,32 +298,124 @@ pub struct RwLockReadGuard<'a, T> { /// /// [`write`]: method@RwLock::write /// [`RwLock`]: struct@RwLock -#[derive(Debug)] pub struct RwLockWriteGuard<'a, T> { - permit: ReleasingPermit<'a, T>, lock: &'a RwLock, } -// Wrapper arround Permit that releases on Drop -#[derive(Debug)] -struct ReleasingPermit<'a, T> { - num_permits: u16, - lock: &'a RwLock, +impl<'a, T> RwLockWriteGuard<'a, T> { + /// Make a new `MappedRwLockWriteGuard` for a component of the locked data. + /// + /// This operation cannot fail as the `RwLockWriteGuard` passed in already + /// locked the data. + /// + /// This is an associated function that needs to be used as + /// `RwLockWriteGuard::map(..)`. A method would interfere with methods of + /// the same name on the contents of the locked data. + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::{RwLock, RwLockWriteGuard}; + /// + /// #[derive(Debug, Clone, Copy, PartialEq, Eq)] + /// struct Foo(u32); + /// + /// # #[tokio::main] + /// # async fn main() { + /// let lock = RwLock::new(Foo(1)); + /// + /// { + /// let mut mapped = RwLockWriteGuard::map(lock.write().await, |f| &mut f.0); + /// *mapped = 2; + /// } + /// + /// assert_eq!(Foo(2), *lock.read().await); + /// # } + /// ``` + #[inline] + pub fn map(self, f: F) -> MappedRwLockWriteGuard<'a, U> + where + F: FnOnce(&mut T) -> &mut U, + { + let data = f(unsafe { &mut *self.lock.c.get() }); + let s = &self.lock.s; + MappedRwLockWriteGuard { + s, + data, + marker: marker::PhantomData, + } + } + + /// Attempts to make a new [`MappedRwLockWriteGuard`] for a component of + /// the locked data. The original guard is return if the closure returns + /// `None`. + /// + /// This operation cannot fail as the `RwLockWriteGuard` passed in already + /// locked the data. + /// + /// This is an associated function that needs to be + /// used as `RwLockWriteGuard::map(...)`. A method would interfere with + /// methods of the same name on the contents of the locked data. + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::{RwLock, RwLockWriteGuard}; + /// + /// #[derive(Debug, Clone, Copy, PartialEq, Eq)] + /// struct Foo(u32); + /// + /// # #[tokio::main] + /// # async fn main() { + /// let lock = RwLock::new(Foo(1)); + /// + /// { + /// let mut mapped = RwLockWriteGuard::try_map(lock.write().await, |f| Some(&mut f.0)).expect("should not fail"); + /// *mapped = 2; + /// } + /// + /// assert_eq!(Foo(2), *lock.read().await); + /// # } + /// ``` + #[inline] + pub fn try_map(self, f: F) -> Result, Self> + where + F: FnOnce(&mut T) -> Option<&mut U>, + { + let data = match f(unsafe { &mut *self.lock.c.get() }) { + Some(data) => data, + None => return Err(self), + }; + let s = &self.lock.s; + Ok(MappedRwLockWriteGuard { + s, + data, + marker: marker::PhantomData, + }) + } } -impl<'a, T> ReleasingPermit<'a, T> { - async fn acquire( - lock: &'a RwLock, - num_permits: u16, - ) -> Result, AcquireError> { - lock.s.acquire(num_permits).cooperate().await?; - Ok(Self { num_permits, lock }) +impl<'a, T> fmt::Debug for RwLockWriteGuard<'a, T> +where + T: fmt::Debug, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt::Debug::fmt(&**self, f) } } -impl<'a, T> Drop for ReleasingPermit<'a, T> { +impl<'a, T> fmt::Display for RwLockWriteGuard<'a, T> +where + T: fmt::Display, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt::Display::fmt(&**self, f) + } +} + +impl<'a, T> Drop for RwLockWriteGuard<'a, T> { fn drop(&mut self) { - self.lock.s.release(self.num_permits as usize); + self.lock.s.release(MAX_READS); } } @@ -158,6 +450,8 @@ unsafe impl Send for RwLock where T: Send {} unsafe impl Sync for RwLock where T: Send + Sync {} unsafe impl<'a, T> Sync for RwLockReadGuard<'a, T> where T: Send + Sync {} unsafe impl<'a, T> Sync for RwLockWriteGuard<'a, T> where T: Send + Sync {} +unsafe impl<'a, T> Sync for MappedRwLockReadGuard<'a, T> where T: Send + Sync {} +unsafe impl<'a, T> Sync for MappedRwLockWriteGuard<'a, T> where T: Send + Sync {} impl RwLock { /// Creates a new instance of an `RwLock` which is unlocked. @@ -208,12 +502,12 @@ impl RwLock { ///} /// ``` pub async fn read(&self) -> RwLockReadGuard<'_, T> { - let permit = ReleasingPermit::acquire(self, 1).await.unwrap_or_else(|_| { + self.s.acquire(1).cooperate().await.unwrap_or_else(|_| { // The semaphore was closed. but, we never explicitly close it, and we have a // handle to it through the Arc, which means that this can never happen. unreachable!() }); - RwLockReadGuard { lock: self, permit } + RwLockReadGuard { lock: self } } /// Locks this rwlock with exclusive write access, causing the current task @@ -239,15 +533,16 @@ impl RwLock { ///} /// ``` pub async fn write(&self) -> RwLockWriteGuard<'_, T> { - let permit = ReleasingPermit::acquire(self, MAX_READS as u16) + self.s + .acquire(MAX_READS as u16) + .cooperate() .await .unwrap_or_else(|_| { // The semaphore was closed. but, we never explicitly close it, and we have a // handle to it through the Arc, which means that this can never happen. unreachable!() }); - - RwLockWriteGuard { lock: self, permit } + RwLockWriteGuard { lock: self } } /// Consumes the lock, returning the underlying data.