Skip to content

Commit

Permalink
Merge branch 'main' into WSMR/gather_dep
Browse files Browse the repository at this point in the history
  • Loading branch information
crusaderky committed Jun 7, 2022
2 parents 0907ec8 + bde90af commit 7df31ed
Show file tree
Hide file tree
Showing 16 changed files with 247 additions and 194 deletions.
15 changes: 7 additions & 8 deletions distributed/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from functools import partial
from numbers import Number
from queue import Queue as pyQueue
from typing import ClassVar, Literal
from typing import Any, ClassVar, Literal

from tlz import first, groupby, keymap, merge, partition_all, valmap

Expand All @@ -51,6 +51,7 @@
from tornado import gen
from tornado.ioloop import PeriodicCallback

import distributed.utils
from distributed import cluster_dump, preloading
from distributed import versions as version_module
from distributed.batched import BatchedSend
Expand Down Expand Up @@ -80,8 +81,6 @@
from distributed.sizeof import sizeof
from distributed.threadpoolexecutor import rejoin
from distributed.utils import (
All,
Any,
CancelledError,
LoopRunner,
NoOpAwaitable,
Expand Down Expand Up @@ -2028,7 +2027,7 @@ async def wait(k):
logger.debug("Waiting on futures to clear before gather")

with suppress(AllExit):
await All(
await distributed.utils.All(
[wait(key) for key in keys if key in self.futures],
quiet_exceptions=AllExit,
)
Expand Down Expand Up @@ -4053,12 +4052,12 @@ def benchmark_hardware(self) -> dict:
"""
return self.sync(self.scheduler.benchmark_hardware)

def log_event(self, topic, msg):
def log_event(self, topic: str | Collection[str], msg: Any):
"""Log an event under a given topic
Parameters
----------
topic : str, list
topic : str, list[str]
Name of the topic under which to log an event. To log the same
event under multiple topics, pass a list of topic names.
msg
Expand Down Expand Up @@ -4648,9 +4647,9 @@ async def _wait(fs, timeout=None, return_when=ALL_COMPLETED):
)
fs = futures_of(fs)
if return_when == ALL_COMPLETED:
wait_for = All
wait_for = distributed.utils.All
elif return_when == FIRST_COMPLETED:
wait_for = Any
wait_for = distributed.utils.Any
else:
raise NotImplementedError(
"Only return_when='ALL_COMPLETED' and 'FIRST_COMPLETED' are supported"
Expand Down
10 changes: 8 additions & 2 deletions distributed/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,13 @@ def __init__(
timeout=None,
io_loop=None,
):
if io_loop is not None:
warnings.warn(
"The io_loop kwarg to Server is ignored and will be deprecated",
DeprecationWarning,
stacklevel=2,
)

self._status = Status.init
self.handlers = {
"identity": self.identity,
Expand Down Expand Up @@ -191,8 +198,7 @@ def __init__(
self._event_finished = asyncio.Event()

self.listeners = []
self.io_loop = io_loop or IOLoop.current()
self.loop = self.io_loop
self.io_loop = self.loop = IOLoop.current()

if not hasattr(self.io_loop, "profile"):
if dask.config.get("distributed.worker.profile.enabled"):
Expand Down
6 changes: 3 additions & 3 deletions distributed/diagnostics/tests/test_worker_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def transition(self, key, start, finish, **kwargs):
async def test_create_with_client(c, s):
await c.register_worker_plugin(MyPlugin(123))

worker = await Worker(s.address, loop=s.loop)
worker = await Worker(s.address)
assert worker._my_plugin_status == "setup"
assert worker._my_plugin_data == 123

Expand All @@ -55,7 +55,7 @@ async def test_remove_with_client(c, s):
await c.register_worker_plugin(MyPlugin(123), name="foo")
await c.register_worker_plugin(MyPlugin(546), name="bar")

worker = await Worker(s.address, loop=s.loop)
worker = await Worker(s.address)
# remove the 'foo' plugin
await c.unregister_worker_plugin("foo")
assert worker._my_plugin_status == "teardown"
Expand All @@ -79,7 +79,7 @@ async def test_remove_with_client(c, s):
async def test_remove_with_client_raises(c, s):
await c.register_worker_plugin(MyPlugin(123), name="foo")

worker = await Worker(s.address, loop=s.loop)
worker = await Worker(s.address)
with pytest.raises(ValueError, match="bar"):
await c.unregister_worker_plugin("bar")

Expand Down
14 changes: 10 additions & 4 deletions distributed/nanny.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,8 +134,16 @@ def __init__(
config=None,
**worker_kwargs,
):
if loop is not None:
warnings.warn(
"the `loop` kwarg to `Nanny` is ignored, and will be removed in a future release. "
"The Nanny always binds to the current loop.",
DeprecationWarning,
stacklevel=2,
)

self._setup_logging(logger)
self.loop = loop or IOLoop.current()
self.loop = self.io_loop = IOLoop.current()

if isinstance(security, dict):
security = Security(**security)
Expand Down Expand Up @@ -246,9 +254,7 @@ def __init__(

self.plugins: dict[str, NannyPlugin] = {}

super().__init__(
handlers=handlers, io_loop=self.loop, connection_args=self.connection_args
)
super().__init__(handlers=handlers, connection_args=self.connection_args)

self.scheduler = self.rpc(self.scheduler_addr)
self.memory_manager = NannyMemoryManager(self, memory_limit=memory_limit)
Expand Down
31 changes: 17 additions & 14 deletions distributed/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -2887,7 +2887,7 @@ def __init__(
stacklevel=2,
)

self.loop = IOLoop.current()
self.loop = self.io_loop = IOLoop.current()
self._setup_logging(logger)

# Attributes
Expand Down Expand Up @@ -3070,7 +3070,7 @@ def __init__(
"get_logs": self.get_logs,
"logs": self.get_logs,
"worker_logs": self.get_worker_logs,
"log_event": self.log_worker_event,
"log_event": self.log_event,
"events": self.get_events,
"nbytes": self.get_nbytes,
"versions": self.versions,
Expand Down Expand Up @@ -3123,7 +3123,6 @@ def __init__(
self,
handlers=self.handlers,
stream_handlers=merge(worker_handlers, client_handlers),
io_loop=self.loop,
connection_limit=connection_limit,
deserialize=False,
connection_args=self.connection_args,
Expand Down Expand Up @@ -6224,7 +6223,11 @@ async def feed(
if teardown:
teardown(self, state)

def log_worker_event(self, worker=None, topic=None, msg=None):
def log_worker_event(
self, worker: str, topic: str | Collection[str], msg: Any
) -> None:
if isinstance(msg, dict):
msg["worker"] = worker
self.log_event(topic, msg)

def subscribe_worker_status(self, comm=None):
Expand Down Expand Up @@ -6906,21 +6909,21 @@ async def get_worker_logs(self, n=None, workers=None, nanny=False):
)
return results

def log_event(self, name, msg):
def log_event(self, topic: str | Collection[str], msg: Any) -> None:
event = (time(), msg)
if isinstance(name, (list, tuple)):
for n in name:
self.events[n].append(event)
self.event_counts[n] += 1
self._report_event(n, event)
if not isinstance(topic, str):
for t in topic:
self.events[t].append(event)
self.event_counts[t] += 1
self._report_event(t, event)
else:
self.events[name].append(event)
self.event_counts[name] += 1
self._report_event(name, event)
self.events[topic].append(event)
self.event_counts[topic] += 1
self._report_event(topic, event)

for plugin in list(self.plugins.values()):
try:
plugin.log_event(name, msg)
plugin.log_event(topic, msg)
except Exception:
logger.info("Plugin failed with exception", exc_info=True)

Expand Down
12 changes: 6 additions & 6 deletions distributed/tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2980,7 +2980,7 @@ async def test_unrunnable_task_runs(c, s, a, b):
assert s.tasks[x.key] in s.unrunnable
assert s.get_task_status(keys=[x.key]) == {x.key: "no-worker"}

w = await Worker(s.address, loop=s.loop)
w = await Worker(s.address)

while x.status != "finished":
await asyncio.sleep(0.01)
Expand Down Expand Up @@ -6603,8 +6603,8 @@ def setup(self, worker=None):
await c.register_worker_plugin(MyPlugin())


@gen_cluster(client=True)
async def test_log_event(c, s, a, b):
@gen_cluster(client=True, nthreads=[("", 1)])
async def test_log_event(c, s, a):

# Log an event from inside a task
def foo():
Expand All @@ -6614,7 +6614,7 @@ def foo():
await c.submit(foo)
events = await c.get_events("topic1")
assert len(events) == 1
assert events[0][1] == {"foo": "bar"}
assert events[0][1] == {"foo": "bar", "worker": a.address}

# Log an event while on the scheduler
def log_scheduler(dask_scheduler):
Expand Down Expand Up @@ -7135,7 +7135,7 @@ def user_event_handler(event):

time_, msg = log[0]
assert isinstance(time_, float)
assert msg == {"important": "event"}
assert msg == {"important": "event", "worker": a.address}

c.unsubscribe_topic("test-topic")

Expand Down Expand Up @@ -7166,7 +7166,7 @@ async def async_user_event_handler(event):
assert len(log) == 2
time_, msg = log[1]
assert isinstance(time_, float)
assert msg == {"async": "event"}
assert msg == {"async": "event", "worker": a.address}

# Even though the middle event was not subscribed to, the scheduler still
# knows about all and we can retrieve them
Expand Down
12 changes: 6 additions & 6 deletions distributed/tests/test_nanny.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,8 +275,8 @@ async def test_wait_for_scheduler():

@gen_cluster(nthreads=[], client=True)
async def test_environment_variable(c, s):
a = Nanny(s.address, loop=s.loop, memory_limit=0, env={"FOO": "123"})
b = Nanny(s.address, loop=s.loop, memory_limit=0, env={"FOO": "456"})
a = Nanny(s.address, memory_limit=0, env={"FOO": "123"})
b = Nanny(s.address, memory_limit=0, env={"FOO": "456"})
await asyncio.gather(a, b)
results = await c.run(lambda: os.environ["FOO"])
assert results == {a.worker_address: "123", b.worker_address: "456"}
Expand All @@ -288,18 +288,18 @@ async def test_environment_variable_by_config(c, s, monkeypatch):

with dask.config.set({"distributed.nanny.environ": "456"}):
with pytest.raises(TypeError, match="configuration must be of type dict"):
Nanny(s.address, loop=s.loop, memory_limit=0)
Nanny(s.address, memory_limit=0)

with dask.config.set({"distributed.nanny.environ": {"FOO": "456"}}):

# precedence
# kwargs > env var > config

with mock.patch.dict(os.environ, {"FOO": "BAR"}, clear=True):
a = Nanny(s.address, loop=s.loop, memory_limit=0, env={"FOO": "123"})
x = Nanny(s.address, loop=s.loop, memory_limit=0)
a = Nanny(s.address, memory_limit=0, env={"FOO": "123"})
x = Nanny(s.address, memory_limit=0)

b = Nanny(s.address, loop=s.loop, memory_limit=0)
b = Nanny(s.address, memory_limit=0)

await asyncio.gather(a, b, x)
results = await c.run(lambda: os.environ["FOO"])
Expand Down
4 changes: 2 additions & 2 deletions distributed/tests/test_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -952,7 +952,7 @@ async def test_file_descriptors(c, s):
num_fds_1 = proc.num_fds()

N = 20
nannies = await asyncio.gather(*(Nanny(s.address, loop=s.loop) for _ in range(N)))
nannies = await asyncio.gather(*(Nanny(s.address) for _ in range(N)))

while len(s.workers) < N:
await asyncio.sleep(0.1)
Expand Down Expand Up @@ -2234,7 +2234,7 @@ async def test_worker_name_collision(s, a):
with raises_with_cause(
RuntimeError, None, ValueError, f"name taken, {a.name!r}"
):
await Worker(s.address, name=a.name, loop=s.loop, host="127.0.0.1")
await Worker(s.address, name=a.name, host="127.0.0.1")

s.validate_state()
assert set(s.workers) == {a.address}
Expand Down
14 changes: 3 additions & 11 deletions distributed/tests/test_steal.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,7 +325,7 @@ async def test_new_worker_steals(c, s, a):
while len(a.tasks) < 10:
await asyncio.sleep(0.01)

b = await Worker(s.address, loop=s.loop, nthreads=1, memory_limit=MEMORY_LIMIT)
b = await Worker(s.address, nthreads=1, memory_limit=MEMORY_LIMIT)

result = await total
assert result == sum(map(inc, range(100)))
Expand Down Expand Up @@ -479,7 +479,7 @@ async def test_steal_resource_restrictions(c, s, a):
await asyncio.sleep(0.01)
assert len(a.tasks) == 101

b = await Worker(s.address, loop=s.loop, nthreads=1, resources={"A": 4})
b = await Worker(s.address, nthreads=1, resources={"A": 4})

while not b.tasks or len(a.tasks) == 101:
await asyncio.sleep(0.01)
Expand All @@ -501,15 +501,7 @@ async def test_steal_resource_restrictions_asym_diff(c, s, a):
await asyncio.sleep(0.01)
assert len(a.tasks) == 101

b = await Worker(
s.address,
loop=s.loop,
nthreads=1,
resources={
"A": 4,
"B": 5,
},
)
b = await Worker(s.address, nthreads=1, resources={"A": 4, "B": 5})

while not b.tasks or len(a.tasks) == 101:
await asyncio.sleep(0.01)
Expand Down

0 comments on commit 7df31ed

Please sign in to comment.