Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Partial functions in aggs may have arguments #9724

Merged
merged 8 commits into from
Dec 9, 2022
68 changes: 56 additions & 12 deletions dask/dataframe/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import operator
import uuid
import warnings
from functools import partial
from numbers import Integral

import numpy as np
Expand Down Expand Up @@ -815,7 +816,13 @@ def _build_agg_args(spec):
applied after the ``agg_funcs``. They are used to create final results
from intermediate representations.
"""
known_np_funcs = {np.min: "min", np.max: "max", np.median: "median"}
known_np_funcs = {
np.min: "min",
np.max: "max",
np.median: "median",
np.std: "std",
np.var: "var",
}

# check that there are no name conflicts for a single input column
by_name = {}
Expand All @@ -831,11 +838,20 @@ def _build_agg_args(spec):
aggs = {}
finalizers = []

# a partial may contain some arguments, pass then down
j-bennet marked this conversation as resolved.
Show resolved Hide resolved
# https://github.com/dask/dask/issues/9615
for (result_column, func, input_column) in spec:
func_args = ()
func_kwargs = {}
if isinstance(func, partial):
func_args, func_kwargs = func.args, func.keywords

if not isinstance(func, Aggregation):
func = funcname(known_np_funcs.get(func, func))

impls = _build_agg_args_single(result_column, func, input_column)
impls = _build_agg_args_single(
result_column, func, func_args, func_kwargs, input_column
)

# overwrite existing result-columns, generate intermediates only once
for spec in impls["chunk_funcs"]:
Expand All @@ -851,7 +867,7 @@ def _build_agg_args(spec):
return chunks, aggs, finalizers


def _build_agg_args_single(result_column, func, input_column):
def _build_agg_args_single(result_column, func, func_args, func_kwargs, input_column):
simple_impl = {
"sum": (M.sum, M.sum),
"min": (M.min, M.min),
Expand All @@ -873,10 +889,14 @@ def _build_agg_args_single(result_column, func, input_column):
)

elif func == "var":
return _build_agg_args_var(result_column, func, input_column)
return _build_agg_args_var(
result_column, func, func_args, func_kwargs, input_column
)

elif func == "std":
return _build_agg_args_std(result_column, func, input_column)
return _build_agg_args_std(
result_column, func, func_args, func_kwargs, input_column
)

elif func == "mean":
return _build_agg_args_mean(result_column, func, input_column)
Expand Down Expand Up @@ -914,11 +934,25 @@ def _build_agg_args_simple(result_column, func, input_column, impl_pair):
)


def _build_agg_args_var(result_column, func, input_column):
def _build_agg_args_var(result_column, func, func_args, func_kwargs, input_column):
int_sum = _make_agg_id("sum", input_column)
int_sum2 = _make_agg_id("sum2", input_column)
int_count = _make_agg_id("count", input_column)

# we don't expect positional args here
if func_args:
raise TypeError(
f"aggregate function '{func}' got unexpected positional arguments {func_args}"
j-bennet marked this conversation as resolved.
Show resolved Hide resolved
)

# and we only expect ddof=N in kwargs
expected_kwargs = {"ddof"}
unexpected_kwargs = set(func_kwargs.keys()) - expected_kwargs
j-bennet marked this conversation as resolved.
Show resolved Hide resolved
for arg in unexpected_kwargs:
raise TypeError(
f"aggregate function '{func}' got an unexpected keyword argument '{arg}'"
)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For multiple unexpected kwargs, it'd probably be a bit more informative to just let users see all the invalid kwargs at once. Otherwise, it may take some iteration to get rid of them all.

Suggested change
for arg in unexpected_kwargs:
raise TypeError(
f"aggregate function '{func}' got an unexpected keyword argument '{arg}'"
)
if unexpected_kwargs:
raise TypeError(
f"The aggregate function '{func}' supports {expected_kwargs} keyword arguments, but got {unexpected_kwargs}"
)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note we'll need to make a corresponding change to test_groupby_aggregate_partial_function_unexpected_kwargs since we're updating the error message


return dict(
chunk_funcs=[
(int_sum, _apply_func_to_column, dict(column=input_column, func=M.sum)),
Expand All @@ -932,13 +966,20 @@ def _build_agg_args_var(result_column, func, input_column):
finalizer=(
result_column,
_finalize_var,
dict(sum_column=int_sum, count_column=int_count, sum2_column=int_sum2),
dict(
sum_column=int_sum,
count_column=int_count,
sum2_column=int_sum2,
**func_kwargs,
),
),
)


def _build_agg_args_std(result_column, func, input_column):
impls = _build_agg_args_var(result_column, func, input_column)
def _build_agg_args_std(result_column, func, func_args, func_kwargs, input_column):
impls = _build_agg_args_var(
result_column, func, func_args, func_kwargs, input_column
)

result_column, _, kwargs = impls["finalizer"]
impls["finalizer"] = (result_column, _finalize_std, kwargs)
Expand Down Expand Up @@ -1118,7 +1159,10 @@ def _finalize_mean(df, sum_column, count_column):
return df[sum_column] / df[count_column]


def _finalize_var(df, count_column, sum_column, sum2_column, ddof=1):
def _finalize_var(df, count_column, sum_column, sum2_column, **kwargs):
# arguments are being checked when building the finalizer. As of this moment,
# we're only using ddof, and raising an error on other keyword args.
ddof = kwargs.get("ddof", 1)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IIUC it looks like even though we're forwarding all kwargs specified by the user, we're only using ddof. Is that the case? If so, it'd be good to raise an error if a user passes in a kwarg that isn't supported. Rather raise an informative error than silently ignore the kwarg

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, this is one of my TODOs and a reason this PR is a draft. I have to see what else I need to pass down, and figure out how to handle the rest. I think I need to see what Pandas does if I pass unexpected parameters, and match the behavior.

Copy link
Contributor Author

@j-bennet j-bennet Dec 8, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added handling for unexpected args, and a test for that. numeric_only kwarg is an interesting one - Pandas supports it, and I'm not sure why Dask doesn't, and whether it should.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

numeric_only is a tricky one, as pandas behavior is changing. See #9471 (comment) . For now we have been filtering the warnings that pandas raises, see this PR #9496 .

@jrbourbeau I can't recall where we are in the discussion of this kwarg.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ncclementi Thank you, I looked over the PRs, and it looks to me that numeric_only would have to be tackled separately, since it would be a pretty large change, and there's no clearly defined path yet.

n = df[count_column]
x = df[sum_column]
x2 = df[sum2_column]
Expand All @@ -1132,8 +1176,8 @@ def _finalize_var(df, count_column, sum_column, sum2_column, ddof=1):
return result


def _finalize_std(df, count_column, sum_column, sum2_column, ddof=1):
result = _finalize_var(df, count_column, sum_column, sum2_column, ddof)
def _finalize_std(df, count_column, sum_column, sum2_column, **kwargs):
result = _finalize_var(df, count_column, sum_column, sum2_column, **kwargs)
return np.sqrt(result)


Expand Down
73 changes: 73 additions & 0 deletions dask/dataframe/tests/test_groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -2599,6 +2599,79 @@ def test_groupby_aggregate_categoricals(grouping, agg):
assert_eq(agg(grouping(pdf)["value"]), agg(grouping(ddf)["value"]))


@pytest.mark.parametrize(
"agg",
[
lambda grp: grp.agg(partial(np.std, ddof=1)),
lambda grp: grp.agg(partial(np.std, ddof=-2)),
lambda grp: grp.agg(partial(np.var, ddof=1)),
lambda grp: grp.agg(partial(np.var, ddof=-2)),
],
)
def test_groupby_aggregate_partial_function(agg):
pdf = pd.DataFrame(
{
"a": [5, 4, 3, 5, 4, 2, 3, 2],
"b": [1, 2, 5, 6, 9, 2, 6, 8],
}
)
ddf = dd.from_pandas(pdf, npartitions=2)

# DataFrameGroupBy
assert_eq(agg(pdf.groupby("a")), agg(ddf.groupby("a")))

# SeriesGroupBy
assert_eq(agg(pdf.groupby("a")["b"]), agg(ddf.groupby("a")["b"]))


@pytest.mark.parametrize(
"agg",
[
lambda grp: grp.agg(partial(np.std, unexpected_arg=1)),
lambda grp: grp.agg(partial(np.var, unexpected_arg=1)),
],
)
def test_groupby_aggregate_partial_function_unexpected_kwargs(agg):
pdf = pd.DataFrame(
{
"a": [5, 4, 3, 5, 4, 2, 3, 2],
"b": [1, 2, 5, 6, 9, 2, 6, 8],
}
)
ddf = dd.from_pandas(pdf, npartitions=2)

with pytest.raises(TypeError, match="unexpected keyword argument"):
agg(ddf.groupby("a")).compute()
j-bennet marked this conversation as resolved.
Show resolved Hide resolved

# SeriesGroupBy
with pytest.raises(TypeError, match="unexpected keyword argument"):
agg(ddf.groupby("a")["b"]).compute()
j-bennet marked this conversation as resolved.
Show resolved Hide resolved


@pytest.mark.parametrize(
"agg",
[
lambda grp: grp.agg(partial(np.std, "positional_arg")),
lambda grp: grp.agg(partial(np.var, "positional_arg")),
],
)
def test_groupby_aggregate_partial_function_unexpected_args(agg):
pdf = pd.DataFrame(
{
"a": [5, 4, 3, 5, 4, 2, 3, 2],
"b": [1, 2, 5, 6, 9, 2, 6, 8],
}
)
ddf = dd.from_pandas(pdf, npartitions=2)

with pytest.raises(TypeError, match="unexpected positional arguments"):
agg(ddf.groupby("a")).compute()
j-bennet marked this conversation as resolved.
Show resolved Hide resolved

# SeriesGroupBy
with pytest.raises(TypeError, match="unexpected positional arguments"):
agg(ddf.groupby("a")["b"]).compute()
j-bennet marked this conversation as resolved.
Show resolved Hide resolved


@pytest.mark.xfail(
not dask.dataframe.utils.PANDAS_GT_110,
reason="dropna kwarg not supported in pandas < 1.1.0.",
Expand Down