Skip to content

Commit

Permalink
BUG: pivot_table raising for nullable dtype and margins (#48714)
Browse files Browse the repository at this point in the history
  • Loading branch information
phofl committed Sep 22, 2022
1 parent 26d1cec commit 36a67f6
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 0 deletions.
1 change: 1 addition & 0 deletions doc/source/whatsnew/v1.6.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,7 @@ Groupby/resample/rolling

Reshaping
^^^^^^^^^
- Bug in :meth:`DataFrame.pivot_table` raising ``TypeError`` for nullable dtype and ``margins=True`` (:issue:`48681`)
- Bug in :meth:`DataFrame.pivot` not respecting ``None`` as column name (:issue:`48293`)
- Bug in :func:`join` when ``left_on`` or ``right_on`` is or includes a :class:`CategoricalIndex` incorrectly raising ``AttributeError`` (:issue:`48464`)
-
Expand Down
5 changes: 5 additions & 0 deletions pandas/core/reshape/pivot.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

from pandas.core.dtypes.cast import maybe_downcast_to_dtype
from pandas.core.dtypes.common import (
is_extension_array_dtype,
is_integer_dtype,
is_list_like,
is_nested_list_like,
Expand Down Expand Up @@ -324,6 +325,10 @@ def _add_margins(
row_names = result.index.names
# check the result column and leave floats
for dtype in set(result.dtypes):
if is_extension_array_dtype(dtype):
# Can hold NA already
continue

cols = result.select_dtypes([dtype]).columns
margin_dummy[cols] = margin_dummy[cols].apply(
maybe_downcast_to_dtype, args=(dtype,)
Expand Down
17 changes: 17 additions & 0 deletions pandas/tests/reshape/test_pivot.py
Original file line number Diff line number Diff line change
Expand Up @@ -2183,6 +2183,23 @@ def test_pivot_table_sort_false(self):
)
tm.assert_frame_equal(result, expected)

def test_pivot_table_nullable_margins(self):
# GH#48681
df = DataFrame(
{"a": "A", "b": [1, 2], "sales": Series([10, 11], dtype="Int64")}
)

result = df.pivot_table(index="b", columns="a", margins=True, aggfunc="sum")
expected = DataFrame(
[[10, 10], [11, 11], [21, 21]],
index=Index([1, 2, "All"], name="b"),
columns=MultiIndex.from_tuples(
[("sales", "A"), ("sales", "All")], names=[None, "a"]
),
dtype="Int64",
)
tm.assert_frame_equal(result, expected)

def test_pivot_table_sort_false_with_multiple_values(self):
df = DataFrame(
{
Expand Down

0 comments on commit 36a67f6

Please sign in to comment.