Ensure that repack collections only return tuple if necessary #11004

Draft · wants to merge 1 commit into base: main
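As a rough sketch of the behavior this draft aims for (illustrative only, not part of the diff, and assuming the change lands as written): compute/persist/optimize called on a single collection would return that collection's result directly instead of a one-element tuple, while calls with several arguments would keep returning a tuple.

# Hypothetical usage under the proposed behavior:
import dask
from dask import delayed

def inc(i):
    return i + 1

x = delayed(inc)(1)

xx = dask.persist(x)        # proposed: returns a single Delayed, not (Delayed,)
a, b = dask.persist(x, x)   # several arguments: still returns a tuple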
20 changes: 14 additions & 6 deletions dask/base.py
@@ -345,8 +345,7 @@ def persist(self, **kwargs):
         --------
         dask.persist
         """
-        (result,) = persist(self, traverse=False, **kwargs)
-        return result
+        return persist(self, traverse=False, **kwargs)
 
     def compute(self, **kwargs):
         """Compute this dask collection
@@ -372,8 +371,7 @@ def compute(self, **kwargs):
         --------
         dask.compute
         """
-        (result,) = compute(self, traverse=False, **kwargs)
-        return result
+        return compute(self, traverse=False, **kwargs)
 
     def __await__(self):
         try:
@@ -460,7 +458,7 @@ def _extract_graph_and_keys(vals):
     return graph, keys
 
 
-def unpack_collections(*args, traverse=True):
+def unpack_collections(arg, *args, traverse=True):
     """Extract collections in preparation for compute/persist/etc...
 
     Intended use is to find all collections in a set of (possibly nested)
@@ -486,6 +484,11 @@ def unpack_collections(*args, traverse=True):
         A function to call on the transformed collections to repackage them as
         they were in the original ``args``.
     """
+    return_tuple = False
+    if args:
+        return_tuple = True
+
+    args = (arg,) + args
 
     collections = []
     repack_dsk = {}
@@ -537,7 +540,12 @@ def _unpack(expr):
     def repack(results):
         dsk = repack_dsk.copy()
         dsk[collections_token] = quote(results)
-        return simple_get(dsk, out)
+        res = simple_get(dsk, out)
+        if return_tuple:
+            return res
+        else:
+            assert len(res) == 1
+            return res[0]
 
     return collections, repack
 
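A small sketch of how the reworked unpack_collections/repack pair would behave under this change (hedged; it mirrors the new test added below rather than documenting a finished API):

# Assumes the diff above: repack only returns a tuple when more than one argument was unpacked.
from dask import delayed
from dask.base import unpack_collections

a = delayed(1)
b = delayed(2)

collections, repack = unpack_collections(a)      # single argument
assert repack([10]) == 10                        # repacked value comes back bare

collections, repack = unpack_collections(a, b)   # multiple arguments
assert repack([10, 20]) == (10, 20)              # tuple is preserved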
9 changes: 6 additions & 3 deletions dask/tests/test_base.py
@@ -136,6 +136,9 @@ def build(a, b, c, iterator):
         return t
 
     args = build(a, b, c, (i for i in [a, b, c]))
+    collections, repack = unpack_collections(a)
+    assert len(collections) == 1
+    assert repack(collections) is a
 
     collections, repack = unpack_collections(*args)
     assert len(collections) == 3
@@ -767,7 +770,7 @@ def test_persist_delayed():
     x1 = delayed(1)
     x2 = delayed(inc)(x1)
     x3 = delayed(inc)(x2)
-    (xx,) = persist(x3)
+    xx = persist(x3)
     assert isinstance(xx, Delayed)
     assert xx.key == x3.key
     assert len(xx.dask) == 1
@@ -806,7 +809,7 @@ def test_persist_delayed_rename(key, rename, new_key):
 
 def test_persist_delayedleaf():
     x = delayed(1)
-    (xx,) = persist(x)
+    xx = persist(x)
     assert isinstance(xx, Delayed)
     assert xx.compute() == 1
 
@@ -816,7 +819,7 @@ class C:
         x = 1
 
     x = delayed(C).x
-    (xx,) = persist(x)
+    xx = persist(x)
     assert isinstance(xx, Delayed)
     assert xx.compute() == 1
 
4 changes: 2 additions & 2 deletions dask/tests/test_delayed.py
@@ -329,7 +329,7 @@ def test_common_subexpressions():
 
 def test_delayed_optimize():
     x = Delayed("b", {"a": 1, "b": (inc, "a"), "c": (inc, "b")})
-    (x2,) = dask.optimize(x)
+    x2 = dask.optimize(x)
     # Delayed's __dask_optimize__ culls out 'c'
     assert sorted(x2.dask.keys()) == ["a", "b"]
     assert x2._layer != x2._key
@@ -836,7 +836,7 @@ def test_annotations_survive_optimization():
 
     # Ensure optimizing a Delayed object returns a HighLevelGraph
     # and doesn't loose annotations
-    (d_opt,) = dask.optimize(d)
+    d_opt = dask.optimize(d)
     assert type(d_opt.dask) is HighLevelGraph
     assert len(d_opt.dask.layers) == 1
     assert len(d_opt.dask.layers["b"]) == 2  # c is culled
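For context, a minimal sketch of what these updated optimize tests exercise (again assuming the draft's single-argument behavior):

# dask.optimize on one Delayed would hand back the optimized Delayed itself.
import dask
from dask import delayed

d = delayed(sum)([1, 2, 3])
d_opt = dask.optimize(d)    # proposed: no (d_opt,) unpacking needed
assert d_opt.compute() == 6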
5 changes: 3 additions & 2 deletions dask/tests/test_distributed.py
@@ -70,13 +70,14 @@ def test_can_import_nested_things():
 @gen_cluster(client=True)
 async def test_persist(c, s, a, b):
     x = delayed(inc)(1)
-    (x2,) = persist(x)
+    x2 = persist(x)
 
     await wait(x2)
     assert x2.key in a.data or x2.key in b.data
 
     y = delayed(inc)(10)
     y2, one = persist(y, 1)
+    assert one == 1
 
     await wait(y2)
     assert y2.key in a.data or y2.key in b.data
@@ -580,7 +581,7 @@ def test_blockwise_different_optimization(c):
     v = da.from_array(np.array([10 + 2j, 7 - 3j, 8 + 1j]))
     cv = v.conj()
     x = u * cv
-    (cv,) = dask.optimize(cv)
+    cv = dask.optimize(cv)
     y = u * cv
     expected = np.array([0 + 0j, 7 + 3j, 16 - 2j])
     with dask.config.set({"optimization.fuse.active": False}):