From 222cc37acbfe3ef7d26bfe31e70e76642ac1b49b Mon Sep 17 00:00:00 2001 From: Bas van Beek <43369155+BvB93@users.noreply.github.com> Date: Wed, 7 Sep 2022 09:22:54 +0200 Subject: [PATCH] TYP,BUG: Reduce argument validation in C-based `__class_getitem__` (#22212) Closes #22185 The __class_getitem__ implementations would previously perform basic validation of the passed value, i.e. it would check whether a tuple of the appropriate length was passed (e.g. np.dtype.__class_getitem__ would expect a single item or a length-1 tuple). As noted in aforementioned issue: this approach can cause issues when (a. 2 or more parameters are involved and (b. a subclasses is created one or more parameters are declared constant (e.g. a fixed dtype & variably shaped array). This PR fixes aforementioned issue by relaxing the runtime argument validation, thus mimicking the behavior of the standard library (more closely). While we could alternatively fix this by adding more special casing (e.g. only disable validation when cls is not np.ndarray), I'm not convinced this would be worth the additional complexity, especially since the standard library also has zero runtime validation for all of its Py_GenericAlias-based implementations of __class_getitem__. (Some edits by seberg to the commit message) --- numpy/core/src/multiarray/methods.c | 2 +- numpy/core/src/multiarray/scalartypes.c.src | 2 +- numpy/core/tests/test_arraymethod.py | 27 ++++++++++----------- numpy/core/tests/test_scalar_methods.py | 10 ++++++++ 4 files changed, 25 insertions(+), 16 deletions(-) diff --git a/numpy/core/src/multiarray/methods.c b/numpy/core/src/multiarray/methods.c index f10f68ea5fce..3e8c78dd0264 100644 --- a/numpy/core/src/multiarray/methods.c +++ b/numpy/core/src/multiarray/methods.c @@ -2829,7 +2829,7 @@ array_class_getitem(PyObject *cls, PyObject *args) Py_ssize_t args_len; args_len = PyTuple_Check(args) ? PyTuple_Size(args) : 1; - if (args_len != 2) { + if ((args_len > 2) || (args_len == 0)) { return PyErr_Format(PyExc_TypeError, "Too %s arguments for %s", args_len > 2 ? "many" : "few", diff --git a/numpy/core/src/multiarray/scalartypes.c.src b/numpy/core/src/multiarray/scalartypes.c.src index 459e5b222f2c..e1f23600182f 100644 --- a/numpy/core/src/multiarray/scalartypes.c.src +++ b/numpy/core/src/multiarray/scalartypes.c.src @@ -1855,7 +1855,7 @@ numbertype_class_getitem_abc(PyObject *cls, PyObject *args) } args_len = PyTuple_Check(args) ? PyTuple_Size(args) : 1; - if (args_len != args_len_expected) { + if ((args_len > args_len_expected) || (args_len == 0)) { return PyErr_Format(PyExc_TypeError, "Too %s arguments for %s", args_len > args_len_expected ? "many" : "few", diff --git a/numpy/core/tests/test_arraymethod.py b/numpy/core/tests/test_arraymethod.py index 49aa9f6dfcfa..6b75d192121d 100644 --- a/numpy/core/tests/test_arraymethod.py +++ b/numpy/core/tests/test_arraymethod.py @@ -3,9 +3,11 @@ this is private API, but when added, public API may be added here. """ +from __future__ import annotations + import sys import types -from typing import Any, Type +from typing import Any import pytest @@ -63,28 +65,25 @@ def test_invalid_arguments(self, args, error): @pytest.mark.skipif(sys.version_info < (3, 9), reason="Requires python 3.9") +@pytest.mark.parametrize( + "cls", [np.ndarray, np.recarray, np.chararray, np.matrix, np.memmap] +) class TestClassGetItem: - @pytest.mark.parametrize( - "cls", [np.ndarray, np.recarray, np.chararray, np.matrix, np.memmap] - ) - def test_class_getitem(self, cls: Type[np.ndarray]) -> None: + def test_class_getitem(self, cls: type[np.ndarray]) -> None: """Test `ndarray.__class_getitem__`.""" alias = cls[Any, Any] assert isinstance(alias, types.GenericAlias) assert alias.__origin__ is cls @pytest.mark.parametrize("arg_len", range(4)) - def test_subscript_tuple(self, arg_len: int) -> None: + def test_subscript_tup(self, cls: type[np.ndarray], arg_len: int) -> None: arg_tup = (Any,) * arg_len - if arg_len == 2: - assert np.ndarray[arg_tup] + if arg_len in (1, 2): + assert cls[arg_tup] else: - with pytest.raises(TypeError): - np.ndarray[arg_tup] - - def test_subscript_scalar(self) -> None: - with pytest.raises(TypeError): - np.ndarray[Any] + match = f"Too {'few' if arg_len == 0 else 'many'} arguments" + with pytest.raises(TypeError, match=match): + cls[arg_tup] @pytest.mark.skipif(sys.version_info >= (3, 9), reason="Requires python 3.8") diff --git a/numpy/core/tests/test_scalar_methods.py b/numpy/core/tests/test_scalar_methods.py index eef4c1433910..769bfd5006db 100644 --- a/numpy/core/tests/test_scalar_methods.py +++ b/numpy/core/tests/test_scalar_methods.py @@ -153,6 +153,16 @@ def test_abc_complexfloating(self) -> None: assert isinstance(alias, types.GenericAlias) assert alias.__origin__ is np.complexfloating + @pytest.mark.parametrize("arg_len", range(4)) + def test_abc_complexfloating_subscript_tuple(self, arg_len: int) -> None: + arg_tup = (Any,) * arg_len + if arg_len in (1, 2): + assert np.complexfloating[arg_tup] + else: + match = f"Too {'few' if arg_len == 0 else 'many'} arguments" + with pytest.raises(TypeError, match=match): + np.complexfloating[arg_tup] + @pytest.mark.parametrize("cls", [np.generic, np.flexible, np.character]) def test_abc_non_numeric(self, cls: Type[np.generic]) -> None: with pytest.raises(TypeError):