From e838bc708678d8d3fd2637b35b4dd2e316435b66 Mon Sep 17 00:00:00 2001
From: jiamingy
Date: Sun, 17 Jul 2022 13:59:15 +0800
Subject: [PATCH] [WIP] [pyspark] Cleanup data processing.

- Use numpy stack for handling list of arrays.
- Reuse concat function from dask.
- Prepare for `QuantileDMatrix`.
---
 python-package/xgboost/compat.py     |  39 ++-
 python-package/xgboost/dask.py       |  41 +---
 python-package/xgboost/spark/core.py |  28 +--
 python-package/xgboost/spark/data.py | 342 +++++++++++++--------------
 tests/python/test_spark/test_data.py | 168 -------------
 5 files changed, 209 insertions(+), 409 deletions(-)
 delete mode 100644 tests/python/test_spark/test_data.py

diff --git a/python-package/xgboost/compat.py b/python-package/xgboost/compat.py
index 63f9137e67c7..2844ef0cd9b1 100644
--- a/python-package/xgboost/compat.py
+++ b/python-package/xgboost/compat.py
@@ -1,13 +1,14 @@
-# coding: utf-8
 # pylint: disable= invalid-name, unused-import
 """For compatibility and optional dependencies."""
-from typing import Any, Type, Dict, Optional, List
+from typing import Any, Type, Dict, Optional, List, Sequence, cast
 import sys
 import types
 import importlib.util
 import logging
 import numpy as np
 
+from ._typing import _T
+
 assert (sys.version_info[0] == 3), 'Python 2 is no longer supported.'
 
 
@@ -16,7 +17,7 @@ def py_str(x: bytes) -> str:
     return x.decode('utf-8')  # type: ignore
 
 
-def lazy_isinstance(instance: Type[object], module: str, name: str) -> bool:
+def lazy_isinstance(instance: Any, module: str, name: str) -> bool:
     """Use string representation to identify a type."""
 
     # Notice, we use .__class__ as opposed to type() in order
@@ -111,6 +112,38 @@ def from_json(self, doc: Dict) -> None:
     SCIPY_INSTALLED = False
 
 
+def concat(value: Sequence[_T]) -> _T:  # pylint: disable=too-many-return-statements
+    """Concatenate row-wise."""
+    if isinstance(value[0], np.ndarray):
+        return np.concatenate(value, axis=0)
+    if scipy_sparse and isinstance(value[0], scipy_sparse.csr_matrix):
+        return scipy_sparse.vstack(value, format="csr")
+    if scipy_sparse and isinstance(value[0], scipy_sparse.csc_matrix):
+        return scipy_sparse.vstack(value, format="csc")
+    if scipy_sparse and isinstance(value[0], scipy_sparse.spmatrix):
+        # other sparse format will be converted to CSR.
+        return scipy_sparse.vstack(value, format="csr")
+    if PANDAS_INSTALLED and isinstance(value[0], (DataFrame, Series)):
+        return pandas_concat(value, axis=0)
+    if lazy_isinstance(value[0], "cudf.core.dataframe", "DataFrame") or lazy_isinstance(
+        value[0], "cudf.core.series", "Series"
+    ):
+        from cudf import concat as CUDF_concat  # pylint: disable=import-error
+
+        return CUDF_concat(value, axis=0)
+    if lazy_isinstance(value[0], "cupy._core.core", "ndarray"):
+        import cupy
+
+        # pylint: disable=c-extension-no-member,no-member
+        d = cupy.cuda.runtime.getDevice()
+        for v in value:
+            arr = cast(cupy.ndarray, v)
+            d_v = arr.device.id
+            assert d_v == d, "Concatenating arrays on different devices."
+        return cupy.concatenate(value, axis=0)
+    raise TypeError("Unknown type.")
+
+
 # Modified from tensorflow with added caching. There's a `LazyLoader` in
 # `importlib.utils`, except it's unclear from its document on how to use it. This one
 # seems to be easy to understand and works out of box.
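
Note: `concat` above is the helper previously living in `dask.py`, moved into
`compat.py` so the Dask and PySpark paths can share it. A minimal sketch of the
row-wise contract, assuming plain numpy inputs (the pandas/scipy/cuDF/cupy branches
follow the same shape semantics):

    import numpy as np

    from xgboost.compat import concat

    # Four (2, 3) chunks are concatenated along the row axis into one (8, 3) matrix.
    parts = [np.arange(6, dtype=np.float32).reshape(2, 3) for _ in range(4)]
    stacked = concat(parts)
    assert stacked.shape == (8, 3)
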
diff --git a/python-package/xgboost/dask.py b/python-package/xgboost/dask.py index 510487edb1e8..df70ec753970 100644 --- a/python-package/xgboost/dask.py +++ b/python-package/xgboost/dask.py @@ -50,11 +50,10 @@ from .callback import TrainingCallback from .compat import LazyLoader -from .compat import scipy_sparse -from .compat import PANDAS_INSTALLED, DataFrame, Series, pandas_concat +from .compat import DataFrame, concat from .compat import lazy_isinstance -from ._typing import FeatureNames, FeatureTypes +from ._typing import FeatureNames, FeatureTypes, _T from .core import DMatrix, DeviceQuantileDMatrix, Booster, _expect, DataIter from .core import Objective, Metric @@ -207,35 +206,11 @@ def __init__(self, args: List[bytes]) -> None: ) -def concat(value: Any) -> Any: # pylint: disable=too-many-return-statements - """To be replaced with dask builtin.""" - if isinstance(value[0], numpy.ndarray): - return numpy.concatenate(value, axis=0) - if scipy_sparse and isinstance(value[0], scipy_sparse.csr_matrix): - return scipy_sparse.vstack(value, format="csr") - if scipy_sparse and isinstance(value[0], scipy_sparse.csc_matrix): - return scipy_sparse.vstack(value, format="csc") - if scipy_sparse and isinstance(value[0], scipy_sparse.spmatrix): - # other sparse format will be converted to CSR. - return scipy_sparse.vstack(value, format="csr") - if PANDAS_INSTALLED and isinstance(value[0], (DataFrame, Series)): - return pandas_concat(value, axis=0) - if lazy_isinstance(value[0], "cudf.core.dataframe", "DataFrame") or lazy_isinstance( - value[0], "cudf.core.series", "Series" - ): - from cudf import concat as CUDF_concat # pylint: disable=import-error - - return CUDF_concat(value, axis=0) - if lazy_isinstance(value[0], "cupy._core.core", "ndarray"): - import cupy - - # pylint: disable=c-extension-no-member,no-member - d = cupy.cuda.runtime.getDevice() - for v in value: - d_v = v.device.id - assert d_v == d, "Concatenating arrays on different devices." - return cupy.concatenate(value, axis=0) - return dd.multi.concat(list(value), axis=0) +def dconcat(value: Sequence[_T]) -> _T: # pylint: disable=too-many-return-statements + try: + return concat(value) + except TypeError: + return dd.multi.concat(list(value), axis=0) def _xgb_get_client(client: Optional["distributed.Client"]) -> "distributed.Client": @@ -770,7 +745,7 @@ def _create_dmatrix( def concat_or_none(data: Sequence[Optional[T]]) -> Optional[T]: if any(part is None for part in data): return None - return concat(data) + return dconcat(data) unzipped_dict = _get_worker_parts(list_of_parts) concated_dict: Dict[str, Any] = {} diff --git a/python-package/xgboost/spark/core.py b/python-package/xgboost/spark/core.py index c38fcbffd48a..0b8353c05b46 100644 --- a/python-package/xgboost/spark/core.py +++ b/python-package/xgboost/spark/core.py @@ -36,9 +36,7 @@ from xgboost.core import Booster from xgboost.training import train as worker_train -from .data import ( - _convert_partition_data_to_dmatrix, -) +from .data import create_dmatrix_from_partitions from .model import ( SparkXGBReader, SparkXGBWriter, @@ -599,25 +597,13 @@ def _train_booster(pandas_df_iter): _rabit_args = _get_args_from_message_list(messages) evals_result = {} with RabitContext(_rabit_args, context): - dtrain, dval = None, [] - if has_validation: - dtrain, dval = _convert_partition_data_to_dmatrix( - pandas_df_iter, - has_weight, - has_validation, - has_base_margin, - dmatrix_kwargs=dmatrix_kwargs, - ) - # TODO: Question: do we need to add dtrain to dval list ? 
- dval = [(dtrain, "training"), (dval, "validation")] + dtrain, dvalid = create_dmatrix_from_partitions( + pandas_df_iter, has_validation, + ) + if dvalid: + dval = [(dtrain, "training"), (dvalid, "validation")] else: - dtrain = _convert_partition_data_to_dmatrix( - pandas_df_iter, - has_weight, - has_validation, - has_base_margin, - dmatrix_kwargs=dmatrix_kwargs, - ) + dval = None booster = worker_train( params=booster_params, diff --git a/python-package/xgboost/spark/data.py b/python-package/xgboost/spark/data.py index 16fe038edf15..2e31ee54f905 100644 --- a/python-package/xgboost/spark/data.py +++ b/python-package/xgboost/spark/data.py @@ -1,192 +1,166 @@ -# type: ignore -"""Xgboost pyspark integration submodule for data related functions.""" -# pylint: disable=too-many-arguments -from typing import Iterator +from typing import Iterator, List, Sequence, Optional, Dict, Tuple, Callable +from collections import defaultdict, namedtuple import numpy as np import pandas as pd -from xgboost import DMatrix - - -def _prepare_train_val_data( - data_iterator, has_weight, has_validation, has_fit_base_margin -): - def gen_data_pdf(): - for pdf in data_iterator: - yield pdf - - return _process_data_iter( - gen_data_pdf(), - train=True, - has_weight=has_weight, - has_validation=has_validation, - has_fit_base_margin=has_fit_base_margin, - has_predict_base_margin=False, - ) - - -def _check_feature_dims(num_dims, expected_dims): - """ - Check all feature vectors has the same dimension - """ - if expected_dims is None: - return num_dims - if num_dims != expected_dims: - raise ValueError( - f"Rows contain different feature dimensions: Expecting {expected_dims}, got {num_dims}." - ) - return expected_dims - - -def _row_tuple_list_to_feature_matrix_y_w( - data_iterator, - train, - has_weight, - has_fit_base_margin, - has_predict_base_margin, - has_validation: bool = False, -): - """ - Construct a feature matrix in ndarray format, label array y and weight array w - from the row_tuple_list. - If train == False, y and w will be None. - If has_weight == False, w will be None. - If has_base_margin == False, b_m will be None. - Note: the row_tuple_list will be cleared during - executing for reducing peak memory consumption - """ - # pylint: disable=too-many-locals - expected_feature_dims = None - label_list, weight_list, base_margin_list = [], [], [] - label_val_list, weight_val_list, base_margin_val_list = [], [], [] - values_list, values_val_list = [], [] - - # Process rows - for pdf in data_iterator: - if len(pdf) == 0: - continue - if train and has_validation: - pdf_val = pdf.loc[pdf["validationIndicator"], :] - pdf = pdf.loc[~pdf["validationIndicator"], :] - - num_feature_dims = len(pdf["values"].values[0]) - - expected_feature_dims = _check_feature_dims( - num_feature_dims, expected_feature_dims - ) +from xgboost import DMatrix, DataIter, DeviceQuantileDMatrix +from xgboost.compat import concat, lazy_isinstance + + +def stack_df(df: pd.DataFrame) -> np.ndarray: + array = df.values + if array.ndim == 1: + array = array.reshape(array.shape[0], 1) + array = np.stack(array[:, 0]) + return array + + +def concat_or_none(seq: Optional[Sequence[np.ndarray]]) -> Optional[np.ndarray]: + if seq: + return concat(seq) + return None + + +Alias = namedtuple("Alias", ("data", "label", "weight", "margin", "valid")) +alias = Alias("values", "label", "weight", "baseMargin", "validationIndicator") - # Note: each element in `pdf["values"]` is an numpy array. 
-        values_list.append(pdf["values"].to_list())
-        if train:
-            label_list.append(pdf["label"].to_numpy())
-        if has_weight:
-            weight_list.append(pdf["weight"].to_numpy())
-        if has_fit_base_margin or has_predict_base_margin:
-            base_margin_list.append(pdf["baseMargin"].to_numpy())
+
+def create_dmatrix_from_partitions(
+    iterator: Iterator[pd.DataFrame], has_validation: bool
+) -> Tuple[DMatrix, Optional[DMatrix]]:
+    """Create the training and (optional) validation DMatrix from data partitions."""
+    train_data: Dict[str, List[np.ndarray]] = defaultdict(list)
+    valid_data: Dict[str, List[np.ndarray]] = defaultdict(list)
+    n_features: List[int] = [0]
+
+    def append(part: pd.DataFrame, name: str, is_valid: bool) -> None:
+        if name in part.columns:
+            array = part[name]
+            if name == alias.data:
+                array = stack_df(array)
+                if n_features[0] == 0:
+                    n_features[0] = array.shape[1]
+                assert n_features[0] == array.shape[1]
+
+            if is_valid:
+                valid_data[name].append(array)
+            else:
+                train_data[name].append(array)
+
+    def make_blob(part: pd.DataFrame, is_valid: bool) -> None:
+        append(part, alias.data, is_valid)
+        append(part, alias.label, is_valid)
+        append(part, alias.weight, is_valid)
+        append(part, alias.margin, is_valid)
+
+    for part in iterator:
         if has_validation:
-            values_val_list.append(pdf_val["values"].to_list())
-            if train:
-                label_val_list.append(pdf_val["label"].to_numpy())
-            if has_weight:
-                weight_val_list.append(pdf_val["weight"].to_numpy())
-            if has_fit_base_margin or has_predict_base_margin:
-                base_margin_val_list.append(pdf_val["baseMargin"].to_numpy())
-
-    # Construct feature_matrix
-    if expected_feature_dims is None:
-        return [], [], [], []
-
-    # Construct feature_matrix, y and w
-    feature_matrix = np.concatenate(values_list)
-    y = np.concatenate(label_list) if train else None
-    w = np.concatenate(weight_list) if has_weight else None
-    b_m = (
-        np.concatenate(base_margin_list)
-        if (has_fit_base_margin or has_predict_base_margin)
-        else None
-    )
-    if has_validation:
-        feature_matrix_val = np.concatenate(values_val_list)
-        y_val = np.concatenate(label_val_list) if train else None
-        w_val = np.concatenate(weight_val_list) if has_weight else None
-        b_m_val = (
-            np.concatenate(base_margin_val_list)
-            if (has_fit_base_margin or has_predict_base_margin)
-            else None
+            train = part.loc[~part[alias.valid], :]
+            valid = part.loc[part[alias.valid], :]
+        else:
+            train, valid = part, None
+
+        make_blob(train, False)
+        if valid is not None:
+            make_blob(valid, True)
+
+    def make(values: Dict[str, List[np.ndarray]]) -> DMatrix:
+        # Concatenate the chunks accumulated for the requested split (train or valid).
+        data = concat_or_none(values[alias.data])
+        label = concat_or_none(values[alias.label])
+        weight = concat_or_none(values[alias.weight])
+        margin = concat_or_none(values[alias.margin])
+        return DMatrix(data=data, label=label, weight=weight, base_margin=margin)
+
+    train = make(train_data)
+    valid = make(valid_data) if has_validation else None
+    return train, valid
+
+
+class PartIter(DataIter):
+    """A data iterator for building ``DeviceQuantileDMatrix`` from partitions."""
+
+    def __init__(
+        self,
+        it: Iterator[pd.DataFrame],
+        feature_cols: Sequence[str],
+        on_device: bool,
+        has_validation: bool,
+    ) -> None:
+        self._iter = 0
+        self.train_data: Dict[str, List] = defaultdict(list)
+        valid_data: Dict[str, List] = defaultdict(list)
+
+        def append(part: pd.DataFrame, name: str) -> None:
+            if name == alias.data or name in part.columns:
+                if name == alias.data:
+                    fname = feature_cols
+                else:
+                    fname = name
+                if has_validation:
+                    train = part.loc[~part[alias.valid], :][fname]
+                    valid = part.loc[part[alias.valid], :][fname]
+                    valid_data[name].append(valid)
+                else:
+                    train = part[fname]
+                    valid = None
+
self.train_data[name].append(train) + + self._cuda = on_device + + for part in it: + append(part, alias.data) + append(part, alias.label) + append(part, alias.weight) + append(part, alias.margin) + + if valid_data: + c_valid_data = concat_or_none(valid_data[alias.data]) + c_valid_label = concat_or_none(valid_data.get(alias.label, None)) + c_valid_weight = concat_or_none(valid_data.get(alias.weight, None)) + c_valid_margin = concat_or_none(valid_data.get(alias.margin, None)) + self.dvalid: Optional[DMatrix] = DMatrix( + data=c_valid_data, + label=c_valid_label, + weight=c_valid_weight, + base_margin=c_valid_margin, + ) + else: + self.dvalid = None + + super().__init__() + + def _fetch(self, data: Optional[Sequence[pd.DataFrame]]) -> Optional[pd.DataFrame]: + if not data: + return None + + if self._cuda: + import cudf + + return cudf.DataFrame(data[self._iter]) + else: + return data[self._iter] + + def next(self, input_data: Callable) -> int: + if self._iter == len(self.train_data[alias.data]): + return 0 + input_data( + data=self._fetch(self.train_data[alias.data]), + label=self._fetch(self.train_data.get(alias.label, None)), + weight=self._fetch(self.train_data.get(alias.weight, None)), + base_margin=self._fetch(self.train_data.get(alias.margin, None)), ) - return feature_matrix, y, w, b_m, feature_matrix_val, y_val, w_val, b_m_val - return feature_matrix, y, w, b_m + self._iter += 1 + return 1 + + def reset(self) -> None: + self._iter = 0 -def _process_data_iter( - data_iterator: Iterator[pd.DataFrame], - train: bool, - has_weight: bool, +def create_qdm_from_partitions( + iterator: Iterator[pd.DataFrame], has_validation: bool, - has_fit_base_margin: bool = False, - has_predict_base_margin: bool = False, -): - """ - If input is for train and has_validation=True, it will split the train data into train dataset - and validation dataset, and return (train_X, train_y, train_w, train_b_m <- - train base margin, val_X, val_y, val_w, val_b_m <- validation base margin) - otherwise return (X, y, w, b_m <- base margin) - """ - return _row_tuple_list_to_feature_matrix_y_w( - data_iterator, - train, - has_weight, - has_fit_base_margin, - has_predict_base_margin, - has_validation, - ) - - -def _convert_partition_data_to_dmatrix( - partition_data_iter, - has_weight, - has_validation, - has_base_margin, - dmatrix_kwargs=None, -): - # pylint: disable=too-many-locals, unbalanced-tuple-unpacking - dmatrix_kwargs = dmatrix_kwargs or {} - # if we are not using external storage, we use the standard method of parsing data. - train_val_data = _prepare_train_val_data( - partition_data_iter, has_weight, has_validation, has_base_margin - ) - if has_validation: - ( - train_x, - train_y, - train_w, - train_b_m, - val_x, - val_y, - val_w, - val_b_m, - ) = train_val_data - training_dmatrix = DMatrix( - data=train_x, - label=train_y, - weight=train_w, - base_margin=train_b_m, - **dmatrix_kwargs, - ) - val_dmatrix = DMatrix( - data=val_x, - label=val_y, - weight=val_w, - base_margin=val_b_m, - **dmatrix_kwargs, - ) - return training_dmatrix, val_dmatrix - - train_x, train_y, train_w, train_b_m = train_val_data - training_dmatrix = DMatrix( - data=train_x, - label=train_y, - weight=train_w, - base_margin=train_b_m, - **dmatrix_kwargs, - ) - return training_dmatrix + features_cols: Sequence[str] = [], +) -> Tuple[DMatrix, Optional[DMatrix]]: + if not features_cols: + # features cols must not be empty + raise ValueError() + # fixme: always on device when this is used. 
+ it = PartIter(iterator, features_cols, True, has_validation) + dvalid = it.dvalid + return DeviceQuantileDMatrix(it), dvalid diff --git a/tests/python/test_spark/test_data.py b/tests/python/test_spark/test_data.py deleted file mode 100644 index 9b6aa1b72305..000000000000 --- a/tests/python/test_spark/test_data.py +++ /dev/null @@ -1,168 +0,0 @@ -import sys -import tempfile -import shutil - -import pytest -import numpy as np -import pandas as pd - -import testing as tm - -if tm.no_spark()["condition"]: - pytest.skip(msg=tm.no_spark()["reason"], allow_module_level=True) -if sys.platform.startswith("win") or sys.platform.startswith("darwin"): - pytest.skip("Skipping PySpark tests on Windows", allow_module_level=True) - -from xgboost.spark.data import ( - _row_tuple_list_to_feature_matrix_y_w, - _convert_partition_data_to_dmatrix, -) - -from xgboost import DMatrix, XGBClassifier -from xgboost.training import train as worker_train -from .utils import SparkTestCase -import logging - -logging.getLogger("py4j").setLevel(logging.INFO) - - -class DataTest(SparkTestCase): - def test_sparse_dense_vector(self): - def row_tup_iter(data): - pdf = pd.DataFrame(data) - yield pdf - - expected_ndarray = np.array([[1.0, 2.0, 3.0], [0.0, 1.0, 5.5]]) - data = {"values": [[1.0, 2.0, 3.0], [0.0, 1.0, 5.5]]} - feature_matrix, y, w, _ = _row_tuple_list_to_feature_matrix_y_w( - list(row_tup_iter(data)), - train=False, - has_weight=False, - has_fit_base_margin=False, - has_predict_base_margin=False, - ) - self.assertIsNone(y) - self.assertIsNone(w) - self.assertTrue(np.allclose(feature_matrix, expected_ndarray)) - - data["label"] = [1, 0] - feature_matrix, y, w, _ = _row_tuple_list_to_feature_matrix_y_w( - row_tup_iter(data), - train=True, - has_weight=False, - has_fit_base_margin=False, - has_predict_base_margin=False, - ) - self.assertIsNone(w) - self.assertTrue(np.allclose(feature_matrix, expected_ndarray)) - self.assertTrue(np.array_equal(y, np.array(data["label"]))) - - data["weight"] = [0.2, 0.8] - feature_matrix, y, w, _ = _row_tuple_list_to_feature_matrix_y_w( - list(row_tup_iter(data)), - train=True, - has_weight=True, - has_fit_base_margin=False, - has_predict_base_margin=False, - ) - self.assertTrue(np.allclose(feature_matrix, expected_ndarray)) - self.assertTrue(np.array_equal(y, np.array(data["label"]))) - self.assertTrue(np.array_equal(w, np.array(data["weight"]))) - - def test_dmatrix_creator(self): - - # This function acts as a pseudo-itertools.chain() - def row_tup_iter(data): - pdf = pd.DataFrame(data) - yield pdf - - # Standard testing DMatrix creation - expected_features = np.array([[1.0, 2.0, 3.0], [0.0, 1.0, 5.5]] * 100) - expected_labels = np.array([1, 0] * 100) - expected_dmatrix = DMatrix(data=expected_features, label=expected_labels) - - data = { - "values": [[1.0, 2.0, 3.0], [0.0, 1.0, 5.5]] * 100, - "label": [1, 0] * 100, - } - output_dmatrix = _convert_partition_data_to_dmatrix( - [pd.DataFrame(data)], - has_weight=False, - has_validation=False, - has_base_margin=False, - ) - # You can't compare DMatrix outputs, so the only way is to predict on the two seperate DMatrices using - # the same classifier and making sure the outputs are equal - model = XGBClassifier() - model.fit(expected_features, expected_labels) - expected_preds = model.get_booster().predict(expected_dmatrix) - output_preds = model.get_booster().predict(output_dmatrix) - self.assertTrue(np.allclose(expected_preds, output_preds, atol=1e-3)) - - # DMatrix creation with weights - expected_weight = np.array([0.2, 0.8] * 100) 
- expected_dmatrix = DMatrix( - data=expected_features, label=expected_labels, weight=expected_weight - ) - - data["weight"] = [0.2, 0.8] * 100 - output_dmatrix = _convert_partition_data_to_dmatrix( - [pd.DataFrame(data)], - has_weight=True, - has_validation=False, - has_base_margin=False, - ) - - model.fit(expected_features, expected_labels, sample_weight=expected_weight) - expected_preds = model.get_booster().predict(expected_dmatrix) - output_preds = model.get_booster().predict(output_dmatrix) - self.assertTrue(np.allclose(expected_preds, output_preds, atol=1e-3)) - - def test_external_storage(self): - # Instantiating base data (features, labels) - features = np.array([[1.0, 2.0, 3.0], [0.0, 1.0, 5.5]] * 100) - labels = np.array([1, 0] * 100) - normal_dmatrix = DMatrix(features, labels) - test_dmatrix = DMatrix(features) - - data = { - "values": [[1.0, 2.0, 3.0], [0.0, 1.0, 5.5]] * 100, - "label": [1, 0] * 100, - } - - # Creating the dmatrix based on storage - temporary_path = tempfile.mkdtemp() - storage_dmatrix = _convert_partition_data_to_dmatrix( - [pd.DataFrame(data)], - has_weight=False, - has_validation=False, - has_base_margin=False, - ) - - # Testing without weights - normal_booster = worker_train({}, normal_dmatrix) - storage_booster = worker_train({}, storage_dmatrix) - normal_preds = normal_booster.predict(test_dmatrix) - storage_preds = storage_booster.predict(test_dmatrix) - self.assertTrue(np.allclose(normal_preds, storage_preds, atol=1e-3)) - shutil.rmtree(temporary_path) - - # Testing weights - weights = np.array([0.2, 0.8] * 100) - normal_dmatrix = DMatrix(data=features, label=labels, weight=weights) - data["weight"] = [0.2, 0.8] * 100 - - temporary_path = tempfile.mkdtemp() - storage_dmatrix = _convert_partition_data_to_dmatrix( - [pd.DataFrame(data)], - has_weight=True, - has_validation=False, - has_base_margin=False, - ) - - normal_booster = worker_train({}, normal_dmatrix) - storage_booster = worker_train({}, storage_dmatrix) - normal_preds = normal_booster.predict(test_dmatrix) - storage_preds = storage_booster.predict(test_dmatrix) - self.assertTrue(np.allclose(normal_preds, storage_preds, atol=1e-3)) - shutil.rmtree(temporary_path)
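
A minimal usage sketch of the new `create_dmatrix_from_partitions` entry point (toy
partitions only; the column names follow the `alias` tuple in `spark/data.py`, and in
the real code path the iterator is the partition iterator `pandas_df_iter` inside
`_train_booster`):

    import numpy as np
    import pandas as pd

    from xgboost.spark.data import create_dmatrix_from_partitions

    def partitions():
        # Two toy partitions; each row keeps its feature vector in the "values" column.
        for _ in range(2):
            yield pd.DataFrame(
                {
                    "values": [np.array([1.0, 2.0, 3.0]), np.array([0.0, 1.0, 5.5])],
                    "label": [1, 0],
                }
            )

    dtrain, dvalid = create_dmatrix_from_partitions(partitions(), has_validation=False)
    assert dvalid is None
    assert dtrain.num_row() == 4 and dtrain.num_col() == 3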