From 9fefc8f7a92fe08583466c668e07585b4e421021 Mon Sep 17 00:00:00 2001 From: Richard Shadrach <45562402+rhshadrach@users.noreply.github.com> Date: Thu, 10 Nov 2022 20:39:41 -0500 Subject: [PATCH] BUG: groupby.nth should be a filter (#49262) --- doc/source/user_guide/groupby.rst | 38 +++--- doc/source/whatsnew/v2.0.0.rst | 56 ++++++++- pandas/core/groupby/base.py | 2 +- pandas/core/groupby/groupby.py | 131 +++++++------------- pandas/tests/groupby/test_categorical.py | 6 +- pandas/tests/groupby/test_function.py | 1 - pandas/tests/groupby/test_grouping.py | 2 + pandas/tests/groupby/test_nth.py | 151 +++++++++-------------- 8 files changed, 179 insertions(+), 208 deletions(-) diff --git a/doc/source/user_guide/groupby.rst b/doc/source/user_guide/groupby.rst index dae42dd4f1118..d8a36b1711b6e 100644 --- a/doc/source/user_guide/groupby.rst +++ b/doc/source/user_guide/groupby.rst @@ -1354,9 +1354,14 @@ This shows the first or last n rows from each group. Taking the nth row of each group ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -To select from a DataFrame or Series the nth item, use -:meth:`~pd.core.groupby.DataFrameGroupBy.nth`. This is a reduction method, and -will return a single row (or no row) per group if you pass an int for n: +To select the nth item from each group, use :meth:`.DataFrameGroupBy.nth` or +:meth:`.SeriesGroupBy.nth`. Arguments supplied can be any integer, lists of integers, +slices, or lists of slices; see below for examples. When the nth element of a group +does not exist an error is *not* raised; instead no corresponding rows are returned. + +In general this operation acts as a filtration. In certain cases it will also return +one row per group, making it also a reduction. However because in general it can +return zero or multiple rows per group, pandas treats it as a filtration in all cases. .. 
ipython:: python @@ -1367,6 +1372,14 @@ will return a single row (or no row) per group if you pass an int for n: g.nth(-1) g.nth(1) +If the nth element of a group does not exist, then no corresponding row is included +in the result. In particular, if the specified ``n`` is larger than any group, the +result will be an empty DataFrame. + +.. ipython:: python + + g.nth(5) + If you want to select the nth not-null item, use the ``dropna`` kwarg. For a DataFrame this should be either ``'any'`` or ``'all'`` just like you would pass to dropna: .. ipython:: python @@ -1376,21 +1389,11 @@ If you want to select the nth not-null item, use the ``dropna`` kwarg. For a Dat g.first() # nth(-1) is the same as g.last() - g.nth(-1, dropna="any") # NaNs denote group exhausted when using dropna + g.nth(-1, dropna="any") g.last() g.B.nth(0, dropna="all") -As with other methods, passing ``as_index=False``, will achieve a filtration, which returns the grouped row. - -.. ipython:: python - - df = pd.DataFrame([[1, np.nan], [1, 4], [5, 6]], columns=["A", "B"]) - g = df.groupby("A", as_index=False) - - g.nth(0) - g.nth(-1) - You can also select multiple rows from each group by specifying multiple nth values as a list of ints. .. ipython:: python @@ -1400,6 +1403,13 @@ You can also select multiple rows from each group by specifying multiple nth val # get the first, 4th, and last date index for each month df.groupby([df.index.year, df.index.month]).nth([0, 3, -1]) +You may also use slices or lists of slices. + +.. ipython:: python + + df.groupby([df.index.year, df.index.month]).nth[1:] + df.groupby([df.index.year, df.index.month]).nth[1:, :-1] + Enumerate group items ~~~~~~~~~~~~~~~~~~~~~ diff --git a/doc/source/whatsnew/v2.0.0.rst b/doc/source/whatsnew/v2.0.0.rst index f81660ba2136b..715ba95eb950b 100644 --- a/doc/source/whatsnew/v2.0.0.rst +++ b/doc/source/whatsnew/v2.0.0.rst @@ -72,7 +72,7 @@ Notable bug fixes These are bug fixes that might have notable behavior changes. -..
_whatsnew_200.notable_bug_fixes.notable_bug_fix1: +.. _whatsnew_200.notable_bug_fixes.cumsum_cumprod_overflow: :meth:`.GroupBy.cumsum` and :meth:`.GroupBy.cumprod` overflow instead of lossy casting to float ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ @@ -102,10 +102,58 @@ We return incorrect results with the 6th value. We overflow with the 7th value, but the 6th value is still correct. -.. _whatsnew_200.notable_bug_fixes.notable_bug_fix2: +.. _whatsnew_200.notable_bug_fixes.groupby_nth_filter: -notable_bug_fix2 -^^^^^^^^^^^^^^^^ +:meth:`.DataFrameGroupBy.nth` and :meth:`.SeriesGroupBy.nth` now behave as filtrations +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +In previous versions of pandas, :meth:`.DataFrameGroupBy.nth` and +:meth:`.SeriesGroupBy.nth` acted as if they were aggregations. However, for most +inputs ``n``, they may return either zero or multiple rows per group. This means +that they are filtrations, similar to e.g. :meth:`.DataFrameGroupBy.head`. pandas +now treats them as filtrations (:issue:`13666`). + +.. ipython:: python + + df = pd.DataFrame({"a": [1, 1, 2, 1, 2], "b": [np.nan, 2.0, 3.0, 4.0, 5.0]}) + gb = df.groupby("a") + +*Old Behavior* + +.. code-block:: ipython + + In [5]: gb.nth(n=1) + Out[5]: + a b + 1 1 2.0 + 4 2 5.0 + +*New Behavior* + +.. ipython:: python + + gb.nth(n=1) + +In particular, the index of the result is derived from the input by selecting +the appropriate rows. Also, when ``n`` is larger than the group, no rows are returned instead of +``NaN``. + +*Old Behavior* + +.. code-block:: ipython + + In [5]: gb.nth(n=3, dropna="any") + Out[5]: + b + a + 1 NaN + 2 NaN + +*New Behavior* + +.. ipython:: python + + gb.nth(n=3, dropna="any") .. --------------------------------------------------------------------------- ..
_whatsnew_200.api_breaking: diff --git a/pandas/core/groupby/base.py b/pandas/core/groupby/base.py index a2e9c059cbcc9..0f6d39be7d32f 100644 --- a/pandas/core/groupby/base.py +++ b/pandas/core/groupby/base.py @@ -37,7 +37,6 @@ class OutputKey: "mean", "median", "min", - "nth", "nunique", "prod", # as long as `quantile`'s signature accepts only @@ -100,6 +99,7 @@ class OutputKey: "indices", "ndim", "ngroups", + "nth", "ohlc", "pipe", "plot", diff --git a/pandas/core/groupby/groupby.py b/pandas/core/groupby/groupby.py index f1c18b7762f66..d10931586d5e0 100644 --- a/pandas/core/groupby/groupby.py +++ b/pandas/core/groupby/groupby.py @@ -2978,97 +2978,68 @@ def nth( ... 'B': [np.nan, 2, 3, 4, 5]}, columns=['A', 'B']) >>> g = df.groupby('A') >>> g.nth(0) - B - A - 1 NaN - 2 3.0 + A B + 0 1 NaN + 2 2 3.0 >>> g.nth(1) - B - A - 1 2.0 - 2 5.0 + A B + 1 1 2.0 + 4 2 5.0 >>> g.nth(-1) - B - A - 1 4.0 - 2 5.0 + A B + 3 1 4.0 + 4 2 5.0 >>> g.nth([0, 1]) - B - A - 1 NaN - 1 2.0 - 2 3.0 - 2 5.0 + A B + 0 1 NaN + 1 1 2.0 + 2 2 3.0 + 4 2 5.0 >>> g.nth(slice(None, -1)) - B - A - 1 NaN - 1 2.0 - 2 3.0 + A B + 0 1 NaN + 1 1 2.0 + 2 2 3.0 Index notation may also be used >>> g.nth[0, 1] - B - A - 1 NaN - 1 2.0 - 2 3.0 - 2 5.0 + A B + 0 1 NaN + 1 1 2.0 + 2 2 3.0 + 4 2 5.0 >>> g.nth[:-1] - B - A - 1 NaN - 1 2.0 - 2 3.0 + A B + 0 1 NaN + 1 1 2.0 + 2 2 3.0 - Specifying `dropna` allows count ignoring ``NaN`` + Specifying `dropna` allows ignoring ``NaN`` values >>> g.nth(0, dropna='any') - B - A - 1 2.0 - 2 3.0 + A B + 1 1 2.0 + 2 2 3.0 - NaNs denote group exhausted when using dropna + When the specified ``n`` is larger than any of the groups, an + empty DataFrame is returned >>> g.nth(3, dropna='any') - B - A - 1 NaN - 2 NaN - - Specifying `as_index=False` in `groupby` keeps the original index. 
- - >>> df.groupby('A', as_index=False).nth(1) - A B - 1 1 2.0 - 4 2 5.0 + Empty DataFrame + Columns: [A, B] + Index: [] """ if not dropna: - with self._group_selection_context(): - mask = self._make_mask_from_positional_indexer(n) + mask = self._make_mask_from_positional_indexer(n) - ids, _, _ = self.grouper.group_info + ids, _, _ = self.grouper.group_info - # Drop NA values in grouping - mask = mask & (ids != -1) + # Drop NA values in grouping + mask = mask & (ids != -1) - out = self._mask_selected_obj(mask) - if not self.as_index: - return out - - result_index = self.grouper.result_index - if self.axis == 0: - out.index = result_index[ids[mask]] - if not self.observed and isinstance(result_index, CategoricalIndex): - out = out.reindex(result_index) - - out = self._reindex_output(out) - else: - out.columns = result_index[ids[mask]] - - return out.sort_index(axis=self.axis) if self.sort else out + out = self._mask_selected_obj(mask) + return out # dropna is truthy if not is_integer(n): @@ -3085,7 +3056,6 @@ def nth( # old behaviour, but with all and any support for DataFrames. 
# modified in GH 7559 to have better perf n = cast(int, n) - max_len = n if n >= 0 else -1 - n dropped = self.obj.dropna(how=dropna, axis=self.axis) # get a new grouper for our dropped obj @@ -3115,22 +3085,7 @@ def nth( grb = dropped.groupby( grouper, as_index=self.as_index, sort=self.sort, axis=self.axis ) - sizes, result = grb.size(), grb.nth(n) - mask = (sizes < max_len)._values - - # set the results which don't meet the criteria - if len(result) and mask.any(): - result.loc[mask] = np.nan - - # reset/reindex to the original groups - if len(self.obj) == len(dropped) or len(result) == len( - self.grouper.result_index - ): - result.index = self.grouper.result_index - else: - result = result.reindex(self.grouper.result_index) - - return result + return grb.nth(n) @final def quantile( diff --git a/pandas/tests/groupby/test_categorical.py b/pandas/tests/groupby/test_categorical.py index 1e2bcb58110dd..ca794d4ae5a3e 100644 --- a/pandas/tests/groupby/test_categorical.py +++ b/pandas/tests/groupby/test_categorical.py @@ -563,11 +563,7 @@ def test_observed_nth(): df = DataFrame({"cat": cat, "ser": ser}) result = df.groupby("cat", observed=False)["ser"].nth(0) - - index = Categorical(["a", "b", "c"], categories=["a", "b", "c"]) - expected = Series([1, np.nan, np.nan], index=index, name="ser") - expected.index.name = "cat" - + expected = df["ser"].iloc[[0]] tm.assert_series_equal(result, expected) diff --git a/pandas/tests/groupby/test_function.py b/pandas/tests/groupby/test_function.py index 5383a4d28c8ce..f05874c3286c7 100644 --- a/pandas/tests/groupby/test_function.py +++ b/pandas/tests/groupby/test_function.py @@ -405,7 +405,6 @@ def test_median_empty_bins(observed): ("last", {"df": [{"a": 1, "b": 2}, {"a": 2, "b": 4}]}), ("min", {"df": [{"a": 1, "b": 1}, {"a": 2, "b": 3}]}), ("max", {"df": [{"a": 1, "b": 2}, {"a": 2, "b": 4}]}), - ("nth", {"df": [{"a": 1, "b": 2}, {"a": 2, "b": 4}], "args": [1]}), ("count", {"df": [{"a": 1, "b": 2}, {"a": 2, "b": 2}], "out_type": 
"int64"}), ], ) diff --git a/pandas/tests/groupby/test_grouping.py b/pandas/tests/groupby/test_grouping.py index 1c8b8e3d33ecf..e3b7ad8f78750 100644 --- a/pandas/tests/groupby/test_grouping.py +++ b/pandas/tests/groupby/test_grouping.py @@ -851,6 +851,8 @@ def test_groupby_with_single_column(self): exp = DataFrame(index=Index(["a", "b", "s"], name="a")) tm.assert_frame_equal(df.groupby("a").count(), exp) tm.assert_frame_equal(df.groupby("a").sum(), exp) + + exp = df.iloc[[3, 4, 5]] tm.assert_frame_equal(df.groupby("a").nth(1), exp) def test_gb_key_len_equal_axis_len(self): diff --git a/pandas/tests/groupby/test_nth.py b/pandas/tests/groupby/test_nth.py index 187c80075f36b..de5025b998b30 100644 --- a/pandas/tests/groupby/test_nth.py +++ b/pandas/tests/groupby/test_nth.py @@ -23,6 +23,7 @@ def test_first_last_nth(df): tm.assert_frame_equal(first, expected) nth = grouped.nth(0) + expected = df.loc[[0, 1]] tm.assert_frame_equal(nth, expected) last = grouped.last() @@ -31,12 +32,11 @@ def test_first_last_nth(df): tm.assert_frame_equal(last, expected) nth = grouped.nth(-1) + expected = df.iloc[[5, 7]] tm.assert_frame_equal(nth, expected) nth = grouped.nth(1) - expected = df.loc[[2, 3], ["B", "C", "D"]].copy() - expected.index = Index(["foo", "bar"], name="A") - expected = expected.sort_index() + expected = df.iloc[[2, 3]] tm.assert_frame_equal(nth, expected) # it works! 
@@ -47,7 +47,7 @@ def test_first_last_nth(df): df.loc[df["A"] == "foo", "B"] = np.nan assert isna(grouped["B"].first()["foo"]) assert isna(grouped["B"].last()["foo"]) - assert isna(grouped["B"].nth(0)["foo"]) + assert isna(grouped["B"].nth(0).iloc[0]) # v0.14.0 whatsnew df = DataFrame([[1, np.nan], [1, 4], [5, 6]], columns=["A", "B"]) @@ -56,7 +56,7 @@ def test_first_last_nth(df): expected = df.iloc[[1, 2]].set_index("A") tm.assert_frame_equal(result, expected) - expected = df.iloc[[1, 2]].set_index("A") + expected = df.iloc[[1, 2]] result = g.nth(0, dropna="any") tm.assert_frame_equal(result, expected) @@ -82,18 +82,10 @@ def test_first_last_with_na_object(method, nulls_fixture): @pytest.mark.parametrize("index", [0, -1]) def test_nth_with_na_object(index, nulls_fixture): # https://github.com/pandas-dev/pandas/issues/32123 - groups = DataFrame({"a": [1, 1, 2, 2], "b": [1, 2, 3, nulls_fixture]}).groupby("a") + df = DataFrame({"a": [1, 1, 2, 2], "b": [1, 2, 3, nulls_fixture]}) + groups = df.groupby("a") result = groups.nth(index) - - if index == 0: - values = [1, 3] - else: - values = [2, nulls_fixture] - - values = np.array(values, dtype=result["b"].dtype) - idx = Index([1, 2], name="a") - expected = DataFrame({"b": values}, index=idx) - + expected = df.iloc[[0, 2]] if index == 0 else df.iloc[[1, 3]] tm.assert_frame_equal(result, expected) @@ -149,9 +141,7 @@ def test_first_last_nth_dtypes(df_mixed_floats): tm.assert_frame_equal(last, expected) nth = grouped.nth(1) - expected = df.loc[[3, 2], ["B", "C", "D", "E", "F"]] - expected.index = Index(["bar", "foo"], name="A") - expected = expected.sort_index() + expected = df.iloc[[2, 3]] tm.assert_frame_equal(nth, expected) # GH 2763, first/last shifting dtypes @@ -166,11 +156,13 @@ def test_first_last_nth_dtypes(df_mixed_floats): def test_first_last_nth_nan_dtype(): # GH 33591 df = DataFrame({"data": ["A"], "nans": Series([np.nan], dtype=object)}) - grouped = df.groupby("data") + expected = df.set_index("data").nans 
tm.assert_series_equal(grouped.nans.first(), expected) tm.assert_series_equal(grouped.nans.last(), expected) + + expected = df.nans tm.assert_series_equal(grouped.nans.nth(-1), expected) tm.assert_series_equal(grouped.nans.nth(0), expected) @@ -198,23 +190,21 @@ def test_nth(): df = DataFrame([[1, np.nan], [1, 4], [5, 6]], columns=["A", "B"]) g = df.groupby("A") - tm.assert_frame_equal(g.nth(0), df.iloc[[0, 2]].set_index("A")) - tm.assert_frame_equal(g.nth(1), df.iloc[[1]].set_index("A")) - tm.assert_frame_equal(g.nth(2), df.loc[[]].set_index("A")) - tm.assert_frame_equal(g.nth(-1), df.iloc[[1, 2]].set_index("A")) - tm.assert_frame_equal(g.nth(-2), df.iloc[[0]].set_index("A")) - tm.assert_frame_equal(g.nth(-3), df.loc[[]].set_index("A")) - tm.assert_series_equal(g.B.nth(0), df.set_index("A").B.iloc[[0, 2]]) - tm.assert_series_equal(g.B.nth(1), df.set_index("A").B.iloc[[1]]) - tm.assert_frame_equal(g[["B"]].nth(0), df.loc[[0, 2], ["A", "B"]].set_index("A")) + tm.assert_frame_equal(g.nth(0), df.iloc[[0, 2]]) + tm.assert_frame_equal(g.nth(1), df.iloc[[1]]) + tm.assert_frame_equal(g.nth(2), df.loc[[]]) + tm.assert_frame_equal(g.nth(-1), df.iloc[[1, 2]]) + tm.assert_frame_equal(g.nth(-2), df.iloc[[0]]) + tm.assert_frame_equal(g.nth(-3), df.loc[[]]) + tm.assert_series_equal(g.B.nth(0), df.B.iloc[[0, 2]]) + tm.assert_series_equal(g.B.nth(1), df.B.iloc[[1]]) + tm.assert_frame_equal(g[["B"]].nth(0), df[["B"]].iloc[[0, 2]]) - exp = df.set_index("A") - tm.assert_frame_equal(g.nth(0, dropna="any"), exp.iloc[[1, 2]]) - tm.assert_frame_equal(g.nth(-1, dropna="any"), exp.iloc[[1, 2]]) + tm.assert_frame_equal(g.nth(0, dropna="any"), df.iloc[[1, 2]]) + tm.assert_frame_equal(g.nth(-1, dropna="any"), df.iloc[[1, 2]]) - exp["B"] = np.nan - tm.assert_frame_equal(g.nth(7, dropna="any"), exp.iloc[[1, 2]]) - tm.assert_frame_equal(g.nth(2, dropna="any"), exp.iloc[[1, 2]]) + tm.assert_frame_equal(g.nth(7, dropna="any"), df.iloc[:0]) + tm.assert_frame_equal(g.nth(2, dropna="any"), 
df.iloc[:0]) # out of bounds, regression from 0.13.1 # GH 6621 @@ -263,13 +253,6 @@ def test_nth(): assert expected.iloc[0] == v assert expected2.iloc[0] == v - # this is NOT the same as .first (as sorted is default!) - # as it keeps the order in the series (and not the group order) - # related GH 7287 - expected = s.groupby(g, sort=False).first() - result = s.groupby(g, sort=False).nth(0, dropna="all") - tm.assert_series_equal(result, expected) - with pytest.raises(ValueError, match="For a DataFrame"): s.groupby(g, sort=False).nth(0, dropna=True) @@ -277,21 +260,21 @@ def test_nth(): df = DataFrame([[1, np.nan], [1, 4], [5, 6]], columns=["A", "B"]) g = df.groupby("A") result = g.B.nth(0, dropna="all") - expected = g.B.first() + expected = df.B.iloc[[1, 2]] tm.assert_series_equal(result, expected) # test multiple nth values df = DataFrame([[1, np.nan], [1, 3], [1, 4], [5, 6], [5, 7]], columns=["A", "B"]) g = df.groupby("A") - tm.assert_frame_equal(g.nth(0), df.iloc[[0, 3]].set_index("A")) - tm.assert_frame_equal(g.nth([0]), df.iloc[[0, 3]].set_index("A")) - tm.assert_frame_equal(g.nth([0, 1]), df.iloc[[0, 1, 3, 4]].set_index("A")) - tm.assert_frame_equal(g.nth([0, -1]), df.iloc[[0, 2, 3, 4]].set_index("A")) - tm.assert_frame_equal(g.nth([0, 1, 2]), df.iloc[[0, 1, 2, 3, 4]].set_index("A")) - tm.assert_frame_equal(g.nth([0, 1, -1]), df.iloc[[0, 1, 2, 3, 4]].set_index("A")) - tm.assert_frame_equal(g.nth([2]), df.iloc[[2]].set_index("A")) - tm.assert_frame_equal(g.nth([3, 4]), df.loc[[]].set_index("A")) + tm.assert_frame_equal(g.nth(0), df.iloc[[0, 3]]) + tm.assert_frame_equal(g.nth([0]), df.iloc[[0, 3]]) + tm.assert_frame_equal(g.nth([0, 1]), df.iloc[[0, 1, 3, 4]]) + tm.assert_frame_equal(g.nth([0, -1]), df.iloc[[0, 2, 3, 4]]) + tm.assert_frame_equal(g.nth([0, 1, 2]), df.iloc[[0, 1, 2, 3, 4]]) + tm.assert_frame_equal(g.nth([0, 1, -1]), df.iloc[[0, 1, 2, 3, 4]]) + tm.assert_frame_equal(g.nth([2]), df.iloc[[2]]) + tm.assert_frame_equal(g.nth([3, 4]), df.loc[[]]) 
business_dates = pd.date_range(start="4/1/2014", end="6/30/2014", freq="B") df = DataFrame(1, index=business_dates, columns=["a", "b"]) @@ -318,12 +301,12 @@ def test_nth(): tm.assert_frame_equal(result, expected) -def test_nth_multi_index(three_group): +def test_nth_multi_grouper(three_group): # PR 9090, related to issue 8979 - # test nth on MultiIndex, should match .first() + # test nth on multiple groupers grouped = three_group.groupby(["A", "B"]) result = grouped.nth(0) - expected = grouped.first() + expected = three_group.iloc[[0, 3, 4, 7]] tm.assert_frame_equal(result, expected) @@ -504,13 +487,7 @@ def test_nth_multi_index_as_expected(): ) grouped = three_group.groupby(["A", "B"]) result = grouped.nth(0) - expected = DataFrame( - {"C": ["dull", "dull", "dull", "dull"]}, - index=MultiIndex.from_arrays( - [["bar", "bar", "foo", "foo"], ["one", "two", "one", "two"]], - names=["A", "B"], - ), - ) + expected = three_group.iloc[[0, 3, 4, 7]] tm.assert_frame_equal(result, expected) @@ -567,7 +544,7 @@ def test_groupby_head_tail_axis_1(op, n, expected_cols): def test_group_selection_cache(): # GH 12839 nth, head, and tail should return same result consistently df = DataFrame([[1, 2], [1, 4], [5, 6]], columns=["A", "B"]) - expected = df.iloc[[0, 2]].set_index("A") + expected = df.iloc[[0, 2]] g = df.groupby("A") result1 = g.head(n=2) @@ -598,13 +575,11 @@ def test_nth_empty(): # GH 16064 df = DataFrame(index=[0], columns=["a", "b", "c"]) result = df.groupby("a").nth(10) - expected = DataFrame(index=Index([], name="a"), columns=["b", "c"]) + expected = df.iloc[:0] tm.assert_frame_equal(result, expected) result = df.groupby(["a", "b"]).nth(10) - expected = DataFrame( - index=MultiIndex([[], []], [[], []], names=["a", "b"]), columns=["c"] - ) + expected = df.iloc[:0] tm.assert_frame_equal(result, expected) @@ -616,15 +591,11 @@ def test_nth_column_order(): columns=["A", "C", "B"], ) result = df.groupby("A").nth(0) - expected = DataFrame( - [["b", 100.0], ["c", 200.0]], 
columns=["C", "B"], index=Index([1, 2], name="A") - ) + expected = df.iloc[[0, 3]] tm.assert_frame_equal(result, expected) result = df.groupby("A").nth(-1, dropna="any") - expected = DataFrame( - [["a", 50.0], ["d", 150.0]], columns=["C", "B"], index=Index([1, 2], name="A") - ) + expected = df.iloc[[1, 4]] tm.assert_frame_equal(result, expected) @@ -636,9 +607,7 @@ def test_nth_nan_in_grouper(dropna): columns=list("abc"), ) result = df.groupby("a").nth(0, dropna=dropna) - expected = DataFrame( - [[2, 3], [6, 7]], columns=list("bc"), index=Index(["abc", "def"], name="a") - ) + expected = df.iloc[[1, 3]] tm.assert_frame_equal(result, expected) @@ -772,29 +741,21 @@ def test_groupby_nth_with_column_axis(): columns=["C", "B", "A"], ) result = df.groupby(df.iloc[1], axis=1).nth(0) - expected = DataFrame( - [ - [6, 4], - [7, 8], - ], - index=["z", "y"], - columns=[7, 8], - ) - expected.columns.name = "y" + expected = df.iloc[:, [0, 2]] tm.assert_frame_equal(result, expected) @pytest.mark.parametrize( "start, stop, expected_values, expected_columns", [ - (None, None, [0, 1, 2, 3, 4], [5, 5, 5, 6, 6]), - (None, 1, [0, 3], [5, 6]), - (None, 9, [0, 1, 2, 3, 4], [5, 5, 5, 6, 6]), - (None, -1, [0, 1, 3], [5, 5, 6]), - (1, None, [1, 2, 4], [5, 5, 6]), - (1, -1, [1], [5]), - (-1, None, [2, 4], [5, 6]), - (-1, 2, [4], [6]), + (None, None, [0, 1, 2, 3, 4], list("ABCDE")), + (None, 1, [0, 3], list("AD")), + (None, 9, [0, 1, 2, 3, 4], list("ABCDE")), + (None, -1, [0, 1, 3], list("ABD")), + (1, None, [1, 2, 4], list("BCE")), + (1, -1, [1], list("B")), + (-1, None, [2, 4], list("CE")), + (-1, 2, [4], list("E")), ], ) @pytest.mark.parametrize("method", ["call", "index"]) @@ -807,7 +768,7 @@ def test_nth_slices_with_column_axis( "call": lambda start, stop: gb.nth(slice(start, stop)), "index": lambda start, stop: gb.nth[start:stop], }[method](start, stop) - expected = DataFrame([expected_values], columns=expected_columns) + expected = DataFrame([expected_values], 
columns=[expected_columns]) tm.assert_frame_equal(result, expected) @@ -824,7 +785,7 @@ def test_head_tail_dropna_true(): result = df.groupby(["X", "Y"]).tail(n=1) tm.assert_frame_equal(result, expected) - result = df.groupby(["X", "Y"]).nth(n=0).reset_index() + result = df.groupby(["X", "Y"]).nth(n=0) tm.assert_frame_equal(result, expected) @@ -839,5 +800,5 @@ def test_head_tail_dropna_false(): result = df.groupby(["X", "Y"], dropna=False).tail(n=1) tm.assert_frame_equal(result, expected) - result = df.groupby(["X", "Y"], dropna=False).nth(n=0).reset_index() + result = df.groupby(["X", "Y"], dropna=False).nth(n=0) tm.assert_frame_equal(result, expected)