Skip to content

Commit

Permalink
REGR: Better warning in pivot_table when dropping nuisance columns (p…
Browse files Browse the repository at this point in the history
…andas-dev#49615)

* REGR: Better warning in pivot_table when dropping nuisance columns

* type-hint fixups

(cherry picked from commit ab89c53)
  • Loading branch information
rhshadrach committed Nov 11, 2022
1 parent 1b8e4db commit 8d80e0b
Show file tree
Hide file tree
Showing 4 changed files with 101 additions and 5 deletions.
14 changes: 13 additions & 1 deletion pandas/core/reshape/pivot.py
Expand Up @@ -21,6 +21,7 @@
Substitution,
deprecate_nonkeyword_arguments,
)
from pandas.util._exceptions import rewrite_warning

from pandas.core.dtypes.cast import maybe_downcast_to_dtype
from pandas.core.dtypes.common import (
Expand Down Expand Up @@ -163,7 +164,18 @@ def __internal_pivot_table(
values = list(values)

grouped = data.groupby(keys, observed=observed, sort=sort)
agged = grouped.agg(aggfunc)
msg = (
"pivot_table dropped a column because it failed to aggregate. This behavior "
"is deprecated and will raise in a future version of pandas. Select only the "
"columns that can be aggregated."
)
with rewrite_warning(
target_message="The default value of numeric_only",
target_category=FutureWarning,
new_message=msg,
):
agged = grouped.agg(aggfunc)

if dropna and isinstance(agged, ABCDataFrame) and len(agged.columns):
agged = agged.dropna(how="all")

Expand Down
8 changes: 4 additions & 4 deletions pandas/tests/reshape/test_pivot.py
Expand Up @@ -146,7 +146,7 @@ def test_pivot_table_nocols(self):
df = DataFrame(
{"rows": ["a", "b", "c"], "cols": ["x", "y", "z"], "values": [1, 2, 3]}
)
msg = "The default value of numeric_only"
msg = "pivot_table dropped a column because it failed to aggregate"
with tm.assert_produces_warning(FutureWarning, match=msg):
rs = df.pivot_table(columns="cols", aggfunc=np.sum)
xp = df.pivot_table(index="cols", aggfunc=np.sum).T
Expand Down Expand Up @@ -907,7 +907,7 @@ def test_no_col(self):

# to help with a buglet
self.data.columns = [k * 2 for k in self.data.columns]
msg = "The default value of numeric_only"
msg = "pivot_table dropped a column because it failed to aggregate"
with tm.assert_produces_warning(FutureWarning, match=msg):
table = self.data.pivot_table(
index=["AA", "BB"], margins=True, aggfunc=np.mean
Expand Down Expand Up @@ -975,7 +975,7 @@ def test_margin_with_only_columns_defined(
}
)

msg = "The default value of numeric_only"
msg = "pivot_table dropped a column because it failed to aggregate"
with tm.assert_produces_warning(FutureWarning, match=msg):
result = df.pivot_table(columns=columns, margins=True, aggfunc=aggfunc)
expected = DataFrame(values, index=Index(["D", "E"]), columns=expected_columns)
Expand Down Expand Up @@ -2004,7 +2004,7 @@ def test_pivot_string_func_vs_func(self, f, f_numpy):
# GH #18713
# for consistency purposes

msg = "The default value of numeric_only"
msg = "pivot_table dropped a column because it failed to aggregate"
with tm.assert_produces_warning(FutureWarning, match=msg):
result = pivot_table(self.data, index="A", columns="B", aggfunc=f)
expected = pivot_table(self.data, index="A", columns="B", aggfunc=f_numpy)
Expand Down
39 changes: 39 additions & 0 deletions pandas/tests/util/test_rewrite_warning.py
@@ -0,0 +1,39 @@
import warnings

import pytest

from pandas.util._exceptions import rewrite_warning

import pandas._testing as tm


@pytest.mark.parametrize(
"target_category, target_message, hit",
[
(FutureWarning, "Target message", True),
(FutureWarning, "Target", True),
(FutureWarning, "get mess", True),
(FutureWarning, "Missed message", False),
(DeprecationWarning, "Target message", False),
],
)
@pytest.mark.parametrize(
"new_category",
[
None,
DeprecationWarning,
],
)
def test_rewrite_warning(target_category, target_message, hit, new_category):
new_message = "Rewritten message"
if hit:
expected_category = new_category if new_category else target_category
expected_message = new_message
else:
expected_category = FutureWarning
expected_message = "Target message"
with tm.assert_produces_warning(expected_category, match=expected_message):
with rewrite_warning(
target_message, target_category, new_message, new_category
):
warnings.warn(message="Target message", category=FutureWarning)
45 changes: 45 additions & 0 deletions pandas/util/_exceptions.py
Expand Up @@ -3,7 +3,9 @@
import contextlib
import inspect
import os
import re
from typing import Iterator
import warnings


@contextlib.contextmanager
Expand Down Expand Up @@ -47,3 +49,46 @@ def find_stack_level() -> int:
else:
break
return n


@contextlib.contextmanager
def rewrite_warning(
target_message: str,
target_category: type[Warning],
new_message: str,
new_category: type[Warning] | None = None,
) -> Iterator[None]:
"""
Rewrite the message of a warning.
Parameters
----------
target_message : str
Warning message to match.
target_category : Warning
Warning type to match.
new_message : str
New warning message to emit.
new_category : Warning or None, default None
New warning type to emit. When None, will be the same as target_category.
"""
if new_category is None:
new_category = target_category
with warnings.catch_warnings(record=True) as record:
yield
if len(record) > 0:
match = re.compile(target_message)
for warning in record:
if warning.category is target_category and re.search(
match, str(warning.message)
):
category = new_category
message: Warning | str = new_message
else:
category, message = warning.category, warning.message
warnings.warn_explicit(
message=message,
category=category,
filename=warning.filename,
lineno=warning.lineno,
)

0 comments on commit 8d80e0b

Please sign in to comment.