Enable named aggregation syntax #9563

Merged 7 commits on Oct 24, 2022
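
This PR lets Dask's groupby ``agg``/``aggregate`` accept pandas' named aggregation syntax. As a rough usage sketch (illustrative only, mirroring the tests added below; the data and column names are arbitrary):

import pandas as pd
import dask.dataframe as dd

df = pd.DataFrame({"a": [1, 1, 2, 2], "b": [1, 2, 5, 6], "c": [6, 3, 6, 7]})
ddf = dd.from_pandas(df, npartitions=2)

# DataFrameGroupBy: each keyword names an output column; the value says
# which input column to aggregate and how.
ddf.groupby("a").agg(
    x=pd.NamedAgg("b", aggfunc="sum"),
    y=("c", "mean"),  # a plain (column, aggfunc) tuple works as well
).compute()

# SeriesGroupBy: keywords name the outputs, values are the aggregations.
ddf.groupby("a").b.agg(total="sum", largest="max").compute()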
174 changes: 119 additions & 55 deletions dask/dataframe/groupby.py
@@ -10,7 +10,11 @@

from dask import config
from dask.base import tokenize
from dask.dataframe._compat import PANDAS_GT_150, check_numeric_only_deprecation
from dask.dataframe._compat import (
PANDAS_GT_140,
PANDAS_GT_150,
check_numeric_only_deprecation,
)
from dask.dataframe.core import (
GROUP_KEYS_DEFAULT,
DataFrame,
@@ -37,6 +41,9 @@
from dask.highlevelgraph import HighLevelGraph
from dask.utils import M, _deprecated, derived_from, funcname, itemgetter

if PANDAS_GT_140:
from pandas.core.apply import reconstruct_func, validate_func_kwargs

# #############################################
#
# GroupBy implementation notes
@@ -1162,13 +1169,14 @@ def wrapper(func):
{based_on_str}
Parameters
----------
arg : callable, str, list or dict
arg : callable, str, list or dict, optional
Aggregation spec. Accepted combinations are:

- callable function
- string function name
- list of functions and/or function names, e.g. ``[np.sum, 'mean']``
- dict of column names -> function, function name or list of such.
- None only if named aggregation syntax is used
split_every : int, optional
Number of intermediate partitions that may be aggregated at once.
This defaults to 8. If your intermediate partitions are likely to
@@ -1185,6 +1193,12 @@ def wrapper(func):
``split_out = 1``. When ``split_out > 1``, it chooses the algorithm
set by the ``shuffle`` option in the dask config system, or ``"tasks"``
if nothing is set.
kwargs : tuple or pd.NamedAgg, optional
Used for named aggregation, where the keywords are the output column
names and the values are tuples whose first element is the input
column name and whose second element is the aggregation function.
``pandas.NamedAgg`` can also be used as the value. To use the named
aggregation syntax, ``arg`` must be set to None.
"""
return func
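
To illustrate what the ``kwargs`` description above means in practice, the following two calls are intended to compute the same aggregations; only the output column names differ (a sketch, assuming ``ddf`` is the Dask DataFrame from the example near the top of this page):

# Named aggregation: output columns are "low" and "high".
ddf.groupby("a").agg(low=("b", "min"), high=pd.NamedAgg("b", "max"))

# Equivalent dict spec, but the outputs keep their default names derived
# from the input column and aggregation function instead of "low"/"high".
ddf.groupby("a").agg({"b": ["min", "max"]})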

@@ -1881,7 +1895,9 @@ def get_group(self, key):
)

@_aggregate_docstring()
def aggregate(self, arg, split_every=None, split_out=1, shuffle=None):
def aggregate(
self, arg=None, split_every=None, split_out=1, shuffle=None, **kwargs
):
if split_out is None:
warnings.warn(
"split_out=None is deprecated, please use a positive integer, "
@@ -1891,7 +1907,20 @@ def aggregate(self, arg, split_every=None, split_out=1, shuffle=None):
split_out = 1
shuffle = _determine_shuffle(shuffle, split_out)

relabeling = None
columns = None
order = None
column_projection = None
if PANDAS_GT_140:
if isinstance(self, DataFrameGroupBy):
if arg is None:
relabeling, arg, columns, order = reconstruct_func(arg, **kwargs)

elif isinstance(self, SeriesGroupBy):
relabeling = arg is None
if relabeling:
columns, arg = validate_func_kwargs(kwargs)
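
For context, a rough sketch of what these two pandas helpers return for a named-aggregation call (illustrative only; ``reconstruct_func`` and ``validate_func_kwargs`` are pandas internals, so the exact return types may vary between pandas versions):

import pandas as pd
from pandas.core.apply import reconstruct_func, validate_func_kwargs

relabeling, arg, columns, order = reconstruct_func(
    None, x=pd.NamedAgg("b", "sum"), y=("c", "mean")
)
# relabeling -> True (named aggregation was used)
# arg        -> roughly {"b": ["sum"], "c": ["mean"]}, keyed by input column
# columns    -> ("x", "y"), the requested output column names
# order      -> integer positions used later to restore the requested order

columns, arg = validate_func_kwargs({"c": "count", "d": "sum"})
# columns -> ["c", "d"]        (output names)
# arg     -> ["count", "sum"]  (aggregations to apply to the series)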

if isinstance(self.obj, DataFrame):
if isinstance(self.by, tuple) or np.isscalar(self.by):
group_columns = {self.by}
@@ -1985,10 +2014,11 @@ def aggregate(self, arg, split_every=None, split_out=1, shuffle=None):
# for larger values of split_out. However, the shuffle
# step requires that the result of `chunk` produces a
# proper DataFrame type

# If we have a median in the spec, we cannot do an initial
# aggregation.
if has_median:
return _shuffle_aggregate(
result = _shuffle_aggregate(
chunk_args,
chunk=_non_agg_chunk,
chunk_kwargs={
@@ -2012,7 +2042,7 @@ def aggregate(self, arg, split_every=None, split_out=1, shuffle=None):
sort=self.sort,
)
else:
return _shuffle_aggregate(
result = _shuffle_aggregate(
chunk_args,
chunk=_groupby_apply_funcs,
chunk_kwargs={
@@ -2035,49 +2065,56 @@ def aggregate(self, arg, split_every=None, split_out=1, shuffle=None):
shuffle=shuffle if isinstance(shuffle, str) else "tasks",
sort=self.sort,
)
else:
if self.sort is None and split_out > 1:
warnings.warn(SORT_SPLIT_OUT_WARNING, FutureWarning)

# Check sort behavior
if self.sort and split_out > 1:
raise NotImplementedError(
"Cannot guarantee sorted keys for `split_out>1` and `shuffle=False`"
" Try using `shuffle=True` if you are grouping on a single column."
" Otherwise, try using split_out=1, or grouping with sort=False."
)

result = aca(
chunk_args,
chunk=_groupby_apply_funcs,
chunk_kwargs=dict(
funcs=chunk_funcs,
sort=False,
**self.observed,
**self.dropna,
),
combine=_groupby_apply_funcs,
combine_kwargs=dict(
funcs=aggregate_funcs,
level=levels,
sort=False,
**self.observed,
**self.dropna,
),
aggregate=_agg_finalize,
aggregate_kwargs=dict(
aggregate_funcs=aggregate_funcs,
finalize_funcs=finalizers,
level=levels,
**self.observed,
**self.dropna,
),
token="aggregate",
split_every=split_every,
split_out=split_out,
split_out_setup=split_out_on_index,
sort=self.sort,
)

if relabeling and result is not None:
if order is not None:
result = result.iloc[:, order]
result.columns = columns

return result
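
As a usage sketch of the two aggregation paths above (illustrative; which path is taken also depends on the dask config and defaults, and ``ddf`` is assumed to be the example DataFrame from the top of this page):

# Default: single output partition, aggregated via the tree (aca) reduction.
ddf.groupby("a").agg(x=("b", "sum"))

# Shuffle-based aggregation producing 4 output partitions.
ddf.groupby("a").agg(x=("b", "sum"), split_out=4, shuffle=True)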

@insert_meta_param_description(pad=12)
def apply(self, func, *args, **kwargs):
@@ -2501,18 +2538,28 @@ def __getattr__(self, key):
raise AttributeError(e) from e

@_aggregate_docstring(based_on="pd.core.groupby.DataFrameGroupBy.aggregate")
def aggregate(self, arg, split_every=None, split_out=1, shuffle=None):
def aggregate(
self, arg=None, split_every=None, split_out=1, shuffle=None, **kwargs
):
if arg == "size":
return self.size()

return super().aggregate(
arg, split_every=split_every, split_out=split_out, shuffle=shuffle
arg=arg,
split_every=split_every,
split_out=split_out,
shuffle=shuffle,
**kwargs,
)

@_aggregate_docstring(based_on="pd.core.groupby.DataFrameGroupBy.agg")
def agg(self, arg, split_every=None, split_out=1, shuffle=None):
def agg(self, arg=None, split_every=None, split_out=1, shuffle=None, **kwargs):
return self.aggregate(
arg, split_every=split_every, split_out=split_out, shuffle=shuffle
arg=arg,
split_every=split_every,
split_out=split_out,
shuffle=shuffle,
**kwargs,
)


@@ -2583,22 +2630,39 @@ def nunique(self, split_every=None, split_out=1):
)

@_aggregate_docstring(based_on="pd.core.groupby.SeriesGroupBy.aggregate")
def aggregate(self, arg, split_every=None, split_out=1, shuffle=None):
def aggregate(
self, arg=None, split_every=None, split_out=1, shuffle=None, **kwargs
):
result = super().aggregate(
arg, split_every=split_every, split_out=split_out, shuffle=shuffle
arg=arg,
split_every=split_every,
split_out=split_out,
shuffle=shuffle,
**kwargs,
)
if self._slice:
result = result[self._slice]
try:
result = result[self._slice]
except KeyError:
pass

if not isinstance(arg, (list, dict)) and isinstance(result, DataFrame):
if (
arg is not None
and not isinstance(arg, (list, dict))
and isinstance(result, DataFrame)
):
result = result[result.columns[0]]

return result
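
A quick sketch of the behaviour this guard preserves (illustrative, assuming the ``ddf`` from the example at the top of this page): a plain scalar spec on a SeriesGroupBy still collapses to a Series, while the named form returns a DataFrame with the requested column names.

ddf.groupby("a").b.agg("sum")        # Dask Series, as before
ddf.groupby("a").b.agg(total="sum")  # Dask DataFrame with a single column "total"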

@_aggregate_docstring(based_on="pd.core.groupby.SeriesGroupBy.agg")
def agg(self, arg, split_every=None, split_out=1, shuffle=None):
def agg(self, arg=None, split_every=None, split_out=1, shuffle=None, **kwargs):
return self.aggregate(
arg, split_every=split_every, split_out=split_out, shuffle=shuffle
arg=arg,
split_every=split_every,
split_out=split_out,
shuffle=shuffle,
**kwargs,
)

@derived_from(pd.core.groupby.SeriesGroupBy)
43 changes: 43 additions & 0 deletions dask/dataframe/tests/test_groupby.py
@@ -3,6 +3,7 @@
import operator
import pickle
import warnings
from functools import partial

import numpy as np
import pandas as pd
@@ -14,6 +15,7 @@
from dask.dataframe._compat import (
PANDAS_GT_110,
PANDAS_GT_130,
PANDAS_GT_140,
PANDAS_GT_150,
check_numeric_only_deprecation,
tm,
@@ -2896,6 +2898,47 @@ def agg(grp, **kwargs):
)


@pytest.mark.skipif(not PANDAS_GT_140, reason="requires pandas >= 1.4.0")
@pytest.mark.parametrize("shuffle", [True, False])
def test_dataframe_named_agg(shuffle):
df = pd.DataFrame(
{
"a": [1, 1, 2, 2],
"b": [1, 2, 5, 6],
"c": [6, 3, 6, 7],
}
)
ddf = dd.from_pandas(df, npartitions=2)

expected = df.groupby("a").agg(
x=pd.NamedAgg("b", aggfunc="sum"),
y=pd.NamedAgg("c", aggfunc=partial(np.std, ddof=1)),
)
actual = ddf.groupby("a").agg(
shuffle=shuffle,
x=pd.NamedAgg("b", aggfunc="sum"),
y=pd.NamedAgg("c", aggfunc=partial(np.std, ddof=1)),
)
assert_eq(expected, actual)


@pytest.mark.skipif(not PANDAS_GT_140, reason="requires pandas >= 1.4.0")
@pytest.mark.parametrize("shuffle", [True, False])
@pytest.mark.parametrize("agg", ["count", np.mean, partial(np.var, ddof=1)])
def test_series_named_agg(shuffle, agg):
df = pd.DataFrame(
{
"a": [5, 4, 3, 5, 4, 2, 3, 2],
"b": [1, 2, 5, 6, 9, 2, 6, 8],
}
)
ddf = dd.from_pandas(df, npartitions=2)

expected = df.groupby("a").b.agg(c=agg, d="sum")
actual = ddf.groupby("a").b.agg(shuffle=shuffle, c=agg, d="sum")
assert_eq(expected, actual)


def test_empty_partitions_with_value_counts():
# https://github.com/dask/dask/issues/7065
df = pd.DataFrame(