diff --git a/futures-util/src/task/waker.rs b/futures-util/src/task/waker.rs index 5fb2f9210a..2d9b3070b3 100644 --- a/futures-util/src/task/waker.rs +++ b/futures-util/src/task/waker.rs @@ -22,12 +22,10 @@ where // code here. We should guard against this by aborting. unsafe fn increase_refcount(data: *const ()) { - // Retain Arc by creating a copy - let arc: Arc = Arc::from_raw(data as *const T); - let arc_clone = arc.clone(); - // Forget the Arcs again, so that the refcount isn't decrased - mem::forget(arc); - mem::forget(arc_clone); + // Retain Arc, but don't touch refcount by wrapping in ManuallyDrop + let arc = mem::ManuallyDrop::new(Arc::::from_raw(data as *const T)); + // Now increase refcount, but don't drop new refcount either + let _arc_clone: mem::ManuallyDrop<_> = arc.clone(); } // used by `waker_ref` @@ -43,9 +41,9 @@ unsafe fn wake_arc_raw(data: *const ()) { // used by `waker_ref` pub(super) unsafe fn wake_by_ref_arc_raw(data: *const ()) { - let arc: Arc = Arc::from_raw(data as *const T); + // Retain Arc, but don't touch refcount by wrapping in ManuallyDrop + let arc = mem::ManuallyDrop::new(Arc::::from_raw(data as *const T)); ArcWake::wake_by_ref(&arc); - mem::forget(arc); } unsafe fn drop_arc_raw(data: *const ()) { diff --git a/futures/tests/arc_wake.rs b/futures/tests/arc_wake.rs index ff07ab9469..aa7a3fc6be 100644 --- a/futures/tests/arc_wake.rs +++ b/futures/tests/arc_wake.rs @@ -44,3 +44,22 @@ fn create_waker_from_arc() { drop(w1); assert_eq!(1, Arc::strong_count(&some_w)); } + +struct PanicWaker; + +impl ArcWake for PanicWaker { + fn wake_by_ref(_arc_self: &Arc) { + panic!("WAKE UP"); + } +} + +#[test] +fn proper_refcount_on_wake_panic() { + let some_w = Arc::new(PanicWaker); + + let w1: Waker = task::waker(some_w.clone()); + assert_eq!("WAKE UP", *std::panic::catch_unwind(|| w1.wake_by_ref()).unwrap_err().downcast::<&str>().unwrap()); + assert_eq!(2, Arc::strong_count(&some_w)); // some_w + w1 + drop(w1); + assert_eq!(1, Arc::strong_count(&some_w)); // some_w +}