Skip to content

Commit

Permalink
Remove STRING_TYPES. (#7827)
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis committed Apr 22, 2022
1 parent c13a2a3 commit f0f7625
Show file tree
Hide file tree
Showing 4 changed files with 10 additions and 14 deletions.
3 changes: 1 addition & 2 deletions python-package/xgboost/callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@

from . import rabit
from .core import Booster, DMatrix, XGBoostError, _get_booster_layer_trees
from .compat import STRING_TYPES


__all__ = [
Expand Down Expand Up @@ -82,7 +81,7 @@ def _aggcv(rlist: List[str]) -> List[Tuple[str, float, float]]:
results = []
for (_, name), s in sorted(cvmap.items(), key=lambda x: x[0][0]):
as_arr = numpy.array(s)
if not isinstance(msg, STRING_TYPES):
if not isinstance(msg, str):
msg = msg.decode()
mean, std = numpy.mean(as_arr), numpy.std(as_arr)
results.extend([(name, mean, std)])
Expand Down
3 changes: 0 additions & 3 deletions python-package/xgboost/compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,6 @@

assert (sys.version_info[0] == 3), 'Python 2 is no longer supported.'

# pylint: disable=invalid-name, redefined-builtin
STRING_TYPES = (str,)


def py_str(x):
"""convert c string back to python string"""
Expand Down
14 changes: 7 additions & 7 deletions python-package/xgboost/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import numpy as np
import scipy.sparse

from .compat import STRING_TYPES, DataFrame, py_str, PANDAS_INSTALLED
from .compat import DataFrame, py_str, PANDAS_INSTALLED
from .libpath import find_lib_path
from ._typing import (
CStrPptr,
Expand Down Expand Up @@ -1387,7 +1387,7 @@ def __init__(
_check_call(
_LIB.XGBoosterUnserializeFromBuffer(self.handle, ptr, length))
self.__dict__.update(state)
elif isinstance(model_file, (STRING_TYPES, os.PathLike, bytearray)):
elif isinstance(model_file, (str, os.PathLike, bytearray)):
self.load_model(model_file)
elif model_file is None:
pass
Expand Down Expand Up @@ -1629,7 +1629,7 @@ def set_attr(self, **kwargs: Optional[str]) -> None:
"""
for key, value in kwargs.items():
if value is not None:
if not isinstance(value, STRING_TYPES):
if not isinstance(value, str):
raise ValueError("Set Attr only accepts string values")
value = c_str(str(value))
_check_call(_LIB.XGBoosterSetAttr(
Expand Down Expand Up @@ -1705,7 +1705,7 @@ def set_param(
"""
if isinstance(params, Mapping):
params = params.items()
elif isinstance(params, STRING_TYPES) and value is not None:
elif isinstance(params, str) and value is not None:
params = [(params, value)]
for key, val in params:
if val is not None:
Expand Down Expand Up @@ -1796,7 +1796,7 @@ def eval_set(
for d in evals:
if not isinstance(d[0], DMatrix):
raise TypeError(f"expected DMatrix, got {type(d[0]).__name__}")
if not isinstance(d[1], STRING_TYPES):
if not isinstance(d[1], str):
raise TypeError(f"expected string, got {type(d[1]).__name__}")
self._validate_features(d[0])

Expand Down Expand Up @@ -2192,7 +2192,7 @@ def save_model(self, fname: Union[str, os.PathLike]) -> None:
Output file name
"""
if isinstance(fname, (STRING_TYPES, os.PathLike)): # assume file name
if isinstance(fname, (str, os.PathLike)): # assume file name
fname = os.fspath(os.path.expanduser(fname))
_check_call(_LIB.XGBoosterSaveModel(
self.handle, c_str(fname)))
Expand Down Expand Up @@ -2301,7 +2301,7 @@ def dump_model(self, fout: Union[str, os.PathLike], fmap: Union[str, os.PathLike
dump_format : string, optional
Format of model dump file. Can be 'text' or 'json'.
"""
if isinstance(fout, (STRING_TYPES, os.PathLike)):
if isinstance(fout, (str, os.PathLike)):
fout = os.fspath(os.path.expanduser(fout))
# pylint: disable=consider-using-with
fout_obj = open(fout, 'w', encoding="utf-8")
Expand Down
4 changes: 2 additions & 2 deletions python-package/xgboost/rabit.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import numpy as np

from .core import _LIB, c_str, STRING_TYPES, _check_call
from .core import _LIB, c_str, _check_call


def _init_rabit() -> None:
Expand Down Expand Up @@ -73,7 +73,7 @@ def tracker_print(msg: Any) -> None:
msg : str
The message to be printed to tracker.
"""
if not isinstance(msg, STRING_TYPES):
if not isinstance(msg, str):
msg = str(msg)
is_dist = _LIB.RabitIsDistributed()
if is_dist != 0:
Expand Down

0 comments on commit f0f7625

Please sign in to comment.