
How to narrow numpy array dtype in if statements #7830

Answered by erictraut
gaufde asked this question in Q&A

The Python static type system isn't rich enough to express the notion that "this object has an attribute called dtype that contains an immutable value that relates to the type of the class". A static type checker therefore cannot narrow the type of arr in your example without additional information.

The TypeGuard type (introduced in PEP 647) was designed to handle cases like this.

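from typing import Any, TypeGuard, TypeVar

import numpy as np
import numpy.typing as npt

T = TypeVar("T", bound=np.generic)

def is_correct_type(arr: npt.NDArray[Any], typ: type[T]) -> TypeGuard[npt.NDArray[T]]:
    # Check against the requested scalar type instead of a hard-coded np.int_,
    # so the runtime test agrees with the narrowed type NDArray[T].
    return arr.dtype == typ or np.issubdtype(arr.dtype, typ)

# arr and func_2 come from the question's example.
if is_correct_type(arr, np.int_):
    arr = func_2(arr)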

Or if you're using Pytho…

Answer selected by gaufde