diff --git a/doc/source/whatsnew/v1.5.2.rst b/doc/source/whatsnew/v1.5.2.rst index a46d081aea072..446235d1656dc 100644 --- a/doc/source/whatsnew/v1.5.2.rst +++ b/doc/source/whatsnew/v1.5.2.rst @@ -13,6 +13,7 @@ including other versions of pandas. Fixed regressions ~~~~~~~~~~~~~~~~~ +- Fixed regression in :meth:`MultiIndex.join` for extension array dtypes (:issue:`49277`) - Fixed regression in :meth:`Series.replace` raising ``RecursionError`` with numeric dtype and when specifying ``value=None`` (:issue:`45725`) - Fixed regression in :meth:`DataFrame.plot` preventing :class:`~matplotlib.colors.Colormap` instance from being passed using the ``colormap`` argument if Matplotlib 3.6+ is used (:issue:`49374`) diff --git a/pandas/core/indexes/base.py b/pandas/core/indexes/base.py index 5abd04c29e5d4..677e1dc1a559a 100644 --- a/pandas/core/indexes/base.py +++ b/pandas/core/indexes/base.py @@ -4701,8 +4701,10 @@ def join( return self._join_non_unique(other, how=how) elif not self.is_unique or not other.is_unique: if self.is_monotonic_increasing and other.is_monotonic_increasing: - if self._can_use_libjoin: + if not is_interval_dtype(self.dtype): # otherwise we will fall through to _join_via_get_indexer + # GH#39133 + # go through object dtype for ea till engine is supported properly return self._join_monotonic(other, how=how) else: return self._join_non_unique(other, how=how) @@ -5079,7 +5081,7 @@ def _wrap_joined_index(self: _IndexT, joined: ArrayLike, other: _IndexT) -> _Ind return self._constructor(joined, name=name) # type: ignore[return-value] else: name = get_op_result_name(self, other) - return self._constructor._with_infer(joined, name=name) + return self._constructor._with_infer(joined, name=name, dtype=self.dtype) @cache_readonly def _can_use_libjoin(self) -> bool: diff --git a/pandas/tests/indexes/multi/test_join.py b/pandas/tests/indexes/multi/test_join.py index e6bec97aedb38..7000724a6b271 100644 --- a/pandas/tests/indexes/multi/test_join.py +++ b/pandas/tests/indexes/multi/test_join.py @@ -5,6 +5,8 @@ Index, Interval, MultiIndex, + Series, + StringDtype, ) import pandas._testing as tm @@ -158,3 +160,49 @@ def test_join_overlapping_interval_level(): result = idx_1.join(idx_2, how="outer") tm.assert_index_equal(result, expected) + + +def test_join_midx_ea(): + # GH#49277 + midx = MultiIndex.from_arrays( + [Series([1, 1, 3], dtype="Int64"), Series([1, 2, 3], dtype="Int64")], + names=["a", "b"], + ) + midx2 = MultiIndex.from_arrays( + [Series([1], dtype="Int64"), Series([3], dtype="Int64")], names=["a", "c"] + ) + result = midx.join(midx2, how="inner") + expected = MultiIndex.from_arrays( + [ + Series([1, 1], dtype="Int64"), + Series([1, 2], dtype="Int64"), + Series([3, 3], dtype="Int64"), + ], + names=["a", "b", "c"], + ) + tm.assert_index_equal(result, expected) + + +def test_join_midx_string(): + # GH#49277 + midx = MultiIndex.from_arrays( + [ + Series(["a", "a", "c"], dtype=StringDtype()), + Series(["a", "b", "c"], dtype=StringDtype()), + ], + names=["a", "b"], + ) + midx2 = MultiIndex.from_arrays( + [Series(["a"], dtype=StringDtype()), Series(["c"], dtype=StringDtype())], + names=["a", "c"], + ) + result = midx.join(midx2, how="inner") + expected = MultiIndex.from_arrays( + [ + Series(["a", "a"], dtype=StringDtype()), + Series(["a", "b"], dtype=StringDtype()), + Series(["c", "c"], dtype=StringDtype()), + ], + names=["a", "b", "c"], + ) + tm.assert_index_equal(result, expected)