diff --git a/numpy/core/src/multiarray/methods.c b/numpy/core/src/multiarray/methods.c index f10f68ea5fce..3e8c78dd0264 100644 --- a/numpy/core/src/multiarray/methods.c +++ b/numpy/core/src/multiarray/methods.c @@ -2829,7 +2829,7 @@ array_class_getitem(PyObject *cls, PyObject *args) Py_ssize_t args_len; args_len = PyTuple_Check(args) ? PyTuple_Size(args) : 1; - if (args_len != 2) { + if ((args_len > 2) || (args_len == 0)) { return PyErr_Format(PyExc_TypeError, "Too %s arguments for %s", args_len > 2 ? "many" : "few", diff --git a/numpy/core/src/multiarray/scalartypes.c.src b/numpy/core/src/multiarray/scalartypes.c.src index 459e5b222f2c..e1f23600182f 100644 --- a/numpy/core/src/multiarray/scalartypes.c.src +++ b/numpy/core/src/multiarray/scalartypes.c.src @@ -1855,7 +1855,7 @@ numbertype_class_getitem_abc(PyObject *cls, PyObject *args) } args_len = PyTuple_Check(args) ? PyTuple_Size(args) : 1; - if (args_len != args_len_expected) { + if ((args_len > args_len_expected) || (args_len == 0)) { return PyErr_Format(PyExc_TypeError, "Too %s arguments for %s", args_len > args_len_expected ? "many" : "few", diff --git a/numpy/core/tests/test_arraymethod.py b/numpy/core/tests/test_arraymethod.py index 49aa9f6dfcfa..6b75d192121d 100644 --- a/numpy/core/tests/test_arraymethod.py +++ b/numpy/core/tests/test_arraymethod.py @@ -3,9 +3,11 @@ this is private API, but when added, public API may be added here. """ +from __future__ import annotations + import sys import types -from typing import Any, Type +from typing import Any import pytest @@ -63,28 +65,25 @@ def test_invalid_arguments(self, args, error): @pytest.mark.skipif(sys.version_info < (3, 9), reason="Requires python 3.9") +@pytest.mark.parametrize( + "cls", [np.ndarray, np.recarray, np.chararray, np.matrix, np.memmap] +) class TestClassGetItem: - @pytest.mark.parametrize( - "cls", [np.ndarray, np.recarray, np.chararray, np.matrix, np.memmap] - ) - def test_class_getitem(self, cls: Type[np.ndarray]) -> None: + def test_class_getitem(self, cls: type[np.ndarray]) -> None: """Test `ndarray.__class_getitem__`.""" alias = cls[Any, Any] assert isinstance(alias, types.GenericAlias) assert alias.__origin__ is cls @pytest.mark.parametrize("arg_len", range(4)) - def test_subscript_tuple(self, arg_len: int) -> None: + def test_subscript_tup(self, cls: type[np.ndarray], arg_len: int) -> None: arg_tup = (Any,) * arg_len - if arg_len == 2: - assert np.ndarray[arg_tup] + if arg_len in (1, 2): + assert cls[arg_tup] else: - with pytest.raises(TypeError): - np.ndarray[arg_tup] - - def test_subscript_scalar(self) -> None: - with pytest.raises(TypeError): - np.ndarray[Any] + match = f"Too {'few' if arg_len == 0 else 'many'} arguments" + with pytest.raises(TypeError, match=match): + cls[arg_tup] @pytest.mark.skipif(sys.version_info >= (3, 9), reason="Requires python 3.8") diff --git a/numpy/core/tests/test_scalar_methods.py b/numpy/core/tests/test_scalar_methods.py index eef4c1433910..769bfd5006db 100644 --- a/numpy/core/tests/test_scalar_methods.py +++ b/numpy/core/tests/test_scalar_methods.py @@ -153,6 +153,16 @@ def test_abc_complexfloating(self) -> None: assert isinstance(alias, types.GenericAlias) assert alias.__origin__ is np.complexfloating + @pytest.mark.parametrize("arg_len", range(4)) + def test_abc_complexfloating_subscript_tuple(self, arg_len: int) -> None: + arg_tup = (Any,) * arg_len + if arg_len in (1, 2): + assert np.complexfloating[arg_tup] + else: + match = f"Too {'few' if arg_len == 0 else 'many'} arguments" + with pytest.raises(TypeError, match=match): + np.complexfloating[arg_tup] + @pytest.mark.parametrize("cls", [np.generic, np.flexible, np.character]) def test_abc_non_numeric(self, cls: Type[np.generic]) -> None: with pytest.raises(TypeError):