Skip to content

Commit

Permalink
Scatter by worker instead of worker->nthreads (#8590)
Browse files Browse the repository at this point in the history
* Scatter round-robin by worker

Not by worker->nthreads

* Refactor requiring nthreads to scatter_to_workers
  • Loading branch information
milesgranger committed Apr 15, 2024
1 parent 0f2290b commit 42c479f
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 14 deletions.
5 changes: 2 additions & 3 deletions distributed/client.py
Expand Up @@ -2450,10 +2450,9 @@ async def _scatter(
nthreads = await self.scheduler.ncores_running(workers=workers)
if not nthreads: # pragma: no cover
raise ValueError("No valid workers found")
workers = list(nthreads.keys())

_, who_has, nbytes = await scatter_to_workers(
nthreads, data2, rpc=self.rpc
)
_, who_has, nbytes = await scatter_to_workers(workers, data2, self.rpc)

await self.scheduler.update_data(
who_has=who_has, nbytes=nbytes, client=self.id
Expand Down
7 changes: 3 additions & 4 deletions distributed/scheduler.py
Expand Up @@ -6132,16 +6132,15 @@ async def scatter(
raise TimeoutError("No valid workers found")
await asyncio.sleep(0.1)

nthreads = {ws.address: ws.nthreads for ws in wss}

assert isinstance(data, dict)

keys, who_has, nbytes = await scatter_to_workers(nthreads, data, rpc=self.rpc)
workers = list(ws.address for ws in wss)
keys, who_has, nbytes = await scatter_to_workers(workers, data, rpc=self.rpc)

self.update_data(who_has=who_has, nbytes=nbytes, client=client)

if broadcast:
n = len(nthreads) if broadcast is True else broadcast
n = len(workers) if broadcast is True else broadcast
await self.replicate(keys=keys, workers=workers, n=n)

self.log_event(
Expand Down
11 changes: 4 additions & 7 deletions distributed/utils_comm.py
Expand Up @@ -9,7 +9,7 @@
from itertools import cycle
from typing import Any, TypeVar

from tlz import concat, drop, groupby, merge
from tlz import drop, groupby, merge

import dask.config
from dask.optimization import SubgraphCallable
Expand Down Expand Up @@ -151,19 +151,16 @@ def __repr__(self):
_round_robin_counter = [0]


async def scatter_to_workers(nthreads, data, rpc=rpc):
async def scatter_to_workers(workers, data, rpc=rpc):
"""Scatter data directly to workers
This distributes data in a round-robin fashion to a set of workers based on
how many cores they have. nthreads should be a dictionary mapping worker
identities to numbers of cores.
This distributes data in a round-robin fashion to a set of workers.
See scatter for parameter docstring
"""
assert isinstance(nthreads, dict)
assert isinstance(data, dict)

workers = list(concat([w] * nc for w, nc in nthreads.items()))
workers = sorted(workers)
names, data = list(zip(*data.items()))

worker_iter = drop(_round_robin_counter[0] % len(workers), cycle(workers))
Expand Down

0 comments on commit 42c479f

Please sign in to comment.