Skip to content

Commit

Permalink
switch to multitask
Browse files Browse the repository at this point in the history
Signed-off-by: Marc-Antoine Perennou <Marc-Antoine@Perennou.com>
  • Loading branch information
Keruspe committed Jul 23, 2020
1 parent fcc220f commit 1a68e3d
Show file tree
Hide file tree
Showing 5 changed files with 121 additions and 19 deletions.
12 changes: 9 additions & 3 deletions Cargo.toml
Expand Up @@ -30,9 +30,9 @@ default = [
"futures-lite",
"kv-log-macro",
"log",
"multitask",
"num_cpus",
"pin-project-lite",
"smol",
]
docs = ["attributes", "unstable", "default"]
unstable = [
Expand All @@ -57,7 +57,7 @@ alloc = [
"futures-core/alloc",
"pin-project-lite",
]
tokio02 = ["smol/tokio02"]
tokio02 = ["tokio"]

[dependencies]
async-attributes = { version = "1.1.1", optional = true }
Expand All @@ -83,7 +83,7 @@ surf = { version = "1.0.3", optional = true }
async-io = { version = "0.1.5", optional = true }
blocking = { version = "0.5.0", optional = true }
futures-lite = { version = "0.1.8", optional = true }
smol = { version = "0.1.17", optional = true }
multitask = { version = "0.2.0", optional = true }

[target.'cfg(target_arch = "wasm32")'.dependencies]
futures-timer = { version = "3.0.2", optional = true, features = ["wasm-bindgen"] }
Expand All @@ -93,6 +93,12 @@ futures-channel = { version = "0.3.4", optional = true }
[target.'cfg(target_arch = "wasm32")'.dev-dependencies]
wasm-bindgen-test = "0.3.10"

[dependencies.tokio]
version = "0.2"
default-features = false
features = ["rt-threaded"]
optional = true

[dev-dependencies]
femme = "1.3.0"
rand = "0.7.3"
Expand Down
14 changes: 7 additions & 7 deletions src/task/builder.rs
Expand Up @@ -7,7 +7,7 @@ use std::task::{Context, Poll};
use pin_project_lite::pin_project;

use crate::io;
use crate::task::{JoinHandle, Task, TaskLocalsWrapper};
use crate::task::{self, JoinHandle, Task, TaskLocalsWrapper};

/// Task builder that configures the settings of a new task.
#[derive(Debug, Default)]
Expand Down Expand Up @@ -61,9 +61,9 @@ impl Builder {
});

let task = wrapped.tag.task().clone();
let smol_task = smol::Task::spawn(wrapped).into();
let handle = task::executor::spawn(wrapped);

Ok(JoinHandle::new(smol_task, task))
Ok(JoinHandle::new(handle, task))
}

/// Spawns a task locally with the configured settings.
Expand All @@ -81,9 +81,9 @@ impl Builder {
});

let task = wrapped.tag.task().clone();
let smol_task = smol::Task::local(wrapped).into();
let handle = task::executor::local(wrapped);

Ok(JoinHandle::new(smol_task, task))
Ok(JoinHandle::new(handle, task))
}

