[WIP] [pyspark] Cleanup data processing.
- Use numpy stack for handling list of arrays.
- Reuse concat function from dask.
- Prepare for `QuantileDMatrix`.
trivialfis committed Jul 17, 2022
1 parent e28f6f6 commit e838bc7
Showing 5 changed files with 209 additions and 409 deletions.
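
The first bullet of the commit message refers to combining per-partition arrays with numpy instead of hand-rolled loops. A minimal sketch of that idea; the array contents and shapes below are made up purely for illustration:

import numpy as np

# Hypothetical per-partition pieces of a single column, e.g. labels.
parts = [np.array([0.0, 1.0]), np.array([1.0, 0.0, 1.0])]

# Row-wise concatenation joins the partitions into one flat array.
labels = np.concatenate(parts, axis=0)
print(labels.shape)  # (5,)

# np.stack instead introduces a new axis, which suits equal-length rows
# such as fixed-width feature vectors gathered from a list of arrays.
rows = [np.array([1.0, 2.0]), np.array([3.0, 4.0])]
features = np.stack(rows, axis=0)
print(features.shape)  # (2, 2)
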
39 changes: 36 additions & 3 deletions python-package/xgboost/compat.py
@@ -1,13 +1,14 @@
# coding: utf-8
# pylint: disable= invalid-name, unused-import
"""For compatibility and optional dependencies."""
from typing import Any, Type, Dict, Optional, List
from typing import Any, Type, Dict, Optional, List, Sequence, cast
import sys
import types
import importlib.util
import logging
import numpy as np

from ._typing import _T

assert (sys.version_info[0] == 3), 'Python 2 is no longer supported.'


@@ -16,7 +17,7 @@ def py_str(x: bytes) -> str:
    return x.decode('utf-8')  # type: ignore


def lazy_isinstance(instance: Type[object], module: str, name: str) -> bool:
def lazy_isinstance(instance: Any, module: str, name: str) -> bool:
"""Use string representation to identify a type."""

# Notice, we use .__class__ as opposed to type() in order
@@ -111,6 +112,38 @@ def from_json(self, doc: Dict) -> None:
    SCIPY_INSTALLED = False


def concat(value: Sequence[_T]) -> _T:  # pylint: disable=too-many-return-statements
    """Concatenate row-wise."""
    if isinstance(value[0], np.ndarray):
        return np.concatenate(value, axis=0)
    if scipy_sparse and isinstance(value[0], scipy_sparse.csr_matrix):
        return scipy_sparse.vstack(value, format="csr")
    if scipy_sparse and isinstance(value[0], scipy_sparse.csc_matrix):
        return scipy_sparse.vstack(value, format="csc")
    if scipy_sparse and isinstance(value[0], scipy_sparse.spmatrix):
        # Other sparse formats will be converted to CSR.
        return scipy_sparse.vstack(value, format="csr")
    if PANDAS_INSTALLED and isinstance(value[0], (DataFrame, Series)):
        return pandas_concat(value, axis=0)
    if lazy_isinstance(value[0], "cudf.core.dataframe", "DataFrame") or lazy_isinstance(
        value[0], "cudf.core.series", "Series"
    ):
        from cudf import concat as CUDF_concat  # pylint: disable=import-error

        return CUDF_concat(value, axis=0)
    if lazy_isinstance(value[0], "cupy._core.core", "ndarray"):
        import cupy

        # pylint: disable=c-extension-no-member,no-member
        d = cupy.cuda.runtime.getDevice()
        for v in value:
            arr = cast(cupy.ndarray, v)
            d_v = arr.device.id
            assert d_v == d, "Concatenating arrays on different devices."
        return cupy.concatenate(value, axis=0)
    raise TypeError("Unknown type.")


# Modified from tensorflow with added caching. There's a `LazyLoader` in
# `importlib.util`, except it's unclear from its documentation how to use it. This one
# seems to be easy to understand and works out of the box.
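
For reference, a small usage sketch of the `concat` helper added above. It assumes this commit is installed and a CPU-only environment with numpy and scipy available; the cuDF/cuPy branches need a GPU stack:

import numpy as np
from scipy import sparse

from xgboost.compat import concat

# Dense partitions are joined row-wise.
dense = concat([np.ones((2, 3)), np.zeros((1, 3))])
assert dense.shape == (3, 3)

# Sparse partitions keep a sparse result; CSR inputs stay CSR.
csr = concat([sparse.csr_matrix(np.eye(2)), sparse.csr_matrix(np.eye(2))])
assert csr.shape == (4, 2) and sparse.isspmatrix_csr(csr)

# Anything unrecognised raises TypeError so callers can supply a fallback.
try:
    concat([object(), object()])
except TypeError:
    print("unsupported partition type")
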
41 changes: 8 additions & 33 deletions python-package/xgboost/dask.py
@@ -50,11 +50,10 @@
from .callback import TrainingCallback

from .compat import LazyLoader
from .compat import scipy_sparse
from .compat import PANDAS_INSTALLED, DataFrame, Series, pandas_concat
from .compat import DataFrame, concat
from .compat import lazy_isinstance

from ._typing import FeatureNames, FeatureTypes
from ._typing import FeatureNames, FeatureTypes, _T

from .core import DMatrix, DeviceQuantileDMatrix, Booster, _expect, DataIter
from .core import Objective, Metric
@@ -207,35 +206,11 @@ def __init__(self, args: List[bytes]) -> None:
)


def concat(value: Any) -> Any:  # pylint: disable=too-many-return-statements
    """To be replaced with dask builtin."""
    if isinstance(value[0], numpy.ndarray):
        return numpy.concatenate(value, axis=0)
    if scipy_sparse and isinstance(value[0], scipy_sparse.csr_matrix):
        return scipy_sparse.vstack(value, format="csr")
    if scipy_sparse and isinstance(value[0], scipy_sparse.csc_matrix):
        return scipy_sparse.vstack(value, format="csc")
    if scipy_sparse and isinstance(value[0], scipy_sparse.spmatrix):
        # other sparse format will be converted to CSR.
        return scipy_sparse.vstack(value, format="csr")
    if PANDAS_INSTALLED and isinstance(value[0], (DataFrame, Series)):
        return pandas_concat(value, axis=0)
    if lazy_isinstance(value[0], "cudf.core.dataframe", "DataFrame") or lazy_isinstance(
        value[0], "cudf.core.series", "Series"
    ):
        from cudf import concat as CUDF_concat  # pylint: disable=import-error

        return CUDF_concat(value, axis=0)
    if lazy_isinstance(value[0], "cupy._core.core", "ndarray"):
        import cupy

        # pylint: disable=c-extension-no-member,no-member
        d = cupy.cuda.runtime.getDevice()
        for v in value:
            d_v = v.device.id
            assert d_v == d, "Concatenating arrays on different devices."
        return cupy.concatenate(value, axis=0)
    return dd.multi.concat(list(value), axis=0)
def dconcat(value: Sequence[_T]) -> _T:
    """Concatenate data partitions, deferring to dask for types ``concat`` doesn't handle."""
    try:
        return concat(value)
    except TypeError:
        return dd.multi.concat(list(value), axis=0)


def _xgb_get_client(client: Optional["distributed.Client"]) -> "distributed.Client":
@@ -770,7 +745,7 @@ def _create_dmatrix(
    def concat_or_none(data: Sequence[Optional[T]]) -> Optional[T]:
        if any(part is None for part in data):
            return None
        return concat(data)
        return dconcat(data)

    unzipped_dict = _get_worker_parts(list_of_parts)
    concated_dict: Dict[str, Any] = {}
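
A sketch of how `dconcat` routes inputs. It assumes `dask[dataframe]` and `distributed` are installed, since importing `xgboost.dask` requires them:

import dask.dataframe as dd
import numpy as np
import pandas as pd

from xgboost.dask import dconcat

# numpy partitions are handled by the shared compat.concat path.
arr = dconcat([np.ones((2, 2)), np.zeros((2, 2))])
assert arr.shape == (4, 2)

# A dask collection is not a type compat.concat recognises, so the TypeError
# fallback hands it to dask's own concat instead.
pdf = pd.DataFrame({"a": [1, 2], "b": [3, 4]})
ddf = dd.from_pandas(pdf, npartitions=2)
combined = dconcat([ddf, ddf])
assert isinstance(combined, dd.DataFrame)
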
28 changes: 7 additions & 21 deletions python-package/xgboost/spark/core.py
@@ -36,9 +36,7 @@
from xgboost.core import Booster
from xgboost.training import train as worker_train

from .data import (
    _convert_partition_data_to_dmatrix,
)
from .data import create_dmatrix_from_partitions
from .model import (
    SparkXGBReader,
    SparkXGBWriter,
@@ -599,25 +597,13 @@ def _train_booster(pandas_df_iter):
            _rabit_args = _get_args_from_message_list(messages)
            evals_result = {}
            with RabitContext(_rabit_args, context):
                dtrain, dval = None, []
                if has_validation:
                    dtrain, dval = _convert_partition_data_to_dmatrix(
                        pandas_df_iter,
                        has_weight,
                        has_validation,
                        has_base_margin,
                        dmatrix_kwargs=dmatrix_kwargs,
                    )
                    # TODO: Question: do we need to add dtrain to dval list ?
                    dval = [(dtrain, "training"), (dval, "validation")]
                dtrain, dvalid = create_dmatrix_from_partitions(
                    pandas_df_iter, has_validation,
                )
                if dvalid:
                    dval = [(dtrain, "training"), (dvalid, "validation")]
                else:
                    dtrain = _convert_partition_data_to_dmatrix(
                        pandas_df_iter,
                        has_weight,
                        has_validation,
                        has_base_margin,
                        dmatrix_kwargs=dmatrix_kwargs,
                    )
                    dval = None

                booster = worker_train(
                    params=booster_params,
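
The new `create_dmatrix_from_partitions` lives in the rewritten `spark/data.py`, which is not included in this excerpt. Below is a hypothetical, simplified stand-in that only captures the contract used above; the helper name `make_dmatrices` and the column names `label` and `validationIndicatorCol` are assumptions for illustration:

from typing import Iterator, List, Optional, Tuple

import pandas as pd
import xgboost
from xgboost.compat import concat


def make_dmatrices(
    frames: Iterator[pd.DataFrame], has_validation: bool
) -> Tuple[xgboost.DMatrix, Optional[xgboost.DMatrix]]:
    """Toy stand-in for ``create_dmatrix_from_partitions``."""
    train_parts: List[pd.DataFrame] = []
    valid_parts: List[pd.DataFrame] = []
    for df in frames:
        if has_validation:
            # Rows are split by a hypothetical validation-indicator column.
            mask = df["validationIndicatorCol"].astype(bool)
            valid_parts.append(df[mask])
            train_parts.append(df[~mask])
        else:
            train_parts.append(df)

    def build(parts: List[pd.DataFrame]) -> xgboost.DMatrix:
        data = concat(parts)  # pandas partitions joined row-wise
        features = data.drop(columns=["label", "validationIndicatorCol"], errors="ignore")
        return xgboost.DMatrix(features, label=data["label"])

    dtrain = build(train_parts)
    dvalid = build(valid_parts) if valid_parts else None
    return dtrain, dvalid

The caller then pairs the results as `[(dtrain, "training"), (dvalid, "validation")]` for `evals`, exactly as the new branch above does. The real helper presumably also threads through weights and base margins and, per the commit message, prepares the ground for `QuantileDMatrix`.
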
