Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

allow np.uint64 to be used in indexing. Support numpy 1.24.1 #510

Merged
merged 1 commit into from Jan 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion pandas-stubs/_typing.pyi
Expand Up @@ -189,7 +189,7 @@ np_ndarray_anyint: TypeAlias = npt.NDArray[np.integer]
np_ndarray_bool: TypeAlias = npt.NDArray[np.bool_]
np_ndarray_str: TypeAlias = npt.NDArray[np.str_]

IndexType: TypeAlias = Union[slice, np_ndarray_int64, Index, list[int], Series[int]]
IndexType: TypeAlias = Union[slice, np_ndarray_anyint, Index, list[int], Series[int]]
MaskType: TypeAlias = Union[Series[bool], np_ndarray_bool, list[bool]]
# Scratch types for generics
S1 = TypeVar(
Expand Down
5 changes: 3 additions & 2 deletions pandas-stubs/core/indexes/base.pyi
Expand Up @@ -38,6 +38,7 @@ from pandas._typing import (
Level,
NaPosition,
Scalar,
np_ndarray_anyint,
np_ndarray_bool,
np_ndarray_int64,
type_t,
Expand Down Expand Up @@ -192,10 +193,10 @@ class Index(IndexOpsMixin, PandasObject):
@overload
def __getitem__(
self: IndexT,
idx: slice | np_ndarray_int64 | Index | Series[bool] | np_ndarray_bool,
idx: slice | np_ndarray_anyint | Index | Series[bool] | np_ndarray_bool,
) -> IndexT: ...
@overload
def __getitem__(self, idx: int | tuple[np_ndarray_int64, ...]) -> Scalar: ...
def __getitem__(self, idx: int | tuple[np_ndarray_anyint, ...]) -> Scalar: ...
def append(self, other): ...
def putmask(self, mask, value): ...
def equals(self, other) -> bool: ...
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Expand Up @@ -42,7 +42,7 @@ pyright = ">=1.1.286"
poethepoet = ">=0.16.5"
loguru = ">=0.6.0"
pandas = "1.5.2"
numpy = "<=1.23.5"
numpy = ">=1.24.1"
typing-extensions = ">=4.2.0"
matplotlib = ">=3.5.1"
pre-commit = ">=2.19.0"
Expand Down
14 changes: 14 additions & 0 deletions tests/test_frame.py
Expand Up @@ -25,6 +25,7 @@
)

import numpy as np
import numpy.typing as npt
import pandas as pd
from pandas._testing import (
ensure_clean,
Expand Down Expand Up @@ -2363,3 +2364,16 @@ def test_frame_dropna_subset() -> None:
assert_type(df.dropna(subset=df.columns.drop("col1")), pd.DataFrame),
pd.DataFrame,
)


def test_npint_loc_indexer() -> None:
# GH 508

df = pd.DataFrame(dict(x=[1, 2, 3]), index=np.array([10, 20, 30], dtype="uint64"))

def get_NDArray(df: pd.DataFrame, key: npt.NDArray[np.uint64]) -> pd.DataFrame:
df2 = df.loc[key]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wouldn't any `npt.NDArray` work (not just integer arrays), as long as the DataFrame index has the same dtype?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could probably only enforce a tight dtype match if `Index` were generic.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wouldn't any `npt.NDArray` work (not just integer arrays), as long as the DataFrame index has the same dtype?

Probably, but since we can't track the dtype of an Index in a DataFrame, I'm limiting this for now to the issue that was reported. I think most people use arrays of int or arrays of str (which we probably could add), but I'd rather be incremental in adding support for all the possible types.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could probably only enforce a tight dtype match if `Index` were generic.

That would only be possible if we knew the dtype of the underlying index, and we don't.

return df2

a: npt.NDArray[np.uint64] = np.array([10, 30], dtype="uint64")
check(assert_type(get_NDArray(df, a), pd.DataFrame), pd.DataFrame)
4 changes: 2 additions & 2 deletions tests/test_pandas.py
Expand Up @@ -13,6 +13,7 @@
import pandas as pd
from pandas import Grouper
from pandas.api.extensions import ExtensionArray
from pandas.util.version import Version
import pytest
from typing_extensions import assert_type

Expand Down Expand Up @@ -1705,7 +1706,7 @@ def test_pivot_table() -> None:
),
pd.DataFrame,
)
with pytest.warns(np.VisibleDeprecationWarning):
if Version(np.__version__) <= Version("1.23.5"):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would be fine with removing the tests for older versions.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can remove this test once the bug is fixed in pandas. There is now a PR for that at pandas-dev/pandas#50682.

check(
assert_type(
pd.pivot_table(
Expand All @@ -1719,7 +1720,6 @@ def test_pivot_table() -> None:
),
pd.DataFrame,
)
with pytest.warns(np.VisibleDeprecationWarning):
check(
assert_type(
pd.pivot_table(
Expand Down