diff --git a/doc/source/whatsnew/v1.4.4.rst b/doc/source/whatsnew/v1.4.4.rst index 927841d42b543..e6ff4c65e323f 100644 --- a/doc/source/whatsnew/v1.4.4.rst +++ b/doc/source/whatsnew/v1.4.4.rst @@ -37,6 +37,7 @@ Bug fixes ~~~~~~~~~ - The :class:`errors.FutureWarning` raised when passing arguments (other than ``filepath_or_buffer``) as positional in :func:`read_csv` is now raised at the correct stacklevel (:issue:`47385`) - Bug in :meth:`DataFrame.to_sql` when ``method`` was a ``callable`` that did not return an ``int`` and would raise a ``TypeError`` (:issue:`46891`) +- Bug in :meth:`DataFrameGroupBy.value_counts` where ``subset`` had no effect (:issue:`44267`) - Bug in :meth:`loc.__getitem__` with a list of keys causing an internal inconsistency that could lead to a disconnect between ``frame.at[x, y]`` vs ``frame[y].loc[x]`` (:issue:`22372`) - Bug in the :meth:`Series.dt.strftime` accessor return a float instead of object dtype Series for all-NaT input, which also causes a spurious deprecation warning (:issue:`45858`) diff --git a/pandas/core/groupby/generic.py b/pandas/core/groupby/generic.py index f1a8a70674582..cd91e89554b67 100644 --- a/pandas/core/groupby/generic.py +++ b/pandas/core/groupby/generic.py @@ -1805,21 +1805,31 @@ def value_counts( name = self._selected_obj.name keys = [] if name in in_axis_names else [self._selected_obj] else: + unique_cols = set(self._selected_obj.columns) + if subset is not None: + subsetted = set(subset) + clashing = subsetted & set(in_axis_names) + if clashing: + raise ValueError( + f"Keys {clashing} in subset cannot be in " + "the groupby column keys." + ) + doesnt_exist = subsetted - unique_cols + if doesnt_exist: + raise ValueError( + f"Keys {doesnt_exist} in subset do not " + f"exist in the DataFrame." + ) + else: + subsetted = unique_cols + keys = [ # Can't use .values because the column label needs to be preserved self._selected_obj.iloc[:, idx] for idx, name in enumerate(self._selected_obj.columns) - if name not in in_axis_names + if name not in in_axis_names and name in subsetted ] - if subset is not None: - clashing = set(subset) & set(in_axis_names) - if clashing: - raise ValueError( - f"Keys {clashing} in subset cannot be in " - "the groupby column keys" - ) - groupings = list(self.grouper.groupings) for key in keys: grouper, _, _ = get_grouper( diff --git a/pandas/tests/groupby/test_frame_value_counts.py b/pandas/tests/groupby/test_frame_value_counts.py index 1e679ad4e7aad..8255fbab40dce 100644 --- a/pandas/tests/groupby/test_frame_value_counts.py +++ b/pandas/tests/groupby/test_frame_value_counts.py @@ -738,3 +738,46 @@ def test_ambiguous_grouping(): result = gb.value_counts() expected = Series([2], index=MultiIndex.from_tuples([[1, 1]], names=[None, "a"])) tm.assert_series_equal(result, expected) + + +def test_subset_overlaps_gb_key_raises(): + # GH 46383 + df = DataFrame({"c1": ["a", "b", "c"], "c2": ["x", "y", "y"]}, index=[0, 1, 1]) + msg = "Keys {'c1'} in subset cannot be in the groupby column keys." + with pytest.raises(ValueError, match=msg): + df.groupby("c1").value_counts(subset=["c1"]) + + +def test_subset_doesnt_exist_in_frame(): + # GH 46383 + df = DataFrame({"c1": ["a", "b", "c"], "c2": ["x", "y", "y"]}, index=[0, 1, 1]) + msg = "Keys {'c3'} in subset do not exist in the DataFrame." + with pytest.raises(ValueError, match=msg): + df.groupby("c1").value_counts(subset=["c3"]) + + +def test_subset(): + # GH 46383 + df = DataFrame({"c1": ["a", "b", "c"], "c2": ["x", "y", "y"]}, index=[0, 1, 1]) + result = df.groupby(level=0).value_counts(subset=["c2"]) + expected = Series( + [1, 2], index=MultiIndex.from_arrays([[0, 1], ["x", "y"]], names=[None, "c2"]) + ) + tm.assert_series_equal(result, expected) + + +def test_subset_duplicate_columns(): + # GH 46383 + df = DataFrame( + [["a", "x", "x"], ["b", "y", "y"], ["b", "y", "y"]], + index=[0, 1, 1], + columns=["c1", "c2", "c2"], + ) + result = df.groupby(level=0).value_counts(subset=["c2"]) + expected = Series( + [1, 2], + index=MultiIndex.from_arrays( + [[0, 1], ["x", "y"], ["x", "y"]], names=[None, "c2", "c2"] + ), + ) + tm.assert_series_equal(result, expected)