Partial functions in aggs may have arguments (#9724)
j-bennet committed Dec 9, 2022
1 parent 0885125 commit 159948e
Showing 2 changed files with 135 additions and 12 deletions.
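For context, a minimal usage sketch of what this commit enables (the frame mirrors the new tests below; the ddof value is arbitrary): GroupBy.agg now honors keyword arguments bound into a functools.partial, such as ddof for np.std and np.var.

from functools import partial

import numpy as np
import pandas as pd

import dask.dataframe as dd

pdf = pd.DataFrame(
    {
        "a": [5, 4, 3, 5, 4, 2, 3, 2],
        "b": [1, 2, 5, 6, 9, 2, 6, 8],
    }
)
ddf = dd.from_pandas(pdf, npartitions=2)

# The ddof keyword carried by the partial is passed through to the std finalizer,
# so the result matches pandas.
result = ddf.groupby("a").agg(partial(np.std, ddof=0))
print(result.compute())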
68 changes: 56 additions & 12 deletions dask/dataframe/groupby.py
@@ -3,6 +3,7 @@
import operator
import uuid
import warnings
from functools import partial
from numbers import Integral

import numpy as np
@@ -815,7 +816,13 @@ def _build_agg_args(spec):
applied after the ``agg_funcs``. They are used to create final results
from intermediate representations.
"""
known_np_funcs = {np.min: "min", np.max: "max", np.median: "median"}
known_np_funcs = {
np.min: "min",
np.max: "max",
np.median: "median",
np.std: "std",
np.var: "var",
}

# check that there are no name conflicts for a single input column
by_name = {}
@@ -831,11 +838,20 @@ def _build_agg_args(spec):
aggs = {}
finalizers = []

# A partial may carry its own arguments; pass them down to the aggregation builders.
# https://github.com/dask/dask/issues/9615
for (result_column, func, input_column) in spec:
func_args = ()
func_kwargs = {}
if isinstance(func, partial):
func_args, func_kwargs = func.args, func.keywords

if not isinstance(func, Aggregation):
func = funcname(known_np_funcs.get(func, func))

impls = _build_agg_args_single(result_column, func, input_column)
impls = _build_agg_args_single(
result_column, func, func_args, func_kwargs, input_column
)

# overwrite existing result-columns, generate intermediates only once
for spec in impls["chunk_funcs"]:
@@ -851,7 +867,7 @@ def _build_agg_args(spec):
return chunks, aggs, finalizers


def _build_agg_args_single(result_column, func, input_column):
def _build_agg_args_single(result_column, func, func_args, func_kwargs, input_column):
simple_impl = {
"sum": (M.sum, M.sum),
"min": (M.min, M.min),
@@ -873,10 +889,14 @@ def _build_agg_args_single(result_column, func, input_column):
)

elif func == "var":
return _build_agg_args_var(result_column, func, input_column)
return _build_agg_args_var(
result_column, func, func_args, func_kwargs, input_column
)

elif func == "std":
return _build_agg_args_std(result_column, func, input_column)
return _build_agg_args_std(
result_column, func, func_args, func_kwargs, input_column
)

elif func == "mean":
return _build_agg_args_mean(result_column, func, input_column)
@@ -914,11 +934,25 @@ def _build_agg_args_simple(result_column, func, input_column, impl_pair):
)


def _build_agg_args_var(result_column, func, input_column):
def _build_agg_args_var(result_column, func, func_args, func_kwargs, input_column):
int_sum = _make_agg_id("sum", input_column)
int_sum2 = _make_agg_id("sum2", input_column)
int_count = _make_agg_id("count", input_column)

# we don't expect positional args here
if func_args:
raise TypeError(
f"aggregate function '{func}' doesn't support positional arguments, but got {func_args}"
)

# and we only expect ddof=N in kwargs
expected_kwargs = {"ddof"}
unexpected_kwargs = func_kwargs.keys() - expected_kwargs
if unexpected_kwargs:
raise TypeError(
f"aggregate function '{func}' supports {expected_kwargs} keyword arguments, but got {unexpected_kwargs}"
)

return dict(
chunk_funcs=[
(int_sum, _apply_func_to_column, dict(column=input_column, func=M.sum)),
@@ -932,13 +966,20 @@ def _build_agg_args_var(result_column, func, input_column):
finalizer=(
result_column,
_finalize_var,
dict(sum_column=int_sum, count_column=int_count, sum2_column=int_sum2),
dict(
sum_column=int_sum,
count_column=int_count,
sum2_column=int_sum2,
**func_kwargs,
),
),
)


def _build_agg_args_std(result_column, func, input_column):
impls = _build_agg_args_var(result_column, func, input_column)
def _build_agg_args_std(result_column, func, func_args, func_kwargs, input_column):
impls = _build_agg_args_var(
result_column, func, func_args, func_kwargs, input_column
)

result_column, _, kwargs = impls["finalizer"]
impls["finalizer"] = (result_column, _finalize_std, kwargs)
@@ -1118,7 +1159,10 @@ def _finalize_mean(df, sum_column, count_column):
return df[sum_column] / df[count_column]


def _finalize_var(df, count_column, sum_column, sum2_column, ddof=1):
def _finalize_var(df, count_column, sum_column, sum2_column, **kwargs):
# Keyword arguments were already validated when the finalizer was built. At the
# moment only ddof is used; any other keyword argument raises a TypeError there.
ddof = kwargs.get("ddof", 1)
n = df[count_column]
x = df[sum_column]
x2 = df[sum2_column]
@@ -1132,8 +1176,8 @@ def _finalize_var(df, count_column, sum_column, sum2_column, ddof=1):
return result


def _finalize_std(df, count_column, sum_column, sum2_column, ddof=1):
result = _finalize_var(df, count_column, sum_column, sum2_column, ddof)
def _finalize_std(df, count_column, sum_column, sum2_column, **kwargs):
result = _finalize_var(df, count_column, sum_column, sum2_column, **kwargs)
return np.sqrt(result)


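For reference, a minimal standalone sketch (not the dask code itself) of how the var/std finalizers above combine their intermediates: each partition contributes a sum, a sum of squares, and a count, and the finalizer applies the usual shortcut formula with the requested ddof.

import numpy as np

def combined_var(chunks, ddof=1):
    # chunks: iterable of (sum, sum_of_squares, count) triples, one per partition.
    s = sum(c[0] for c in chunks)
    s2 = sum(c[1] for c in chunks)
    n = sum(c[2] for c in chunks)
    # Shortcut formula for variance from raw moments: (sum(x^2) - sum(x)^2 / n) / (n - ddof).
    return (s2 - s**2 / n) / (n - ddof)

data = np.array([1.0, 2.0, 5.0, 6.0, 9.0, 2.0, 6.0, 8.0])
partitions = [data[:4], data[4:]]
chunks = [(p.sum(), (p**2).sum(), len(p)) for p in partitions]
assert np.isclose(combined_var(chunks, ddof=1), data.var(ddof=1))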
79 changes: 79 additions & 0 deletions dask/dataframe/tests/test_groupby.py
@@ -2599,6 +2599,85 @@ def test_groupby_aggregate_categoricals(grouping, agg):
assert_eq(agg(grouping(pdf)["value"]), agg(grouping(ddf)["value"]))


@pytest.mark.parametrize(
"agg",
[
lambda grp: grp.agg(partial(np.std, ddof=1)),
lambda grp: grp.agg(partial(np.std, ddof=-2)),
lambda grp: grp.agg(partial(np.var, ddof=1)),
lambda grp: grp.agg(partial(np.var, ddof=-2)),
],
)
def test_groupby_aggregate_partial_function(agg):
pdf = pd.DataFrame(
{
"a": [5, 4, 3, 5, 4, 2, 3, 2],
"b": [1, 2, 5, 6, 9, 2, 6, 8],
}
)
ddf = dd.from_pandas(pdf, npartitions=2)

# DataFrameGroupBy
assert_eq(agg(pdf.groupby("a")), agg(ddf.groupby("a")))

# SeriesGroupBy
assert_eq(agg(pdf.groupby("a")["b"]), agg(ddf.groupby("a")["b"]))


@pytest.mark.parametrize(
"agg",
[
lambda grp: grp.agg(partial(np.std, unexpected_arg=1)),
lambda grp: grp.agg(partial(np.var, unexpected_arg=1)),
],
)
def test_groupby_aggregate_partial_function_unexpected_kwargs(agg):
pdf = pd.DataFrame(
{
"a": [5, 4, 3, 5, 4, 2, 3, 2],
"b": [1, 2, 5, 6, 9, 2, 6, 8],
}
)
ddf = dd.from_pandas(pdf, npartitions=2)

with pytest.raises(
TypeError,
match="supports {'ddof'} keyword arguments, but got {'unexpected_arg'}",
):
agg(ddf.groupby("a"))

# SeriesGroupBy
with pytest.raises(
TypeError,
match="supports {'ddof'} keyword arguments, but got {'unexpected_arg'}",
):
agg(ddf.groupby("a")["b"])


@pytest.mark.parametrize(
"agg",
[
lambda grp: grp.agg(partial(np.std, "positional_arg")),
lambda grp: grp.agg(partial(np.var, "positional_arg")),
],
)
def test_groupby_aggregate_partial_function_unexpected_args(agg):
pdf = pd.DataFrame(
{
"a": [5, 4, 3, 5, 4, 2, 3, 2],
"b": [1, 2, 5, 6, 9, 2, 6, 8],
}
)
ddf = dd.from_pandas(pdf, npartitions=2)

with pytest.raises(TypeError, match="doesn't support positional arguments"):
agg(ddf.groupby("a"))

# SeriesGroupBy
with pytest.raises(TypeError, match="doesn't support positional arguments"):
agg(ddf.groupby("a")["b"])


@pytest.mark.xfail(
not dask.dataframe.utils.PANDAS_GT_110,
reason="dropna kwarg not supported in pandas < 1.1.0.",
