Skip to content

Commit

Permalink
Merge pull request #22222 from charris/backport-22212
Browse files Browse the repository at this point in the history
TYP,BUG: Reduce argument validation in C-based ``__class_getitem__``
  • Loading branch information
charris committed Sep 7, 2022
2 parents 63ab75d + 222cc37 commit 754ec89
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 16 deletions.
2 changes: 1 addition & 1 deletion numpy/core/src/multiarray/methods.c
Expand Up @@ -2829,7 +2829,7 @@ array_class_getitem(PyObject *cls, PyObject *args)
Py_ssize_t args_len;

args_len = PyTuple_Check(args) ? PyTuple_Size(args) : 1;
if (args_len != 2) {
if ((args_len > 2) || (args_len == 0)) {
return PyErr_Format(PyExc_TypeError,
"Too %s arguments for %s",
args_len > 2 ? "many" : "few",
Expand Down
2 changes: 1 addition & 1 deletion numpy/core/src/multiarray/scalartypes.c.src
Expand Up @@ -1855,7 +1855,7 @@ numbertype_class_getitem_abc(PyObject *cls, PyObject *args)
}

args_len = PyTuple_Check(args) ? PyTuple_Size(args) : 1;
if (args_len != args_len_expected) {
if ((args_len > args_len_expected) || (args_len == 0)) {
return PyErr_Format(PyExc_TypeError,
"Too %s arguments for %s",
args_len > args_len_expected ? "many" : "few",
Expand Down
27 changes: 13 additions & 14 deletions numpy/core/tests/test_arraymethod.py
Expand Up @@ -3,9 +3,11 @@
this is private API, but when added, public API may be added here.
"""

from __future__ import annotations

import sys
import types
from typing import Any, Type
from typing import Any

import pytest

Expand Down Expand Up @@ -63,28 +65,25 @@ def test_invalid_arguments(self, args, error):


@pytest.mark.skipif(sys.version_info < (3, 9), reason="Requires python 3.9")
@pytest.mark.parametrize(
"cls", [np.ndarray, np.recarray, np.chararray, np.matrix, np.memmap]
)
class TestClassGetItem:
@pytest.mark.parametrize(
"cls", [np.ndarray, np.recarray, np.chararray, np.matrix, np.memmap]
)
def test_class_getitem(self, cls: Type[np.ndarray]) -> None:
def test_class_getitem(self, cls: type[np.ndarray]) -> None:
"""Test `ndarray.__class_getitem__`."""
alias = cls[Any, Any]
assert isinstance(alias, types.GenericAlias)
assert alias.__origin__ is cls

@pytest.mark.parametrize("arg_len", range(4))
def test_subscript_tuple(self, arg_len: int) -> None:
def test_subscript_tup(self, cls: type[np.ndarray], arg_len: int) -> None:
arg_tup = (Any,) * arg_len
if arg_len == 2:
assert np.ndarray[arg_tup]
if arg_len in (1, 2):
assert cls[arg_tup]
else:
with pytest.raises(TypeError):
np.ndarray[arg_tup]

def test_subscript_scalar(self) -> None:
with pytest.raises(TypeError):
np.ndarray[Any]
match = f"Too {'few' if arg_len == 0 else 'many'} arguments"
with pytest.raises(TypeError, match=match):
cls[arg_tup]


@pytest.mark.skipif(sys.version_info >= (3, 9), reason="Requires python 3.8")
Expand Down
10 changes: 10 additions & 0 deletions numpy/core/tests/test_scalar_methods.py
Expand Up @@ -153,6 +153,16 @@ def test_abc_complexfloating(self) -> None:
assert isinstance(alias, types.GenericAlias)
assert alias.__origin__ is np.complexfloating

@pytest.mark.parametrize("arg_len", range(4))
def test_abc_complexfloating_subscript_tuple(self, arg_len: int) -> None:
arg_tup = (Any,) * arg_len
if arg_len in (1, 2):
assert np.complexfloating[arg_tup]
else:
match = f"Too {'few' if arg_len == 0 else 'many'} arguments"
with pytest.raises(TypeError, match=match):
np.complexfloating[arg_tup]

@pytest.mark.parametrize("cls", [np.generic, np.flexible, np.character])
def test_abc_non_numeric(self, cls: Type[np.generic]) -> None:
with pytest.raises(TypeError):
Expand Down

0 comments on commit 754ec89

Please sign in to comment.