Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

TYP,BUG: Reduce argument validation in C-based __class_getitem__ #22222

Merged
merged 1 commit into from Sep 7, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion numpy/core/src/multiarray/methods.c
Expand Up @@ -2829,7 +2829,7 @@ array_class_getitem(PyObject *cls, PyObject *args)
Py_ssize_t args_len;

args_len = PyTuple_Check(args) ? PyTuple_Size(args) : 1;
if (args_len != 2) {
if ((args_len > 2) || (args_len == 0)) {
return PyErr_Format(PyExc_TypeError,
"Too %s arguments for %s",
args_len > 2 ? "many" : "few",
Expand Down
2 changes: 1 addition & 1 deletion numpy/core/src/multiarray/scalartypes.c.src
Expand Up @@ -1855,7 +1855,7 @@ numbertype_class_getitem_abc(PyObject *cls, PyObject *args)
}

args_len = PyTuple_Check(args) ? PyTuple_Size(args) : 1;
if (args_len != args_len_expected) {
if ((args_len > args_len_expected) || (args_len == 0)) {
return PyErr_Format(PyExc_TypeError,
"Too %s arguments for %s",
args_len > args_len_expected ? "many" : "few",
Expand Down
27 changes: 13 additions & 14 deletions numpy/core/tests/test_arraymethod.py
Expand Up @@ -3,9 +3,11 @@
this is private API, but when added, public API may be added here.
"""

from __future__ import annotations

import sys
import types
from typing import Any, Type
from typing import Any

import pytest

Expand Down Expand Up @@ -63,28 +65,25 @@ def test_invalid_arguments(self, args, error):


@pytest.mark.skipif(sys.version_info < (3, 9), reason="Requires python 3.9")
@pytest.mark.parametrize(
"cls", [np.ndarray, np.recarray, np.chararray, np.matrix, np.memmap]
)
class TestClassGetItem:
@pytest.mark.parametrize(
"cls", [np.ndarray, np.recarray, np.chararray, np.matrix, np.memmap]
)
def test_class_getitem(self, cls: Type[np.ndarray]) -> None:
def test_class_getitem(self, cls: type[np.ndarray]) -> None:
"""Test `ndarray.__class_getitem__`."""
alias = cls[Any, Any]
assert isinstance(alias, types.GenericAlias)
assert alias.__origin__ is cls

@pytest.mark.parametrize("arg_len", range(4))
def test_subscript_tuple(self, arg_len: int) -> None:
def test_subscript_tup(self, cls: type[np.ndarray], arg_len: int) -> None:
arg_tup = (Any,) * arg_len
if arg_len == 2:
assert np.ndarray[arg_tup]
if arg_len in (1, 2):
assert cls[arg_tup]
else:
with pytest.raises(TypeError):
np.ndarray[arg_tup]

def test_subscript_scalar(self) -> None:
with pytest.raises(TypeError):
np.ndarray[Any]
match = f"Too {'few' if arg_len == 0 else 'many'} arguments"
with pytest.raises(TypeError, match=match):
cls[arg_tup]


@pytest.mark.skipif(sys.version_info >= (3, 9), reason="Requires python 3.8")
Expand Down
10 changes: 10 additions & 0 deletions numpy/core/tests/test_scalar_methods.py
Expand Up @@ -153,6 +153,16 @@ def test_abc_complexfloating(self) -> None:
assert isinstance(alias, types.GenericAlias)
assert alias.__origin__ is np.complexfloating

@pytest.mark.parametrize("arg_len", range(4))
def test_abc_complexfloating_subscript_tuple(self, arg_len: int) -> None:
arg_tup = (Any,) * arg_len
if arg_len in (1, 2):
assert np.complexfloating[arg_tup]
else:
match = f"Too {'few' if arg_len == 0 else 'many'} arguments"
with pytest.raises(TypeError, match=match):
np.complexfloating[arg_tup]

@pytest.mark.parametrize("cls", [np.generic, np.flexible, np.character])
def test_abc_non_numeric(self, cls: Type[np.generic]) -> None:
with pytest.raises(TypeError):
Expand Down