From 0533e094364a5b0889b173475d2ed55f63bec25b Mon Sep 17 00:00:00 2001 From: Richard Shadrach <45562402+rhshadrach@users.noreply.github.com> Date: Thu, 10 Nov 2022 15:32:43 -0500 Subject: [PATCH] REGR: Better warning in pivot_table when dropping nuisance columns (#49615) * REGR: Better warning in pivot_table when dropping nuisance columns * type-hint fixups --- pandas/core/reshape/pivot.py | 14 ++++++- pandas/tests/reshape/test_pivot.py | 8 ++-- pandas/tests/util/test_rewrite_warning.py | 39 ++++++++++++++++++++ pandas/util/_exceptions.py | 45 +++++++++++++++++++++++ 4 files changed, 101 insertions(+), 5 deletions(-) create mode 100644 pandas/tests/util/test_rewrite_warning.py diff --git a/pandas/core/reshape/pivot.py b/pandas/core/reshape/pivot.py index 37e78c7dbf7a2..810a428098df2 100644 --- a/pandas/core/reshape/pivot.py +++ b/pandas/core/reshape/pivot.py @@ -21,6 +21,7 @@ Appender, Substitution, ) +from pandas.util._exceptions import rewrite_warning from pandas.core.dtypes.cast import maybe_downcast_to_dtype from pandas.core.dtypes.common import ( @@ -164,7 +165,18 @@ def __internal_pivot_table( values = list(values) grouped = data.groupby(keys, observed=observed, sort=sort) - agged = grouped.agg(aggfunc) + msg = ( + "pivot_table dropped a column because it failed to aggregate. This behavior " + "is deprecated and will raise in a future version of pandas. Select only the " + "columns that can be aggregated." + ) + with rewrite_warning( + target_message="The default value of numeric_only", + target_category=FutureWarning, + new_message=msg, + ): + agged = grouped.agg(aggfunc) + if dropna and isinstance(agged, ABCDataFrame) and len(agged.columns): agged = agged.dropna(how="all") diff --git a/pandas/tests/reshape/test_pivot.py b/pandas/tests/reshape/test_pivot.py index 14ea670fa6cf9..f9119ea43160b 100644 --- a/pandas/tests/reshape/test_pivot.py +++ b/pandas/tests/reshape/test_pivot.py @@ -147,7 +147,7 @@ def test_pivot_table_nocols(self): df = DataFrame( {"rows": ["a", "b", "c"], "cols": ["x", "y", "z"], "values": [1, 2, 3]} ) - msg = "The default value of numeric_only" + msg = "pivot_table dropped a column because it failed to aggregate" with tm.assert_produces_warning(FutureWarning, match=msg): rs = df.pivot_table(columns="cols", aggfunc=np.sum) xp = df.pivot_table(index="cols", aggfunc=np.sum).T @@ -911,7 +911,7 @@ def test_no_col(self, data): # to help with a buglet data.columns = [k * 2 for k in data.columns] - msg = "The default value of numeric_only" + msg = "pivot_table dropped a column because it failed to aggregate" with tm.assert_produces_warning(FutureWarning, match=msg): table = data.pivot_table(index=["AA", "BB"], margins=True, aggfunc=np.mean) for value_col in table.columns: @@ -975,7 +975,7 @@ def test_margin_with_only_columns_defined( } ) - msg = "The default value of numeric_only" + msg = "pivot_table dropped a column because it failed to aggregate" with tm.assert_produces_warning(FutureWarning, match=msg): result = df.pivot_table(columns=columns, margins=True, aggfunc=aggfunc) expected = DataFrame(values, index=Index(["D", "E"]), columns=expected_columns) @@ -2004,7 +2004,7 @@ def test_pivot_string_func_vs_func(self, f, f_numpy, data): # GH #18713 # for consistency purposes - msg = "The default value of numeric_only" + msg = "pivot_table dropped a column because it failed to aggregate" with tm.assert_produces_warning(FutureWarning, match=msg): result = pivot_table(data, index="A", columns="B", aggfunc=f) expected = pivot_table(data, index="A", columns="B", aggfunc=f_numpy) diff --git a/pandas/tests/util/test_rewrite_warning.py b/pandas/tests/util/test_rewrite_warning.py new file mode 100644 index 0000000000000..f847a06d8ea8d --- /dev/null +++ b/pandas/tests/util/test_rewrite_warning.py @@ -0,0 +1,39 @@ +import warnings + +import pytest + +from pandas.util._exceptions import rewrite_warning + +import pandas._testing as tm + + +@pytest.mark.parametrize( + "target_category, target_message, hit", + [ + (FutureWarning, "Target message", True), + (FutureWarning, "Target", True), + (FutureWarning, "get mess", True), + (FutureWarning, "Missed message", False), + (DeprecationWarning, "Target message", False), + ], +) +@pytest.mark.parametrize( + "new_category", + [ + None, + DeprecationWarning, + ], +) +def test_rewrite_warning(target_category, target_message, hit, new_category): + new_message = "Rewritten message" + if hit: + expected_category = new_category if new_category else target_category + expected_message = new_message + else: + expected_category = FutureWarning + expected_message = "Target message" + with tm.assert_produces_warning(expected_category, match=expected_message): + with rewrite_warning( + target_message, target_category, new_message, new_category + ): + warnings.warn(message="Target message", category=FutureWarning) diff --git a/pandas/util/_exceptions.py b/pandas/util/_exceptions.py index f3a640feb46fc..1eefd06a133fb 100644 --- a/pandas/util/_exceptions.py +++ b/pandas/util/_exceptions.py @@ -3,7 +3,9 @@ import contextlib import inspect import os +import re from typing import Generator +import warnings @contextlib.contextmanager @@ -47,3 +49,46 @@ def find_stack_level() -> int: else: break return n + + +@contextlib.contextmanager +def rewrite_warning( + target_message: str, + target_category: type[Warning], + new_message: str, + new_category: type[Warning] | None = None, +) -> Generator[None, None, None]: + """ + Rewrite the message of a warning. + + Parameters + ---------- + target_message : str + Warning message to match. + target_category : Warning + Warning type to match. + new_message : str + New warning message to emit. + new_category : Warning or None, default None + New warning type to emit. When None, will be the same as target_category. + """ + if new_category is None: + new_category = target_category + with warnings.catch_warnings(record=True) as record: + yield + if len(record) > 0: + match = re.compile(target_message) + for warning in record: + if warning.category is target_category and re.search( + match, str(warning.message) + ): + category = new_category + message: Warning | str = new_message + else: + category, message = warning.category, warning.message + warnings.warn_explicit( + message=message, + category=category, + filename=warning.filename, + lineno=warning.lineno, + )