diff --git a/futures-util/src/compat/compat03as01.rs b/futures-util/src/compat/compat03as01.rs index a571826ec5..763d19b2be 100644 --- a/futures-util/src/compat/compat03as01.rs +++ b/futures-util/src/compat/compat03as01.rs @@ -201,9 +201,9 @@ impl Current { let ptr = current_to_ptr(self); let vtable = &RawWakerVTable::new(clone, wake, wake, drop); - unsafe { - WakerRef::new(task03::Waker::from_raw(RawWaker::new(ptr, vtable))) - } + WakerRef::new_unowned(std::mem::ManuallyDrop::new(unsafe { + task03::Waker::from_raw(RawWaker::new(ptr, vtable)) + })) } } diff --git a/futures-util/src/task/mod.rs b/futures-util/src/task/mod.rs index 2687f0d275..fdc70136e9 100644 --- a/futures-util/src/task/mod.rs +++ b/futures-util/src/task/mod.rs @@ -1,20 +1,6 @@ //! Task notification cfg_target_has_atomic! { - /// A macro for creating a `RawWaker` vtable for a type that implements - /// the `ArcWake` trait. - #[cfg(feature = "alloc")] - macro_rules! waker_vtable { - ($ty:ident) => { - &RawWakerVTable::new( - clone_arc_raw::<$ty>, - wake_arc_raw::<$ty>, - wake_by_ref_arc_raw::<$ty>, - drop_arc_raw::<$ty>, - ) - }; - } - #[cfg(feature = "alloc")] mod arc_wake; #[cfg(feature = "alloc")] diff --git a/futures-util/src/task/waker.rs b/futures-util/src/task/waker.rs index 2d9b3070b3..1ea2f2ef6c 100644 --- a/futures-util/src/task/waker.rs +++ b/futures-util/src/task/waker.rs @@ -3,6 +3,15 @@ use core::mem; use core::task::{Waker, RawWaker, RawWakerVTable}; use alloc::sync::Arc; +pub(super) fn waker_vtable() -> &'static RawWakerVTable { + &RawWakerVTable::new( + clone_arc_raw::, + wake_arc_raw::, + wake_by_ref_arc_raw::, + drop_arc_raw::, + ) +} + /// Creates a [`Waker`] from an `Arc`. /// /// The returned [`Waker`] will call @@ -14,7 +23,7 @@ where let ptr = Arc::into_raw(wake) as *const (); unsafe { - Waker::from_raw(RawWaker::new(ptr, waker_vtable!(W))) + Waker::from_raw(RawWaker::new(ptr, waker_vtable::())) } } @@ -29,9 +38,9 @@ unsafe fn increase_refcount(data: *const ()) { } // used by `waker_ref` -pub(super) unsafe fn clone_arc_raw(data: *const ()) -> RawWaker { +unsafe fn clone_arc_raw(data: *const ()) -> RawWaker { increase_refcount::(data); - RawWaker::new(data, waker_vtable!(T)) + RawWaker::new(data, waker_vtable::()) } unsafe fn wake_arc_raw(data: *const ()) { @@ -40,7 +49,7 @@ unsafe fn wake_arc_raw(data: *const ()) { } // used by `waker_ref` -pub(super) unsafe fn wake_by_ref_arc_raw(data: *const ()) { +unsafe fn wake_by_ref_arc_raw(data: *const ()) { // Retain Arc, but don't touch refcount by wrapping in ManuallyDrop let arc = mem::ManuallyDrop::new(Arc::::from_raw(data as *const T)); ArcWake::wake_by_ref(&arc); diff --git a/futures-util/src/task/waker_ref.rs b/futures-util/src/task/waker_ref.rs index 9e83742f54..8dda499a62 100644 --- a/futures-util/src/task/waker_ref.rs +++ b/futures-util/src/task/waker_ref.rs @@ -1,9 +1,10 @@ -use super::arc_wake::ArcWake; -use super::waker::{clone_arc_raw, wake_by_ref_arc_raw}; +use super::arc_wake::{ArcWake}; +use super::waker::waker_vtable; use alloc::sync::Arc; +use core::mem::ManuallyDrop; use core::marker::PhantomData; use core::ops::Deref; -use core::task::{Waker, RawWaker, RawWakerVTable}; +use core::task::{Waker, RawWaker}; /// A [`Waker`] that is only valid for a given lifetime. /// @@ -11,17 +12,29 @@ use core::task::{Waker, RawWaker, RawWakerVTable}; /// so it can be used to get a `&Waker`. #[derive(Debug)] pub struct WakerRef<'a> { - waker: Waker, + waker: ManuallyDrop, _marker: PhantomData<&'a ()>, } -impl WakerRef<'_> { - /// Create a new [`WakerRef`] from a [`Waker`]. +impl<'a> WakerRef<'a> { + /// Create a new [`WakerRef`] from a [`Waker`] reference. + pub fn new(waker: &'a Waker) -> Self { + // copy the underlying (raw) waker without calling a clone, + // as we won't call Waker::drop either. + let waker = ManuallyDrop::new(unsafe { core::ptr::read(waker) }); + WakerRef { + waker, + _marker: PhantomData, + } + } + + /// Create a new [`WakerRef`] from a [`Waker`] that must not be dropped. /// - /// Note: this function is safe, but it is generally only used - /// from `unsafe` contexts that need to create a `Waker` - /// that is guaranteed not to outlive a particular lifetime. - pub fn new(waker: Waker) -> Self { + /// Note: this if for rare cases where the caller created a [`Waker`] in + /// an unsafe way (that will be valid only for a lifetime to be determined + /// by the caller), and the [`Waker`] doesn't need to or must not be + /// destroyed. + pub fn new_unowned(waker: ManuallyDrop) -> Self { WakerRef { waker, _marker: PhantomData, @@ -37,21 +50,6 @@ impl Deref for WakerRef<'_> { } } -#[inline] -unsafe fn noop(_data: *const ()) {} - -unsafe fn wake_unreachable(_data: *const ()) { - // With only a reference, calling `wake_arc_raw()` would be unsound, - // since the `WakerRef` didn't increment the refcount of the `ArcWake`, - // and `wake_arc_raw` would *decrement* it. - // - // This should never be reachable, since `WakerRef` only provides a `Deref` - // to the inner `Waker`. - // - // Still, safer to panic here than to call `wake_arc_raw`. - unreachable!("WakerRef::wake"); -} - /// Creates a reference to a [`Waker`] from a reference to `Arc`. /// /// The resulting [`Waker`] will call @@ -61,21 +59,12 @@ pub fn waker_ref(wake: &Arc) -> WakerRef<'_> where W: ArcWake { - // This uses the same mechanism as Arc::into_raw, without needing a reference. - // This is potentially not stable - let ptr = &*wake as &W as *const W as *const (); - - // Similar to `waker_vtable`, but with a no-op `drop` function. - // Clones of the resulting `RawWaker` will still be dropped normally. - let vtable = &RawWakerVTable::new( - clone_arc_raw::, - wake_unreachable, - wake_by_ref_arc_raw::, - noop, - ); + // simply copy the pointer instead of using Arc::into_raw, + // as we don't actually keep a refcount by using ManuallyDrop.< + let ptr = (&**wake as *const W) as *const (); - let waker = unsafe { - Waker::from_raw(RawWaker::new(ptr, vtable)) - }; - WakerRef::new(waker) + let waker = ManuallyDrop::new(unsafe { + Waker::from_raw(RawWaker::new(ptr, waker_vtable::())) + }); + WakerRef::new_unowned(waker) } diff --git a/futures/tests/arc_wake.rs b/futures/tests/arc_wake.rs index aa7a3fc6be..1940e4f98b 100644 --- a/futures/tests/arc_wake.rs +++ b/futures/tests/arc_wake.rs @@ -63,3 +63,15 @@ fn proper_refcount_on_wake_panic() { drop(w1); assert_eq!(1, Arc::strong_count(&some_w)); // some_w } + +#[test] +fn waker_ref_wake_same() { + let some_w = Arc::new(CountingWaker::new()); + + let w1: Waker = task::waker(some_w.clone()); + let w2 = task::waker_ref(&some_w); + let w3 = w2.clone(); + + assert!(w1.will_wake(&w2)); + assert!(w2.will_wake(&w3)); +}