From 731d13223c7af8d62d797a4df169898030f57a60 Mon Sep 17 00:00:00 2001 From: crusaderky Date: Fri, 21 Jan 2022 12:12:03 +0000 Subject: [PATCH 01/10] Fix flaky test_close_gracefully and test_lifetime --- distributed/tests/test_worker.py | 44 +++++++++++++++++++++++--------- 1 file changed, 32 insertions(+), 12 deletions(-) diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py index db580920d8a..3a4a784e765 100644 --- a/distributed/tests/test_worker.py +++ b/distributed/tests/test_worker.py @@ -1627,32 +1627,52 @@ async def test_worker_listens_on_same_interface_by_default(cleanup, Worker): @gen_cluster(client=True) async def test_close_gracefully(c, s, a, b): - futures = c.map(slowinc, range(200), delay=0.1) + futures = c.map(slowinc, range(200), delay=0.1, workers=[b.address]) - while not b.data: + # Note: keys will appear in b.data several milliseconds before they switch to + # status=memory in s.tasks. It's important to sample the in-memory keys from the + # scheduler side, because those that the scheduler thinks are still processing won't + # be replicated by the Active Memory Manager. + while True: + mem = {k for k, ts in s.tasks.items() if ts.state == "memory"} + if len(mem) >= 8: + break await asyncio.sleep(0.01) - mem = set(b.data) - proc = {ts for ts in b.tasks.values() if ts.state == "executing"} - assert proc + + assert any(ts for ts in b.tasks.values() if ts.state == "executing") await b.close_gracefully() assert b.status == Status.closed assert b.address not in s.workers - assert mem.issubset(a.data.keys()) - for ts in proc: - assert ts.state in ("executing", "memory") + + # All tasks that were in memory in b have been copied over to a; + # they have not been recomputed + for key in mem: + assert_worker_story( + a.story(key), + [ + (key, "put-in-memory"), + (key, "receive-from-scatter"), + ], + strict=True, + ) + assert key in a.data @pytest.mark.slow @gen_cluster(client=True, nthreads=[]) async def test_lifetime(c, s): async with Worker(s.address) as a, Worker(s.address, lifetime="1 seconds") as b: - futures = c.map(slowinc, range(200), delay=0.1, worker=[b.address]) - await asyncio.sleep(1.5) - assert b.status not in (Status.running, Status.paused) + futures = c.map(slowinc, range(200), delay=0.1, workers=[b.address]) + await asyncio.sleep(0.5) + assert not a.data + assert b.data + b_keys = set(b.data) + while b.status == Status.running: + await asyncio.sleep(0.01) await b.finished() - assert set(b.data) == set(a.data) # successfully moved data over + assert b_keys.issubset(a.data) # successfully moved data over from b to a @gen_cluster(worker_kwargs={"lifetime": "10s", "lifetime_stagger": "2s"}) From 940bb45ba0ddffe042100dcbcaa69c2408931a88 Mon Sep 17 00:00:00 2001 From: crusaderky Date: Fri, 21 Jan 2022 12:20:10 +0000 Subject: [PATCH 02/10] tweak comment --- distributed/tests/test_worker.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py index 3a4a784e765..83fbe98cdaf 100644 --- a/distributed/tests/test_worker.py +++ b/distributed/tests/test_worker.py @@ -1632,7 +1632,7 @@ async def test_close_gracefully(c, s, a, b): # Note: keys will appear in b.data several milliseconds before they switch to # status=memory in s.tasks. It's important to sample the in-memory keys from the # scheduler side, because those that the scheduler thinks are still processing won't - # be replicated by the Active Memory Manager. + # be replicated by retire_workers(). 
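+    # (In this version, `Worker.close_gracefully` asks the scheduler to retire
+    # the worker, and `Scheduler.retire_workers` is what copies the in-memory
+    # keys to the surviving worker; that is why the scheduler-side sample below
+    # is the authoritative one.)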
while True: mem = {k for k, ts in s.tasks.items() if ts.state == "memory"} if len(mem) >= 8: From aef3b719617571cd8b37cf23d8bd9b2a6819ceac Mon Sep 17 00:00:00 2001 From: crusaderky Date: Fri, 21 Jan 2022 14:29:10 +0000 Subject: [PATCH 03/10] Increase resilience on slow CI --- distributed/tests/test_worker.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py index 83fbe98cdaf..d9405ecfa35 100644 --- a/distributed/tests/test_worker.py +++ b/distributed/tests/test_worker.py @@ -1663,9 +1663,9 @@ async def test_close_gracefully(c, s, a, b): @pytest.mark.slow @gen_cluster(client=True, nthreads=[]) async def test_lifetime(c, s): - async with Worker(s.address) as a, Worker(s.address, lifetime="1 seconds") as b: + async with Worker(s.address) as a, Worker(s.address, lifetime="2 seconds") as b: futures = c.map(slowinc, range(200), delay=0.1, workers=[b.address]) - await asyncio.sleep(0.5) + await asyncio.sleep(1) assert not a.data assert b.data b_keys = set(b.data) From 7faab519e9f3a133a02f36495ed2e77a7fe8fd3f Mon Sep 17 00:00:00 2001 From: crusaderky Date: Fri, 21 Jan 2022 14:39:41 +0000 Subject: [PATCH 04/10] harden test --- distributed/tests/test_worker.py | 32 ++++++++++++++++++++++++-------- 1 file changed, 24 insertions(+), 8 deletions(-) diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py index d9405ecfa35..720e3dc0015 100644 --- a/distributed/tests/test_worker.py +++ b/distributed/tests/test_worker.py @@ -1661,18 +1661,34 @@ async def test_close_gracefully(c, s, a, b): @pytest.mark.slow -@gen_cluster(client=True, nthreads=[]) -async def test_lifetime(c, s): - async with Worker(s.address) as a, Worker(s.address, lifetime="2 seconds") as b: +@gen_cluster(client=True, nthreads=[("", 1)], timeout=10) +async def test_lifetime(c, s, a): + async with Worker(s.address, lifetime="2 seconds") as b: futures = c.map(slowinc, range(200), delay=0.1, workers=[b.address]) await asyncio.sleep(1) assert not a.data - assert b.data - b_keys = set(b.data) - while b.status == Status.running: + # Note: keys will appear in b.data several milliseconds before they switch to + # status=memory in s.tasks. It's important to sample the in-memory keys from the + # scheduler side, because those that the scheduler thinks are still processing + # won't be replicated by retire_workers(). 
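+        # (A worker whose ``lifetime`` expires closes itself gracefully, which
+        # presumably funnels through the same retire_workers() machinery, so
+        # the same scheduler-side sampling applies here too.)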
+ mem = {k for k, ts in s.tasks.items() if ts.state == "memory"} + assert mem + + while b.status != Status.closed: await asyncio.sleep(0.01) - await b.finished() - assert b_keys.issubset(a.data) # successfully moved data over from b to a + + # All tasks that were in memory in b have been copied over to a; + # they have not been recomputed + for key in mem: + assert_worker_story( + a.story(key), + [ + (key, "put-in-memory"), + (key, "receive-from-scatter"), + ], + strict=True, + ) + assert key in a.data @gen_cluster(worker_kwargs={"lifetime": "10s", "lifetime_stagger": "2s"}) From 0595052137ce0b883c87abcacc1fe0fca21ed01d Mon Sep 17 00:00:00 2001 From: crusaderky Date: Fri, 21 Jan 2022 17:25:03 +0000 Subject: [PATCH 05/10] Dump cluster state on all test failures (#5674) --- .github/workflows/tests.yaml | 6 +- .gitignore | 5 +- .../diagnostics/tests/test_worker_plugin.py | 29 +++---- distributed/tests/test_actor.py | 60 ++++++-------- distributed/tests/test_client.py | 78 +++++++++---------- distributed/tests/test_scheduler.py | 30 +++---- distributed/tests/test_steal.py | 19 ++--- distributed/utils_test.py | 46 +++++++++-- 8 files changed, 142 insertions(+), 131 deletions(-) diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index e4d823ac416..6755d5fed5a 100644 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -139,13 +139,13 @@ jobs: with: name: ${{ env.TEST_ID }} path: reports - - name: Upload timeout reports + - name: Upload gen_cluster dumps for failed tests # ensure this runs even if pytest fails if: > always() && (steps.run_tests.outcome == 'success' || steps.run_tests.outcome == 'failure') uses: actions/upload-artifact@v2 with: - name: ${{ env.TEST_ID }}-timeouts - path: test_timeout_dump + name: ${{ env.TEST_ID }}_cluster_dumps + path: test_cluster_dump if-no-files-found: ignore diff --git a/.gitignore b/.gitignore index 6f9792237e3..3be3f122893 100644 --- a/.gitignore +++ b/.gitignore @@ -30,6 +30,7 @@ dask-worker-space/ tags .ipynb_checkpoints .venv/ +.mypy_cache/ -# Test timeouts will dump the cluster state in here -test_timeout_dump/ +# Test failures will dump the cluster state in here +test_cluster_dump/ diff --git a/distributed/diagnostics/tests/test_worker_plugin.py b/distributed/diagnostics/tests/test_worker_plugin.py index 0b0ca52e12b..c6dfd1d561d 100644 --- a/distributed/diagnostics/tests/test_worker_plugin.py +++ b/distributed/diagnostics/tests/test_worker_plugin.py @@ -207,7 +207,8 @@ class MyCustomPlugin(WorkerPlugin): assert next(iter(w.plugins)).startswith("MyCustomPlugin-") -def test_release_key_deprecated(): +@gen_cluster(client=True, nthreads=[("", 1)]) +async def test_release_key_deprecated(c, s, a): class ReleaseKeyDeprecated(WorkerPlugin): def __init__(self): self._called = False @@ -222,20 +223,18 @@ def teardown(self, worker): assert self._called return super().teardown(worker) - @gen_cluster(client=True, nthreads=[("", 1)]) - async def test(c, s, a): - - await c.register_worker_plugin(ReleaseKeyDeprecated()) - fut = await c.submit(inc, 1, key="task") - assert fut == 2 + await c.register_worker_plugin(ReleaseKeyDeprecated()) with pytest.deprecated_call( match="The `WorkerPlugin.release_key` hook is depreacted" ): - test() + assert await c.submit(inc, 1, key="x") == 2 + while "x" in a.tasks: + await asyncio.sleep(0.01) -def test_assert_no_warning_no_overload(): +@gen_cluster(client=True, nthreads=[("", 1)]) +async def test_assert_no_warning_no_overload(c, s, a): """Assert we do not receive a deprecation warning if we 
do not overload any methods """ @@ -243,15 +242,11 @@ def test_assert_no_warning_no_overload(): class Dummy(WorkerPlugin): pass - @gen_cluster(client=True, nthreads=[("", 1)]) - async def test(c, s, a): - - await c.register_worker_plugin(Dummy()) - fut = await c.submit(inc, 1, key="task") - assert fut == 2 - with pytest.warns(None): - test() + await c.register_worker_plugin(Dummy()) + assert await c.submit(inc, 1, key="x") == 2 + while "x" in a.tasks: + await asyncio.sleep(0.01) @gen_cluster(nthreads=[("127.0.0.1", 1)], client=True) diff --git a/distributed/tests/test_actor.py b/distributed/tests/test_actor.py index cf026fcf233..a84ee22bc38 100644 --- a/distributed/tests/test_actor.py +++ b/distributed/tests/test_actor.py @@ -71,13 +71,11 @@ def get(self, key): @pytest.mark.parametrize("direct_to_workers", [True, False]) -def test_client_actions(direct_to_workers): - @gen_cluster(client=True) - async def test(c, s, a, b): - c = await Client( - s.address, asynchronous=True, direct_to_workers=direct_to_workers - ) - +@gen_cluster() +async def test_client_actions(s, a, b, direct_to_workers): + async with Client( + s.address, asynchronous=True, direct_to_workers=direct_to_workers + ) as c: counter = c.submit(Counter, workers=[a.address], actor=True) assert isinstance(counter, Future) counter = await counter @@ -86,8 +84,7 @@ async def test(c, s, a, b): assert hasattr(counter, "add") assert hasattr(counter, "n") - n = await counter.n - assert n == 0 + assert await counter.n == 0 assert counter._address == a.address @@ -96,45 +93,36 @@ async def test(c, s, a, b): await asyncio.gather(counter.increment(), counter.increment()) - n = await counter.n - assert n == 2 + assert await counter.n == 2 counter.add(10) while (await counter.n) != 10 + 2: - n = await counter.n await asyncio.sleep(0.01) - await c.close() - - test() - @pytest.mark.parametrize("separate_thread", [False, True]) -def test_worker_actions(separate_thread): - @gen_cluster(client=True) - async def test(c, s, a, b): - counter = c.submit(Counter, workers=[a.address], actor=True) - a_address = a.address - - def f(counter): - start = counter.n +@gen_cluster(client=True) +async def test_worker_actions(c, s, a, b, separate_thread): + counter = c.submit(Counter, workers=[a.address], actor=True) + a_address = a.address - assert type(counter) is Actor - assert counter._address == a_address + def f(counter): + start = counter.n - future = counter.increment(separate_thread=separate_thread) - assert isinstance(future, ActorFuture) - assert "Future" in type(future).__name__ - end = future.result(timeout=1) - assert end > start + assert type(counter) is Actor + assert counter._address == a_address - futures = [c.submit(f, counter, pure=False) for _ in range(10)] - await c.gather(futures) + future = counter.increment(separate_thread=separate_thread) + assert isinstance(future, ActorFuture) + assert "Future" in type(future).__name__ + end = future.result(timeout=1) + assert end > start - counter = await counter - assert await counter.n == 10 + futures = [c.submit(f, counter, pure=False) for _ in range(10)] + await c.gather(futures) - test() + counter = await counter + assert await counter.n == 10 @gen_cluster(client=True) diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index fbdfec48ef3..3d316cee2d9 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -5747,43 +5747,37 @@ async def test_client_active_bad_port(): @pytest.mark.parametrize("direct", [True, False]) -def 
test_turn_off_pickle(direct): - @gen_cluster() - async def test(s, a, b): - np = pytest.importorskip("numpy") - - async with Client( - s.address, asynchronous=True, serializers=["dask", "msgpack"] - ) as c: - assert (await c.submit(inc, 1)) == 2 - await c.submit(np.ones, 5) - await c.scatter(1) +@gen_cluster(client=True, client_kwargs={"serializers": ["dask", "msgpack"]}) +async def test_turn_off_pickle(c, s, a, b, direct): + np = pytest.importorskip("numpy") - # Can't send complex data - with pytest.raises(TypeError): - future = await c.scatter(inc) + assert (await c.submit(inc, 1)) == 2 + await c.submit(np.ones, 5) + await c.scatter(1) - # can send complex tasks (this uses pickle regardless) - future = c.submit(lambda x: x, inc) - await wait(future) + # Can't send complex data + with pytest.raises(TypeError): + await c.scatter(inc) - # but can't receive complex results - with pytest.raises(TypeError): - await c.gather(future, direct=direct) + # can send complex tasks (this uses pickle regardless) + future = c.submit(lambda x: x, inc) + await wait(future) - # Run works - result = await c.run(lambda: 1) - assert list(result.values()) == [1, 1] - result = await c.run_on_scheduler(lambda: 1) - assert result == 1 + # but can't receive complex results + with pytest.raises(TypeError): + await c.gather(future, direct=direct) - # But not with complex return values - with pytest.raises(TypeError): - await c.run(lambda: inc) - with pytest.raises(TypeError): - await c.run_on_scheduler(lambda: inc) + # Run works + result = await c.run(lambda: 1) + assert list(result.values()) == [1, 1] + result = await c.run_on_scheduler(lambda: 1) + assert result == 1 - test() + # But not with complex return values + with pytest.raises(TypeError): + await c.run(lambda: inc) + with pytest.raises(TypeError): + await c.run_on_scheduler(lambda: inc) @gen_cluster() @@ -6620,21 +6614,19 @@ async def test_annotations_task_state(c, s, a, b): @pytest.mark.parametrize("fn", ["compute", "persist"]) -def test_annotations_compute_time(fn): +@gen_cluster(client=True) +async def test_annotations_compute_time(c, s, a, b, fn): da = pytest.importorskip("dask.array") + x = da.ones(10, chunks=(5,)) - @gen_cluster(client=True) - async def test(c, s, a, b): - x = da.ones(10, chunks=(5,)) - - with dask.annotate(foo="bar"): - # Turn off optimization to avoid rewriting layers and picking up annotations - # that way. Instead, we want `compute`/`persist` to be able to pick them up. - x = await getattr(c, fn)(x, optimize_graph=False) - - assert all({"foo": "bar"} == ts.annotations for ts in s.tasks.values()) + with dask.annotate(foo="bar"): + # Turn off optimization to avoid rewriting layers and picking up annotations + # that way. Instead, we want `compute`/`persist` to be able to pick them up. 
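+        # NOTE: `c.compute` returns a Future while `c.persist` returns a dask
+        # collection; `wait` resolves either, which is presumably why the
+        # result is handed to `wait` below rather than awaited directly.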
+ fut = getattr(c, fn)(x, optimize_graph=False) - test() + await wait(fut) + assert s.tasks + assert all(ts.annotations == {"foo": "bar"} for ts in s.tasks.values()) @pytest.mark.xfail(reason="https://github.com/dask/dask/issues/7036") diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py index 0f958120483..0bb2567e053 100644 --- a/distributed/tests/test_scheduler.py +++ b/distributed/tests/test_scheduler.py @@ -140,14 +140,16 @@ def test_decide_worker_coschedule_order_neighbors(ndeps, nthreads): nthreads=nthreads, config={"distributed.scheduler.work-stealing": False}, ) - async def test(c, s, *workers): + async def test_decide_worker_coschedule_order_neighbors_(c, s, *workers): r""" - Ensure that sibling root tasks are scheduled to the same node, reducing future data transfer. + Ensure that sibling root tasks are scheduled to the same node, reducing future + data transfer. - We generate a wide layer of "root" tasks (random NumPy arrays). All of those tasks share 0-5 - trivial dependencies. The ``ndeps=0`` and ``ndeps=1`` cases are most common in real-world use - (``ndeps=1`` is basically ``da.from_array(..., inline_array=False)`` or ``da.from_zarr``). - The graph is structured like this (though the number of tasks and workers is different): + We generate a wide layer of "root" tasks (random NumPy arrays). All of those + tasks share 0-5 trivial dependencies. The ``ndeps=0`` and ``ndeps=1`` cases are + most common in real-world use (``ndeps=1`` is basically ``da.from_array(..., + inline_array=False)`` or ``da.from_zarr``). The graph is structured like this + (though the number of tasks and workers is different): |-W1-| |-W2-| |-W3-| |-W4-| < ---- ideal task scheduling @@ -159,9 +161,9 @@ async def test(c, s, *workers): \ \ \ | | / / / TRIVIAL * 0..5 - Neighboring `random-` tasks should be scheduled on the same worker. We test that generally, - only one worker holds each row of the array, that the `random-` tasks are never transferred, - and that there are few transfers overall. + Neighboring `random-` tasks should be scheduled on the same worker. We test that + generally, only one worker holds each row of the array, that the `random-` tasks + are never transferred, and that there are few transfers overall. """ da = pytest.importorskip("dask.array") np = pytest.importorskip("numpy") @@ -222,16 +224,18 @@ def random(**kwargs): keys = log["keys"] # The root-ish tasks should never be transferred assert not any(k.startswith("random") for k in keys), keys - # `object-` keys (the trivial deps of the root random tasks) should be transferred + # `object-` keys (the trivial deps of the root random tasks) should be + # transferred if any(not k.startswith("object") for k in keys): # But not many other things should be unexpected_transfers.append(list(keys)) - # A transfer at the very end to move aggregated results is fine (necessary with unbalanced workers in fact), - # but generally there should be very very few transfers. + # A transfer at the very end to move aggregated results is fine (necessary with + # unbalanced workers in fact), but generally there should be very very few + # transfers. 
assert len(unexpected_transfers) <= 3, unexpected_transfers - test() + test_decide_worker_coschedule_order_neighbors_() @gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 3) diff --git a/distributed/tests/test_steal.py b/distributed/tests/test_steal.py index 72bf63c1afb..7f997ed77a1 100644 --- a/distributed/tests/test_steal.py +++ b/distributed/tests/test_steal.py @@ -699,6 +699,7 @@ async def assert_balanced(inp, expected, c, s, *workers): raise Exception(f"Expected: {expected2}; got: {result2}") +@pytest.mark.slow @pytest.mark.parametrize( "inp,expected", [ @@ -732,19 +733,15 @@ async def assert_balanced(inp, expected, c, s, *workers): ], ) def test_balance(inp, expected): - async def test(*args, **kwargs): + async def test_balance_(*args, **kwargs): await assert_balanced(inp, expected, *args, **kwargs) - test = gen_cluster( - client=True, - nthreads=[("127.0.0.1", 1)] * len(inp), - config={ - "distributed.scheduler.default-task-durations": { - str(i): 1 for i in range(10) - } - }, - )(test) - test() + config = { + "distributed.scheduler.default-task-durations": {str(i): 1 for i in range(10)} + } + gen_cluster(client=True, nthreads=[("", 1)] * len(inp), config=config)( + test_balance_ + )() @gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 2, Worker=Nanny, timeout=60) diff --git a/distributed/utils_test.py b/distributed/utils_test.py index 30019429ee7..cc628cbb401 100644 --- a/distributed/utils_test.py +++ b/distributed/utils_test.py @@ -678,7 +678,7 @@ def cluster( for worker in workers: worker["address"] = worker["queue"].get(timeout=5) except queue.Empty: - raise pytest.xfail.Exception("Worker failed to start in test") + pytest.xfail("Worker failed to start in test") saddr = scheduler_q.get() @@ -895,7 +895,7 @@ def gen_cluster( config: dict[str, Any] = {}, clean_kwargs: dict[str, Any] = {}, allow_unclosed: bool = False, - cluster_dump_directory: str | Literal[False] = "test_timeout_dump", + cluster_dump_directory: str | Literal[False] = "test_cluster_dump", ) -> Callable[[Callable], Callable]: from distributed import Client @@ -979,15 +979,16 @@ async def coro(): **client_kwargs, ) args = [c] + args + try: coro = func(*args, *outer_args, **kwargs) task = asyncio.create_task(coro) - coro2 = asyncio.wait_for(asyncio.shield(task), timeout) result = await coro2 if s.validate: s.validate_state() - except asyncio.TimeoutError as e: + + except asyncio.TimeoutError: assert task buffer = io.StringIO() # This stack indicates where the coro/test is suspended @@ -1004,9 +1005,31 @@ async def coro(): task.cancel() while not task.cancelled(): await asyncio.sleep(0.01) + + # Remove as much of the traceback as possible; it's + # uninteresting boilerplate from utils_test and asyncio and + # not from the code being tested. 
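+                    # (The ``from None`` added below serves the same purpose:
+                    # it suppresses the chained asyncio.TimeoutError.)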
raise TimeoutError( - f"Test timeout after {timeout}s.\n{buffer.getvalue()}" - ) from e + f"Test timeout after {timeout}s.\n" + "========== Test stack trace starts here ==========\n" + f"{buffer.getvalue()}" + ) from None + + except pytest.xfail.Exception: + raise + + except Exception: + if cluster_dump_directory and not has_pytestmark( + test_func, "xfail" + ): + await dump_cluster_state( + s, + ws, + output_dir=cluster_dump_directory, + func_name=func.__name__, + ) + raise + finally: if client and c.status not in ("closing", "closed"): await c._close(fast=s.status == Status.closed) @@ -1892,3 +1915,14 @@ def read(self, deserializers=None): def write(self, msg, serializers=None, on_error=None): raise OSError() + + +def has_pytestmark(test_func: Callable, name: str) -> bool: + """Return True if the test function is marked by the given @pytest.mark.; + False otherwise. + + FIXME doesn't work with individually marked parameters inside + @pytest.mark.parametrize + """ + marks = getattr(test_func, "pytestmark", []) + return any(mark.name == name for mark in marks) From b34171151d2d265ce747db19f66677a27bb31915 Mon Sep 17 00:00:00 2001 From: Julia Signell Date: Fri, 21 Jan 2022 13:53:18 -0500 Subject: [PATCH 06/10] Update client.py docstrings (#5670) Adding the Returns numpy style documentation to the doc strings. This is for issue #3578. Co-authored-by: tharris72 Co-authored-by: Tim Harris --- distributed/client.py | 520 ++++++++++++++++++++++++++++++++++-------- 1 file changed, 423 insertions(+), 97 deletions(-) diff --git a/distributed/client.py b/distributed/client.py index 006beb41921..99c7d1c557f 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -156,6 +156,8 @@ class Future(WrappedKey): Client that should own this future. Defaults to _get_global_client() inform: bool Do we inform the scheduler that we need an update on this future + state: FutureState + The state of the future Examples -------- @@ -212,21 +214,55 @@ def __init__(self, key, client=None, inform=True, state=None): @property def executor(self): + """Returns the executor, which is the client. + + Returns + ------- + Client + The executor + """ return self.client @property def status(self): + """Returns the status + + Returns + ------- + str + The status + """ return self._state.status def done(self): - """Is the computation complete?""" + """Returns whether or not the computation completed. + + Returns + ------- + bool + True if the computation is complete, otherwise False + """ return self._state.done() def result(self, timeout=None): """Wait until computation completes, gather result to local process. - If *timeout* seconds are elapsed before returning, a - ``dask.distributed.TimeoutError`` is raised. + Parameters + ---------- + timeout : number, optional + Time in seconds after which to raise a + ``dask.distributed.TimeoutError`` + + Raises + ------ + dask.distributed.TimeoutError + If *timeout* seconds are elapsed before returning, a + ``dask.distributed.TimeoutError`` is raised. + + Returns + ------- + result : asyncio.Future + The Future that contains the result of the computation """ if self.client.asynchronous: return self.client.sync(self._result, callback_timeout=timeout) @@ -270,8 +306,20 @@ async def _exception(self): def exception(self, timeout=None, **kwargs): """Return the exception of a failed task - If *timeout* seconds are elapsed before returning, a - ``dask.distributed.TimeoutError`` is raised. 
+ Parameters + ---------- + timeout : number, optional + Time in seconds after which to raise a + ``dask.distributed.TimeoutError`` + **kwargs : dict + Optional keyword arguments for the function + + Returns + ------- + Exception + The exception that was raised + If *timeout* seconds are elapsed before returning, a + ``dask.distributed.TimeoutError`` is raised. See Also -------- @@ -287,6 +335,11 @@ def add_done_callback(self, fn): errs, or is cancelled The callback is executed in a separate thread. + + Parameters + ---------- + fn : callable + The method or function to be called """ cls = Future if cls._cb_executor is None or cls._cb_executor_pid != os.getpid(): @@ -309,7 +362,7 @@ def execute_callback(fut): ) def cancel(self, **kwargs): - """Cancel request to run this future + """Cancel the request to run this future See Also -------- @@ -327,7 +380,13 @@ def retry(self, **kwargs): return self.client.retry([self], **kwargs) def cancelled(self): - """Returns True if the future has been cancelled""" + """Returns True if the future has been cancelled + + Returns + ------- + bool + True if the future was 'cancelled', otherwise False + """ return self._state.status == "cancelled" async def _traceback(self): @@ -344,8 +403,13 @@ def traceback(self, timeout=None, **kwargs): ``traceback`` module. Alternatively if you call ``future.result()`` this traceback will accompany the raised exception. - If *timeout* seconds are elapsed before returning, a - ``dask.distributed.TimeoutError`` is raised. + Parameters + ---------- + timeout : number, optional + Time in seconds after which to raise a + ``dask.distributed.TimeoutError`` + If *timeout* seconds are elapsed before returning, a + ``dask.distributed.TimeoutError`` is raised. Examples -------- @@ -354,6 +418,11 @@ def traceback(self, timeout=None, **kwargs): >>> traceback.format_tb(tb) # doctest: +SKIP [...] + Returns + ------- + Future + The Future that contains the traceback + See Also -------- Future.exception @@ -362,11 +431,22 @@ def traceback(self, timeout=None, **kwargs): @property def type(self): + """Returns the type""" return self._state.type def release(self, _in_destructor=False): - # NOTE: this method can be called from different threads - # (see e.g. Client.get() or Future.__del__()) + """ + + Parameters + ---------- + _in_destructor: bool + Not used + + Notes + ----- + This method can be called from different threads + (see e.g. Client.get() or Future.__del__()) + """ if not self._cleared and self.client.generation == self._generation: self._cleared = True try: @@ -446,25 +526,47 @@ def _get_event(self): return event def cancel(self): + """Cancels the operation""" self.status = "cancelled" self.exception = CancelledError() self._get_event().set() def finish(self, type=None): + """Sets the status to 'finished' and sets the event + + Parameters + ---------- + type : any + The type + """ self.status = "finished" self._get_event().set() if type is not None: self.type = type def lose(self): + """Sets the status to 'lost' and clears the event""" self.status = "lost" self._get_event().clear() def retry(self): + """Sets the status to 'pending' and clears the event""" self.status = "pending" self._get_event().clear() def set_error(self, exception, traceback): + """Sets the error data + + Sets the status to 'error'. 
Sets the exception, the traceback, + and the event + + Parameters + ---------- + exception: Exception + The exception + traceback: Exception + The traceback + """ _, exception, traceback = clean_exception(exception, traceback) self.status = "error" @@ -473,14 +575,24 @@ def set_error(self, exception, traceback): self._get_event().set() def done(self): + """Returns 'True' if the event is not None and the event is set""" return self._event is not None and self._event.is_set() def reset(self): + """Sets the status to 'pending' and clears the event""" self.status = "pending" if self._event is not None: self._event.clear() async def wait(self, timeout=None): + """Wait for the awaitable to complete with a timeout. + + Parameters + ---------- + timeout : number, optional + Time in seconds after which to raise a + ``dask.distributed.TimeoutError`` + """ await asyncio.wait_for(self._get_event().wait(), timeout) def __repr__(self): @@ -488,7 +600,15 @@ def __repr__(self): async def done_callback(future, callback): - """Coroutine that waits on future, then calls callback""" + """Coroutine that waits on the future, then calls the callback + + Parameters + ---------- + future : asyncio.Future + The future + callback : callable + The callback + """ while future.status == "pending": await future._state.wait() callback(future) @@ -496,6 +616,13 @@ async def done_callback(future, callback): @partial(normalize_token.register, Future) def normalize_future(f): + """Returns the key and the type as a list + + Parameters + ---------- + list + The key and the type + """ return [f.key, type(f)] @@ -539,6 +666,8 @@ class Client(SyncMethodMixin): address: string, or Cluster This can be the address of a ``Scheduler`` server like a string ``'127.0.0.1:8786'`` or a cluster object like ``LocalCluster()`` + loop + The event loop timeout: int Timeout duration for initial connection to the scheduler set_as_default: bool (True) @@ -556,11 +685,20 @@ class Client(SyncMethodMixin): name: string (optional) Gives the client a name that will be included in logs generated on the scheduler for matters relating to this client + heartbeat_interval: int (optional) + Time in milliseconds between heartbeats to scheduler + serializers + The serializers to turn an object into a string + deserializers + The deserializers to turn the string into the original object + extensions : list + The extensions direct_to_workers: bool (optional) Whether or not to connect directly to the workers, or to ask the scheduler to serve as intermediary. - heartbeat_interval: int - Time in milliseconds between heartbeats to scheduler + connection_limit : int + The number of open comms to maintain at once in the connection pool + **kwargs: If you do not pass a scheduler address, Client will create a ``LocalCluster`` object, passing any extra keyword arguments. @@ -774,9 +912,9 @@ def __init__( @contextmanager def as_current(self): - """Thread-local, Task-local context manager that causes the Client.current class - method to return self. Any Future objects deserialized inside this context - manager will be automatically attached to this Client. + """Thread-local, Task-local context manager that causes the Client.current + class method to return self. Any Future objects deserialized inside this + context manager will be automatically attached to this Client. """ tok = _current_client.set(self) try: @@ -789,8 +927,23 @@ def current(cls, allow_global=True): """When running within the context of `as_client`, return the context-local current client. 
Otherwise, return the latest initialised Client. If no Client instances exist, raise ValueError. - If allow_global is set to False, raise ValueError if running outside of the - `as_client` context manager. + If allow_global is set to False, raise ValueError if running outside of + the `as_client` context manager. + + Parameters + ---------- + allow_global : bool + If True returns the default client + + Returns + ------- + Client + The current client + + Raises + ------ + ValueError + If there is no client set, a ValueError is raised """ out = _current_client.get() if out: @@ -1132,7 +1285,16 @@ async def _wait_for_workers(self, n_workers=0, timeout=None): info = await self.scheduler.identity() def wait_for_workers(self, n_workers=0, timeout=None): - """Blocking call to wait for n workers before continuing""" + """Blocking call to wait for n workers before continuing + + Parameters + ---------- + n_workers : int + The number of workers + timeout : number, optional + Time in seconds after which to raise a + ``dask.distributed.TimeoutError`` + """ return self.sync(self._wait_for_workers, n_workers, timeout=timeout) def _heartbeat(self): @@ -1361,6 +1523,13 @@ def close(self, timeout=no_default): If you started a client without arguments like ``Client()`` then this will also close the local cluster that was started at the same time. + + Parameters + ---------- + timeout : number + Time in seconds after which to raise a + ``dask.distributed.TimeoutError`` + See Also -------- Client.restart @@ -1424,7 +1593,8 @@ def shutdown(self): def get_executor(self, **kwargs): """ - Return a concurrent.futures Executor for submitting tasks on this Client + Return a concurrent.futures Executor for submitting tasks on this + Client Parameters ---------- @@ -1434,8 +1604,9 @@ def get_executor(self, **kwargs): Returns ------- - An Executor object that's fully compatible with the concurrent.futures - API. + ClientExecutor + An Executor object that's fully compatible with the + concurrent.futures API. """ return ClientExecutor(self, **kwargs) @@ -1465,20 +1636,18 @@ def submit( coroutine, it will be run on the main event loop of a worker. Otherwise ``func`` will be run in a worker's task executor pool (see ``Worker.executors`` for more information.) - *args - **kwargs - pure : bool (defaults to True) - Whether or not the function is pure. Set ``pure=False`` for - impure functions like ``np.random.random``. - See :ref:`pure functions` for more details. + *args : tuple + Optional positional arguments + key : str + Unique identifier for the task. Defaults to function-name and hash workers : string or iterable of strings A set of worker addresses or hostnames on which computations may be performed. Leave empty to default to all workers (common case) - key : str - Unique identifier for the task. Defaults to function-name and hash - allow_other_workers : bool (defaults to False) - Used with ``workers``. Indicates whether or not the computations - may be performed on workers that are not in the `workers` set(s). + resources : dict (defaults to {}) + Defines the ``resources`` each instance of this mapped task + requires on the worker; e.g. ``{'GPU': 2}``. + See :doc:`worker resources ` for details on defining + resources. 
retries : int (default to 0) Number of allowed automatic retries if the task fails priority : Number @@ -1486,16 +1655,19 @@ def submit( Higher priorities take precedence fifo_timeout : str timedelta (default '100ms') Allowed amount of time between calls to consider the same priority - resources : dict (defaults to {}) - Defines the ``resources`` each instance of this mapped task requires - on the worker; e.g. ``{'GPU': 2}``. - See :doc:`worker resources ` for details on defining - resources. + allow_other_workers : bool (defaults to False) + Used with ``workers``. Indicates whether or not the computations + may be performed on workers that are not in the `workers` set(s). actor : bool (default False) Whether this task should exist on the worker as a stateful actor. See :doc:`actors` for additional details. actors : bool (default False) Alias for `actor` + pure : bool (defaults to True) + Whether or not the function is pure. Set ``pure=False`` for + impure functions like ``np.random.random``. + See :ref:`pure functions` for more details. + **kwargs Examples -------- @@ -1504,6 +1676,16 @@ def submit( Returns ------- Future + If running in asynchronous mode, returns the future. Otherwise + returns the concrete value + + Raises + ------ + TypeError + If 'func' is not callable, a TypeError is raised + ValueError + If 'allow_other_workers'is True and 'workers' is None, a + ValueError is raised See Also -------- @@ -1591,40 +1773,41 @@ def map( List-like objects to map over. They should have the same length. key : str, list Prefix for task names if string. Explicit names if list. - pure : bool (defaults to True) - Whether or not the function is pure. Set ``pure=False`` for - impure functions like ``np.random.random``. - See :ref:`pure functions` for more details. workers : string or iterable of strings A set of worker hostnames on which computations may be performed. Leave empty to default to all workers (common case) - allow_other_workers : bool (defaults to False) - Used with `workers`. Indicates whether or not the computations - may be performed on workers that are not in the `workers` set(s). retries : int (default to 0) Number of allowed automatic retries if a task fails - priority : Number - Optional prioritization of task. Zero is default. - Higher priorities take precedence - fifo_timeout : str timedelta (default '100ms') - Allowed amount of time between calls to consider the same priority resources : dict (defaults to {}) Defines the `resources` each instance of this mapped task requires on the worker; e.g. ``{'GPU': 2}``. See :doc:`worker resources ` for details on defining resources. + priority : Number + Optional prioritization of task. Zero is default. + Higher priorities take precedence + allow_other_workers : bool (defaults to False) + Used with `workers`. Indicates whether or not the computations + may be performed on workers that are not in the `workers` set(s). + fifo_timeout : str timedelta (default '100ms') + Allowed amount of time between calls to consider the same priority actor : bool (default False) Whether these tasks should exist on the worker as stateful actors. See :doc:`actors` for additional details. actors : bool (default False) Alias for `actor` + pure : bool (defaults to True) + Whether or not the function is pure. Set ``pure=False`` for + impure functions like ``np.random.random``. + See :ref:`pure functions` for more details. batch_size : int, optional - Submit tasks to the scheduler in batches of (at most) ``batch_size``. 
+ Submit tasks to the scheduler in batches of (at most) + ``batch_size``. Larger batch sizes can be useful for very large ``iterables``, as the cluster can start processing tasks while later ones are submitted asynchronously. **kwargs : dict - Extra keywords to send to the function. + Extra keyword arguments to send to the function. Large values will be included explicitly in the task graph. Examples @@ -1910,6 +2093,8 @@ def gather(self, futures, errors="raise", direct=None, asynchronous=None): Whether or not to connect directly to the workers, or to ask the scheduler to serve as intermediary. This can also be set when creating the Client. + asynchronous: bool + If True the client is in asynchronous mode Returns ------- @@ -2085,7 +2270,8 @@ def scatter( Data to scatter out to workers. Output type matches input type. workers : list of tuples (optional) Optionally constrain locations of data. - Specify workers as hostname/port pairs, e.g. ``('127.0.0.1', 8787)``. + Specify workers as hostname/port pairs, e.g. + ``('127.0.0.1', 8787)``. broadcast : bool (defaults to False) Whether to send each data element to all workers. By default we round-robin based on number of cores. @@ -2096,6 +2282,11 @@ def scatter( hash : bool (optional) Whether or not to hash data to determine key. If False then this uses a random key + timeout : number, optional + Time in seconds after which to raise a + ``dask.distributed.TimeoutError`` + asynchronous: bool + If True the client is in asynchronous mode Returns ------- @@ -2177,7 +2368,10 @@ def cancel(self, futures, asynchronous=None, force=False): Parameters ---------- - futures : list of Futures + futures : List[Future] + The list of Futures + asynchronous: bool + If True the client is in asynchronous mode force : boolean (False) Cancel this future even if other clients desire it """ @@ -2197,6 +2391,9 @@ def retry(self, futures, asynchronous=None): Parameters ---------- futures : list of Futures + The list of Futures + asynchronous: bool + If True the client is in asynchronous mode """ return self.sync(self._retry, futures, asynchronous=asynchronous) @@ -2246,9 +2443,6 @@ def publish_dataset(self, *args, **kwargs): Parameters ---------- args : list of objects to publish as name - name : optional name of the dataset to publish - override : bool (optional, default False) - if true, override any already present dataset with the same name kwargs : dict named collections to publish on the scheduler @@ -2286,6 +2480,11 @@ def unpublish_dataset(self, name, **kwargs): """ Remove named datasets from scheduler + Parameters + ---------- + name : str + The name of the dataset to unpublish + Examples -------- >>> c.list_datasets() # doctest: +SKIP @@ -2329,11 +2528,18 @@ def get_dataset(self, name, default=NO_DEFAULT_PLACEHOLDER, **kwargs): Parameters ---------- - name : name of the dataset to retrieve - default : optional, not set by default - If set, do not raise a KeyError if the name is not present but return this default + name : str + name of the dataset to retrieve + default : str + optional, not set by default + If set, do not raise a KeyError if the name is not present but + return this default kwargs : dict - additional arguments to _get_dataset + additional keyword arguments to _get_dataset + + Returns + ------- + The dataset from the scheduler, if present See Also -------- @@ -2362,6 +2568,15 @@ def run_on_scheduler(self, function, *args, **kwargs): keyword argument ``dask_scheduler=``, which will be given the scheduler object itself. 
+ Parameters + ---------- + function : callable + The function to run on the scheduler process + *args : tuple + Optional arguments for the function + **kwargs : dict + Optional keyword arguments for the function + Examples -------- >>> def get_number_of_tasks(dask_scheduler=None): @@ -2461,10 +2676,14 @@ def run( Parameters ---------- function : callable - *args : arguments for remote function - **kwargs : keyword arguments for remote function + The function to run + *args : tuple + Optional arguments for the remote function + **kwargs : dict + Optional keyword arguments for the remote function workers : list - Workers on which to run the function. Defaults to all known workers. + Workers on which to run the function. Defaults to all known + workers. wait : boolean (optional) If the function is asynchronous whether or not to wait until that function finishes. @@ -2540,13 +2759,10 @@ def run_coroutine(self, function, *args, **kwargs): function : a coroutine function (typically a function wrapped in gen.coroutine or a Python 3.5+ async function) - *args : arguments for remote function - **kwargs : keyword arguments for remote function - wait : boolean (default True) - Whether to wait for coroutines to end. - workers : list - Workers on which to run the function. Defaults to all known workers. - + *args : tuple + Optional arguments for the remote function + **kwargs : dict + Optional keyword arguments for the remote function """ return self.run(function, *args, **kwargs) @@ -2695,22 +2911,39 @@ def get( allow_other_workers : bool (defaults to False) Used with ``workers``. Indicates whether or not the computations may be performed on workers that are not in the `workers` set(s). - retries : int (default to 0) - Number of allowed automatic retries if computing a result fails - priority : Number - Optional prioritization of task. Zero is default. - Higher priorities take precedence resources : dict (defaults to {}) - Defines the ``resources`` each instance of this mapped task requires - on the worker; e.g. ``{'GPU': 2}``. + Defines the ``resources`` each instance of this mapped task + requires on the worker; e.g. ``{'GPU': 2}``. See :doc:`worker resources ` for details on defining resources. sync : bool (optional) Returns Futures if False or concrete values if True (default). + asynchronous: bool + If True the client is in asynchronous mode direct : bool Whether or not to connect directly to the workers, or to ask the scheduler to serve as intermediary. This can also be set when creating the Client. + retries : int (default to 0) + Number of allowed automatic retries if computing a result fails + priority : Number + Optional prioritization of task. Zero is default. + Higher priorities take precedence + fifo_timeout : timedelta str (defaults to '60s') + Allowed amount of time between calls to consider the same priority + actors : bool or dict (default None) + Whether these tasks should exist on the worker as stateful actors. + Specified on a global (True/False) or per-task (``{'x': True, + 'y': False}``) basis. See :doc:`actors` for additional details. + + + Returns + ------- + results + If 'sync' is True, returns the results. Otherwise, returns the + known data packed + If 'sync' is False, returns the known data. Otherwise, returns + the results Examples -------- @@ -2784,6 +3017,13 @@ def normalize_collection(self, collection): known futures within the scheduler. It returns a copy of the collection with a task graph that includes the overlapping futures. 
+ Parameters + ---------- + collection + + Returns + ------- + Examples -------- >>> len(x.__dask_graph__()) # x is a dask collection with 100 tasks # doctest: +SKIP @@ -3118,6 +3358,8 @@ def upload_file(self, filename, **kwargs): ---------- filename : string Filename of .py, .egg or .zip file to send to workers + **kwargs : dict + Optional keyword arguments for the function Examples -------- @@ -3161,6 +3403,8 @@ def rebalance(self, futures=None, workers=None, **kwargs): A list of futures to balance, defaults all data workers : list, optional A list of workers on which to balance, defaults to all workers + **kwargs : dict + Optional keyword arguments for the function """ return self.sync(self._rebalance, futures, workers, **kwargs) @@ -3194,6 +3438,8 @@ def replicate(self, futures, n=None, workers=None, branching_factor=2, **kwargs) Defaults to all. branching_factor : int, optional The number of workers that can copy data in each generation + **kwargs : dict + Optional keyword arguments for the remote function Examples -------- @@ -3225,6 +3471,8 @@ def nthreads(self, workers=None, **kwargs): workers : list (optional) A list of workers that we care about specifically. Leave empty to receive information about all workers. + **kwargs : dict + Optional keyword arguments for the remote function Examples -------- @@ -3256,6 +3504,8 @@ def who_has(self, futures=None, **kwargs): ---------- futures : list (optional) A list of futures, defaults to all data + **kwargs : dict + Optional keyword arguments for the remote function Examples -------- @@ -3296,6 +3546,8 @@ def has_what(self, workers=None, **kwargs): ---------- workers : list (optional) A list of worker addresses, defaults to all + **kwargs : dict + Optional keyword arguments for the remote function Examples -------- @@ -3366,6 +3618,8 @@ def nbytes(self, keys=None, summary=True, **kwargs): A list of keys, defaults to all keys summary : boolean, (optional) Summarize keys into key types + **kwargs : dict + Optional keyword arguments for the remote function Examples -------- @@ -3516,6 +3770,11 @@ async def _profile( def scheduler_info(self, **kwargs): """Basic information about the workers in the cluster + Parameters + ---------- + **kwargs : dict + Optional keyword arguments for the remote function + Examples -------- >>> c.scheduler_info() # doctest: +SKIP @@ -3733,9 +3992,9 @@ def get_worker_logs(self, n=None, workers=None, nanny=False): workers : iterable List of worker addresses to retrieve. Gets all workers by default. nanny : bool, default False - Whether to get the logs from the workers (False) or the nannies (True). If - specified, the addresses in `workers` should still be the worker addresses, - not the nanny addresses. + Whether to get the logs from the workers (False) or the nannies + (True). If specified, the addresses in `workers` should still be + the worker addresses, not the nanny addresses. Returns ------- @@ -3833,6 +4092,13 @@ def retire_workers(self, workers=None, close_workers=True, **kwargs): See dask.distributed.Scheduler.retire_workers for the full docstring. 
+ Parameters + ---------- + workers + close_workers + **kwargs : dict + Optional keyword arguments for the remote function + Examples -------- You can get information about active workers using the following: @@ -3939,10 +4205,25 @@ async def _get_versions(self, check=False, packages=[]): return result def futures_of(self, futures): + """Wrapper method of futures_of + + Parameters + ---------- + futures : tuple + The futures + """ return futures_of(futures, client=self) def start_ipython(self, *args, **kwargs): - """Deprecated - Method moved to start_ipython_workers""" + """Deprecated - Method moved to start_ipython_workers + + Parameters + ---------- + *args : tuple + Optional arguments for the function + **kwargs : dict + Optional keyword arguments for the function + """ raise Exception("Method moved to start_ipython_workers") async def _start_ipython_workers(self, workers): @@ -4293,9 +4574,9 @@ def register_worker_plugin(self, plugin=None, name=None, nanny=None, **kwargs): Registers a lifecycle worker plugin for all current and future workers. This registers a new object to handle setup, task state transitions and - teardown for workers in this cluster. The plugin will instantiate itself - on all currently connected workers. It will also be run on any worker - that connects in the future. + teardown for workers in this cluster. The plugin will instantiate + itself on all currently connected workers. It will also be run on any + worker that connects in the future. The plugin may include methods ``setup``, ``teardown``, ``transition``, and ``release_key``. See the @@ -4333,7 +4614,8 @@ class will be instantiated with any extra keyword arguments. ... pass ... def teardown(self, worker: dask.distributed.Worker): ... pass - ... def transition(self, key: str, start: str, finish: str, **kwargs): + ... def transition(self, key: str, start: str, finish: str, + ... **kwargs): ... pass ... def release_key(self, key: str, state: str, cause: str | None, reason: None, report: bool): ... pass @@ -4493,9 +4775,10 @@ def wait(fs, timeout=None, return_when=ALL_COMPLETED): Parameters ---------- - fs : list of futures + fs : List[Future] timeout : number, optional - Time in seconds after which to raise a ``dask.distributed.TimeoutError`` + Time in seconds after which to raise a + ``dask.distributed.TimeoutError`` return_when : str, optional One of `ALL_COMPLETED` or `FIRST_COMPLETED` @@ -4556,8 +4839,8 @@ class as_completed: Whether to wait and include results of futures as well; in this case `as_completed` yields a tuple of (future, result) raise_errors: bool (True) - Whether we should raise when the result of a future raises an exception; - only affects behavior when `with_results=True`. + Whether we should raise when the result of a future raises an + exception; only affects behavior when `with_results=True`. 
Examples -------- @@ -4787,7 +5070,18 @@ def AsCompleted(*args, **kwargs): def default_client(c=None): - """Return a client if one has started""" + """Return a client if one has started + + Parameters + ---------- + c : Client + The client + + Returns + ------- + c : Client + The client, if one has started + """ c = c or _get_global_client() if c: return c @@ -4801,12 +5095,36 @@ def default_client(c=None): def ensure_default_client(client): - """Ensures the client passed as argument is set as the default""" + """Ensures the client passed as argument is set as the default + + Parameters + ---------- + client : Client + The client + """ dask.config.set(scheduler="dask.distributed") _set_global_client(client) def redict_collection(c, dsk): + """Change the dictionary in the collection + + Parameters + ---------- + c : collection + The collection + dsk : dict + The dictionary + + Returns + ------- + c : Delayed + If the collection is a 'Delayed' object the collection is returned + cc : collection + If the collection is not a 'Delayed' object a copy of the + collection with xthe new dictionary is returned + + """ from dask.delayed import Delayed if isinstance(c, Delayed): @@ -4824,6 +5142,8 @@ def futures_of(o, client=None): ---------- o : collection A possibly nested collection of Dask objects + client : Client, optional + The client Examples -------- @@ -4831,6 +5151,11 @@ def futures_of(o, client=None): [, ] + Raises + ------ + CancelledError + If one of the futures is cancelled a CancelledError is raised + Returns ------- futures : List[Future] @@ -5084,12 +5409,13 @@ def temp_default_client(c): """Set the default client for the duration of the context .. note:: - This function should be used exclusively for unit testing the default client - functionality. In all other cases, please use ``Client.as_current`` instead. + This function should be used exclusively for unit testing the default + client functionality. In all other cases, please use + ``Client.as_current`` instead. .. note:: - Unlike ``Client.as_current``, this context manager is neither thread-local nor - task-local. + Unlike ``Client.as_current``, this context manager is neither + thread-local nor task-local. 
Parameters ---------- From 337f152a0eccf7c1511c0462f007dda6d507e4d3 Mon Sep 17 00:00:00 2001 From: crusaderky Date: Fri, 21 Jan 2022 20:41:23 +0000 Subject: [PATCH 07/10] Fix flaky `test_dump_cluster_unresponsive_remote_worker` (#5679) --- distributed/tests/test_utils_test.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/distributed/tests/test_utils_test.py b/distributed/tests/test_utils_test.py index 80632b27c9f..67d24ce4dbc 100755 --- a/distributed/tests/test_utils_test.py +++ b/distributed/tests/test_utils_test.py @@ -10,6 +10,8 @@ import yaml from tornado import gen +import dask.config + from distributed import Client, Nanny, Scheduler, Worker, config, default_client from distributed.core import Server, rpc from distributed.metrics import time @@ -100,11 +102,8 @@ async def test_gen_cluster_parametrized_variadic_workers(c, s, *workers, foo): ) async def test_gen_cluster_set_config_nanny(c, s, a, b): def assert_config(): - import dask - assert dask.config.get("distributed.comm.timeouts.connect") == "1s" assert dask.config.get("new.config.value") == "foo" - return dask.config await c.run(assert_config) await c.run_on_scheduler(assert_config) @@ -535,12 +534,11 @@ async def test_dump_cluster_state_unresponsive_local_worker(s, a, b, tmpdir): @gen_cluster( client=True, Worker=Nanny, - config={"distributed.comm.timeouts.connect": "200ms"}, + config={"distributed.comm.timeouts.connect": "600ms"}, ) async def test_dump_cluster_unresponsive_remote_worker(c, s, a, b, tmpdir): - addr1, addr2 = s.workers clog_fut = asyncio.create_task( - c.run(lambda dask_scheduler: dask_scheduler.stop(), workers=[addr1]) + c.run(lambda dask_scheduler: dask_scheduler.stop(), workers=[a.worker_address]) ) await asyncio.sleep(0.2) @@ -549,7 +547,9 @@ async def test_dump_cluster_unresponsive_remote_worker(c, s, a, b, tmpdir): out = yaml.safe_load(fh) assert out.keys() == {"scheduler", "workers", "versions"} - assert isinstance(out["workers"][addr2], dict) - assert out["workers"][addr1].startswith("OSError('Timed out trying to connect to") + assert isinstance(out["workers"][b.worker_address], dict) + assert out["workers"][a.worker_address].startswith( + "OSError('Timed out trying to connect to" + ) clog_fut.cancel() From 80320467690e9f24e92693392b39d4dbd2ac06ec Mon Sep 17 00:00:00 2001 From: Gabe Joseph Date: Mon, 24 Jan 2022 05:08:03 -0700 Subject: [PATCH 08/10] P2P shuffle skeleton (#5520) --- distributed/shuffle/__init__.py | 18 + distributed/shuffle/shuffle.py | 111 ++++++ distributed/shuffle/shuffle_extension.py | 329 ++++++++++++++++++ distributed/shuffle/tests/__init__.py | 0 distributed/shuffle/tests/test_graph.py | 88 +++++ .../shuffle/tests/test_shuffle_extension.py | 283 +++++++++++++++ distributed/worker.py | 4 +- 7 files changed, 832 insertions(+), 1 deletion(-) create mode 100644 distributed/shuffle/__init__.py create mode 100644 distributed/shuffle/shuffle.py create mode 100644 distributed/shuffle/shuffle_extension.py create mode 100644 distributed/shuffle/tests/__init__.py create mode 100644 distributed/shuffle/tests/test_graph.py create mode 100644 distributed/shuffle/tests/test_shuffle_extension.py diff --git a/distributed/shuffle/__init__.py b/distributed/shuffle/__init__.py new file mode 100644 index 00000000000..a431c5bddfd --- /dev/null +++ b/distributed/shuffle/__init__.py @@ -0,0 +1,18 @@ +try: + import pandas +except ImportError: + SHUFFLE_AVAILABLE = False +else: + del pandas + SHUFFLE_AVAILABLE = True + + from .shuffle import rearrange_by_column_p2p + 
from .shuffle_extension import ShuffleId, ShuffleMetadata, ShuffleWorkerExtension + +__all__ = [ + "SHUFFLE_AVAILABLE", + "rearrange_by_column_p2p", + "ShuffleId", + "ShuffleMetadata", + "ShuffleWorkerExtension", +] diff --git a/distributed/shuffle/shuffle.py b/distributed/shuffle/shuffle.py new file mode 100644 index 00000000000..e30ddde746c --- /dev/null +++ b/distributed/shuffle/shuffle.py @@ -0,0 +1,111 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +from dask.base import tokenize +from dask.dataframe import DataFrame +from dask.delayed import Delayed, delayed +from dask.highlevelgraph import HighLevelGraph + +from .shuffle_extension import NewShuffleMetadata, ShuffleId, ShuffleWorkerExtension + +if TYPE_CHECKING: + import pandas as pd + + +def get_ext() -> ShuffleWorkerExtension: + from distributed import get_worker + + try: + worker = get_worker() + except ValueError as e: + raise RuntimeError( + "`shuffle='p2p'` requires Dask's distributed scheduler. This task is not running on a Worker; " + "please confirm that you've created a distributed Client and are submitting this computation through it." + ) from e + extension: ShuffleWorkerExtension | None = worker.extensions.get("shuffle") + if not extension: + raise RuntimeError( + f"The worker {worker.address} does not have a ShuffleExtension. " + "Is pandas installed on the worker?" + ) + return extension + + +def shuffle_setup(metadata: NewShuffleMetadata) -> None: + get_ext().create_shuffle(metadata) + + +def shuffle_transfer(input: pd.DataFrame, id: ShuffleId, setup=None) -> None: + get_ext().add_partition(input, id) + + +def shuffle_unpack(id: ShuffleId, output_partition: int, barrier=None) -> pd.DataFrame: + return get_ext().get_output_partition(id, output_partition) + + +def shuffle_barrier(id: ShuffleId, transfers: list[None]) -> None: + get_ext().barrier(id) + + +def rearrange_by_column_p2p( + df: DataFrame, + column: str, + npartitions: int | None = None, +): + npartitions = npartitions or df.npartitions + token = tokenize(df, column, npartitions) + + setup = delayed(shuffle_setup, pure=True)( + NewShuffleMetadata( + ShuffleId(token), + df._meta, + column, + npartitions, + ) + ) + + transferred = df.map_partitions( + shuffle_transfer, + token, + setup, + meta=df, + enforce_metadata=False, + transform_divisions=False, + ) + + barrier_key = "shuffle-barrier-" + token + barrier_dsk = {barrier_key: (shuffle_barrier, token, transferred.__dask_keys__())} + barrier = Delayed( + barrier_key, + HighLevelGraph.from_collections( + barrier_key, barrier_dsk, dependencies=[transferred] + ), + ) + + name = "shuffle-unpack-" + token + dsk = { + (name, i): (shuffle_unpack, token, i, barrier_key) for i in range(npartitions) + } + # TODO use this blockwise (https://github.com/coiled/oss-engineering/issues/49) + # Changes task names, so breaks setting worker restrictions at the moment. + # Also maybe would be nice if the `DataFrameIOLayer` interface supported this? 
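+    # (Worker restrictions are set during `_barrier` against the current
+    # "shuffle-unpack-<token>" task names, so renaming the unpack tasks via
+    # `blockwise` would leave those restrictions pointing at keys that no
+    # longer exist.)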
+ # dsk = blockwise( + # shuffle_unpack, + # name, + # "i", + # token, + # None, + # BlockwiseDepDict({(i,): i for i in range(npartitions)}), + # "i", + # barrier_key, + # None, + # numblocks={}, + # ) + + return DataFrame( + HighLevelGraph.from_collections(name, dsk, [barrier]), + name, + df._meta, + [None] * (npartitions + 1), + ) diff --git a/distributed/shuffle/shuffle_extension.py b/distributed/shuffle/shuffle_extension.py new file mode 100644 index 00000000000..e5e0baaf7bc --- /dev/null +++ b/distributed/shuffle/shuffle_extension.py @@ -0,0 +1,329 @@ +from __future__ import annotations + +import asyncio +import math +from collections import defaultdict +from dataclasses import dataclass +from typing import TYPE_CHECKING, NewType + +import pandas as pd + +from distributed.protocol import to_serialize +from distributed.utils import sync + +if TYPE_CHECKING: + from distributed.worker import Worker + +ShuffleId = NewType("ShuffleId", str) + + +# NOTE: we use these dataclasses primarily for type-checking benefits. +# They take the place of positional arguments to `shuffle_init`, +# which the type-checker can't validate when it's called as an RPC. + + +@dataclass(frozen=True, eq=False) +class NewShuffleMetadata: + "Metadata to create a shuffle" + id: ShuffleId + empty: pd.DataFrame + column: str + npartitions: int + + +@dataclass(frozen=True, eq=False) +class ShuffleMetadata(NewShuffleMetadata): + """ + Metadata every worker needs to share about a shuffle. + + A `ShuffleMetadata` is created with a task and sent to all workers + over the `ShuffleWorkerExtension.shuffle_init` RPC. + """ + + workers: list[str] + + def worker_for(self, output_partition: int) -> str: + "Get the address of the worker which should hold this output partition number" + assert output_partition >= 0, f"Negative output partition: {output_partition}" + if output_partition >= self.npartitions: + raise IndexError( + f"Output partition {output_partition} does not exist in a shuffle producing {self.npartitions} partitions" + ) + i = len(self.workers) * output_partition // self.npartitions + return self.workers[i] + + def _partition_range(self, worker: str) -> tuple[int, int]: + "Get the output partition numbers (inclusive) that a worker will hold" + i = self.workers.index(worker) + first = math.ceil(self.npartitions * i / len(self.workers)) + last = math.ceil(self.npartitions * (i + 1) / len(self.workers)) - 1 + return first, last + + def npartitions_for(self, worker: str) -> int: + "Get the number of output partitions a worker will hold" + first, last = self._partition_range(worker) + return last - first + 1 + + +class Shuffle: + "State for a single active shuffle" + + def __init__(self, metadata: ShuffleMetadata, worker: Worker) -> None: + self.metadata = metadata + self.worker = worker + self.output_partitions: defaultdict[int, list[pd.DataFrame]] = defaultdict(list) + self.output_partitions_left = metadata.npartitions_for(worker.address) + self.transferred = False + + def receive(self, output_partition: int, data: pd.DataFrame) -> None: + assert not self.transferred, "`receive` called after barrier task" + self.output_partitions[output_partition].append(data) + + async def add_partition(self, data: pd.DataFrame) -> None: + assert not self.transferred, "`add_partition` called after barrier task" + tasks = [] + # NOTE: `groupby` blocks the event loop, but it also holds the GIL, + # so we don't bother offloading to a thread. See bpo-7946. 
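+        # (bpo-7946, the "convoy effect": a thread that holds the GIL starves
+        # the event loop just the same, so moving the `groupby` off the loop
+        # would not actually keep it responsive.)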
+        for output_partition, data in data.groupby(self.metadata.column):
+            # NOTE: `column` must refer to an integer column, which is the output partition number for the row.
+            # This is always `_partitions`, added by `dask/dataframe/shuffle.py::shuffle`.
+            addr = self.metadata.worker_for(int(output_partition))
+            task = asyncio.create_task(
+                self.worker.rpc(addr).shuffle_receive(
+                    shuffle_id=self.metadata.id,
+                    output_partition=output_partition,
+                    data=to_serialize(data),
+                )
+            )
+            tasks.append(task)
+
+        # TODO Once RerunGroup logic exists (https://github.com/dask/distributed/issues/5403),
+        # handle errors and cancellation here in a way that lets other workers cancel & clean up their shuffles.
+        # Without it, letting errors kill the task is all we can do.
+        await asyncio.gather(*tasks)
+
+    def get_output_partition(self, i: int) -> pd.DataFrame:
+        assert self.transferred, "`get_output_partition` called before barrier task"
+
+        assert self.metadata.worker_for(i) == self.worker.address, (
+            f"Output partition {i} belongs on {self.metadata.worker_for(i)}, "
+            f"not {self.worker.address}. {self.metadata!r}"
+        )
+        # ^ NOTE: this check isn't necessary, just a nice validation to prevent incorrect
+        # data in the case something has gone very wrong
+
+        assert (
+            self.output_partitions_left > 0
+        ), f"No outputs remaining, but requested output partition {i} on {self.worker.address}."
+        self.output_partitions_left -= 1
+
+        try:
+            parts = self.output_partitions.pop(i)
+        except KeyError:
+            return self.metadata.empty
+
+        assert parts, f"Empty entry for output partition {i}"
+        return pd.concat(parts, copy=False)
+
+    def inputs_done(self) -> None:
+        assert not self.transferred, "`inputs_done` called multiple times"
+        self.transferred = True
+
+    def done(self) -> bool:
+        return self.transferred and self.output_partitions_left == 0
+
+
+class ShuffleWorkerExtension:
+    "Extend the Worker with routes and state for peer-to-peer shuffles"
+
+    def __init__(self, worker: Worker) -> None:
+        # Attach to worker
+        worker.handlers["shuffle_receive"] = self.shuffle_receive
+        worker.handlers["shuffle_init"] = self.shuffle_init
+        worker.handlers["shuffle_inputs_done"] = self.shuffle_inputs_done
+        worker.extensions["shuffle"] = self
+
+        # Initialize
+        self.worker: Worker = worker
+        self.shuffles: dict[ShuffleId, Shuffle] = {}
+
+    # Handlers
+    ##########
+    # NOTE: handlers are not threadsafe, but they're called from async comms, so that's okay
+
+    def shuffle_init(self, comm: object, metadata: ShuffleMetadata) -> None:
+        """
+        Handler: Register a new shuffle that is about to begin.
+        Using a shuffle with an already-known ID is an error.
+        """
+        if metadata.id in self.shuffles:
+            raise ValueError(
+                f"Shuffle {metadata.id!r} is already registered on worker {self.worker.address}"
+            )
+        self.shuffles[metadata.id] = Shuffle(metadata, self.worker)
+
+    def shuffle_receive(
+        self,
+        comm: object,
+        shuffle_id: ShuffleId,
+        output_partition: int,
+        data: pd.DataFrame,
+    ) -> None:
+        """
+        Handler: Receive an incoming shard of data from a peer worker.
+        Using an unknown ``shuffle_id`` is an error.
+        """
+        self._get_shuffle(shuffle_id).receive(output_partition, data)
+
+    def shuffle_inputs_done(self, comm: object, shuffle_id: ShuffleId) -> None:
+        """
+        Handler: Inform the extension that all input partitions have been handed off to extensions.
+        Using an unknown ``shuffle_id`` is an error.
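+        Called via RPC (including a self-call) by `_barrier` once the barrier
+        task has run.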
+ """ + shuffle = self._get_shuffle(shuffle_id) + shuffle.inputs_done() + if shuffle.done(): + # If the shuffle has no output partitions, remove it now; + # `get_output_partition` will never be called. + # This happens when there are fewer output partitions than workers. + del self.shuffles[shuffle_id] + + # Tasks + ####### + + def create_shuffle(self, new_metadata: NewShuffleMetadata) -> ShuffleMetadata: + return sync(self.worker.loop, self._create_shuffle, new_metadata) # type: ignore + + async def _create_shuffle( + self, new_metadata: NewShuffleMetadata + ) -> ShuffleMetadata: + """ + Task: Create a new shuffle and broadcast it to all workers. + """ + # TODO would be nice to not have to have the RPC in this method, and have shuffles started implicitly + # by the first `receive`/`add_partition`. To do that, shuffle metadata would be passed into + # every task, and from there into the extension (rather than stored within a `Shuffle`), + # However: + # 1. It makes scheduling much harder, since it's a widely-shared common dep + # (https://github.com/dask/distributed/pull/5325) + # 2. Passing in metadata everywhere feels contrived when it would be so easy to store + # 3. The metadata may not be _that_ small (1000s of columns + 1000s of workers); + # serializing and transferring it repeatedly adds overhead. + if new_metadata.id in self.shuffles: + raise ValueError( + f"Shuffle {new_metadata.id!r} is already registered on worker {self.worker.address}" + ) + + identity = await self.worker.scheduler.identity() + + workers = list(identity["workers"]) + metadata = ShuffleMetadata( + new_metadata.id, + new_metadata.empty, + new_metadata.column, + new_metadata.npartitions, + workers, + ) + + # Start the shuffle on all peers + # Note that this will call `shuffle_init` on our own worker as well + await asyncio.gather( + *( + self.worker.rpc(addr).shuffle_init(metadata=to_serialize(metadata)) + for addr in metadata.workers + ), + ) + # TODO handle errors from peers, and cancellation. + # If any peers can't start the shuffle, tell successful peers to cancel it. + + return metadata # NOTE: unused in tasks, just handy for tests + + def add_partition(self, data: pd.DataFrame, shuffle_id: ShuffleId) -> None: + sync(self.worker.loop, self._add_partition, data, shuffle_id) + + async def _add_partition(self, data: pd.DataFrame, shuffle_id: ShuffleId) -> None: + """ + Task: Hand off an input partition to the ShuffleExtension. + + This will block until the extension is ready to receive another input partition. + + Using an unknown ``shuffle_id`` is an error. + """ + await self._get_shuffle(shuffle_id).add_partition(data) + + def barrier(self, shuffle_id: ShuffleId) -> None: + sync(self.worker.loop, self._barrier, shuffle_id) + + async def _barrier(self, shuffle_id: ShuffleId) -> None: + """ + Task: Note that the barrier task has been reached (`add_partition` called for all input partitions) + + Using an unknown ``shuffle_id`` is an error. Calling this before all partitions have been + added is undefined. + """ + # NOTE: in this basic shuffle implementation, doing things during the barrier + # is mostly unnecessary. We only need it to inform workers that don't receive + # any output partitions that they can clean up. + # (Otherwise, they'd have no way to know if they needed to keep the `Shuffle` around + # for more input partitions, which might come at some point. 
Workers that _do_ receive + # output partitions could infer this, since once `get_output_partition` gets called the + # first time, they can assume there are no more inputs.) + # + # Technically right now, we could call the `shuffle_inputs_done` RPC only on workers + # where `metadata.npartitions_for(worker) == 0`. + # However, when we have buffering, this barrier step will become important for + # all workers, since they'll use it to flush their buffers and send any leftover shards + # to their peers. + + metadata = self._get_shuffle(shuffle_id).metadata + + # Set worker restrictions for unpack tasks + + # Could do this during `create_shuffle`, but we might as well overlap it with the time + # workers will be flushing buffers to each other. + name = "shuffle-unpack-" + metadata.id # TODO single-source task name + + # FIXME TODO XXX what about when culling means not all of the output tasks actually exist??! + # - these restrictions are invalid + # - get_output_partition won't be called enough times, so cleanup won't happen + # - also, we're transferring data we don't need to transfer + restrictions = { + f"('{name}', {i})": [metadata.worker_for(i)] + for i in range(metadata.npartitions) + } + + # Tell all peers that we've reached the barrier + + # Note that this will call `shuffle_inputs_done` on our own worker as well + await asyncio.gather( + *( + self.worker.rpc(worker).shuffle_inputs_done(shuffle_id=shuffle_id) + for worker in metadata.workers + ), + self.worker.scheduler.set_restrictions(worker=restrictions), + ) + # TODO handle errors from workers and scheduler, and cancellation. + + def get_output_partition( + self, shuffle_id: ShuffleId, output_partition: int + ) -> pd.DataFrame: + """ + Task: Retrieve a shuffled output partition from the ShuffleExtension. + + Calling this for a ``shuffle_id`` which is unknown or incomplete is an error. + """ + shuffle = self._get_shuffle(shuffle_id) + output = shuffle.get_output_partition(output_partition) + if shuffle.done(): + # key missing if another thread got to it first + self.shuffles.pop(shuffle_id, None) + return output + + def _get_shuffle(self, shuffle_id: ShuffleId) -> Shuffle: + "Get a shuffle by ID; raise ValueError if it's not registered." 
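+        # Raise ValueError rather than a bare KeyError so the failing task
+        # gets a clear message naming the shuffle and this worker's address.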
+ try: + return self.shuffles[shuffle_id] + except KeyError: + raise ValueError( + f"Shuffle {shuffle_id!r} is not registered on worker {self.worker.address}" + ) from None diff --git a/distributed/shuffle/tests/__init__.py b/distributed/shuffle/tests/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/distributed/shuffle/tests/test_graph.py b/distributed/shuffle/tests/test_graph.py new file mode 100644 index 00000000000..256f9132cd4 --- /dev/null +++ b/distributed/shuffle/tests/test_graph.py @@ -0,0 +1,88 @@ +from __future__ import annotations + +import asyncio +from typing import TYPE_CHECKING + +import dask +import dask.dataframe as dd +from dask.blockwise import Blockwise +from dask.dataframe.shuffle import partitioning_index, rearrange_by_column_tasks +from dask.utils_test import hlg_layer_topological + +from distributed.utils_test import gen_cluster + +from ..shuffle import rearrange_by_column_p2p +from ..shuffle_extension import ShuffleWorkerExtension + +if TYPE_CHECKING: + from distributed import Client, Scheduler, Worker + + +def shuffle( + df: dd.DataFrame, on: str, rearrange=rearrange_by_column_p2p +) -> dd.DataFrame: + "Simple version of `DataFrame.shuffle`, so we don't need dask to know about 'p2p'" + return ( + df.assign( + partition=lambda df: df[on].map_partitions( + partitioning_index, df.npartitions, transform_divisions=False + ) + ) + .pipe(rearrange, "partition") + .drop("partition", axis=1) + ) + + +def test_shuffle_helper(client: Client): + df = dd.demo.make_timeseries(freq="15D", partition_freq="30D") + shuffle_helper = shuffle(df, "id", rearrange=rearrange_by_column_tasks) + dask_shuffle = df.shuffle("id", shuffle="tasks") + dd.utils.assert_eq(shuffle_helper, dask_shuffle, scheduler=client) + + +def test_basic(client: Client): + df = dd.demo.make_timeseries(freq="15D", partition_freq="30D") + shuffled = shuffle(df, "id") + + (opt,) = dask.optimize(shuffled) + assert isinstance(hlg_layer_topological(opt.dask, 1), Blockwise) + # setup -> blockwise -> barrier -> unpack -> drop_by_shallow_copy + assert len(opt.dask.layers) == 5 + + dd.utils.assert_eq(shuffled, df.shuffle("id", shuffle="tasks"), scheduler=client) + # ^ NOTE: this works because `assert_eq` sorts the rows before comparing + + +@gen_cluster([("", 2)] * 4, client=True) +async def test_basic_state(c: Client, s: Scheduler, *workers: Worker): + df = dd.demo.make_timeseries(freq="15D", partition_freq="30D") + shuffled = shuffle(df, "id") + + exts: list[ShuffleWorkerExtension] = [w.extensions["shuffle"] for w in workers] + for ext in exts: + assert not ext.shuffles + + f = c.compute(shuffled) + # TODO this is a bad/pointless test. the `f.done()` is necessary in case the shuffle is really fast. + # To test state more thoroughly, we'd need a way to 'stop the world' at various stages. Like have the + # scheduler pause everything when the barrier is reached. Not sure yet how to implement that. 
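+    # Poll until every worker has registered the shuffle, bailing out early
+    # if the whole computation already finished on a fast machine.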
+ while not all(len(ext.shuffles) == 1 for ext in exts) and not f.done(): + await asyncio.sleep(0.1) + + await f + assert all(not ext.shuffles for ext in exts) + + +def test_multiple_linear(client: Client): + df = dd.demo.make_timeseries(freq="15D", partition_freq="30D") + s1 = shuffle(df, "id") + s1["x"] = s1["x"] + 1 + s2 = shuffle(s1, "x") + + # TODO eventually test for fusion between s1's unpacks, the `+1`, and s2's `transfer`s + + dd.utils.assert_eq( + s2, + df.assign(x=lambda df: df.x + 1).shuffle("x", shuffle="tasks"), + scheduler=client, + ) diff --git a/distributed/shuffle/tests/test_shuffle_extension.py b/distributed/shuffle/tests/test_shuffle_extension.py new file mode 100644 index 00000000000..52abf8cd37d --- /dev/null +++ b/distributed/shuffle/tests/test_shuffle_extension.py @@ -0,0 +1,283 @@ +from __future__ import annotations + +import asyncio +import string +from collections import Counter +from typing import TYPE_CHECKING + +import pandas as pd +import pytest +from pandas.testing import assert_frame_equal + +from distributed.utils_test import gen_cluster + +from ..shuffle_extension import ( + NewShuffleMetadata, + ShuffleId, + ShuffleMetadata, + ShuffleWorkerExtension, +) + +if TYPE_CHECKING: + from distributed import Client, Future, Scheduler, Worker + + +@pytest.mark.parametrize("npartitions", [1, 2, 3, 5]) +@pytest.mark.parametrize("n_workers", [1, 2, 3, 5]) +def test_worker_for_distribution(npartitions: int, n_workers: int): + "Test that `worker_for` distributes evenly" + metadata = ShuffleMetadata( + ShuffleId("foo"), + pd.DataFrame({"A": []}), + "A", + npartitions, + list(string.ascii_lowercase[:n_workers]), + ) + + with pytest.raises(AssertionError, match="Negative"): + metadata.worker_for(-1) + + assignments = [metadata.worker_for(i) for i in range(metadata.npartitions)] + + # Test internal `_partition_range` method + for w in metadata.workers: + first, last = metadata._partition_range(w) + assert all( + [ + first <= p_i <= last if a == w else p_i < first or p_i > last + for p_i, a in enumerate(assignments) + ] + ) + + counter = Counter(assignments) + assert len(counter) == min(npartitions, n_workers) + + # Test `npartitions_for` + calculated_counter = {w: metadata.npartitions_for(w) for w in metadata.workers} + assert counter == { + w: count for w, count in calculated_counter.items() if count != 0 + } + assert calculated_counter.keys() == set(metadata.workers) + # ^ this also checks that workers receiving 0 output partitions were calculated properly + + # Test the distribution of worker assignments. + # All workers should be assigned the same number of partitions, or if + # there's an odd number, some workers will be assigned only one extra partition. 
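+    # e.g. (illustrative) 5 partitions over 3 workers -> per-worker counts
+    # (2, 2, 1): two distinct values that differ by exactly 1.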
+ counts = set(counter.values()) + assert len(counts) <= 2 + if len(counts) == 2: + lo, hi = sorted(counts) + assert lo == hi - 1 + + with pytest.raises(IndexError, match="does not exist"): + metadata.worker_for(npartitions) + + +@gen_cluster([("", 1)]) +async def test_installation(s: Scheduler, worker: Worker): + ext = worker.extensions["shuffle"] + assert isinstance(ext, ShuffleWorkerExtension) + assert worker.handlers["shuffle_receive"] == ext.shuffle_receive + assert worker.handlers["shuffle_init"] == ext.shuffle_init + assert worker.handlers["shuffle_inputs_done"] == ext.shuffle_inputs_done + + +@gen_cluster([("", 1)]) +async def test_init(s: Scheduler, worker: Worker): + ext: ShuffleWorkerExtension = worker.extensions["shuffle"] + assert not ext.shuffles + metadata = ShuffleMetadata( + ShuffleId("foo"), + pd.DataFrame({"A": []}), + "A", + 5, + [worker.address], + ) + + ext.shuffle_init(None, metadata) + assert list(ext.shuffles) == [metadata.id] + + with pytest.raises(ValueError, match="already registered"): + ext.shuffle_init(None, metadata) + + assert list(ext.shuffles) == [metadata.id] + + +async def add_dummy_unpack_keys( + new_metadata: NewShuffleMetadata, client: Client +) -> dict[str, Future]: + """ + Add dummy keys to the scheduler, so setting worker restrictions during `barrier` succeeds. + + Note: you must hang onto the Futures returned by this function, so they don't get released prematurely. + """ + # NOTE: `scatter` is just used as an easy way to create keys on the scheduler that won't actually + # be scheduled. It would be reasonable if this stops working in the future, if some validation is + # added preventing worker restrictions on scattered data (since it makes no sense). + fs = await client.scatter( + { + str(("shuffle-unpack-" + new_metadata.id, i)): None + for i in range(new_metadata.npartitions) + } + ) # type: ignore + await asyncio.gather(*fs.values()) + return fs + + +@gen_cluster([("", 1)] * 4) +async def test_create(s: Scheduler, *workers: Worker): + exts: list[ShuffleWorkerExtension] = [w.extensions["shuffle"] for w in workers] + + new_metadata = NewShuffleMetadata( + ShuffleId("foo"), + pd.DataFrame({"A": []}), + "A", + 5, + ) + + metadata = await exts[0]._create_shuffle(new_metadata) + + # Check shuffle was created on all workers + for ext in exts: + assert len(ext.shuffles) == 1 + shuffle = ext.shuffles[new_metadata.id] + assert sorted(shuffle.metadata.workers) == sorted(w.address for w in workers) + + # TODO (resilience stage) what happens if some workers already have + # the ID registered, but others don't? 
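+    # Re-creating an existing shuffle must fail loudly rather than silently
+    # resetting its state on some of the workers.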
+ + with pytest.raises(ValueError, match="already registered"): + await exts[0]._create_shuffle(new_metadata) + + +@gen_cluster([("", 1)] * 4) +async def test_add_partition(s: Scheduler, *workers: Worker): + exts: dict[str, ShuffleWorkerExtension] = { + w.address: w.extensions["shuffle"] for w in workers + } + + new_metadata = NewShuffleMetadata( + ShuffleId("foo"), + pd.DataFrame({"A": [], "partition": []}), + "partition", + 8, + ) + + ext = next(iter(exts.values())) + metadata = await ext._create_shuffle(new_metadata) + partition = pd.DataFrame( + { + "A": ["a", "b", "c", "d", "e", "f", "g", "h"], + "partition": [0, 1, 2, 3, 4, 5, 6, 7], + } + ) + await ext._add_partition(partition, new_metadata.id) + + with pytest.raises(ValueError, match="not registered"): + await ext._add_partition(partition, ShuffleId("bar")) + + for i, data in partition.groupby(new_metadata.column): + addr = metadata.worker_for(int(i)) + ext = exts[addr] + received = ext.shuffles[metadata.id].output_partitions[int(i)] + assert len(received) == 1 + assert_frame_equal(data, received[0]) + + # TODO (resilience stage) test failed sends + + +@gen_cluster([("", 1)] * 4, client=True) +async def test_barrier(c: Client, s: Scheduler, *workers: Worker): + exts: dict[str, ShuffleWorkerExtension] = { + w.address: w.extensions["shuffle"] for w in workers + } + + new_metadata = NewShuffleMetadata( + ShuffleId("foo"), + pd.DataFrame({"A": [], "partition": []}), + "partition", + 4, + ) + fs = await add_dummy_unpack_keys(new_metadata, c) + + ext = next(iter(exts.values())) + metadata = await ext._create_shuffle(new_metadata) + partition = pd.DataFrame( + { + "A": ["a", "b", "c"], + "partition": [0, 1, 2], + } + ) + await ext._add_partition(partition, metadata.id) + + await ext._barrier(metadata.id) + + # Check scheduler restrictions were set for unpack tasks + for key, i in zip(fs, range(metadata.npartitions)): + assert s.tasks[key].worker_restrictions == {metadata.worker_for(i)} + + # Check all workers have been informed of the barrier + for addr, ext in exts.items(): + if metadata.npartitions_for(addr): + shuffle = ext.shuffles[metadata.id] + assert shuffle.transferred + assert not shuffle.done() + else: + # No output partitions on this worker; shuffle already cleaned up + assert not ext.shuffles + + +@gen_cluster([("", 1)] * 4, client=True) +async def test_get_partition(c: Client, s: Scheduler, *workers: Worker): + exts: dict[str, ShuffleWorkerExtension] = { + w.address: w.extensions["shuffle"] for w in workers + } + + new_metadata = NewShuffleMetadata( + ShuffleId("foo"), + pd.DataFrame({"A": [], "partition": []}), + "partition", + 8, + ) + _ = await add_dummy_unpack_keys(new_metadata, c) + + ext = next(iter(exts.values())) + metadata = await ext._create_shuffle(new_metadata) + p1 = pd.DataFrame( + { + "A": ["a", "b", "c", "d", "e", "f", "g", "h"], + "partition": [0, 1, 2, 3, 4, 5, 6, 6], + } + ) + p2 = pd.DataFrame( + { + "A": ["a", "b", "c", "d", "e", "f", "g", "h"], + "partition": [0, 1, 2, 3, 0, 0, 2, 3], + } + ) + await asyncio.gather( + ext._add_partition(p1, metadata.id), ext._add_partition(p2, metadata.id) + ) + await ext._barrier(metadata.id) + + with pytest.raises(AssertionError, match="belongs on"): + ext.get_output_partition(metadata.id, 7) + + full = pd.concat([p1, p2]) + expected_groups = full.groupby("partition") + for output_i in range(metadata.npartitions): + addr = metadata.worker_for(output_i) + ext = exts[addr] + result = ext.get_output_partition(metadata.id, output_i) + try: + expected = 
expected_groups.get_group(output_i) + except KeyError: + expected = metadata.empty + assert_frame_equal(expected, result) + + # Once all partitions are retrieved, shuffles are cleaned up + for ext in exts.values(): + assert not ext.shuffles + + with pytest.raises(ValueError, match="not registered"): + ext.get_output_partition(metadata.id, 0) diff --git a/distributed/worker.py b/distributed/worker.py index 25148958f9d..72d3c93f5e3 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -44,7 +44,7 @@ typename, ) -from . import comm, preloading, profile, system, utils +from . import comm, preloading, profile, shuffle, system, utils from .batched import BatchedSend from .comm import Comm, connect, get_address_host from .comm.addressing import address_from_user_args, parse_address @@ -116,6 +116,8 @@ RUNNING = {Status.running, Status.paused, Status.closing_gracefully} DEFAULT_EXTENSIONS: list[type] = [PubSubWorkerExtension] +if shuffle.SHUFFLE_AVAILABLE: + DEFAULT_EXTENSIONS.append(shuffle.ShuffleWorkerExtension) DEFAULT_METRICS: dict[str, Callable[[Worker], Any]] = {} From ccad288518eab594717fe92375a9a991e89c9592 Mon Sep 17 00:00:00 2001 From: crusaderky Date: Mon, 24 Jan 2022 13:41:54 +0000 Subject: [PATCH 09/10] Paused workers shouldn't steal tasks (#5665) --- distributed/stealing.py | 11 ++++---- distributed/tests/test_scheduler.py | 7 +++-- distributed/tests/test_steal.py | 40 +++++++++++++++++++++-------- 3 files changed, 41 insertions(+), 17 deletions(-) diff --git a/distributed/stealing.py b/distributed/stealing.py index 101a228ce04..337ff24f756 100644 --- a/distributed/stealing.py +++ b/distributed/stealing.py @@ -15,7 +15,7 @@ from dask.utils import parse_timedelta from .comm.addressing import get_address_host -from .core import CommClosedError +from .core import CommClosedError, Status from .diagnostics.plugin import SchedulerPlugin from .utils import log_errors, recursive_to_dict @@ -393,22 +393,23 @@ def maybe_move_task(level, ts, sat, idl, duration, cost_multiplier): with log_errors(): i = 0 - idle = s.idle.values() - saturated = s.saturated + # Paused and closing workers must never become thieves + idle = [ws for ws in s.idle.values() if ws.status == Status.running] if not idle or len(idle) == len(s.workers): return log = [] start = time() - if not s.saturated: + saturated = s.saturated + if not saturated: saturated = topk(10, s.workers.values(), key=combined_occupancy) saturated = [ ws for ws in saturated if combined_occupancy(ws) > 0.2 and len(ws.processing) > ws.nthreads ] - elif len(s.saturated) < 20: + elif len(saturated) < 20: saturated = sorted(saturated, key=combined_occupancy, reverse=True) if len(idle) < 20: idle = sorted(idle, key=combined_occupancy) diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py index 0bb2567e053..5fcc3bb1f15 100644 --- a/distributed/tests/test_scheduler.py +++ b/distributed/tests/test_scheduler.py @@ -3243,8 +3243,11 @@ async def test_avoid_paused_workers(c, s, w1, w2, w3): while s.workers[w2.address].status != Status.paused: await asyncio.sleep(0.01) futures = c.map(slowinc, range(8), delay=0.1) - while (len(w1.tasks), len(w2.tasks), len(w3.tasks)) != (4, 0, 4): - await asyncio.sleep(0.01) + await wait(futures) + assert w1.data + assert not w2.data + assert w3.data + assert len(w1.data) + len(w3.data) == 8 @gen_cluster(client=True, nthreads=[("", 1)]) diff --git a/distributed/tests/test_steal.py b/distributed/tests/test_steal.py index 7f997ed77a1..e7f76dddc86 100644 --- 
a/distributed/tests/test_steal.py +++ b/distributed/tests/test_steal.py @@ -16,6 +16,7 @@ from distributed import Lock, Nanny, Worker, wait, worker_client from distributed.compatibility import LINUX, WINDOWS from distributed.config import config +from distributed.core import Status from distributed.metrics import time from distributed.scheduler import key_split from distributed.system import MEMORY_LIMIT @@ -813,28 +814,47 @@ async def test_steal_twice(c, s, a, b): while len(s.tasks) < 100: # tasks are all allocated await asyncio.sleep(0.01) + # Wait for b to start stealing tasks + while len(b.tasks) < 30: + await asyncio.sleep(0.01) # Army of new workers arrives to help - workers = await asyncio.gather(*(Worker(s.address, loop=s.loop) for _ in range(20))) + workers = await asyncio.gather(*(Worker(s.address) for _ in range(20))) await wait(futures) - has_what = dict(s.has_what) # take snapshot - empty_workers = [w for w, keys in has_what.items() if not len(keys)] - if len(empty_workers) > 2: - pytest.fail( - "Too many workers without keys (%d out of %d)" - % (len(empty_workers), len(has_what)) - ) - assert max(map(len, has_what.values())) < 30 + # Note: this includes a and b + empty_workers = [w for w, keys in s.has_what.items() if not keys] + assert ( + len(empty_workers) < 3 + ), f"Too many workers without keys ({len(empty_workers)} out of {len(s.workers)})" + # This also tests that some tasks were stolen from b + # (see `while len(b.tasks) < 30` above) + assert max(map(len, s.has_what.values())) < 30 assert a.in_flight_tasks == 0 assert b.in_flight_tasks == 0 - await c._close() await asyncio.gather(*(w.close() for w in workers)) +@gen_cluster(client=True, nthreads=[("", 1)] * 3) +async def test_paused_workers_must_not_steal(c, s, w1, w2, w3): + w2.memory_pause_fraction = 1e-15 + while s.workers[w2.address].status != Status.paused: + await asyncio.sleep(0.01) + + x = c.submit(inc, 1, workers=w1.address) + await wait(x) + + futures = [c.submit(slowadd, x, i, delay=0.1) for i in range(10)] + await wait(futures) + + assert w1.data + assert not w2.data + assert w3.data + + @gen_cluster(client=True) async def test_dont_steal_already_released(c, s, a, b): future = c.submit(slowinc, 1, delay=0.05, workers=a.address) From b0b8e95bf0e7e8fae44ca019652a394dcca92353 Mon Sep 17 00:00:00 2001 From: crusaderky Date: Mon, 24 Jan 2022 14:32:29 +0000 Subject: [PATCH 10/10] Code review --- distributed/tests/test_worker.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py index 720e3dc0015..b6c4961ff6d 100644 --- a/distributed/tests/test_worker.py +++ b/distributed/tests/test_worker.py @@ -1635,12 +1635,10 @@ async def test_close_gracefully(c, s, a, b): # be replicated by retire_workers(). 
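+    # Wait until b holds enough keys in memory (as seen by the scheduler)
+    # and, in the same loop iteration, is still executing at least one task.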
while True: mem = {k for k, ts in s.tasks.items() if ts.state == "memory"} - if len(mem) >= 8: + if len(mem) >= 8 and any(ts.state == "executing" for ts in b.tasks.values()): break await asyncio.sleep(0.01) - assert any(ts for ts in b.tasks.values() if ts.state == "executing") - await b.close_gracefully() assert b.status == Status.closed @@ -1663,16 +1661,22 @@ async def test_close_gracefully(c, s, a, b): @pytest.mark.slow @gen_cluster(client=True, nthreads=[("", 1)], timeout=10) async def test_lifetime(c, s, a): + # Note: test was occasionally failing with lifetime="1 seconds" async with Worker(s.address, lifetime="2 seconds") as b: futures = c.map(slowinc, range(200), delay=0.1, workers=[b.address]) - await asyncio.sleep(1) - assert not a.data + # Note: keys will appear in b.data several milliseconds before they switch to # status=memory in s.tasks. It's important to sample the in-memory keys from the # scheduler side, because those that the scheduler thinks are still processing # won't be replicated by retire_workers(). - mem = {k for k, ts in s.tasks.items() if ts.state == "memory"} - assert mem + while True: + mem = {k for k, ts in s.tasks.items() if ts.state == "memory"} + if len(mem) >= 8: + break + await asyncio.sleep(0.01) + + assert b.status == Status.running + assert not a.data while b.status != Status.closed: await asyncio.sleep(0.01)