Enable column projection for groupby slicing (#9667)
rjzamora committed Nov 30, 2022
1 parent 8ef438b commit 49b516b
Showing 2 changed files with 54 additions and 5 deletions.
28 changes: 23 additions & 5 deletions dask/dataframe/groupby.py
@@ -9,7 +9,7 @@
 import pandas as pd
 
 from dask import config
-from dask.base import tokenize
+from dask.base import is_dask_collection, tokenize
 from dask.dataframe._compat import (
     PANDAS_GT_140,
     PANDAS_GT_150,
@@ -34,6 +34,7 @@
     PANDAS_GT_110,
     insert_meta_param_description,
     is_dataframe_like,
+    is_index_like,
     is_series_like,
     make_meta,
     raise_on_meta_error,
@@ -1245,9 +1246,29 @@ def __init__(
         if any(isinstance(key, pd.Grouper) for key in by_):
             raise NotImplementedError("pd.Grouper is currently not supported by Dask.")
 
+        # slicing key applied to _GroupBy instance
+        self._slice = slice
+
+        # Check if we can project columns
+        projection = None
+        if (
+            np.isscalar(self._slice)
+            or isinstance(self._slice, (str, list, tuple))
+            or (
+                (is_index_like(self._slice) or is_series_like(self._slice))
+                and not is_dask_collection(self._slice)
+            )
+        ):
+            projection = set(by_).union(
+                {self._slice}
+                if (np.isscalar(self._slice) or isinstance(self._slice, str))
+                else self._slice
+            )
+            projection = [c for c in df.columns if c in projection]
+
         assert isinstance(df, (DataFrame, Series))
         self.group_keys = group_keys
-        self.obj = df
+        self.obj = df[projection] if projection else df
         # grouping key passed via groupby method
         self.by = _normalize_by(df, by)
         self.sort = sort
@@ -1262,9 +1283,6 @@ def __init__(
                 "The grouped object and 'by' of the groupby must have the same divisions."
             )
 
-        # slicing key applied to _GroupBy instance
-        self._slice = slice
-
         if isinstance(self.by, list):
             by_meta = [
                 item._meta if isinstance(item, Series) else item for item in self.by
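
What the new projection does: when the groupby slice is a concrete (non-Dask) key, i.e. a scalar, string, list, tuple, or an in-memory Index/Series, the grouped frame is narrowed to the grouping keys plus the sliced columns before any aggregation is built. A minimal sketch of the resulting behavior; inspecting the internal .obj attribute here is only an illustration device, not public API:

import pandas as pd
import dask.dataframe as dd

pdf = pd.DataFrame(
    {"key": ["a", "b", "a"], "kept": [1, 2, 3], "unused": [4.0, 5.0, 6.0]}
)
ddf = dd.from_pandas(pdf, npartitions=2)

# Slicing the groupby with "kept" projects the grouped frame down to
# ["key", "kept"], so "unused" never enters the aggregation's task graph.
gb = ddf.groupby("key")["kept"]
print(list(gb.obj.columns))  # expected: ['key', 'kept']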
31 changes: 31 additions & 0 deletions dask/dataframe/tests/test_groupby.py
@@ -3075,3 +3075,34 @@ def test_groupby_None_split_out_warns():
     ddf = dd.from_pandas(df, npartitions=1)
     with pytest.warns(FutureWarning, match="split_out=None"):
         ddf.groupby("a").agg({"b": "max"}, split_out=None)
+
+
+@pytest.mark.parametrize("by", ["key1", ["key1", "key2"]])
+@pytest.mark.parametrize(
+    "slice_key",
+    [
+        3,
+        "value",
+        ["value"],
+        ("value",),
+        pd.Index(["value"]),
+        pd.Series(["value"]),
+    ],
+)
+def test_groupby_slice_getitem(by, slice_key):
+    pdf = pd.DataFrame(
+        {
+            "key1": ["a", "b", "a"],
+            "key2": ["c", "c", "c"],
+            "value": [1, 2, 3],
+            3: [1, 2, 3],
+        }
+    )
+    ddf = dd.from_pandas(pdf, npartitions=3)
+    expect = pdf.groupby(by)[slice_key].count()
+    got = ddf.groupby(by)[slice_key].count()
+
+    # We should have a getitem layer, enabling
+    # column projection after read_parquet etc
+    assert hlg_layer(got.dask, "getitem")
+    assert_eq(expect, got)
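
The hlg_layer assertion above confirms that the slice materializes as a getitem layer in the result's high-level graph; that layer is what allows graph optimization to push column selection down into IO layers such as read_parquet. A rough standalone equivalent of the check, assuming only that getitem layers are named with a "getitem" prefix:

import pandas as pd
import dask.dataframe as dd

pdf = pd.DataFrame({"key": ["a", "b"], "value": [1, 2], "other": [3, 4]})
ddf = dd.from_pandas(pdf, npartitions=2)
got = ddf.groupby("key")["value"].count()

# got.dask is a HighLevelGraph; the column projection should show up as a
# layer whose name starts with "getitem".
assert any(name.startswith("getitem") for name in got.dask.layers)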
