diff --git a/tokio/src/task/local.rs b/tokio/src/task/local.rs
index e4a198bd053..b27c6b02aaa 100644
--- a/tokio/src/task/local.rs
+++ b/tokio/src/task/local.rs
@@ -233,6 +233,10 @@ struct Context {
     /// True if a task panicked without being handled and the local set is
     /// configured to shutdown on unhandled panic.
     unhandled_panic: Cell<bool>,
+
+    /// Should the currently running call to `tick` return after polling the
+    /// current future?
+    yield_now: Cell<bool>,
 }
 
 /// LocalSet state shared between threads.
@@ -393,6 +397,7 @@ impl LocalSet {
                     #[cfg(tokio_unstable)]
                     unhandled_panic: crate::runtime::UnhandledPanic::Ignore,
                 }),
+                yield_now: Cell::new(false),
                 unhandled_panic: Cell::new(false),
             }),
             _not_send: PhantomData,
@@ -598,10 +603,14 @@ impl LocalSet {
     /// Ticks the scheduler, returning whether the local future needs to be
     /// notified again.
     fn tick(&self) -> bool {
+        self.context.yield_now.set(false);
+
         for _ in 0..MAX_TASKS_PER_TICK {
-            // Make sure we didn't hit an unhandled panic
-            if self.context.unhandled_panic.get() {
-                panic!("a spawned task panicked and the LocalSet is configured to shutdown on unhandled panic");
+            // If yield_now is set, then the task we polled in the previous
+            // iteration woke itself. In this case, we should yield to the
+            // scheduler immediately.
+            if self.context.yield_now.get() || self.context.unhandled_panic.get() {
+                break;
             }
 
             match self.next_task() {
@@ -621,6 +630,11 @@ impl LocalSet {
             }
         }
 
+        // Make sure we didn't hit an unhandled panic.
+        if self.context.unhandled_panic.get() {
+            panic!("a spawned task panicked and the LocalSet is configured to shutdown on unhandled panic");
+        }
+
         true
     }
 
@@ -896,7 +910,7 @@ impl Context {
         };
 
         if let Some(notified) = notified {
-            self.shared.schedule(notified);
+            self.shared.schedule(notified, false);
         }
 
         handle
@@ -938,41 +952,42 @@ impl<T: Future> Future for RunUntil<'_, T> {
 impl Shared {
     /// Schedule the provided task on the scheduler.
-    fn schedule(&self, task: task::Notified<Arc<Self>>) {
+    fn schedule(&self, task: task::Notified<Arc<Self>>, yield_now: bool) {
         CURRENT.with(|localdata| {
-            match localdata.ctx.get() {
-                Some(cx) if cx.shared.ptr_eq(self) => unsafe {
-                    // Safety: if the current `LocalSet` context points to this
-                    // `LocalSet`, then we are on the thread that owns it.
-                    cx.shared.local_state.task_push_back(task);
-                },
-
-                // We are on the thread that owns the `LocalSet`, so we can
-                // wake to the local queue.
-                _ if localdata.get_id() == Some(self.local_state.owner) => {
-                    unsafe {
-                        // Safety: we just checked that the thread ID matches
-                        // the localset's owner, so this is safe.
-                        self.local_state.task_push_back(task);
+            if localdata.get_id() == Some(self.local_state.owner) {
+                unsafe {
+                    // Safety: we just checked that the thread ID matches
+                    // the localset's owner, so this is safe.
+                    self.local_state.task_push_back(task);
+                }
+
+                let mut should_wake = true;
+                if let Some(cx) = localdata.ctx.get() {
+                    if cx.shared.ptr_eq(self) {
+                        should_wake = false;
+                        // If the future woke itself, then we should return
+                        // from tick.
+                        cx.yield_now.set(yield_now | cx.yield_now.get());
                     }
+                }
 
+                if should_wake {
                     // We still have to wake the `LocalSet`, because it isn't
                     // currently being polled.
                     self.waker.wake();
                 }
-
+            } else {
                 // We are *not* on the thread that owns the `LocalSet`, so we
                 // have to wake to the remote queue.
-                _ => {
-                    // First, check whether the queue is still there (if not, the
-                    // LocalSet is dropped). Then push to it if so, and if not,
-                    // do nothing.
-                    let mut lock = self.queue.lock();
-
-                    if let Some(queue) = lock.as_mut() {
-                        queue.push_back(task);
-                        drop(lock);
-                        self.waker.wake();
-                    }
+                //
+                // First, check whether the queue is still there (if not, the
+                // LocalSet is dropped). Then push to it if so, and if not,
+                // do nothing.
+                let mut lock = self.queue.lock();
+
+                if let Some(queue) = lock.as_mut() {
+                    queue.push_back(task);
+                    drop(lock);
+                    self.waker.wake();
                 }
             }
         });
@@ -994,7 +1009,11 @@ impl task::Schedule for Arc<Shared> {
     }
 
     fn schedule(&self, task: task::Notified<Self>) {
-        Shared::schedule(self, task);
+        Shared::schedule(self, task, false);
+    }
+
+    fn yield_now(&self, task: task::Notified<Self>) {
+        Shared::schedule(self, task, true);
     }
 
     cfg_unstable! {
diff --git a/tokio/tests/task_local_set.rs b/tokio/tests/task_local_set.rs
index 2da87f5aed2..43fda5c68f2 100644
--- a/tokio/tests/task_local_set.rs
+++ b/tokio/tests/task_local_set.rs
@@ -612,6 +612,40 @@ fn store_local_set_in_thread_local_with_runtime() {
     });
 }
 
+#[test]
+fn test_yield_now() {
+    use std::task::Poll;
+
+    static IS_OK: AtomicBool = AtomicBool::new(false);
+
+    let mut set = LocalSet::new();
+    let rt = rt();
+
+    let jh = set.spawn_local(async {
+        // If polled once, then it is ok.
+        IS_OK.store(true, Ordering::SeqCst);
+
+        tokio::task::yield_now().await;
+
+        // If polled twice, then it is no longer ok.
+        IS_OK.store(false, Ordering::SeqCst);
+    });
+
+    // Poll the set once.
+    //
+    // Since the task wakes itself, the LocalSet should only poll it once.
+    assert!(rt
+        .block_on(futures::future::poll_fn(|cx| Poll::Ready(
+            set.poll_unpin(cx)
+        )))
+        .is_pending());
+    // This cancels the future, assuming that it was polled only once.
+    drop(set);
+
+    assert!(rt.block_on(jh).unwrap_err().is_cancelled());
+    assert!(IS_OK.load(Ordering::SeqCst));
+}
+
 #[cfg(tokio_unstable)]
 mod unstable {
     use tokio::runtime::UnhandledPanic;