Dask in pyodide #9053

Merged (14 commits, Jun 20, 2022)
7 changes: 5 additions & 2 deletions dask/array/core.py
@@ -31,7 +31,7 @@
 from tlz import accumulate, concat, first, frequencies, groupby, partition
 from tlz.curried import pluck
 
-from dask import compute, config, core, threaded
+from dask import compute, config, core
 from dask.array import chunk
 from dask.array.chunk import getitem
 from dask.array.chunk_types import is_valid_array_chunk, is_valid_chunk_type
@@ -49,6 +49,7 @@
     compute_as_if_collection,
     dont_optimize,
     is_dask_collection,
+    named_schedulers,
     persist,
     tokenize,
 )
@@ -84,6 +85,8 @@
 )
 from dask.widgets import get_template
 
+DEFAULT_GET = named_schedulers.get("threads", named_schedulers["sync"])
+
 config.update_defaults({"array": {"chunk-size": "128MiB", "rechunk-threshold": 4}})
 
 unknown_chunk_message = (
@@ -1406,7 +1409,7 @@ def __dask_tokenize__(self):
     __dask_optimize__ = globalmethod(
         optimize, key="array_optimize", falsey=dont_optimize
     )
-    __dask_scheduler__ = staticmethod(threaded.get)
+    __dask_scheduler__ = staticmethod(DEFAULT_GET)
 
     def __dask_postcompute__(self):
         return finalize, ()
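The `DEFAULT_GET` pattern introduced here (and repeated below in `dask.bag`, `dask.dataframe`, and `dask.delayed`) resolves the preferred scheduler at import time and falls back to the always-available synchronous one. A minimal standalone sketch of the lookup, with a stub registry standing in for `dask.base.named_schedulers`:

```python
# Sketch of the fallback lookup above: `named_schedulers` maps scheduler
# names to get-callables, and under Emscripten/Pyodide the "threads" entry
# is never registered, so dict.get returns the synchronous scheduler.
def sync_get(dsk, keys, **kwargs):
    """Stand-in for dask.local.get_sync."""


named_schedulers = {"sync": sync_get}  # "threads" absent, as under Emscripten

DEFAULT_GET = named_schedulers.get("threads", named_schedulers["sync"])
assert DEFAULT_GET is sync_get
```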
15 changes: 11 additions & 4 deletions dask/bag/core.py
@@ -40,13 +40,18 @@
 from dask import config
 from dask.bag import chunk
 from dask.bag.avro import to_avro
-from dask.base import DaskMethodsMixin, dont_optimize, replace_name_in_key, tokenize
+from dask.base import (
+    DaskMethodsMixin,
+    dont_optimize,
+    named_schedulers,
+    replace_name_in_key,
+    tokenize,
+)
 from dask.blockwise import blockwise
 from dask.context import globalmethod
 from dask.core import flatten, get_dependencies, istask, quote, reverse_dict
 from dask.delayed import Delayed, unpack_collections
 from dask.highlevelgraph import HighLevelGraph
-from dask.multiprocessing import get as mpget
 from dask.optimization import cull, fuse, inline
 from dask.sizeof import sizeof
 from dask.utils import (
@@ -64,6 +69,8 @@
     takes_multiple_arguments,
 )
 
+DEFAULT_GET = named_schedulers.get("processes", named_schedulers["sync"])
+
 no_default = "__no__default__"
 no_result = type(
     "no_result", (object,), {"__slots__": (), "__reduce__": lambda self: "no_result"}
@@ -371,7 +378,7 @@ def __dask_tokenize__(self):
         return self.key
 
     __dask_optimize__ = globalmethod(optimize, key="bag_optimize", falsey=dont_optimize)
-    __dask_scheduler__ = staticmethod(mpget)
+    __dask_scheduler__ = staticmethod(DEFAULT_GET)
 
     def __dask_postcompute__(self):
         return finalize_item, ()
@@ -481,7 +488,7 @@ def __dask_tokenize__(self):
         return self.name
 
     __dask_optimize__ = globalmethod(optimize, key="bag_optimize", falsey=dont_optimize)
-    __dask_scheduler__ = staticmethod(mpget)
+    __dask_scheduler__ = staticmethod(DEFAULT_GET)
 
     def __dask_postcompute__(self):
         return finalize, ()
24 changes: 15 additions & 9 deletions dask/base.py
@@ -22,14 +22,15 @@
 from tlz import curry, groupby, identity, merge
 from tlz.functoolz import Compose
 
-from dask import config, local, threaded
-from dask.compatibility import _PY_VERSION
+from dask import config, local
+from dask.compatibility import _EMSCRIPTEN, _PY_VERSION
 from dask.context import thread_state
 from dask.core import flatten
 from dask.core import get as simple_get
 from dask.core import literal, quote
 from dask.hashing import hash_buffer_hex
 from dask.system import CPU_COUNT
+from dask.typing import SchedulerGetCallable
 from dask.utils import Dispatch, apply, ensure_dict, key_split
 
 __all__ = (
@@ -1284,19 +1285,24 @@ def _colorize(t):
     return "#" + h
 
 
-named_schedulers = {
+named_schedulers: dict[str, SchedulerGetCallable] = {
     "sync": local.get_sync,
     "synchronous": local.get_sync,
     "single-threaded": local.get_sync,
-    "threads": threaded.get,
-    "threading": threaded.get,
 }
 
-try:
+if not _EMSCRIPTEN:
+    from dask import threaded
+
+    named_schedulers.update(
+        {
+            "threads": threaded.get,
+            "threading": threaded.get,
+        }
+    )
+
     from dask import multiprocessing as dask_multiprocessing
-except ImportError:
-    pass
-else:
+
     named_schedulers.update(
         {
             "processes": dask_multiprocessing.get,
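The last hunk is truncated, but read together with the lines above it, the registry logic in `dask/base.py` after this change comes out roughly as below; the closing lines of the `if` block are an inference from the cut-off tail. Moving the `threaded` import under the platform check, in place of the old `try`/`except ImportError` around multiprocessing, is what lets `import dask` succeed where the `threading` module cannot start threads:

```python
# Reconstructed post-change logic (imports as shown in the hunks above;
# the final update() call is inferred from the truncated tail of the diff).
from dask import local
from dask.compatibility import _EMSCRIPTEN
from dask.typing import SchedulerGetCallable

named_schedulers: dict[str, SchedulerGetCallable] = {
    "sync": local.get_sync,
    "synchronous": local.get_sync,
    "single-threaded": local.get_sync,
}

if not _EMSCRIPTEN:
    # Threads and subprocesses are unavailable under Emscripten, so both
    # the imports and the registrations are skipped on that platform.
    from dask import threaded

    named_schedulers.update({"threads": threaded.get, "threading": threaded.get})

    from dask import multiprocessing as dask_multiprocessing

    named_schedulers.update({"processes": dask_multiprocessing.get})
```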
2 changes: 2 additions & 0 deletions dask/compatibility.py
@@ -3,3 +3,5 @@
 from packaging.version import parse as parse_version
 
 _PY_VERSION = parse_version(".".join(map(str, sys.version_info[:3])))
+
+_EMSCRIPTEN = sys.platform == "emscripten"
16 changes: 12 additions & 4 deletions dask/dataframe/core.py
@@ -20,10 +20,16 @@
 from tlz import first, merge, partition_all, remove, unique
 
 import dask.array as da
-from dask import core, threaded
+from dask import core
 from dask.array.core import Array, normalize_arg
 from dask.bag import map_partitions as map_bag_partitions
-from dask.base import DaskMethodsMixin, dont_optimize, is_dask_collection, tokenize
+from dask.base import (
+    DaskMethodsMixin,
+    dont_optimize,
+    is_dask_collection,
+    named_schedulers,
+    tokenize,
+)
 from dask.blockwise import Blockwise, BlockwiseDep, BlockwiseDepDict, blockwise
 from dask.context import globalmethod
 from dask.dataframe import methods
@@ -79,6 +85,8 @@
 )
 from dask.widgets import get_template
 
+DEFAULT_GET = named_schedulers.get("threads", named_schedulers["sync"])
+
 no_default = "__no_default__"
 
 GROUP_KEYS_DEFAULT = None if PANDAS_GT_150 else True
@@ -163,7 +171,7 @@ def __dask_layers__(self):
     __dask_optimize__ = globalmethod(
         optimize, key="dataframe_optimize", falsey=dont_optimize
     )
-    __dask_scheduler__ = staticmethod(threaded.get)
+    __dask_scheduler__ = staticmethod(DEFAULT_GET)
 
     def __dask_postcompute__(self):
         return first, ()
@@ -345,7 +353,7 @@ def __dask_tokenize__(self):
     __dask_optimize__ = globalmethod(
         optimize, key="dataframe_optimize", falsey=dont_optimize
     )
-    __dask_scheduler__ = staticmethod(threaded.get)
+    __dask_scheduler__ = staticmethod(DEFAULT_GET)
 
     def __dask_postcompute__(self):
         return finalize, ()
13 changes: 10 additions & 3 deletions dask/dataframe/io/hdf.py
@@ -8,14 +8,21 @@
 from fsspec.utils import build_name_function, stringify_path
 from tlz import merge
 
-from dask import config, multiprocessing
-from dask.base import compute_as_if_collection, get_scheduler, tokenize
+from dask import config
+from dask.base import (
+    compute_as_if_collection,
+    get_scheduler,
+    named_schedulers,
+    tokenize,
+)
 from dask.dataframe.core import DataFrame
 from dask.dataframe.io.io import _link, from_map
 from dask.dataframe.io.utils import DataFrameIOFunction
 from dask.delayed import Delayed, delayed
 from dask.utils import get_scheduler_lock
 
+MP_GET = named_schedulers.get("processes", object())
+
 
 def _pd_to_hdf(pd_to_hdf, lock, args, kwargs=None):
     """A wrapper function around pd_to_hdf that enables locking"""
@@ -193,7 +200,7 @@ def to_hdf(
     if lock is None:
         if not single_node:
             lock = True
-        elif not single_file and _actual_get is not multiprocessing.get:
+        elif not single_file and _actual_get is not MP_GET:
             # if we're writing to multiple files with the multiprocessing
             # scheduler we don't need to lock
             lock = True
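`MP_GET = named_schedulers.get("processes", object())` is worth a note: when the multiprocessing scheduler was never registered, the fresh `object()` serves as an identity sentinel that no real scheduler can equal, so the `_actual_get is not MP_GET` test in `to_hdf` safely reads as "not the multiprocessing scheduler". A standalone sketch of the idiom, with hypothetical names:

```python
# Identity-sentinel sketch: when "processes" is missing from the registry,
# MP_GET is a brand-new object, `scheduler is MP_GET` is always False, and
# the multiprocessing-only "no lock needed" branch is simply never taken.
registry = {}  # pretend "processes" is unavailable, as under Emscripten

MP_GET = registry.get("processes", object())


def chosen_scheduler(dsk, keys, **kwargs):
    """Whatever scheduler the user ended up with."""


assert chosen_scheduler is not MP_GET  # lock heuristic stays conservative
```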
8 changes: 6 additions & 2 deletions dask/delayed.py
@@ -7,11 +7,12 @@
 
 from tlz import concat, curry, merge, unique
 
-from dask import config, threaded
+from dask import config
 from dask.base import (
     DaskMethodsMixin,
     dont_optimize,
     is_dask_collection,
+    named_schedulers,
     replace_name_in_key,
 )
 from dask.base import tokenize as _tokenize
@@ -23,6 +24,9 @@
 __all__ = ["Delayed", "delayed"]
 
 
+DEFAULT_GET = named_schedulers.get("threads", named_schedulers["sync"])
+
+
 def unzip(ls, nout):
     """Unzip a list of lists into ``nout`` outputs."""
     out = list(zip(*ls))
@@ -518,7 +522,7 @@ def __dask_layers__(self):
     def __dask_tokenize__(self):
         return self.key
 
-    __dask_scheduler__ = staticmethod(threaded.get)
+    __dask_scheduler__ = staticmethod(DEFAULT_GET)
     __dask_optimize__ = globalmethod(optimize, key="delayed_optimize")
 
     def __dask_postcompute__(self):
5 changes: 4 additions & 1 deletion dask/local.py
@@ -106,7 +106,10 @@
 
 See the function ``inline_functions`` for more information.
 """
+from __future__ import annotations
+
 import os
+from collections.abc import Hashable, Mapping, Sequence
 from concurrent.futures import Executor, Future
 from functools import partial
 from queue import Empty, Queue
@@ -545,7 +548,7 @@ def submit(self, fn, *args, **kwargs):
 synchronous_executor = SynchronousExecutor()
 
 
-def get_sync(dsk, keys, **kwargs):
+def get_sync(dsk: Mapping, keys: Sequence[Hashable] | Hashable, **kwargs):
     """A naive synchronous version of get_async
 
     Can be useful for debugging.
5 changes: 3 additions & 2 deletions dask/multiprocessing.py
@@ -7,6 +7,7 @@
 import pickle
 import sys
 import traceback
+from collections.abc import Hashable, Mapping, Sequence
 from concurrent.futures import ProcessPoolExecutor
 from functools import partial
 from warnings import warn
@@ -143,8 +144,8 @@ def get_context():
 
 
 def get(
-    dsk,
-    keys,
+    dsk: Mapping,
+    keys: Sequence[Hashable] | Hashable,
     num_workers=None,
     func_loads=None,
     func_dumps=None,
57 changes: 57 additions & 0 deletions dask/tests/test_base.py
@@ -1,5 +1,6 @@
 import dataclasses
 import datetime
+import inspect
 import os
 import subprocess
 import sys
@@ -1539,3 +1540,59 @@ def __dask_optimize__(cls, dsk, keys, **kwargs):
     )[0]
     assert optimized
     da.utils.assert_eq(x, result)
+
+
+# A function designed to be run in a subprocess with dask.compatibility._EMSCRIPTEN
+# patched. This allows for checking for different default schedulers depending on the
+# platform. One might prefer patching `sys.platform` for a more direct test, but that
+# causes problems in other libraries.
+def check_default_scheduler(module, collection, expected, emscripten):
+    from contextlib import nullcontext
+    from unittest import mock
+
+    from dask.local import get_sync
+
+    if emscripten:
+        ctx = mock.patch("dask.base.named_schedulers", {"sync": get_sync})
+    else:
+        ctx = nullcontext()
+    with ctx:
+        import importlib
+
+        if expected == "sync":
+            from dask.local import get_sync as get
+        elif expected == "threads":
+            from dask.threaded import get
+        elif expected == "processes":
+            from dask.multiprocessing import get
+
+        mod = importlib.import_module(module)
+
+        assert getattr(mod, collection).__dask_scheduler__ == get
+
+
+@pytest.mark.parametrize(
+    "params",
+    (
+        "'dask.dataframe', '_Frame', 'sync', True",
+        "'dask.dataframe', '_Frame', 'threads', False",
+        "'dask.array', 'Array', 'sync', True",
+        "'dask.array', 'Array', 'threads', False",
+        "'dask.bag', 'Bag', 'sync', True",
+        "'dask.bag', 'Bag', 'processes', False",
+    ),
+)
+def test_emscripten_default_scheduler(params):
+    pytest.importorskip("dask.array")
+    pytest.importorskip("dask.dataframe")
+    proc = subprocess.run(
+        [
+            sys.executable,
+            "-c",
+            (
+                inspect.getsource(check_default_scheduler)
+                + f"check_default_scheduler({params})\n"
+            ),
+        ]
+    )
+    proc.check_returncode()
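The indirection through a subprocess exists because `DEFAULT_GET` is bound at import time: patching `_EMSCRIPTEN` in the running test process would come too late. Instead, the test assembles a script from the source of `check_default_scheduler` and runs it in a fresh interpreter. The general shape of that technique, as a sketch with a hypothetical check function:

```python
# Run an import-time check in a pristine interpreter: grab the function's
# source with inspect.getsource, append a call to it, and execute the whole
# thing with `python -c`. Useful whenever the behavior under test is fixed
# during module import.
import inspect
import subprocess
import sys


def check():
    import dask

    assert hasattr(dask, "compute")  # trivial stand-in assertion


src = inspect.getsource(check) + "check()\n"
subprocess.run([sys.executable, "-c", src], check=True)
```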
14 changes: 11 additions & 3 deletions dask/threaded.py
@@ -10,6 +10,7 @@
 import sys
 import threading
 from collections import defaultdict
+from collections.abc import Hashable, Mapping, Sequence
 from concurrent.futures import Executor, ThreadPoolExecutor
 from threading import Lock, current_thread
 
@@ -32,15 +33,22 @@ def pack_exception(e, dumps):
     return e, sys.exc_info()[2]
 
 
-def get(dsk, result, cache=None, num_workers=None, pool=None, **kwargs):
+def get(
+    dsk: Mapping,
+    keys: Sequence[Hashable] | Hashable,
+    cache=None,
+    num_workers=None,
+    pool=None,
+    **kwargs,
+):
     """Threaded cached implementation of dask.get
 
     Parameters
     ----------
 
     dsk: dict
         A dask dictionary specifying a workflow
-    result: key or list of keys
+    keys: key or list of keys
         Keys corresponding to desired data
     num_workers: integer of thread count
         The number of threads to use in the ThreadPool that will actually execute tasks
@@ -82,7 +90,7 @@ def get(dsk, result, cache=None, num_workers=None, pool=None, **kwargs):
         pool.submit,
         pool._max_workers,
         dsk,
-        result,
+        keys,
         cache=cache,
         get_id=_thread_get_id,
         pack_exception=pack_exception,
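The `result` to `keys` rename only changes the parameter name, aligning the signature with `dask.local.get_sync` and `dask.multiprocessing.get` (and with the `SchedulerGetCallable` protocol imported in `dask/base.py` above); positional callers are unaffected. A quick usage sketch of the entry point, assuming dask's usual get semantics (a list of keys comes back as a tuple of results):

```python
# dask.threaded.get executes a graph with a thread pool; the second argument
# accepts a single key or a list of keys.
from operator import add

from dask.threaded import get

dsk = {"x": 1, "y": 2, "z": (add, "x", "y")}
assert get(dsk, "z") == 3              # single key -> single result
assert get(dsk, ["x", "z"]) == (1, 3)  # list of keys -> tuple of results
```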