diff --git a/tokio/src/runtime/basic_scheduler.rs b/tokio/src/runtime/basic_scheduler.rs index 48c8ad9e850..c674b961d01 100644 --- a/tokio/src/runtime/basic_scheduler.rs +++ b/tokio/src/runtime/basic_scheduler.rs @@ -1,11 +1,12 @@ use crate::park::{Park, Unpark}; use crate::task::{self, JoinHandle, Schedule, ScheduleSendOnly, Task}; -use std::cell::UnsafeCell; +use std::cell::{Cell, UnsafeCell}; use std::collections::VecDeque; use std::fmt; use std::future::Future; use std::mem::ManuallyDrop; +use std::ptr; use std::sync::{Arc, Mutex}; use std::task::{RawWaker, RawWakerVTable, Waker}; use std::time::Duration; @@ -87,6 +88,10 @@ const MAX_TASKS_PER_TICK: usize = 61; /// How often to check the remote queue first const CHECK_REMOTE_INTERVAL: u8 = 13; +thread_local! { + static ACTIVE: Cell<*const SchedulerPriv> = Cell::new(ptr::null()) +} + impl

BasicScheduler

where P: Park, @@ -138,6 +143,27 @@ where let local = &mut self.local; let scheduler = &*self.scheduler; + struct Guard { + old: *const SchedulerPriv, + } + + impl Drop for Guard { + fn drop(&mut self) { + ACTIVE.with(|cell| cell.set(self.old)); + } + } + + // Track the current scheduler + let _guard = ACTIVE.with(|cell| { + let guard = Guard { + old: cell.get(), + }; + + cell.set(scheduler as *const SchedulerPriv); + + guard + }); + runtime::global::with_basic_scheduler(scheduler, || { let mut _enter = runtime::enter(); @@ -283,9 +309,11 @@ impl Schedule for SchedulerPriv { } fn schedule(&self, task: Task) { - use crate::runtime::global; + let is_current = ACTIVE.with(|cell| { + cell.get() == self as *const SchedulerPriv + }); - if global::basic_scheduler_is_current(self) { + if is_current { unsafe { self.schedule_local(task) }; } else { let mut lock = self.remote_queue.lock().unwrap(); diff --git a/tokio/src/runtime/global.rs b/tokio/src/runtime/global.rs index 557be914a5d..c84f348b7c4 100644 --- a/tokio/src/runtime/global.rs +++ b/tokio/src/runtime/global.rs @@ -65,13 +65,6 @@ where ) } -pub(super) fn basic_scheduler_is_current(basic_scheduler: &basic_scheduler::SchedulerPriv) -> bool { - EXECUTOR.with(|current_executor| match current_executor.get() { - State::Basic(ptr) => ptr == basic_scheduler as *const _, - _ => false, - }) -} - cfg_rt_threaded! { use crate::runtime::thread_pool; diff --git a/tokio/tests/fs.rs b/tokio/tests/fs.rs new file mode 100644 index 00000000000..13c44c08d6a --- /dev/null +++ b/tokio/tests/fs.rs @@ -0,0 +1,20 @@ +#![warn(rust_2018_idioms)] +#![cfg(feature = "full")] + +use tokio::fs; +use tokio_test::assert_ok; + +#[tokio::test] +async fn path_read_write() { + let temp = tempdir(); + let dir = temp.path(); + + assert_ok!(fs::write(dir.join("bar"), b"bytes").await); + let out = assert_ok!(fs::read(dir.join("bar")).await); + + assert_eq!(out, b"bytes"); +} + +fn tempdir() -> tempfile::TempDir { + tempfile::tempdir().unwrap() +} diff --git a/tokio/tests/rt_common.rs b/tokio/tests/rt_common.rs index 1d57bd47ee1..a3b68f1c7c8 100644 --- a/tokio/tests/rt_common.rs +++ b/tokio/tests/rt_common.rs @@ -198,6 +198,21 @@ rt_test! { } } + #[test] + fn spawn_await_chain() { + let mut rt = rt(); + + let out = rt.block_on(async { + assert_ok!(tokio::spawn(async { + assert_ok!(tokio::spawn(async { + "hello" + }).await) + }).await) + }); + + assert_eq!(out, "hello"); + } + #[test] fn outstanding_tasks_dropped() { let mut rt = rt(); diff --git a/tokio/tests/task_blocking.rs b/tokio/tests/task_blocking.rs new file mode 100644 index 00000000000..4cd83d8a0d6 --- /dev/null +++ b/tokio/tests/task_blocking.rs @@ -0,0 +1,29 @@ +#![warn(rust_2018_idioms)] +#![cfg(feature = "full")] + +use tokio::task; +use tokio_test::assert_ok; + +use std::thread; +use std::time::Duration; + +#[tokio::test] +async fn basic_blocking() { + // Run a few times + for _ in 0..100 { + let out = assert_ok!( + tokio::spawn(async { + assert_ok!( + task::spawn_blocking(|| { + thread::sleep(Duration::from_millis(5)); + "hello" + }) + .await + ) + }) + .await + ); + + assert_eq!(out, "hello"); + } +}