From 49b516b43e44caf3091340d15247d4fe80e8a5c3 Mon Sep 17 00:00:00 2001 From: "Richard (Rick) Zamora" Date: Wed, 30 Nov 2022 13:04:32 -0600 Subject: [PATCH] Enable column projection for groupby slicing (#9667) --- dask/dataframe/groupby.py | 28 ++++++++++++++++++++----- dask/dataframe/tests/test_groupby.py | 31 ++++++++++++++++++++++++++++ 2 files changed, 54 insertions(+), 5 deletions(-) diff --git a/dask/dataframe/groupby.py b/dask/dataframe/groupby.py index 7fb4fa637a6..1ba694bc0a3 100644 --- a/dask/dataframe/groupby.py +++ b/dask/dataframe/groupby.py @@ -9,7 +9,7 @@ import pandas as pd from dask import config -from dask.base import tokenize +from dask.base import is_dask_collection, tokenize from dask.dataframe._compat import ( PANDAS_GT_140, PANDAS_GT_150, @@ -34,6 +34,7 @@ PANDAS_GT_110, insert_meta_param_description, is_dataframe_like, + is_index_like, is_series_like, make_meta, raise_on_meta_error, @@ -1245,9 +1246,29 @@ def __init__( if any(isinstance(key, pd.Grouper) for key in by_): raise NotImplementedError("pd.Grouper is currently not supported by Dask.") + # slicing key applied to _GroupBy instance + self._slice = slice + + # Check if we can project columns + projection = None + if ( + np.isscalar(self._slice) + or isinstance(self._slice, (str, list, tuple)) + or ( + (is_index_like(self._slice) or is_series_like(self._slice)) + and not is_dask_collection(self._slice) + ) + ): + projection = set(by_).union( + {self._slice} + if (np.isscalar(self._slice) or isinstance(self._slice, str)) + else self._slice + ) + projection = [c for c in df.columns if c in projection] + assert isinstance(df, (DataFrame, Series)) self.group_keys = group_keys - self.obj = df + self.obj = df[projection] if projection else df # grouping key passed via groupby method self.by = _normalize_by(df, by) self.sort = sort @@ -1262,9 +1283,6 @@ def __init__( "The grouped object and 'by' of the groupby must have the same divisions." ) - # slicing key applied to _GroupBy instance - self._slice = slice - if isinstance(self.by, list): by_meta = [ item._meta if isinstance(item, Series) else item for item in self.by diff --git a/dask/dataframe/tests/test_groupby.py b/dask/dataframe/tests/test_groupby.py index 108e30c7706..2688c9b58ce 100644 --- a/dask/dataframe/tests/test_groupby.py +++ b/dask/dataframe/tests/test_groupby.py @@ -3075,3 +3075,34 @@ def test_groupby_None_split_out_warns(): ddf = dd.from_pandas(df, npartitions=1) with pytest.warns(FutureWarning, match="split_out=None"): ddf.groupby("a").agg({"b": "max"}, split_out=None) + + +@pytest.mark.parametrize("by", ["key1", ["key1", "key2"]]) +@pytest.mark.parametrize( + "slice_key", + [ + 3, + "value", + ["value"], + ("value",), + pd.Index(["value"]), + pd.Series(["value"]), + ], +) +def test_groupby_slice_getitem(by, slice_key): + pdf = pd.DataFrame( + { + "key1": ["a", "b", "a"], + "key2": ["c", "c", "c"], + "value": [1, 2, 3], + 3: [1, 2, 3], + } + ) + ddf = dd.from_pandas(pdf, npartitions=3) + expect = pdf.groupby(by)[slice_key].count() + got = ddf.groupby(by)[slice_key].count() + + # We should have a getitem layer, enabling + # column projection after read_parquet etc + assert hlg_layer(got.dask, "getitem") + assert_eq(expect, got)