Skip to content

Commit

Permalink
Update broadcast array for numpy 2 (#11096)
Browse files Browse the repository at this point in the history
  • Loading branch information
quasiben committed May 6, 2024
1 parent 067c668 commit dc1a5a3
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 3 deletions.
7 changes: 5 additions & 2 deletions dask/array/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
einsum_lookup,
tensordot_lookup,
)
from dask.array.numpy_compat import _Recurser
from dask.array.numpy_compat import NUMPY_GE_200, _Recurser
from dask.array.slicing import replace_ellipsis, setitem_array, slice_array
from dask.base import (
DaskMethodsMixin,
Expand Down Expand Up @@ -5035,7 +5035,10 @@ def broadcast_arrays(*args, subok=False):
shape = broadcast_shapes(*(e.shape for e in args))
chunks = broadcast_chunks(*(e.chunks for e in args))

result = [broadcast_to(e, shape=shape, chunks=chunks) for e in args]
if NUMPY_GE_200:
result = tuple(broadcast_to(e, shape=shape, chunks=chunks) for e in args)
else:
result = [broadcast_to(e, shape=shape, chunks=chunks) for e in args]

return result

Expand Down
2 changes: 1 addition & 1 deletion dask/array/tests/test_array_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1108,7 +1108,7 @@ def test_broadcast_arrays():
a_r = np.broadcast_arrays(a_0, a_1)
d_r = da.broadcast_arrays(d_a_0, d_a_1)

assert isinstance(d_r, list)
assert isinstance(d_r, (list, tuple))
assert len(a_r) == len(d_r)

for e_a_r, e_d_r in zip(a_r, d_r):
Expand Down

0 comments on commit dc1a5a3

Please sign in to comment.