Skip to content

Commit

Permalink
Upgrade mypy. (#8302)
Browse files Browse the repository at this point in the history
Some breaking changes were made in mypy.
  • Loading branch information
trivialfis committed Oct 5, 2022
1 parent 97c3a80 commit e47b3a3
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 27 deletions.
63 changes: 45 additions & 18 deletions python-package/xgboost/_typing.py
@@ -1,7 +1,18 @@
# pylint: disable=protected-access
"""Shared typing definition."""
import ctypes
import os
from typing import Any, Callable, Dict, List, Sequence, Type, TypeVar, Union
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
List,
Sequence,
Type,
TypeVar,
Union,
)

# os.PathLike/string/numpy.array/scipy.sparse/pd.DataFrame/dt.Frame/
# cudf.DataFrame/cupy.array/dlpack
Expand Down Expand Up @@ -32,14 +43,15 @@
# c_bst_ulong corresponds to bst_ulong defined in xgboost/c_api.h
c_bst_ulong = ctypes.c_uint64 # pylint: disable=C0103

CTypeT = Union[
CTypeT = TypeVar(
"CTypeT",
ctypes.c_void_p,
ctypes.c_char_p,
ctypes.c_int,
ctypes.c_float,
ctypes.c_uint,
ctypes.c_size_t,
]
)

# supported numeric types
CNumeric = Union[
Expand All @@ -52,21 +64,36 @@
]

# c pointer types
# real type should be, as defined in typeshed
# but this has to be put in a .pyi file
# c_str_ptr_t = ctypes.pointer[ctypes.c_char]
CStrPtr = ctypes.pointer
# c_str_pptr_t = ctypes.pointer[ctypes.c_char_p]
CStrPptr = ctypes.pointer
# c_float_ptr_t = ctypes.pointer[ctypes.c_float]
CFloatPtr = ctypes.pointer

# c_numeric_ptr_t = Union[
# ctypes.pointer[ctypes.c_float], ctypes.pointer[ctypes.c_double],
# ctypes.pointer[ctypes.c_uint], ctypes.pointer[ctypes.c_uint64],
# ctypes.pointer[ctypes.c_int32], ctypes.pointer[ctypes.c_int64]
# ]
CNumericPtr = ctypes.pointer
if TYPE_CHECKING:
CStrPtr = ctypes._Pointer[ctypes.c_char]

CStrPptr = ctypes._Pointer[ctypes.c_char_p]

CFloatPtr = ctypes._Pointer[ctypes.c_float]

CNumericPtr = Union[
ctypes._Pointer[ctypes.c_float],
ctypes._Pointer[ctypes.c_double],
ctypes._Pointer[ctypes.c_uint],
ctypes._Pointer[ctypes.c_uint64],
ctypes._Pointer[ctypes.c_int32],
ctypes._Pointer[ctypes.c_int64],
]
else:
CStrPtr = ctypes._Pointer

CStrPptr = ctypes._Pointer

CFloatPtr = ctypes._Pointer

CNumericPtr = Union[
ctypes._Pointer,
ctypes._Pointer,
ctypes._Pointer,
ctypes._Pointer,
ctypes._Pointer,
ctypes._Pointer,
]

# template parameter
_T = TypeVar("_T")
Expand Down
10 changes: 5 additions & 5 deletions python-package/xgboost/core.py
Expand Up @@ -99,9 +99,9 @@ def from_cstr_to_pystr(data: CStrPptr, length: c_bst_ulong) -> List[str]:
res = []
for i in range(length.value):
try:
res.append(str(data[i].decode('ascii'))) # type: ignore
res.append(str(cast(bytes, data[i]).decode('ascii')))
except UnicodeDecodeError:
res.append(str(data[i].decode('utf-8'))) # type: ignore
res.append(str(cast(bytes, data[i]).decode('utf-8')))
return res


Expand Down Expand Up @@ -381,7 +381,7 @@ def ctypes2buffer(cptr: CStrPtr, length: int) -> bytearray:
raise RuntimeError('expected char pointer')
res = bytearray(length)
rptr = (ctypes.c_char * length).from_buffer(res)
if not ctypes.memmove(rptr, cptr, length): # type: ignore
if not ctypes.memmove(rptr, cptr, length):
raise RuntimeError('memmove failed')
return res

Expand All @@ -393,8 +393,8 @@ def c_str(string: str) -> ctypes.c_char_p:

def c_array(
ctype: Type[CTypeT], values: ArrayLike
) -> Union[ctypes.Array, ctypes.pointer]:
"""Convert a python string to c array."""
) -> Union[ctypes.Array, ctypes._Pointer]:
"""Convert a python array to c array."""
if isinstance(values, np.ndarray) and values.dtype.itemsize == ctypes.sizeof(ctype):
return values.ctypes.data_as(ctypes.POINTER(ctype))
return (ctype * len(values))(*values)
Expand Down
6 changes: 3 additions & 3 deletions python-package/xgboost/training.py
Expand Up @@ -13,7 +13,7 @@
from .core import Booster, DMatrix, XGBoostError, _deprecate_positional_args
from .core import Metric, Objective
from .compat import SKLEARN_INSTALLED, XGBStratifiedKFold, DataFrame
from ._typing import _F, FPreProcCallable, BoosterParam
from ._typing import Callable, FPreProcCallable, BoosterParam

_CVFolds = Sequence["CVPack"]

Expand Down Expand Up @@ -205,10 +205,10 @@ def __init__(self, dtrain: DMatrix, dtest: DMatrix, param: Optional[Union[Dict,
self.watchlist = [(dtrain, 'train'), (dtest, 'test')]
self.bst = Booster(param, [dtrain, dtest])

def __getattr__(self, name: str) -> _F:
def __getattr__(self, name: str) -> Callable:
def _inner(*args: Any, **kwargs: Any) -> Any:
return getattr(self.bst, name)(*args, **kwargs)
return cast(_F, _inner)
return _inner

def update(self, iteration: int, fobj: Optional[Objective]) -> None:
""""Update the boosters for one iteration"""
Expand Down
2 changes: 1 addition & 1 deletion tests/ci_build/conda_env/python_lint.yml
Expand Up @@ -6,7 +6,7 @@ dependencies:
- pylint
- wheel
- setuptools
- mypy=0.961
- mypy>=0.981
- numpy
- scipy
- pandas
Expand Down

0 comments on commit e47b3a3

Please sign in to comment.