diff --git a/dask/dataframe/groupby.py b/dask/dataframe/groupby.py
index d958f3a77c8..e5e601ed914 100644
--- a/dask/dataframe/groupby.py
+++ b/dask/dataframe/groupby.py
@@ -10,7 +10,11 @@
 from dask import config
 from dask.base import tokenize
-from dask.dataframe._compat import PANDAS_GT_150, check_numeric_only_deprecation
+from dask.dataframe._compat import (
+    PANDAS_GT_140,
+    PANDAS_GT_150,
+    check_numeric_only_deprecation,
+)
 from dask.dataframe.core import (
     GROUP_KEYS_DEFAULT,
     DataFrame,
@@ -37,6 +41,9 @@
 from dask.highlevelgraph import HighLevelGraph
 from dask.utils import M, _deprecated, derived_from, funcname, itemgetter
 
+if PANDAS_GT_140:
+    from pandas.core.apply import reconstruct_func, validate_func_kwargs
+
 # #############################################
 #
 # GroupBy implementation notes
@@ -1162,13 +1169,14 @@ def wrapper(func):
         {based_on_str}
         Parameters
         ----------
-        arg : callable, str, list or dict
+        arg : callable, str, list or dict, optional
             Aggregation spec. Accepted combinations are:
 
             - callable function
             - string function name
             - list of functions and/or function names, e.g. ``[np.sum, 'mean']``
             - dict of column names -> function, function name or list of such.
+            - None, only if the named aggregation syntax is used
         split_every : int, optional
             Number of intermediate partitions that may be aggregated at once.
             This defaults to 8. If your intermediate partitions are likely to
@@ -1185,6 +1193,12 @@ def wrapper(func):
             ``split_out = 1``. When ``split_out > 1``, it chooses the algorithm
             set by the ``shuffle`` option in the dask config system, or
             ``"tasks"`` if nothing is set.
+        kwargs : tuple or pd.NamedAgg, optional
+            Used for named aggregation, where the keywords are the output
+            column names and the values are tuples whose first element is the
+            input column name and whose second element is the aggregation
+            function. ``pandas.NamedAgg`` can also be used as the value. To
+            use the named aggregation syntax, ``arg`` must be set to None.
         """
         return func
@@ -1881,7 +1895,9 @@ def get_group(self, key):
         )
 
     @_aggregate_docstring()
-    def aggregate(self, arg, split_every=None, split_out=1, shuffle=None):
+    def aggregate(
+        self, arg=None, split_every=None, split_out=1, shuffle=None, **kwargs
+    ):
         if split_out is None:
             warnings.warn(
                 "split_out=None is deprecated, please use a positive integer, "
@@ -1891,7 +1907,20 @@ def aggregate(self, arg, split_every=None, split_out=1, shuffle=None):
             split_out = 1
         shuffle = _determine_shuffle(shuffle, split_out)
 
+        relabeling = None
+        columns = None
+        order = None
         column_projection = None
+        if PANDAS_GT_140:
+            if isinstance(self, DataFrameGroupBy):
+                if arg is None:
+                    relabeling, arg, columns, order = reconstruct_func(arg, **kwargs)
+
+            elif isinstance(self, SeriesGroupBy):
+                relabeling = arg is None
+                if relabeling:
+                    columns, arg = validate_func_kwargs(kwargs)
+
         if isinstance(self.obj, DataFrame):
             if isinstance(self.by, tuple) or np.isscalar(self.by):
                 group_columns = {self.by}
@@ -1985,10 +2014,11 @@ def aggregate(self, arg, split_every=None, split_out=1, shuffle=None):
             # for larger values of split_out. However, the shuffle
             # step requires that the result of `chunk` produces a
             # proper DataFrame type
+            # If we have a median in the spec, we cannot do an initial aggregation.
             if has_median:
-                return _shuffle_aggregate(
+                result = _shuffle_aggregate(
                     chunk_args,
                     chunk=_non_agg_chunk,
                     chunk_kwargs={
@@ -2012,7 +2042,7 @@ def aggregate(self, arg, split_every=None, split_out=1, shuffle=None):
                     sort=self.sort,
                 )
             else:
-                return _shuffle_aggregate(
+                result = _shuffle_aggregate(
                     chunk_args,
                     chunk=_groupby_apply_funcs,
                     chunk_kwargs={
@@ -2035,49 +2065,56 @@ def aggregate(self, arg, split_every=None, split_out=1, shuffle=None):
                     shuffle=shuffle if isinstance(shuffle, str) else "tasks",
                     sort=self.sort,
                 )
+        else:
+            if self.sort is None and split_out > 1:
+                warnings.warn(SORT_SPLIT_OUT_WARNING, FutureWarning)
+
+            # Check sort behavior
+            if self.sort and split_out > 1:
+                raise NotImplementedError(
+                    "Cannot guarantee sorted keys for `split_out>1` and `shuffle=False`"
+                    " Try using `shuffle=True` if you are grouping on a single column."
+                    " Otherwise, try using split_out=1, or grouping with sort=False."
+                )
-        if self.sort is None and split_out > 1:
-            warnings.warn(SORT_SPLIT_OUT_WARNING, FutureWarning)
-
-        # Check sort behavior
-        if self.sort and split_out > 1:
-            raise NotImplementedError(
-                "Cannot guarantee sorted keys for `split_out>1` and `shuffle=False`"
-                " Try using `shuffle=True` if you are grouping on a single column."
-                " Otherwise, try using split_out=1, or grouping with sort=False."
+            result = aca(
+                chunk_args,
+                chunk=_groupby_apply_funcs,
+                chunk_kwargs=dict(
+                    funcs=chunk_funcs,
+                    sort=False,
+                    **self.observed,
+                    **self.dropna,
+                ),
+                combine=_groupby_apply_funcs,
+                combine_kwargs=dict(
+                    funcs=aggregate_funcs,
+                    level=levels,
+                    sort=False,
+                    **self.observed,
+                    **self.dropna,
+                ),
+                aggregate=_agg_finalize,
+                aggregate_kwargs=dict(
+                    aggregate_funcs=aggregate_funcs,
+                    finalize_funcs=finalizers,
+                    level=levels,
+                    **self.observed,
+                    **self.dropna,
+                ),
+                token="aggregate",
+                split_every=split_every,
+                split_out=split_out,
+                split_out_setup=split_out_on_index,
+                sort=self.sort,
             )
-        return aca(
-            chunk_args,
-            chunk=_groupby_apply_funcs,
-            chunk_kwargs=dict(
-                funcs=chunk_funcs,
-                sort=False,
-                **self.observed,
-                **self.dropna,
-            ),
-            combine=_groupby_apply_funcs,
-            combine_kwargs=dict(
-                funcs=aggregate_funcs,
-                level=levels,
-                sort=False,
-                **self.observed,
-                **self.dropna,
-            ),
-            aggregate=_agg_finalize,
-            aggregate_kwargs=dict(
-                aggregate_funcs=aggregate_funcs,
-                finalize_funcs=finalizers,
-                level=levels,
-                **self.observed,
-                **self.dropna,
-            ),
-            token="aggregate",
-            split_every=split_every,
-            split_out=split_out,
-            split_out_setup=split_out_on_index,
-            sort=self.sort,
-        )
+        if relabeling and result is not None:
+            if order is not None:
+                result = result.iloc[:, order]
+            result.columns = columns
+
+        return result
 
     @insert_meta_param_description(pad=12)
     def apply(self, func, *args, **kwargs):
@@ -2501,18 +2538,28 @@ def __getattr__(self, key):
             raise AttributeError(e) from e
 
     @_aggregate_docstring(based_on="pd.core.groupby.DataFrameGroupBy.aggregate")
-    def aggregate(self, arg, split_every=None, split_out=1, shuffle=None):
+    def aggregate(
+        self, arg=None, split_every=None, split_out=1, shuffle=None, **kwargs
+    ):
         if arg == "size":
             return self.size()
 
         return super().aggregate(
-            arg, split_every=split_every, split_out=split_out, shuffle=shuffle
+            arg=arg,
+            split_every=split_every,
+            split_out=split_out,
+            shuffle=shuffle,
+            **kwargs,
         )
 
     @_aggregate_docstring(based_on="pd.core.groupby.DataFrameGroupBy.agg")
-    def agg(self, arg, split_every=None, split_out=1, shuffle=None):
+    def agg(self, arg=None, split_every=None, split_out=1, shuffle=None, **kwargs):
         return self.aggregate(
-            arg, split_every=split_every, split_out=split_out, shuffle=shuffle
+            arg=arg,
+            split_every=split_every,
+            split_out=split_out,
+            shuffle=shuffle,
+            **kwargs,
         )
@@ -2583,22 +2630,39 @@ def nunique(self, split_every=None, split_out=1):
         )
 
     @_aggregate_docstring(based_on="pd.core.groupby.SeriesGroupBy.aggregate")
-    def aggregate(self, arg, split_every=None, split_out=1, shuffle=None):
+    def aggregate(
+        self, arg=None, split_every=None, split_out=1, shuffle=None, **kwargs
+    ):
         result = super().aggregate(
-            arg, split_every=split_every, split_out=split_out, shuffle=shuffle
+            arg=arg,
+            split_every=split_every,
+            split_out=split_out,
+            shuffle=shuffle,
+            **kwargs,
         )
         if self._slice:
-            result = result[self._slice]
+            try:
+                result = result[self._slice]
+            except KeyError:
+                pass
 
-        if not isinstance(arg, (list, dict)) and isinstance(result, DataFrame):
+        if (
+            arg is not None
+            and not isinstance(arg, (list, dict))
+            and isinstance(result, DataFrame)
+        ):
             result = result[result.columns[0]]
 
         return result
 
     @_aggregate_docstring(based_on="pd.core.groupby.SeriesGroupBy.agg")
-    def agg(self, arg, split_every=None, split_out=1, shuffle=None):
+    def agg(self, arg=None, split_every=None, split_out=1, shuffle=None, **kwargs):
         return self.aggregate(
-            arg, split_every=split_every, split_out=split_out, shuffle=shuffle
+            arg=arg,
+            split_every=split_every,
+            split_out=split_out,
+            shuffle=shuffle,
+            **kwargs,
         )
 
     @derived_from(pd.core.groupby.SeriesGroupBy)
diff --git a/dask/dataframe/tests/test_groupby.py b/dask/dataframe/tests/test_groupby.py
index d355afb6957..c667e5759e0 100644
--- a/dask/dataframe/tests/test_groupby.py
+++ b/dask/dataframe/tests/test_groupby.py
@@ -3,6 +3,7 @@
 import operator
 import pickle
 import warnings
+from functools import partial
 
 import numpy as np
 import pandas as pd
@@ -14,6 +15,7 @@
 from dask.dataframe._compat import (
     PANDAS_GT_110,
     PANDAS_GT_130,
+    PANDAS_GT_140,
     PANDAS_GT_150,
     check_numeric_only_deprecation,
     tm,
@@ -2896,6 +2898,47 @@ def agg(grp, **kwargs):
     )
 
 
+@pytest.mark.skipif(not PANDAS_GT_140, reason="requires pandas >= 1.4.0")
+@pytest.mark.parametrize("shuffle", [True, False])
+def test_dataframe_named_agg(shuffle):
+    df = pd.DataFrame(
+        {
+            "a": [1, 1, 2, 2],
+            "b": [1, 2, 5, 6],
+            "c": [6, 3, 6, 7],
+        }
+    )
+    ddf = dd.from_pandas(df, npartitions=2)
+
+    expected = df.groupby("a").agg(
+        x=pd.NamedAgg("b", aggfunc="sum"),
+        y=pd.NamedAgg("c", aggfunc=partial(np.std, ddof=1)),
+    )
+    actual = ddf.groupby("a").agg(
+        shuffle=shuffle,
+        x=pd.NamedAgg("b", aggfunc="sum"),
+        y=pd.NamedAgg("c", aggfunc=partial(np.std, ddof=1)),
+    )
+    assert_eq(expected, actual)
+
+
+@pytest.mark.skipif(not PANDAS_GT_140, reason="requires pandas >= 1.4.0")
+@pytest.mark.parametrize("shuffle", [True, False])
+@pytest.mark.parametrize("agg", ["count", np.mean, partial(np.var, ddof=1)])
+def test_series_named_agg(shuffle, agg):
+    df = pd.DataFrame(
+        {
+            "a": [5, 4, 3, 5, 4, 2, 3, 2],
+            "b": [1, 2, 5, 6, 9, 2, 6, 8],
+        }
+    )
+    ddf = dd.from_pandas(df, npartitions=2)
+
+    expected = df.groupby("a").b.agg(c=agg, d="sum")
+    actual = ddf.groupby("a").b.agg(shuffle=shuffle, c=agg, d="sum")
+    assert_eq(expected, actual)
+
+
 def test_empty_partitions_with_value_counts():
     # https://github.com/dask/dask/issues/7065
     df = pd.DataFrame(
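
For reference, a minimal usage sketch of the named aggregation syntax this patch enables, mirroring the tests above. The data and output column names here are illustrative, not part of the patch; pandas >= 1.4 is required, since the keyword handling is gated on PANDAS_GT_140:

    import pandas as pd
    import dask.dataframe as dd

    df = pd.DataFrame({"a": [1, 1, 2, 2], "b": [1, 2, 5, 6]})
    ddf = dd.from_pandas(df, npartitions=2)

    # DataFrame groupby: keywords name the output columns; each value is an
    # (input column, aggfunc) tuple or a pandas.NamedAgg. `arg` must be left
    # unset (None) for the keyword form to be recognized.
    result = ddf.groupby("a").agg(
        b_sum=pd.NamedAgg(column="b", aggfunc="sum"),
        b_mean=("b", "mean"),
    )

    # Series groupby: keywords name the outputs; values are the aggfuncs,
    # validated via pandas' validate_func_kwargs.
    b_stats = ddf.groupby("a").b.agg(total="sum", largest="max")

    print(result.compute())
    print(b_stats.compute())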