Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow owned data in MappedMutexGuard #290

Open
wants to merge 7 commits into
base: master
Choose a base branch
from
2 changes: 2 additions & 0 deletions core/src/spinwait.rs
Expand Up @@ -6,11 +6,13 @@
// copied, modified, or distributed except according to those terms.

use crate::thread_parker;
#[allow(deprecated)]
use std::sync::atomic::spin_loop_hint;

// Wastes some CPU time for the given number of iterations,
// using a hint to indicate to the CPU that we are spinning.
#[inline]
#[allow(deprecated)]
fn cpu_relax(iterations: u32) {
for _ in 0..iterations {
spin_loop_hint()
Expand Down
81 changes: 81 additions & 0 deletions lock_api/src/lib.rs
Expand Up @@ -114,3 +114,84 @@ pub use crate::remutex::*;

mod rwlock;
pub use crate::rwlock::*;

/// A "shim trait" to allow generalizing over functions which return some generic
/// type which may borrow elements of its arguments, but without specifying that
/// the return type is `&`, `&mut` or something else concrete. This allows using
/// HRTB to force a caller to supply a function which works for any lifetime,
/// and therefore avoids the caller relying on a specific lifetime for the
/// argument, which can cause UB if the inner data lives for the static lifetime.
///
/// It also allows the output type to depend on the input type, which is important
/// when using lifetimes in HRTBs but is not possible with the stable syntax for
/// the `Fn` traits.
pub trait FnOnceShim<'a, T: 'a> {
/// Equivalent to `std::ops::FnOnce::Output`.
type Output: 'a;

/// Equivalent to `std::ops::FnOnce::call`
fn call(self, input: T) -> Self::Output;
}

impl<'a, F, In, Out> FnOnceShim<'a, In> for F
where
F: FnOnce(In) -> Out,
In: 'a,
Out: 'a,
{
type Output = Out;

fn call(self, input: In) -> Self::Output {
self(input)
}
}

/// As `FnOnceShim`, but specialized for functions which return an `Option` (used
/// for `try_map`).
pub trait FnOnceOptionShim<'a, T: 'a> {
/// Equivalent to `std::ops::FnOnce::Output`.
type Output: 'a;

/// Equivalent to `std::ops::FnOnce::call`
fn call(self, input: T) -> Option<Self::Output>;
}

impl<'a, F, In, Out> FnOnceOptionShim<'a, In> for F
where
F: FnOnce(In) -> Option<Out>,
In: 'a,
Out: 'a,
{
type Output = Out;

fn call(self, input: In) -> Option<Self::Output> {
self(input)
}
}

/// As `FnOnceShim`, but specialized for functions which return an `Result` (used
/// for `try_map`).
pub trait FnOnceResultShim<'a, T: 'a> {
/// Equivalent to `std::ops::FnOnce::Output`.
type Output: 'a;
/// Equivalent to `std::ops::FnOnce::Output`.
type Error: 'a;

/// Equivalent to `std::ops::FnOnce::call`
fn call(self, input: T) -> Result<Self::Output, Self::Error>;
}

impl<'a, F, In, Out, Error> FnOnceResultShim<'a, In> for F
where
F: FnOnce(In) -> Result<Out, Error>,
In: 'a,
Out: 'a,
Error: 'a,
{
type Output = Out;
type Error = Error;

fn call(self, input: In) -> Result<Self::Output, Self::Error> {
self(input)
}
}
137 changes: 94 additions & 43 deletions lock_api/src/mutex.rs
Expand Up @@ -5,18 +5,21 @@
// http://opensource.org/licenses/MIT>, at your option. This file may not be
// copied, modified, or distributed except according to those terms.

use core::cell::UnsafeCell;
use core::fmt;
use core::marker::PhantomData;
use core::mem;
use core::ops::{Deref, DerefMut};
use core::{
cell::UnsafeCell,
fmt,
marker::PhantomData,
mem,
ops::{Deref, DerefMut},
ptr,
};

use crate::{FnOnceOptionShim, FnOnceResultShim, FnOnceShim};

#[cfg(feature = "arc_lock")]
use alloc::sync::Arc;
#[cfg(feature = "arc_lock")]
use core::mem::ManuallyDrop;
#[cfg(feature = "arc_lock")]
use core::ptr;

#[cfg(feature = "owning_ref")]
use owning_ref::StableAddress;
Expand Down Expand Up @@ -508,13 +511,18 @@ impl<'a, R: RawMutex + 'a, T: ?Sized + 'a> MutexGuard<'a, R, T> {
/// used as `MutexGuard::map(...)`. A method would interfere with methods of
/// the same name on the contents of the locked data.
#[inline]
pub fn map<U: ?Sized, F>(s: Self, f: F) -> MappedMutexGuard<'a, R, U>
pub fn map<F>(
s: Self,
f: F,
) -> MappedMutexGuard<'a, R, <F as FnOnceShim<'a, &'a mut T>>::Output>
where
F: FnOnce(&mut T) -> &mut U,
for<'any> F: FnOnceShim<'any, &'any mut T>,
{
let raw = &s.mutex.raw;
let data = f(unsafe { &mut *s.mutex.data.get() });
let data = unsafe { &mut *s.mutex.data.get() };
mem::forget(s);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This will invalidate the data reference.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think that is true, forget doesn't invalidate anything. For example, Box::leak leaks the value and returns a borrow of any length.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You are right. This doesn't actually invalidate the borrow because s.mutex is a reference.


let data = f.call(data);
MappedMutexGuard {
raw,
data,
Expand All @@ -532,16 +540,25 @@ impl<'a, R: RawMutex + 'a, T: ?Sized + 'a> MutexGuard<'a, R, T> {
/// used as `MutexGuard::try_map(...)`. A method would interfere with methods of
/// the same name on the contents of the locked data.
#[inline]
pub fn try_map<U: ?Sized, F>(s: Self, f: F) -> Result<MappedMutexGuard<'a, R, U>, Self>
pub fn try_map<F>(
s: Self,
f: F,
) -> Result<MappedMutexGuard<'a, R, <F as FnOnceOptionShim<'a, &'a mut T>>::Output>, Self>
where
F: FnOnce(&mut T) -> Option<&mut U>,
for<'any> F: FnOnceOptionShim<'any, &'any mut T>,
{
let raw = &s.mutex.raw;
let data = match f(unsafe { &mut *s.mutex.data.get() }) {
let data = unsafe { &mut *s.mutex.data.get() };

let data = match f.call(data) {
Some(data) => data,
None => return Err(s),
};

// We use `mem::forget` instead of `ManuallyDrop` because we want to drop `self` if
// `f` panicks. This is safe, as `self` must outlive the `&'a mut T` reference.
mem::forget(s);

Ok(MappedMutexGuard {
raw,
data,
Expand Down Expand Up @@ -663,8 +680,8 @@ impl<'a, R: RawMutex + 'a, T: fmt::Display + ?Sized + 'a> fmt::Display for Mutex
unsafe impl<'a, R: RawMutex + 'a, T: ?Sized + 'a> StableAddress for MutexGuard<'a, R, T> {}

/// An RAII mutex guard returned by the `Arc` locking operations on `Mutex`.
///
/// This is similar to the `MutexGuard` struct, except instead of using a reference to unlock the `Mutex` it
///
/// This is similar to the `MutexGuard` struct, except instead of using a reference to unlock the `Mutex` it
/// uses an `Arc<Mutex>`. This has several advantages, most notably that it has an `'static` lifetime.
#[cfg(feature = "arc_lock")]
#[must_use = "if unused the Mutex will immediately unlock"]
Expand Down Expand Up @@ -713,7 +730,7 @@ impl<R: RawMutexFair, T: ?Sized> ArcMutexGuard<R, T> {

// SAFETY: make sure the Arc gets it reference decremented
let mut s = ManuallyDrop::new(s);
unsafe { ptr::drop_in_place(&mut s.mutex) };
unsafe { ptr::drop_in_place(&mut s.mutex) };
}

/// Temporarily unlocks the mutex to execute the given function.
Expand Down Expand Up @@ -780,10 +797,12 @@ impl<R: RawMutex, T: ?Sized> Drop for ArcMutexGuard<R, T> {
/// could introduce soundness issues if the locked object is modified by another
/// thread.
#[must_use = "if unused the Mutex will immediately unlock"]
pub struct MappedMutexGuard<'a, R: RawMutex, T: ?Sized> {
pub struct MappedMutexGuard<'a, R: RawMutex, T: ?Sized + 'a> {
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think T has to live as long as 'a, we can remove the 'a bound on T.

Only R needs the bound because we have a &'a R.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It definitely doesn't need the 'a bound on the struct itself, and I think you're right that it isn't needed elsewhere either.

raw: &'a R,
data: *mut T,
marker: PhantomData<&'a mut T>,
// We use `&'a mut` to make this type invariant over `'a`
marker: PhantomData<&'a mut ()>,
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does it actually need to be invariant here? I think we can safely remove this.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I kept it invariant because that was how it was in the original code, and it's the most conservative option. Without it this type would be variant, which is probably correct but there's a chance of it being subtly broken.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually this made T invariant, not 'a (see https://doc.rust-lang.org/nomicon/subtyping.html#variance).

The proper way to get invariance on 'a is PhantomData<fn(&'a ()) -> &'a ()>

// `data` at the end so we can cast `MappedMutexGuard<'_, _, [T; N]>` to `MappedMutexGuard<'_, _, [T]>`.
data: T,
}

unsafe impl<'a, R: RawMutex + Sync + 'a, T: ?Sized + Sync + 'a> Sync
Expand All @@ -795,7 +814,7 @@ unsafe impl<'a, R: RawMutex + 'a, T: ?Sized + Send + 'a> Send for MappedMutexGua
{
}

impl<'a, R: RawMutex + 'a, T: ?Sized + 'a> MappedMutexGuard<'a, R, T> {
impl<'a, R: RawMutex + 'a, T: 'a> MappedMutexGuard<'a, R, T> {
/// Makes a new `MappedMutexGuard` for a component of the locked data.
///
/// This operation cannot fail as the `MappedMutexGuard` passed
Expand All @@ -805,13 +824,27 @@ impl<'a, R: RawMutex + 'a, T: ?Sized + 'a> MappedMutexGuard<'a, R, T> {
/// used as `MappedMutexGuard::map(...)`. A method would interfere with methods of
/// the same name on the contents of the locked data.
#[inline]
pub fn map<U: ?Sized, F>(s: Self, f: F) -> MappedMutexGuard<'a, R, U>
pub fn map<F>(s: Self, f: F) -> MappedMutexGuard<'a, R, <F as FnOnceShim<'a, T>>::Output>
where
F: FnOnce(&mut T) -> &mut U,
for<'any> F: FnOnceShim<'any, T>,
{
let raw = s.raw;
let data = f(unsafe { &mut *s.data });
mem::forget(s);
let (data, raw) = {
let s = mem::ManuallyDrop::new(s);
(unsafe { ptr::read(&s.data) }, s.raw)
};

// `panic::catch_unwind` isn't available in `core`, so we use a dummy guard to unlock
// the mutex in case of unwind.
let lock_guard: MappedMutexGuard<'a, R, ()> = MappedMutexGuard {
raw,
data: (),
marker: PhantomData,
};

let data = f.call(data);

mem::forget(lock_guard);

MappedMutexGuard {
raw,
data,
Expand All @@ -829,25 +862,47 @@ impl<'a, R: RawMutex + 'a, T: ?Sized + 'a> MappedMutexGuard<'a, R, T> {
/// used as `MappedMutexGuard::try_map(...)`. A method would interfere with methods of
/// the same name on the contents of the locked data.
#[inline]
pub fn try_map<U: ?Sized, F>(s: Self, f: F) -> Result<MappedMutexGuard<'a, R, U>, Self>
pub fn try_map<F>(
s: Self,
f: F,
) -> Result<
MappedMutexGuard<'a, R, <F as FnOnceResultShim<'a, T>>::Output>,
MappedMutexGuard<'a, R, <F as FnOnceResultShim<'a, T>>::Error>,
>
where
F: FnOnce(&mut T) -> Option<&mut U>,
for<'any> F: FnOnceResultShim<'any, T>,
{
let raw = s.raw;
let data = match f(unsafe { &mut *s.data }) {
Some(data) => data,
None => return Err(s),
let (data, raw) = {
let s = mem::ManuallyDrop::new(s);
(unsafe { ptr::read(&s.data) }, s.raw)
};
mem::forget(s);
Ok(MappedMutexGuard {

// `panic::catch_unwind` isn't available in `core`, so we use a dummy guard to unlock
// the mutex in case of unwind.
let lock_guard: MappedMutexGuard<'a, R, ()> = MappedMutexGuard {
raw,
data: (),
marker: PhantomData,
};

let out = f.call(data);

mem::forget(lock_guard);

out.map(|data| MappedMutexGuard {
raw,
data,
marker: PhantomData,
})
.map_err(|data| MappedMutexGuard {
raw,
data,
marker: PhantomData,
})
}
}

impl<'a, R: RawMutexFair + 'a, T: ?Sized + 'a> MappedMutexGuard<'a, R, T> {
impl<'a, R: RawMutexFair + 'a, T: 'a> MappedMutexGuard<'a, R, T> {
/// Unlocks the mutex using a fair unlock protocol.
///
/// By default, mutexes are unfair and allow the current thread to re-lock
Expand All @@ -872,16 +927,17 @@ impl<'a, R: RawMutexFair + 'a, T: ?Sized + 'a> MappedMutexGuard<'a, R, T> {

impl<'a, R: RawMutex + 'a, T: ?Sized + 'a> Deref for MappedMutexGuard<'a, R, T> {
type Target = T;

#[inline]
fn deref(&self) -> &T {
unsafe { &*self.data }
&self.data
}
}

impl<'a, R: RawMutex + 'a, T: ?Sized + 'a> DerefMut for MappedMutexGuard<'a, R, T> {
#[inline]
fn deref_mut(&mut self) -> &mut T {
unsafe { &mut *self.data }
&mut self.data
}
}

Expand All @@ -895,19 +951,14 @@ impl<'a, R: RawMutex + 'a, T: ?Sized + 'a> Drop for MappedMutexGuard<'a, R, T> {
}
}

impl<'a, R: RawMutex + 'a, T: fmt::Debug + ?Sized + 'a> fmt::Debug for MappedMutexGuard<'a, R, T> {
impl<'a, R: RawMutex + 'a, T: fmt::Debug + 'a> fmt::Debug for MappedMutexGuard<'a, R, T> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
fmt::Debug::fmt(&**self, f)
}
}

impl<'a, R: RawMutex + 'a, T: fmt::Display + ?Sized + 'a> fmt::Display
for MappedMutexGuard<'a, R, T>
{
impl<'a, R: RawMutex + 'a, T: fmt::Display + 'a> fmt::Display for MappedMutexGuard<'a, R, T> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
(**self).fmt(f)
}
}

#[cfg(feature = "owning_ref")]
unsafe impl<'a, R: RawMutex + 'a, T: ?Sized + 'a> StableAddress for MappedMutexGuard<'a, R, T> {}