Enable named aggregation syntax (#9563)
ChrisJar committed Oct 24, 2022
1 parent c8dc395 commit 9adabf1
Showing 2 changed files with 162 additions and 55 deletions.
174 changes: 119 additions & 55 deletions dask/dataframe/groupby.py
@@ -10,7 +10,11 @@

from dask import config
from dask.base import tokenize
from dask.dataframe._compat import PANDAS_GT_150, check_numeric_only_deprecation
from dask.dataframe._compat import (
PANDAS_GT_140,
PANDAS_GT_150,
check_numeric_only_deprecation,
)
from dask.dataframe.core import (
GROUP_KEYS_DEFAULT,
DataFrame,
@@ -37,6 +41,9 @@
from dask.highlevelgraph import HighLevelGraph
from dask.utils import M, _deprecated, derived_from, funcname, itemgetter

if PANDAS_GT_140:
from pandas.core.apply import reconstruct_func, validate_func_kwargs

# #############################################
#
# GroupBy implementation notes
@@ -1162,13 +1169,14 @@ def wrapper(func):
{based_on_str}
Parameters
----------
arg : callable, str, list or dict
arg : callable, str, list or dict, optional
Aggregation spec. Accepted combinations are:
- callable function
- string function name
- list of functions and/or function names, e.g. ``[np.sum, 'mean']``
- dict of column names -> function, function name or list of such.
- None, only when the named aggregation syntax is used
split_every : int, optional
Number of intermediate partitions that may be aggregated at once.
This defaults to 8. If your intermediate partitions are likely to
@@ -1185,6 +1193,12 @@ def wrapper(func):
``split_out = 1``. When ``split_out > 1``, it chooses the algorithm
set by the ``shuffle`` option in the dask config system, or ``"tasks"``
if nothing is set.
kwargs : tuple or pd.NamedAgg, optional
Used for named aggregation, where each keyword is an output column
name and each value is a tuple whose first element is the input
column name and whose second element is the aggregation function.
``pandas.NamedAgg`` can also be used as the value. To use the named
aggregation syntax, ``arg`` must be set to ``None``.
"""
return func
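
For context, this is the user-facing syntax the commit enables. A minimal sketch mirroring the tests added further down (the small frames here are hypothetical, not taken from the diff):

import pandas as pd
import dask.dataframe as dd

df = pd.DataFrame({"a": [1, 1, 2, 2], "b": [1, 2, 5, 6], "c": [6, 3, 6, 7]})
ddf = dd.from_pandas(df, npartitions=2)

# DataFrameGroupBy: each keyword names an output column and maps it to an
# (input column, aggregation function) pair via pandas.NamedAgg.
res = ddf.groupby("a").agg(
    x=pd.NamedAgg("b", aggfunc="sum"),
    y=pd.NamedAgg("c", aggfunc="mean"),
)

# SeriesGroupBy: each keyword names an output column and maps it directly
# to an aggregation function.
res_b = ddf.groupby("a").b.agg(c="mean", d="sum")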

@@ -1881,7 +1895,9 @@ def get_group(self, key):
)

@_aggregate_docstring()
def aggregate(self, arg, split_every=None, split_out=1, shuffle=None):
def aggregate(
self, arg=None, split_every=None, split_out=1, shuffle=None, **kwargs
):
if split_out is None:
warnings.warn(
"split_out=None is deprecated, please use a positive integer, "
@@ -1891,7 +1907,20 @@ def aggregate(self, arg, split_every=None, split_out=1, shuffle=None):
split_out = 1
shuffle = _determine_shuffle(shuffle, split_out)

relabeling = None
columns = None
order = None
column_projection = None
if PANDAS_GT_140:
if isinstance(self, DataFrameGroupBy):
if arg is None:
relabeling, arg, columns, order = reconstruct_func(arg, **kwargs)

elif isinstance(self, SeriesGroupBy):
relabeling = arg is None
if relabeling:
columns, arg = validate_func_kwargs(kwargs)

if isinstance(self.obj, DataFrame):
if isinstance(self.by, tuple) or np.isscalar(self.by):
group_columns = {self.by}
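
A small sketch of what the two pandas helpers used above return (pandas >= 1.4 private API; the exact values in the comments are illustrative assumptions, not taken from the diff):

import pandas as pd
from pandas.core.apply import reconstruct_func, validate_func_kwargs

# DataFrameGroupBy branch: with arg=None, the kwargs are folded back into a
# dict-style spec plus the output column names and their ordering, which are
# used later to relabel the aggregated result.
relabeling, arg, columns, order = reconstruct_func(
    None, x=pd.NamedAgg("b", "sum"), y=pd.NamedAgg("c", "mean")
)
# relabeling -> True; arg -> {"b": ["sum"], "c": ["mean"]};
# columns -> ("x", "y"); order -> positions used to reorder the result.

# SeriesGroupBy branch: kwargs map output names directly to functions.
columns, funcs = validate_func_kwargs({"c": "mean", "d": "sum"})
# columns -> ["c", "d"]; funcs -> ["mean", "sum"]
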
@@ -1985,10 +2014,11 @@ def aggregate(self, arg, split_every=None, split_out=1, shuffle=None):
# for larger values of split_out. However, the shuffle
# step requires that the result of `chunk` produces a
# proper DataFrame type

# If we have a median in the spec, we cannot do an initial
# aggregation.
if has_median:
return _shuffle_aggregate(
result = _shuffle_aggregate(
chunk_args,
chunk=_non_agg_chunk,
chunk_kwargs={
@@ -2012,7 +2042,7 @@ def aggregate(self, arg, split_every=None, split_out=1, shuffle=None):
sort=self.sort,
)
else:
return _shuffle_aggregate(
result = _shuffle_aggregate(
chunk_args,
chunk=_groupby_apply_funcs,
chunk_kwargs={
@@ -2035,49 +2065,56 @@ def aggregate(self, arg, split_every=None, split_out=1, shuffle=None):
shuffle=shuffle if isinstance(shuffle, str) else "tasks",
sort=self.sort,
)
else:
if self.sort is None and split_out > 1:
warnings.warn(SORT_SPLIT_OUT_WARNING, FutureWarning)

# Check sort behavior
if self.sort and split_out > 1:
raise NotImplementedError(
"Cannot guarantee sorted keys for `split_out>1` and `shuffle=False`"
" Try using `shuffle=True` if you are grouping on a single column."
" Otherwise, try using split_out=1, or grouping with sort=False."
)

if self.sort is None and split_out > 1:
warnings.warn(SORT_SPLIT_OUT_WARNING, FutureWarning)

# Check sort behavior
if self.sort and split_out > 1:
raise NotImplementedError(
"Cannot guarantee sorted keys for `split_out>1` and `shuffle=False`"
" Try using `shuffle=True` if you are grouping on a single column."
" Otherwise, try using split_out=1, or grouping with sort=False."
)
result = aca(
chunk_args,
chunk=_groupby_apply_funcs,
chunk_kwargs=dict(
funcs=chunk_funcs,
sort=False,
**self.observed,
**self.dropna,
),
combine=_groupby_apply_funcs,
combine_kwargs=dict(
funcs=aggregate_funcs,
level=levels,
sort=False,
**self.observed,
**self.dropna,
),
aggregate=_agg_finalize,
aggregate_kwargs=dict(
aggregate_funcs=aggregate_funcs,
finalize_funcs=finalizers,
level=levels,
**self.observed,
**self.dropna,
),
token="aggregate",
split_every=split_every,
split_out=split_out,
split_out_setup=split_out_on_index,
sort=self.sort,
)

return aca(
chunk_args,
chunk=_groupby_apply_funcs,
chunk_kwargs=dict(
funcs=chunk_funcs,
sort=False,
**self.observed,
**self.dropna,
),
combine=_groupby_apply_funcs,
combine_kwargs=dict(
funcs=aggregate_funcs,
level=levels,
sort=False,
**self.observed,
**self.dropna,
),
aggregate=_agg_finalize,
aggregate_kwargs=dict(
aggregate_funcs=aggregate_funcs,
finalize_funcs=finalizers,
level=levels,
**self.observed,
**self.dropna,
),
token="aggregate",
split_every=split_every,
split_out=split_out,
split_out_setup=split_out_on_index,
sort=self.sort,
)
if relabeling and result is not None:
if order is not None:
result = result.iloc[:, order]
result.columns = columns

return result
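
To make the relabeling step above concrete, a pandas-only sketch of what ``columns`` and ``order`` do to the aggregated frame (column names and values are illustrative assumptions):

import numpy as np
import pandas as pd

# Result as produced by the dict-style spec: one column per
# (input column, function) pair, in spec order.
agg_result = pd.DataFrame(
    {"b_sum": [3, 11], "c_mean": [4.5, 6.5]},
    index=pd.Index([1, 2], name="a"),
)

# reconstruct_func also returned the user's output names and the positions
# of the matching columns; reorder first, then rename.
columns = ("x", "y")
order = np.array([0, 1])
relabeled = agg_result.iloc[:, order]
relabeled.columns = columns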

@insert_meta_param_description(pad=12)
def apply(self, func, *args, **kwargs):
@@ -2501,18 +2538,28 @@ def __getattr__(self, key):
raise AttributeError(e) from e

@_aggregate_docstring(based_on="pd.core.groupby.DataFrameGroupBy.aggregate")
def aggregate(self, arg, split_every=None, split_out=1, shuffle=None):
def aggregate(
self, arg=None, split_every=None, split_out=1, shuffle=None, **kwargs
):
if arg == "size":
return self.size()

return super().aggregate(
arg, split_every=split_every, split_out=split_out, shuffle=shuffle
arg=arg,
split_every=split_every,
split_out=split_out,
shuffle=shuffle,
**kwargs,
)

@_aggregate_docstring(based_on="pd.core.groupby.DataFrameGroupBy.agg")
def agg(self, arg, split_every=None, split_out=1, shuffle=None):
def agg(self, arg=None, split_every=None, split_out=1, shuffle=None, **kwargs):
return self.aggregate(
arg, split_every=split_every, split_out=split_out, shuffle=shuffle
arg=arg,
split_every=split_every,
split_out=split_out,
shuffle=shuffle,
**kwargs,
)


@@ -2583,22 +2630,39 @@ def nunique(self, split_every=None, split_out=1):
)

@_aggregate_docstring(based_on="pd.core.groupby.SeriesGroupBy.aggregate")
def aggregate(self, arg, split_every=None, split_out=1, shuffle=None):
def aggregate(
self, arg=None, split_every=None, split_out=1, shuffle=None, **kwargs
):
result = super().aggregate(
arg, split_every=split_every, split_out=split_out, shuffle=shuffle
arg=arg,
split_every=split_every,
split_out=split_out,
shuffle=shuffle,
**kwargs,
)
if self._slice:
result = result[self._slice]
try:
result = result[self._slice]
except KeyError:
pass

if not isinstance(arg, (list, dict)) and isinstance(result, DataFrame):
if (
arg is not None
and not isinstance(arg, (list, dict))
and isinstance(result, DataFrame)
):
result = result[result.columns[0]]

return result
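
The extra ``arg is not None`` check above preserves the old scalar behaviour for SeriesGroupBy while letting named aggregation keep its frame shape. A rough sketch (small hypothetical frame):

import pandas as pd
import dask.dataframe as dd

ddf = dd.from_pandas(
    pd.DataFrame({"a": [1, 1, 2], "b": [1, 2, 5]}), npartitions=2
)

# A single scalar spec still collapses to a Series, as before.
ddf.groupby("a").b.agg("sum")               # dask Series

# Named aggregation (arg left as None) returns a DataFrame with the
# requested output column names.
ddf.groupby("a").b.agg(c="mean", d="sum")   # dask DataFrame, columns c and d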

@_aggregate_docstring(based_on="pd.core.groupby.SeriesGroupBy.agg")
def agg(self, arg, split_every=None, split_out=1, shuffle=None):
def agg(self, arg=None, split_every=None, split_out=1, shuffle=None, **kwargs):
return self.aggregate(
arg, split_every=split_every, split_out=split_out, shuffle=shuffle
arg=arg,
split_every=split_every,
split_out=split_out,
shuffle=shuffle,
**kwargs,
)

@derived_from(pd.core.groupby.SeriesGroupBy)
43 changes: 43 additions & 0 deletions dask/dataframe/tests/test_groupby.py
@@ -3,6 +3,7 @@
import operator
import pickle
import warnings
from functools import partial

import numpy as np
import pandas as pd
@@ -14,6 +15,7 @@
from dask.dataframe._compat import (
PANDAS_GT_110,
PANDAS_GT_130,
PANDAS_GT_140,
PANDAS_GT_150,
check_numeric_only_deprecation,
tm,
@@ -2896,6 +2898,47 @@ def agg(grp, **kwargs):
)


@pytest.mark.skipif(not PANDAS_GT_140, reason="requires pandas >= 1.4.0")
@pytest.mark.parametrize("shuffle", [True, False])
def test_dataframe_named_agg(shuffle):
df = pd.DataFrame(
{
"a": [1, 1, 2, 2],
"b": [1, 2, 5, 6],
"c": [6, 3, 6, 7],
}
)
ddf = dd.from_pandas(df, npartitions=2)

expected = df.groupby("a").agg(
x=pd.NamedAgg("b", aggfunc="sum"),
y=pd.NamedAgg("c", aggfunc=partial(np.std, ddof=1)),
)
actual = ddf.groupby("a").agg(
shuffle=shuffle,
x=pd.NamedAgg("b", aggfunc="sum"),
y=pd.NamedAgg("c", aggfunc=partial(np.std, ddof=1)),
)
assert_eq(expected, actual)


@pytest.mark.skipif(not PANDAS_GT_140, reason="requires pandas >= 1.4.0")
@pytest.mark.parametrize("shuffle", [True, False])
@pytest.mark.parametrize("agg", ["count", np.mean, partial(np.var, ddof=1)])
def test_series_named_agg(shuffle, agg):
df = pd.DataFrame(
{
"a": [5, 4, 3, 5, 4, 2, 3, 2],
"b": [1, 2, 5, 6, 9, 2, 6, 8],
}
)
ddf = dd.from_pandas(df, npartitions=2)

expected = df.groupby("a").b.agg(c=agg, d="sum")
actual = ddf.groupby("a").b.agg(shuffle=shuffle, c=agg, d="sum")
assert_eq(expected, actual)


def test_empty_partitions_with_value_counts():
# https://github.com/dask/dask/issues/7065
df = pd.DataFrame(
