Skip to content

Commit

Permalink
Merge pull request #1720 from bluesky/1641_collect_while_completing
Browse files Browse the repository at this point in the history
Add collect_while_completing plan stub and tests
  • Loading branch information
jsouter committed May 21, 2024
2 parents a17729f + bb7788f commit 27ca36c
Show file tree
Hide file tree
Showing 5 changed files with 179 additions and 18 deletions.
2 changes: 1 addition & 1 deletion docs/multi_run_plans.rst
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ In the plan, the run may be defined by explicitely enclosing the code in `bps.op
RE(sample_plan())
or using `@bpp.run_decorator`, which inserts `open_run` and `close_run` control messages
before and after the sequnce generated by the enclosed code:
before and after the sequence generated by the enclosed code:

.. code-block:: python
Expand Down
44 changes: 41 additions & 3 deletions src/bluesky/plan_stubs.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
merge_cycler,
plan,
separate_devices,
short_uid,
)
from .utils import (
short_uid as _short_uid,
Expand Down Expand Up @@ -505,7 +506,7 @@ def sleep(time):


@plan
def wait(group=None, *, timeout=None):
def wait(group=None, *, timeout=None, move_on=False):
"""
Wait for all statuses in a group to report being finished.
Expand All @@ -517,9 +518,9 @@ def wait(group=None, *, timeout=None):
Yields
------
msg : Msg
Msg('wait', None, group=group)
Msg('wait', None, group=group, move_on=move_on, timeout=timeout)
"""
return (yield Msg("wait", None, group=group, timeout=timeout))
return (yield Msg("wait", None, group=group, move_on=move_on, timeout=timeout))


_wait = wait # for internal references to avoid collision with 'wait' kwarg
Expand Down Expand Up @@ -842,6 +843,43 @@ def collect(obj, *args, stream=False, return_payload=True, name=None):
return (yield Msg("collect", obj, *args, stream=stream, return_payload=return_payload, name=name))


@plan
def collect_while_completing(flyers, dets, flush_period=None, stream_name=None):
"""
Collect data from one or more fly-scanning devices and emit documents, then collect and emit
data from one or more Collectable detectors until all are done.
Parameters
----------
flyers: An iterable sequence of fly-able devices with 'kickoff', 'complete' and
'collect' methods.
dets: An iterable sequence of collectable devices with 'describe_collect' method.
flush_period: float, int
Time period in seconds between each yield from collect while waiting for triggered
objects to be done
stream_name: str, optional
If not None, will collect for the named string specifically, else collect will be performed
on all streams.
Yields
------
msg : Msg
A 'complete' message or 'collect' message
See Also
--------
:func:`bluesky.plan_stubs.complete`
:func:`bluesky.plan_stubs.collect`
"""
group = short_uid(label="complete")
yield from complete_all(*flyers, group=group, wait=False)
done = False
while not done:
done = yield from wait(group=group, timeout=flush_period, move_on=True)
yield from collect(*dets, name=stream_name)


@plan
def configure(obj, *args, **kwargs):
"""
Expand Down
39 changes: 31 additions & 8 deletions src/bluesky/run_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -502,6 +502,7 @@ def setup_run_permit():
self._groups = defaultdict(set) # sets of Events to wait for
self._status_objs = defaultdict(set) # status objects to wait for
self._temp_callback_ids = set() # ids from CallbackRegistry
self._seen_wait_and_move_on_keys = set() # group ids that have been passed to _wait_and_move_on
self._msg_cache = deque() # history of processed msgs for rewinding
self._rewindable_flag = True # if the RE is allowed to replay msgs
self._plan_stack = deque() # stack of generators to work off of
Expand Down Expand Up @@ -2270,43 +2271,65 @@ async def _trigger(self, msg):

return ret

def _call_waiting_hook(self, *args, **kwargs):
if self.waiting_hook is not None:
self.waiting_hook(*args, **kwargs)

async def _wait(self, msg):
"""Block progress until every object that was triggered or set
with the keyword argument `group=<GROUP>` is done.
with the keyword argument `group=<GROUP>` is done. Returns a boolean that is
true when all triggered objects are done. When the keyword argument
`move_on=<MOVE_ON>` is true, this method can return before all objects are done
after a flush period given by the `timeout=<TIMEOUT>` keyword argument.
Expected message object is:
Msg('wait', group=<GROUP>)
Msg('wait', group=<GROUP>, move_on=<MOVE_ON>)
where ``<GROUP>`` is any hashable key.
where ``<GROUP>`` is any hashable key and ``<MOVE_ON>`` is a boolean.
"""
done = False # boolean that tracks whether waiting is complete
if msg.args:
(group,) = msg.args
move_on = False
else:
group = msg.kwargs["group"]
move_on = msg.kwargs.get("move_on", False)
futs = list(self._groups.pop(group, []))
if futs:
status_objs = self._status_objs.pop(group)
try:
if self.waiting_hook is not None:
if move_on:
if group not in self._seen_wait_and_move_on_keys:
self._seen_wait_and_move_on_keys.add(group)
self._call_waiting_hook(status_objs)
else: # if move_on False
# Notify the waiting_hook function that the RunEngine is
# waiting for these status_objs to complete. Users can use
# the information these encapsulate to create a progress
# bar.
self.waiting_hook(status_objs)
self._call_waiting_hook(status_objs)
await self._wait_for(Msg("wait_for", None, futs, timeout=msg.kwargs.get("timeout", None)))
except WaitForTimeoutError:
# We might wait to call wait again, so put the futures and status objects back in
self._groups[group] = futs
self._status_objs[group] = status_objs
raise
if not move_on:
raise
finally:
if self.waiting_hook is not None:
if not move_on:
# Notify the waiting_hook function that we have moved on by
# sending it `None`. If all goes well, it could have
# inferred this from the status_obj, but there are edge
# cases.
self.waiting_hook(None)
self._call_waiting_hook(None)
done = True
else:
done = all(obj.done for obj in status_objs)
if done:
self._call_waiting_hook(None)
self._seen_wait_and_move_on_keys.remove(group)
return done

def _status_object_completed(self, ret, p_event, pardon_failures):
"""
Expand Down
92 changes: 90 additions & 2 deletions src/bluesky/tests/test_flyer.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,108 @@
import functools
from collections import defaultdict
from time import time
from typing import Dict

import pytest
from event_model.documents.event import PartialEvent
from ophyd import Component as Cpt
from ophyd import Device
from ophyd.sim import NullStatus, TrivialFlyer
from ophyd.sim import NullStatus, StatusBase, TrivialFlyer

from bluesky import Msg
from bluesky.plan_stubs import close_run, complete, complete_all, kickoff, kickoff_all, open_run, wait
from bluesky.plan_stubs import (
close_run,
collect_while_completing,
complete,
complete_all,
declare_stream,
kickoff,
kickoff_all,
open_run,
wait,
)
from bluesky.plans import count, fly
from bluesky.protocols import Preparable
from bluesky.run_engine import IllegalMessageSequence
from bluesky.tests import requires_ophyd
from bluesky.tests.utils import DocCollector


def collect_while_completing_plan(flyers, dets, stream_name: str = "test_stream", pre_declare: bool = True):
yield from open_run()
if pre_declare:
yield from declare_stream(*dets, name=stream_name, collect=True)
yield from collect_while_completing(flyers, dets, flush_period=0.1, stream_name=stream_name)
yield from close_run()


def test_collect_while_completing_plan_trivial_case(RE):
completed = []
collected = []

class StatusDoneAfterTenthCall(StatusBase):
times_called = 0

@property
def done(self):
self.times_called += 1
return self.times_called >= 10

class SlowFlyer:
name = "trivial-flyer"
custom_status = StatusDoneAfterTenthCall()

def kickoff(self):
return NullStatus()

def complete(self):
completed.append(self)
return self.custom_status

class TrivialDetector:
name = "trivial-detector"
times_collected = 0

def describe_collect(self):
collected.append(self)
return {
"times_collected": {
"dims": [],
"dtype": "number",
"shape": [],
"source": "times_collected",
}
}

def collect(self):
self.times_collected += 1
yield PartialEvent(
data={"times_collected": self.times_collected},
timestamps={"times_collected": time()},
)

det = TrivialDetector()
flyer = SlowFlyer()
docs = defaultdict(list)

def assert_emitted(docs: Dict[str, list], **numbers: int):
assert list(docs) == list(numbers)
assert {name: len(d) for name, d in docs.items()} == numbers

RE(collect_while_completing_plan([flyer], [det]), lambda name, doc: docs[name].append(doc))
for idx, event in enumerate(docs["event_page"]):
print(idx, event)
assert "times_collected" in event["data"] and event["data"]["times_collected"] == [idx + 1]
# The detector will be collected nine times as the flyer's done property is
# checked once during the initial complete call, then the detector collects
# every loop and checks if the flyer is done until it is checked for a tenth time
assert_emitted(docs, start=1, descriptor=1, event_page=9, stop=1)
# key should be removed from set when collection done
assert not RE._seen_wait_and_move_on_keys
assert det in collected
assert flyer in completed


@requires_ophyd
def test_flyer_with_collect_asset_documents(RE):
from ophyd.sim import det, new_trivial_flyer, trivial_flyer
Expand Down
20 changes: 16 additions & 4 deletions src/bluesky/tests/test_new_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,8 +104,8 @@
(trigger, ("det",), {}, [Msg("trigger", "det", group=None)]),
(trigger, ("det",), {"group": "A"}, [Msg("trigger", "det", group="A")]),
(sleep, (2,), {}, [Msg("sleep", None, 2)]),
(wait, (), {}, [Msg("wait", None, group=None, timeout=None)]),
(wait, ("A",), {}, [Msg("wait", None, group="A", timeout=None)]),
(wait, (), {}, [Msg("wait", None, move_on=False, group=None, timeout=None)]),
(wait, ("A",), {}, [Msg("wait", None, group="A", move_on=False, timeout=None)]),
(checkpoint, (), {}, [Msg("checkpoint")]),
(clear_checkpoint, (), {}, [Msg("clear_checkpoint")]),
(pause, (), {}, [Msg("pause", None, defer=False)]),
Expand Down Expand Up @@ -693,14 +693,26 @@ def plan():
def test_trigger_and_read(hw):
det = hw.det
msgs = list(trigger_and_read([det]))
expected = [Msg("trigger", det), Msg("wait"), Msg("create", name="primary"), Msg("read", det), Msg("save")]
expected = [
Msg("trigger", det),
Msg("wait", move_on=False),
Msg("create", name="primary"),
Msg("read", det),
Msg("save"),
]
for msg in msgs:
msg.kwargs.pop("group", None)
msg.kwargs.pop("timeout", None)
assert msgs == expected

msgs = list(trigger_and_read([det], "custom"))
expected = [Msg("trigger", det), Msg("wait"), Msg("create", name="custom"), Msg("read", det), Msg("save")]
expected = [
Msg("trigger", det),
Msg("wait", move_on=False),
Msg("create", name="custom"),
Msg("read", det),
Msg("save"),
]
for msg in msgs:
msg.kwargs.pop("group", None)
msg.kwargs.pop("timeout", None)
Expand Down

0 comments on commit 27ca36c

Please sign in to comment.