From af7ee55b13d00967023d88e0151dcf9566229a5a Mon Sep 17 00:00:00 2001 From: Sebastian Berg Date: Wed, 21 Dec 2022 09:41:19 +0100 Subject: [PATCH] API: Ensure a full mask is returned for masked_invalid Matplotlib relies on this, so we don't seem to have much of a choice. I am surprised that we were not notified of the issue before release time. Closes gh-22720, gh-22720 --- numpy/ma/core.py | 7 ++++++- numpy/ma/tests/test_core.py | 12 ++++++++++++ 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/numpy/ma/core.py b/numpy/ma/core.py index 3ca6e1a23284..640abf628692 100644 --- a/numpy/ma/core.py +++ b/numpy/ma/core.py @@ -2357,7 +2357,12 @@ def masked_invalid(a, copy=True): """ a = np.array(a, copy=False, subok=True) - return masked_where(~(np.isfinite(a)), a, copy=copy) + res = masked_where(~(np.isfinite(a)), a, copy=copy) + # masked_invalid previously never returned nomask as a mask and doing so + # threw off matplotlib (gh-22842). So use shrink=False: + if res._mask is nomask: + res._mask = make_mask_none(res.shape, res.dtype) + return res ############################################################################### # Printing options # diff --git a/numpy/ma/tests/test_core.py b/numpy/ma/tests/test_core.py index 96864da5db66..47c4a799e7bb 100644 --- a/numpy/ma/tests/test_core.py +++ b/numpy/ma/tests/test_core.py @@ -4519,6 +4519,18 @@ def __array__(self): assert_array_equal(arr._data, np.array(Series())) assert_array_equal(arr._mask, [False, True, True]) + @pytest.mark.parametrize("copy", [True, False]) + def test_masked_invalid_full_mask(self, copy): + # Matplotlib relied on masked_invalid always returning a full mask + # (Also astropy projects, but were ok with it gh-22720 and gh-22842) + a = np.ma.array([1, 2, 3, 4]) + assert a._mask is nomask + res = np.ma.masked_invalid(a, copy=copy) + assert res.mask is not nomask + # mask of a should not be mutated + assert a.mask is nomask + assert np.may_share_memory(a._data, res._data) != copy + def test_choose(self): # Test choose choices = [[0, 1, 2, 3], [10, 11, 12, 13],