diff --git a/tokio/src/sync/notify.rs b/tokio/src/sync/notify.rs index 74b97cc481c..787147b36ab 100644 --- a/tokio/src/sync/notify.rs +++ b/tokio/src/sync/notify.rs @@ -551,6 +551,10 @@ impl Future for Notified<'_> { return Poll::Ready(()); } + // Clone the waker before locking, a waker clone can be + // triggering arbitrary code. + let waker = cx.waker().clone(); + // Acquire the lock and attempt to transition to the waiting // state. let mut waiters = notify.waiters.lock(); @@ -612,7 +616,7 @@ impl Future for Notified<'_> { // Safety: called while locked. unsafe { - (*waiter.get()).waker = Some(cx.waker().clone()); + (*waiter.get()).waker = Some(waker); } // Insert the waiter into the linked list diff --git a/tokio/src/sync/tests/mod.rs b/tokio/src/sync/tests/mod.rs index c5d5601961d..ee76418ac59 100644 --- a/tokio/src/sync/tests/mod.rs +++ b/tokio/src/sync/tests/mod.rs @@ -1,5 +1,6 @@ cfg_not_loom! { mod atomic_waker; + mod notify; mod semaphore_batch; } diff --git a/tokio/src/sync/tests/notify.rs b/tokio/src/sync/tests/notify.rs new file mode 100644 index 00000000000..8c9a5735c67 --- /dev/null +++ b/tokio/src/sync/tests/notify.rs @@ -0,0 +1,44 @@ +use crate::sync::Notify; +use std::future::Future; +use std::mem::ManuallyDrop; +use std::sync::Arc; +use std::task::{Context, RawWaker, RawWakerVTable, Waker}; + +#[test] +fn notify_clones_waker_before_lock() { + const VTABLE: &RawWakerVTable = &RawWakerVTable::new(clone_w, wake, wake_by_ref, drop_w); + + unsafe fn clone_w(data: *const ()) -> RawWaker { + let arc = ManuallyDrop::new(Arc::::from_raw(data as *const Notify)); + // Or some other arbitrary code that shouldn't be executed while the + // Notify wait list is locked. + arc.notify_one(); + let _arc_clone: ManuallyDrop<_> = arc.clone(); + RawWaker::new(data, VTABLE) + } + + unsafe fn drop_w(data: *const ()) { + let _ = Arc::::from_raw(data as *const Notify); + } + + unsafe fn wake(_data: *const ()) { + unreachable!() + } + + unsafe fn wake_by_ref(_data: *const ()) { + unreachable!() + } + + let notify = Arc::new(Notify::new()); + let notify2 = notify.clone(); + + let waker = + unsafe { Waker::from_raw(RawWaker::new(Arc::into_raw(notify2) as *const _, VTABLE)) }; + let mut cx = Context::from_waker(&waker); + + let future = notify.notified(); + pin!(future); + + // The result doesn't matter, we're just testing that we don't deadlock. + let _ = future.poll(&mut cx); +}