/// Spawns a task locally with the configured settings.
Expand Down Expand Up @@ -166,8 +166,8 @@ impl Builder {
unsafe {
TaskLocalsWrapper::set_current(&wrapped.tag, || {
let res = if should_run {
// The first call should use run.
smol::run(wrapped)
// The first call should run the executor
task::executor::run(wrapped)
} else {
futures_lite::future::block_on(wrapped)
};
Expand Down
91 changes: 91 additions & 0 deletions src/task/executor.rs
@@ -0,0 +1,91 @@
use std::cell::RefCell;
use std::future::Future;
use std::task::{Context, Poll};

static GLOBAL_EXECUTOR: once_cell::sync::Lazy<multitask::Executor> = once_cell::sync::Lazy::new(multitask::Executor::new);

struct Executor {
local_executor: multitask::LocalExecutor,
parker: async_io::parking::Parker,
}

thread_local! {
static EXECUTOR: RefCell<Executor> = RefCell::new({
let (parker, unparker) = async_io::parking::pair();
let local_executor = multitask::LocalExecutor::new(move || unparker.unpark());
Executor { local_executor, parker }
});
}

pub(crate) fn spawn<F, T>(future: F) -> multitask::Task<T>
where
F: Future<Output = T> + Send + 'static,
T: Send + 'static,
{
GLOBAL_EXECUTOR.spawn(future)
}

#[cfg(feature = "unstable")]
pub(crate) fn local<F, T>(future: F) -> multitask::Task<T>
where
F: Future<Output = T> + 'static,
T: 'static,
{
EXECUTOR.with(|executor| executor.borrow().local_executor.spawn(future))
}

pub(crate) fn run<F, T>(future: F) -> T
where
F: Future<Output = T>,
{
enter(|| EXECUTOR.with(|executor| {
let executor = executor.borrow();
let unparker = executor.parker.unparker();
let global_ticker = GLOBAL_EXECUTOR.ticker(move || unparker.unpark());
let unparker = executor.parker.unparker();
let waker = async_task::waker_fn(move || unparker.unpark());
let cx = &mut Context::from_waker(&waker);
pin_utils::pin_mut!(future);
loop {
if let Poll::Ready(res) = future.as_mut().poll(cx) {
return res;
}
if let Ok(false) = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| executor.local_executor.tick() || global_ticker.tick())) {
executor.parker.park();
}
}
}))
}

/// Enters the tokio context if the `tokio` feature is enabled.
fn enter<T>(f: impl FnOnce() -> T) -> T {
#[cfg(not(feature = "tokio02"))]
return f();

#[cfg(feature = "tokio02")]
{
use std::cell::Cell;
use tokio::runtime::Runtime;

thread_local! {
/// The level of nested `enter` calls we are in, to ensure that the outermost always
/// has a runtime spawned.
static NESTING: Cell<usize> = Cell::new(0);
}

/// The global tokio runtime.
static RT: once_cell::sync::Lazy<Runtime> = once_cell::sync::Lazy::new(|| Runtime::new().expect("cannot initialize tokio"));

NESTING.with(|nesting| {
let res = if nesting.get() == 0 {
nesting.replace(1);
RT.enter(f)
} else {
nesting.replace(nesting.get() + 1);
f()
};
nesting.replace(nesting.get() - 1);
res
})
}
}
21 changes: 12 additions & 9 deletions src/task/join_handle.rs
Expand Up @@ -18,7 +18,7 @@ pub struct JoinHandle<T> {
}

#[cfg(not(target_os = "unknown"))]
type InnerHandle<T> = async_task::JoinHandle<T, ()>;
type InnerHandle<T> = multitask::Task<T>;
#[cfg(target_arch = "wasm32")]
type InnerHandle<T> = futures_channel::oneshot::Receiver<T>;

Expand Down Expand Up @@ -54,8 +54,7 @@ impl<T> JoinHandle<T> {
#[cfg(not(target_os = "unknown"))]
pub async fn cancel(mut self) -> Option<T> {
let handle = self.handle.take().unwrap();
handle.cancel();
handle.await
handle.cancel().await
}

/// Cancel this task.
Expand All @@ -67,15 +66,19 @@ impl<T> JoinHandle<T> {
}
}

#[cfg(not(target_os = "unknown"))]
impl<T> Drop for JoinHandle<T> {
fn drop(&mut self) {
if let Some(handle) = self.handle.take() {
handle.detach();
}
}
}

impl<T> Future for JoinHandle<T> {
type Output = T;

fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
match Pin::new(&mut self.handle.as_mut().unwrap()).poll(cx) {
Poll::Pending => Poll::Pending,
Poll::Ready(output) => {
Poll::Ready(output.expect("cannot await the result of a panicked task"))
}
}
Pin::new(&mut self.handle.as_mut().unwrap()).poll(cx)
}
}
2 changes: 2 additions & 0 deletions src/task/mod.rs
Expand Up @@ -148,6 +148,8 @@ cfg_default! {
mod block_on;
mod builder;
mod current;
#[cfg(not(target_os = "unknown"))]
mod executor;
mod join_handle;
mod sleep;
#[cfg(not(target_os = "unknown"))]
Expand Down

0 comments on commit 1a68e3d

Please sign in to comment.