Move tests (#8631)
hendrikmakait committed Apr 30, 2024
1 parent 5a588ae commit 4986fa4
Showing 2 changed files with 84 additions and 98 deletions.
90 changes: 84 additions & 6 deletions distributed/shuffle/tests/test_shuffle.py
@@ -1232,12 +1232,90 @@ async def test_head(c, s, a, b):


def test_split_by_worker():
-    workers = ["a", "b", "c"]
-    npartitions = 5
-    df = pd.DataFrame({"x": range(100), "y": range(100)})
-    df["_partitions"] = df.x % npartitions
-    worker_for = {i: random.choice(workers) for i in range(npartitions)}
-    s = pd.Series(worker_for, name="_worker").astype("category")
+    pytest.importorskip("pyarrow")
+
+    df = pd.DataFrame(
+        {
+            "x": [1, 2, 3, 4, 5],
+            "_partition": [0, 1, 2, 0, 1],
+        }
+    )
+    meta = df[["x"]].head(0)
+    workers = ["alice", "bob"]
+    worker_for_mapping = {}
+    npartitions = 3
+    for part in range(npartitions):
+        worker_for_mapping[part] = _get_worker_for_range_sharding(
+            npartitions, part, workers
+        )
+    worker_for = pd.Series(worker_for_mapping, name="_workers").astype("category")
+    out = split_by_worker(df, "_partition", meta, worker_for)
+    assert set(out) == {"alice", "bob"}
+    assert list(out["alice"].to_pandas().columns) == list(df.columns)
+
+    assert sum(map(len, out.values())) == len(df)
+
+
+def test_split_by_worker_empty():
+    pytest.importorskip("pyarrow")
+
+    df = pd.DataFrame(
+        {
+            "x": [1, 2, 3, 4, 5],
+            "_partition": [0, 1, 2, 0, 1],
+        }
+    )
+    meta = df[["x"]].head(0)
+    worker_for = pd.Series({5: "chuck"}, name="_workers").astype("category")
+    out = split_by_worker(df, "_partition", meta, worker_for)
+    assert out == {}
+
+
+def test_split_by_worker_many_workers():
+    pytest.importorskip("pyarrow")
+
+    df = pd.DataFrame(
+        {
+            "x": [1, 2, 3, 4, 5],
+            "_partition": [5, 7, 5, 0, 1],
+        }
+    )
+    meta = df[["x"]].head(0)
+    workers = ["a", "b", "c", "d", "e", "f", "g", "h"]
+    npartitions = 10
+    worker_for_mapping = {}
+    for part in range(npartitions):
+        worker_for_mapping[part] = _get_worker_for_range_sharding(
+            npartitions, part, workers
+        )
+    worker_for = pd.Series(worker_for_mapping, name="_workers").astype("category")
+    out = split_by_worker(df, "_partition", meta, worker_for)
+    assert _get_worker_for_range_sharding(npartitions, 5, workers) in out
+    assert _get_worker_for_range_sharding(npartitions, 0, workers) in out
+    assert _get_worker_for_range_sharding(npartitions, 7, workers) in out
+    assert _get_worker_for_range_sharding(npartitions, 1, workers) in out
+
+    assert sum(map(len, out.values())) == len(df)
+
+
+@pytest.mark.parametrize("drop_column", [True, False])
+def test_split_by_partition(drop_column):
+    pa = pytest.importorskip("pyarrow")
+
+    df = pd.DataFrame(
+        {
+            "x": [1, 2, 3, 4, 5],
+            "_partition": [3, 1, 2, 3, 1],
+        }
+    )
+    t = pa.Table.from_pandas(df)
+
+    out = split_by_partition(t, "_partition", drop_column)
+    assert set(out) == {1, 2, 3}
+    if drop_column:
+        df = df.drop(columns="_partition")
+    assert out[1].column_names == list(df.columns)
+    assert sum(map(len, out.values())) == len(df)


@gen_cluster(client=True, nthreads=[("", 1)] * 2)
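
The moved tests above build their expected placement with _get_worker_for_range_sharding(npartitions, part, workers). Below is a minimal, self-contained sketch of range sharding as the tests appear to use it; the name range_sharding_sketch is hypothetical, and this illustrates the assignment scheme rather than the library's actual implementation:

def range_sharding_sketch(npartitions: int, output_partition: int, workers: list[str]) -> str:
    # Range sharding: hand out contiguous blocks of output partitions to
    # workers in list order, so each worker owns a consecutive range of
    # roughly npartitions / len(workers) partitions.
    return workers[len(workers) * output_partition // npartitions]

# With workers=["alice", "bob"] and npartitions=3 this yields
# {0: "alice", 1: "alice", 2: "bob"}: both workers own at least one of the
# frame's _partition values (0, 1, 2), which is why test_split_by_worker
# expects set(out) == {"alice", "bob"}.

test_split_by_worker_empty exercises the opposite situation: its mapping covers only output partition 5, while the frame's _partition column holds 0, 1, and 2, so no row can be routed to a known worker and split_by_worker returns an empty dict.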
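test_split_by_partition checks the per-partition split of a pyarrow Table instead: one shard per distinct value of the _partition column, with that column optionally dropped from every shard. The following is a rough sketch of the behavior under test using only public pyarrow APIs; the name split_by_partition_sketch is made up for illustration, and the real helper need not be implemented this way:

import pyarrow as pa
import pyarrow.compute as pc

def split_by_partition_sketch(t: pa.Table, column: str, drop_column: bool) -> dict:
    # Build one sub-table per distinct value of the partition column;
    # every input row lands in exactly one shard, so shard lengths sum
    # to len(t).
    shards = {}
    for value in pc.unique(t[column]).to_pylist():
        shard = t.filter(pc.equal(t[column], value))
        if drop_column:
            shard = shard.drop_columns([column])
        shards[value] = shard
    return shards

# For the test's frame, _partition holds the values {3, 1, 2}, so the result
# keys are {1, 2, 3}, and with drop_column=True each shard keeps only the
# "x" column, matching the assertions in test_split_by_partition.
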
92 changes: 0 additions & 92 deletions distributed/shuffle/tests/test_shuffle_plugins.py
@@ -5,11 +5,6 @@
import pytest

from distributed.shuffle._scheduler_plugin import ShuffleSchedulerPlugin
-from distributed.shuffle._shuffle import (
-    _get_worker_for_range_sharding,
-    split_by_partition,
-    split_by_worker,
-)
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
from distributed.utils_test import gen_cluster

@@ -35,90 +30,3 @@ async def test_installation_on_scheduler(s, a):
    assert isinstance(ext, ShuffleSchedulerPlugin)
    assert s.handlers["shuffle_barrier"] == ext.barrier
    assert s.handlers["shuffle_get"] == ext.get


-def test_split_by_worker():
-    pytest.importorskip("pyarrow")
-
-    df = pd.DataFrame(
-        {
-            "x": [1, 2, 3, 4, 5],
-            "_partition": [0, 1, 2, 0, 1],
-        }
-    )
-    meta = df[["x"]].head(0)
-    workers = ["alice", "bob"]
-    worker_for_mapping = {}
-    npartitions = 3
-    for part in range(npartitions):
-        worker_for_mapping[part] = _get_worker_for_range_sharding(
-            npartitions, part, workers
-        )
-    worker_for = pd.Series(worker_for_mapping, name="_workers").astype("category")
-    out = split_by_worker(df, "_partition", meta, worker_for)
-    assert set(out) == {"alice", "bob"}
-    assert list(out["alice"].to_pandas().columns) == list(df.columns)
-
-    assert sum(map(len, out.values())) == len(df)
-
-
-def test_split_by_worker_empty():
-    pytest.importorskip("pyarrow")
-
-    df = pd.DataFrame(
-        {
-            "x": [1, 2, 3, 4, 5],
-            "_partition": [0, 1, 2, 0, 1],
-        }
-    )
-    meta = df[["x"]].head(0)
-    worker_for = pd.Series({5: "chuck"}, name="_workers").astype("category")
-    out = split_by_worker(df, "_partition", meta, worker_for)
-    assert out == {}
-
-
-def test_split_by_worker_many_workers():
-    pytest.importorskip("pyarrow")
-
-    df = pd.DataFrame(
-        {
-            "x": [1, 2, 3, 4, 5],
-            "_partition": [5, 7, 5, 0, 1],
-        }
-    )
-    meta = df[["x"]].head(0)
-    workers = ["a", "b", "c", "d", "e", "f", "g", "h"]
-    npartitions = 10
-    worker_for_mapping = {}
-    for part in range(npartitions):
-        worker_for_mapping[part] = _get_worker_for_range_sharding(
-            npartitions, part, workers
-        )
-    worker_for = pd.Series(worker_for_mapping, name="_workers").astype("category")
-    out = split_by_worker(df, "_partition", meta, worker_for)
-    assert _get_worker_for_range_sharding(npartitions, 5, workers) in out
-    assert _get_worker_for_range_sharding(npartitions, 0, workers) in out
-    assert _get_worker_for_range_sharding(npartitions, 7, workers) in out
-    assert _get_worker_for_range_sharding(npartitions, 1, workers) in out
-
-    assert sum(map(len, out.values())) == len(df)
-
-
-@pytest.mark.parametrize("drop_column", [True, False])
-def test_split_by_partition(drop_column):
-    pa = pytest.importorskip("pyarrow")
-
-    df = pd.DataFrame(
-        {
-            "x": [1, 2, 3, 4, 5],
-            "_partition": [3, 1, 2, 3, 1],
-        }
-    )
-    t = pa.Table.from_pandas(df)
-
-    out = split_by_partition(t, "_partition", drop_column)
-    assert set(out) == {1, 2, 3}
-    if drop_column:
-        df = df.drop(columns="_partition")
-    assert out[1].column_names == list(df.columns)
-    assert sum(map(len, out.values())) == len(df)
