Skip to content

Commit

Permalink
Positional-only args for get functions. They were already inconsistent
Browse files Browse the repository at this point in the history
with regards to argument name, this just enforces it and allows for type
checking of getters.
  • Loading branch information
Ian Rose committed May 10, 2022
1 parent 6f612af commit 04c1579
Show file tree
Hide file tree
Showing 5 changed files with 22 additions and 6 deletions.
3 changes: 2 additions & 1 deletion dask/base.py
Expand Up @@ -29,6 +29,7 @@
from dask.core import literal, quote
from dask.hashing import hash_buffer_hex
from dask.system import CPU_COUNT
from dask.typing import SchedulerGetCallable
from dask.utils import Dispatch, apply, ensure_dict, key_split

__all__ = (
Expand Down Expand Up @@ -1238,7 +1239,7 @@ def _colorize(t):
return "#" + h


named_schedulers = {
named_schedulers: dict[str, SchedulerGetCallable] = {
"sync": local.get_sync,
"synchronous": local.get_sync,
"single-threaded": local.get_sync,
Expand Down
5 changes: 4 additions & 1 deletion dask/local.py
Expand Up @@ -106,7 +106,10 @@
See the function ``inline_functions`` for more information.
"""
from __future__ import annotations

import os
from collections.abc import Hashable, Mapping, Sequence
from concurrent.futures import Executor, Future
from functools import partial
from queue import Empty, Queue
Expand Down Expand Up @@ -545,7 +548,7 @@ def submit(self, fn, *args, **kwargs):
synchronous_executor = SynchronousExecutor()


def get_sync(dsk, keys, **kwargs):
def get_sync(dsk: Mapping, keys: Sequence[Hashable] | Hashable, /, **kwargs):
"""A naive synchronous version of get_async
Can be useful for debugging.
Expand Down
6 changes: 4 additions & 2 deletions dask/multiprocessing.py
Expand Up @@ -7,6 +7,7 @@
import pickle
import sys
import traceback
from collections.abc import Hashable, Mapping, Sequence
from concurrent.futures import ProcessPoolExecutor
from functools import partial
from warnings import warn
Expand Down Expand Up @@ -143,8 +144,9 @@ def get_context():


def get(
dsk,
keys,
dsk: Mapping,
keys: Sequence[Hashable] | Hashable,
/,
num_workers=None,
func_loads=None,
func_dumps=None,
Expand Down
11 changes: 10 additions & 1 deletion dask/threaded.py
Expand Up @@ -10,6 +10,7 @@
import sys
import threading
from collections import defaultdict
from collections.abc import Hashable, Mapping, Sequence
from concurrent.futures import Executor, ThreadPoolExecutor
from threading import Lock, current_thread

Expand All @@ -32,7 +33,15 @@ def pack_exception(e, dumps):
return e, sys.exc_info()[2]


def get(dsk, result, cache=None, num_workers=None, pool=None, **kwargs):
def get(
dsk: Mapping,
result: Sequence[Hashable] | Hashable,
/,
cache=None,
num_workers=None,
pool=None,
**kwargs,
):
"""Threaded cached implementation of dask.get
Parameters
Expand Down
3 changes: 2 additions & 1 deletion dask/typing.py
Expand Up @@ -20,8 +20,9 @@ class SchedulerGetCallable(Protocol):

def __call__(
self,
dask: Mapping,
dsk: Mapping,
keys: Sequence[Hashable] | Hashable,
/,
**kwargs: Any,
) -> Any:
"""Method called as the default scheduler for a collection.
Expand Down

0 comments on commit 04c1579

Please sign in to comment.