Copy ellpack to ghist.
Start working on sparse data.

Fix race.

Remove check.

Merge functions.

Cleanup.

Cleanup.

Start writing tests.

Fix.

comp column.

Python test.

lint.

lint.

Fix.

Cleanup.

Avoid binary search.

Use quantile dmatrix by default in sklearn interface.

dask as well.

Fix max_bin.

Fix empty dmatrix for CPU.

Fix GPU version.

Fix empty DMatrix.

pylint.
trivialfis committed Sep 6, 2022
1 parent b5eb36f commit cff8513
Showing 5 changed files with 87 additions and 37 deletions.
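At a high level, this commit switches the sklearn and Dask estimators to a `QuantileDMatrix` for the `hist` and `gpu_hist` tree methods, so the data is quantized once instead of being stored as a full `DMatrix`. A minimal sketch of what the new default amounts to for a user of the sklearn interface (data shapes and `max_bin` are illustrative, not from the diff):

    import numpy as np
    import xgboost as xgb

    X, y = np.random.rand(1024, 16), np.random.rand(1024)

    # Roughly what XGBRegressor(tree_method="hist").fit(X, y) now does internally:
    # build a QuantileDMatrix rather than a DMatrix to save memory.
    Xy = xgb.QuantileDMatrix(X, label=y, max_bin=256)
    booster = xgb.train({"tree_method": "hist"}, Xy)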
25 changes: 18 additions & 7 deletions python-package/xgboost/dask.py
@@ -726,10 +726,9 @@ def _create_quantile_dmatrix(
if parts is None:
msg = f"worker {worker.address} has an empty DMatrix."
LOGGER.warning(msg)
import cupy

d = QuantileDMatrix(
cupy.zeros((0, 0)),
numpy.empty((0, 0)),
feature_names=feature_names,
feature_types=feature_types,
max_bin=max_bin,
@@ -1544,15 +1543,21 @@ def inplace_predict( # pylint: disable=unused-argument


async def _async_wrap_evaluation_matrices(
client: Optional["distributed.Client"], **kwargs: Any
client: Optional["distributed.Client"],
tree_method: Optional[str],
max_bin: Optional[int],
**kwargs: Any,
) -> Tuple[DaskDMatrix, Optional[List[Tuple[DaskDMatrix, str]]]]:
"""A switch function for async environment."""

def _inner(**kwargs: Any) -> DaskDMatrix:
m = DaskDMatrix(client=client, **kwargs)
return m
def _dispatch(ref: Optional[DaskDMatrix], **kwargs: Any) -> DaskDMatrix:
if tree_method in ("hist", "gpu_hist"):
return DaskQuantileDMatrix(
client=client, ref=ref, max_bin=max_bin, **kwargs
)
return DaskDMatrix(client=client, **kwargs)

train_dmatrix, evals = _wrap_evaluation_matrices(create_dmatrix=_inner, **kwargs)
train_dmatrix, evals = _wrap_evaluation_matrices(create_dmatrix=_dispatch, **kwargs)
train_dmatrix = await train_dmatrix
if evals is None:
return train_dmatrix, evals
@@ -1756,6 +1761,8 @@ async def _fit_async(
params = self.get_xgb_params()
dtrain, evals = await _async_wrap_evaluation_matrices(
client=self.client,
tree_method=self.tree_method,
max_bin=self.max_bin,
X=X,
y=y,
group=None,
@@ -1851,6 +1858,8 @@ async def _fit_async(
params = self.get_xgb_params()
dtrain, evals = await _async_wrap_evaluation_matrices(
self.client,
tree_method=self.tree_method,
max_bin=self.max_bin,
X=X,
y=y,
group=None,
@@ -2057,6 +2066,8 @@ async def _fit_async(
params = self.get_xgb_params()
dtrain, evals = await _async_wrap_evaluation_matrices(
self.client,
tree_method=self.tree_method,
max_bin=self.max_bin,
X=X,
y=y,
group=None,
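For the Dask interface, the `_dispatch` helper above selects `DaskQuantileDMatrix` whenever `tree_method` is `hist` or `gpu_hist`, and falls back to `DaskDMatrix` otherwise. A minimal usage sketch, assuming a local cluster and synthetic data (none of this is in the diff):

    from dask import array as da
    from dask.distributed import Client, LocalCluster
    from xgboost import dask as dxgb

    with Client(LocalCluster(n_workers=2)) as client:
        X = da.random.random((10_000, 32), chunks=(1_000, 32))
        y = da.random.random(10_000, chunks=1_000)
        # With this change the estimator builds a DaskQuantileDMatrix internally
        # because tree_method is "hist"; no change is needed on the user side.
        reg = dxgb.DaskXGBRegressor(tree_method="hist", max_bin=128)
        reg.fit(X, y, eval_set=[(X, y)])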
23 changes: 19 additions & 4 deletions python-package/xgboost/sklearn.py
@@ -38,6 +38,7 @@
Booster,
DMatrix,
Metric,
QuantileDMatrix,
XGBoostError,
_convert_ntree_limit,
_deprecate_positional_args,
@@ -430,7 +431,8 @@ def _wrap_evaluation_matrices(
enable_categorical: bool,
feature_types: Optional[FeatureTypes],
) -> Tuple[Any, List[Tuple[Any, str]]]:
"""Convert array_like evaluation matrices into DMatrix. Perform validation on the way."""
"""Convert array_like evaluation matrices into DMatrix. Perform validation on the
way."""
train_dmatrix = create_dmatrix(
data=X,
label=y,
@@ -442,6 +444,7 @@
missing=missing,
enable_categorical=enable_categorical,
feature_types=feature_types,
ref=None,
)

n_validation = 0 if eval_set is None else len(eval_set)
@@ -491,6 +494,7 @@ def validate_or_none(meta: Optional[Sequence], name: str) -> Sequence:
missing=missing,
enable_categorical=enable_categorical,
feature_types=feature_types,
ref=train_dmatrix,
)
evals.append(m)
nevals = len(evals)
@@ -904,6 +908,17 @@ def _duplicated(parameter: str) -> None:

return model, metric, params, early_stopping_rounds, callbacks

def _create_dmatrix(self, ref: Optional[DMatrix], **kwargs: Any) -> DMatrix:
# Use `QuantileDMatrix` to save memory.
if self.tree_method in ("hist", "gpu_hist"):
try:
return QuantileDMatrix(
**kwargs, ref=ref, nthread=self.n_jobs, max_bin=self.max_bin
)
except TypeError: # `QuantileDMatrix` supports lesser types than DMatrix
pass
return DMatrix(**kwargs, nthread=self.n_jobs)

def _set_evaluation_result(self, evals_result: TrainingCallback.EvalsLog) -> None:
if evals_result:
self.evals_result_ = cast(Dict[str, Dict[str, List[float]]], evals_result)
@@ -996,7 +1011,7 @@ def fit(
base_margin_eval_set=base_margin_eval_set,
eval_group=None,
eval_qid=None,
create_dmatrix=lambda **kwargs: DMatrix(nthread=self.n_jobs, **kwargs),
create_dmatrix=self._create_dmatrix,
enable_categorical=self.enable_categorical,
feature_types=self.feature_types,
)
@@ -1479,7 +1494,7 @@ def fit(
base_margin_eval_set=base_margin_eval_set,
eval_group=None,
eval_qid=None,
create_dmatrix=lambda **kwargs: DMatrix(nthread=self.n_jobs, **kwargs),
create_dmatrix=self._create_dmatrix,
enable_categorical=self.enable_categorical,
feature_types=self.feature_types,
)
@@ -1930,7 +1945,7 @@ def fit(
base_margin_eval_set=base_margin_eval_set,
eval_group=eval_group,
eval_qid=eval_qid,
create_dmatrix=lambda **kwargs: DMatrix(nthread=self.n_jobs, **kwargs),
create_dmatrix=self._create_dmatrix,
enable_categorical=self.enable_categorical,
feature_types=self.feature_types,
)
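The `ref=train_dmatrix` argument threaded through `_wrap_evaluation_matrices` is what keeps evaluation data on the same quantile cuts as the training data. A hedged sketch of the equivalent explicit construction (shapes and parameters are assumptions):

    import numpy as np
    import xgboost as xgb

    X_train, y_train = np.random.rand(512, 8), np.random.rand(512)
    X_valid, y_valid = np.random.rand(128, 8), np.random.rand(128)

    dtrain = xgb.QuantileDMatrix(X_train, label=y_train, max_bin=256)
    # The validation matrix reuses the training cuts through `ref`, mirroring
    # the `ref=train_dmatrix` now passed by `_wrap_evaluation_matrices`.
    dvalid = xgb.QuantileDMatrix(X_valid, label=y_valid, max_bin=256, ref=dtrain)
    booster = xgb.train({"tree_method": "hist"}, dtrain, evals=[(dvalid, "valid")])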
40 changes: 40 additions & 0 deletions src/data/iterative_dmatrix.cc
@@ -14,6 +14,45 @@
namespace xgboost {
namespace data {

IterativeDMatrix::IterativeDMatrix(DataIterHandle iter_handle, DMatrixHandle proxy,
std::shared_ptr<DMatrix> ref, DataIterResetCallback* reset,
XGDMatrixCallbackNext* next, float missing, int nthread,
bst_bin_t max_bin)
: proxy_{proxy}, reset_{reset}, next_{next} {
// fetch the first batch
auto iter =
DataIterProxy<DataIterResetCallback, XGDMatrixCallbackNext>{iter_handle, reset_, next_};
iter.Reset();
bool valid = iter.Next();
CHECK(valid) << "Iterative DMatrix must have at least 1 batch.";

auto d = MakeProxy(proxy_)->DeviceIdx();

StringView msg{"All batch should be on the same device."};
if (batch_param_.gpu_id != Context::kCpuId) {
CHECK_EQ(d, batch_param_.gpu_id) << msg;
}

int32_t max_device{d};
rabit::Allreduce<rabit::op::Max>(&max_device, 1);
if (max_device != d) {
CHECK_EQ(MakeProxy(proxy_)->Info().num_row_, 0);
CHECK_NE(d, Context::kCpuId) << msg;
d = max_device;
}

batch_param_ = BatchParam{d, max_bin};
batch_param_.sparse_thresh = 0.2; // default from TrainParam

ctx_.UpdateAllowUnknown(
Args{{"nthread", std::to_string(nthread)}, {"gpu_id", std::to_string(d)}});
if (ctx_.IsCPU()) {
this->InitFromCPU(iter_handle, missing, ref);
} else {
this->InitFromCUDA(iter_handle, missing, ref);
}
}

void GetCutsFromRef(std::shared_ptr<DMatrix> ref_, bst_feature_t n_features, BatchParam p,
common::HistogramCuts* p_cuts) {
CHECK(ref_);
@@ -199,6 +238,7 @@ void IterativeDMatrix::InitFromCPU(DataIterHandle iter_handle, float missing,
if (n_batches == 1) {
this->info_ = std::move(proxy->Info());
this->info_.num_nonzero_ = nnz;
this->info_.num_col_ = n_features; // proxy might be empty.
CHECK_EQ(proxy->Info().labels.Size(), 0);
}
}
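The constructor moved into this translation unit drives a callback-based data iterator through the C API. On the Python side the same machinery is reached via `xgboost.DataIter`; a small sketch assuming in-memory numpy batches (the iterator class and batch sizes are made up for illustration):

    import numpy as np
    import xgboost as xgb

    class BatchIter(xgb.DataIter):
        """Feeds batches to the iterator-backed QuantileDMatrix."""

        def __init__(self, batches):
            self.batches = batches
            self.it = 0
            super().__init__()

        def next(self, input_data):
            if self.it == len(self.batches):
                return 0  # end of iteration
            X, y = self.batches[self.it]
            input_data(data=X, label=y)
            self.it += 1
            return 1

        def reset(self):
            self.it = 0

    batches = [(np.random.rand(256, 4), np.random.rand(256)) for _ in range(3)]
    Xy = xgb.QuantileDMatrix(BatchIter(batches), max_bin=64)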
11 changes: 9 additions & 2 deletions src/data/iterative_dmatrix.cu
@@ -173,8 +173,15 @@ BatchSet<EllpackPage> IterativeDMatrix::GetEllpackBatches(BatchParam const& para
}
if (!ellpack_ && ghist_) {
ellpack_.reset(new EllpackPage());
this->ctx_.gpu_id = param.gpu_id;
this->Info().feature_types.SetDevice(param.gpu_id);
// Evaluation QuantileDMatrix initialized from CPU data might not have the correct GPU
// ID.
if (this->ctx_.IsCPU()) {
this->ctx_.gpu_id = param.gpu_id;
}
if (this->ctx_.IsCPU()) {
this->ctx_.gpu_id = dh::CurrentDevice();
}
this->Info().feature_types.SetDevice(this->ctx_.gpu_id);
*ellpack_->Impl() =
EllpackPageImpl(&ctx_, *this->ghist_, this->Info().feature_types.ConstDeviceSpan());
}
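The ghist-to-ellpack branch above is taken when a `QuantileDMatrix` was quantized on the CPU (`GHistIndexMatrix`) but the GPU later requests `EllpackPage` batches. One scenario that should exercise it, assuming a CUDA device is available (illustrative, not part of the diff):

    import numpy as np
    import xgboost as xgb

    X, y = np.random.rand(1024, 8), np.random.rand(1024)
    # Quantized on the CPU first; training with gpu_hist then asks for
    # EllpackPage batches, which copies the histogram index to the device.
    Xy = xgb.QuantileDMatrix(X, label=y, max_bin=256)
    booster = xgb.train({"tree_method": "gpu_hist"}, Xy, num_boost_round=8)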
25 changes: 1 addition & 24 deletions src/data/iterative_dmatrix.h
@@ -75,30 +75,7 @@ class IterativeDMatrix : public DMatrix {
explicit IterativeDMatrix(DataIterHandle iter_handle, DMatrixHandle proxy,
std::shared_ptr<DMatrix> ref, DataIterResetCallback *reset,
XGDMatrixCallbackNext *next, float missing, int nthread,
bst_bin_t max_bin)
: proxy_{proxy}, reset_{reset}, next_{next} {
// fetch the first batch
auto iter =
DataIterProxy<DataIterResetCallback, XGDMatrixCallbackNext>{iter_handle, reset_, next_};
iter.Reset();
bool valid = iter.Next();
CHECK(valid) << "Iterative DMatrix must have at least 1 batch.";

auto d = MakeProxy(proxy_)->DeviceIdx();
if (batch_param_.gpu_id != Context::kCpuId) {
CHECK_EQ(d, batch_param_.gpu_id) << "All batch should be on the same device.";
}
batch_param_ = BatchParam{d, max_bin};
batch_param_.sparse_thresh = 0.2; // default from TrainParam

ctx_.UpdateAllowUnknown(
Args{{"nthread", std::to_string(nthread)}, {"gpu_id", std::to_string(d)}});
if (ctx_.IsCPU()) {
this->InitFromCPU(iter_handle, missing, ref);
} else {
this->InitFromCUDA(iter_handle, missing, ref);
}
}
bst_bin_t max_bin);
~IterativeDMatrix() override = default;

bool EllpackExists() const override { return static_cast<bool>(ellpack_); }
