From 5b1161bb64424ee4fc240255dbdaf547bb05975e Mon Sep 17 00:00:00 2001
From: Jiaming Yuan
Date: Fri, 17 Dec 2021 00:58:35 +0800
Subject: [PATCH] Convert labels into tensor. (#7456)

* Add a new ctor to tensor for `initializer_list`.
* Change labels from host device vector to tensor.
* Rename the field from `labels_` to `labels` since it's a public member.
---
 include/xgboost/data.h                      |  8 +--
 include/xgboost/linalg.h                    | 46 +++++++++----
 plugin/example/custom_obj.cc                |  6 +-
 src/common/device_helpers.cuh               | 10 +++
 src/data/data.cc                            | 65 +++++++++++-------
 src/data/data.cu                            | 13 ++--
 src/data/iterative_device_dmatrix.cu        |  2 +-
 src/data/simple_dmatrix.cc                  | 12 ++--
 src/metric/auc.cc                           | 73 +++++++++++----------
 src/metric/auc.cu                           | 46 ++++++-------
 src/metric/elementwise_metric.cu            |  4 +-
 src/metric/multiclass_metric.cu             | 11 ++--
 src/metric/rank_metric.cc                   | 20 +++---
 src/metric/rank_metric.cu                   |  6 +-
 src/objective/hinge.cu                      |  8 +--
 src/objective/multiclass_obj.cu             | 10 +--
 src/objective/rank_obj.cu                   | 22 +++---
 src/objective/regression_obj.cu             | 35 +++++-----
 tests/cpp/c_api/test_c_api.cc               |  4 +-
 tests/cpp/data/test_metainfo.cc             | 31 ++++-----
 tests/cpp/data/test_metainfo.cu             | 14 ++--
 tests/cpp/data/test_metainfo.h              | 43 ++++++------
 tests/cpp/data/test_proxy_dmatrix.cu        |  2 +-
 tests/cpp/data/test_simple_dmatrix.cc       |  4 +-
 tests/cpp/data/test_sparse_page_dmatrix.cc  |  2 +-
 tests/cpp/gbm/test_gbtree.cc                |  4 +-
 tests/cpp/helpers.cc                        | 23 ++++---
 tests/cpp/metric/test_auc.cc                |  4 +-
 tests/cpp/metric/test_elementwise_metric.cc |  2 +-
 tests/cpp/metric/test_multiclass_metric.cc  |  5 +-
 tests/cpp/objective/test_regression_obj.cc  |  6 +-
 tests/cpp/predictor/test_predictor.cc       |  4 +-
 tests/cpp/test_learner.cc                   | 14 ++--
 tests/cpp/test_serialization.cc             | 15 +++--
 tests/python/test_with_pandas.py            |  3 +-
 35 files changed, 319 insertions(+), 258 deletions(-)

diff --git a/include/xgboost/data.h b/include/xgboost/data.h
index 8cafbf02813e..2d1ed7f125b4 100644
--- a/include/xgboost/data.h
+++ b/include/xgboost/data.h
@@ -56,7 +56,7 @@ class MetaInfo {
   /*! \brief number of nonzero entries in the data */
   uint64_t num_nonzero_{0};  // NOLINT
   /*! \brief label of each instance */
-  HostDeviceVector<bst_float> labels_;  // NOLINT
+  linalg::Tensor<float, 2> labels;
   /*!
    * \brief the index of begin and end of a group
    * needed when the learning task is ranking.
    */
@@ -119,12 +119,12 @@
   }
  /*! \brief get sorted indexes (argsort) of labels by absolute value (used by cox loss) */
  inline const std::vector<size_t>& LabelAbsSort() const {
-    if (label_order_cache_.size() == labels_.Size()) {
+    if (label_order_cache_.size() == labels.Size()) {
       return label_order_cache_;
     }
-    label_order_cache_.resize(labels_.Size());
+    label_order_cache_.resize(labels.Size());
     std::iota(label_order_cache_.begin(), label_order_cache_.end(), 0);
-    const auto& l = labels_.HostVector();
+    const auto& l = labels.Data()->HostVector();
     XGBOOST_PARALLEL_SORT(label_order_cache_.begin(), label_order_cache_.end(),
                           [&l](size_t i1, size_t i2) {return std::abs(l[i1]) < std::abs(l[i2]);});
diff --git a/include/xgboost/linalg.h b/include/xgboost/linalg.h
index 200610367616..78a18c044f61 100644
--- a/include/xgboost/linalg.h
+++ b/include/xgboost/linalg.h
@@ -635,6 +635,20 @@ class Tensor {
   HostDeviceVector<T> data_;
   ShapeT shape_{0};
 
+  template <typename I, std::int32_t D>
+  void Initialize(I const (&shape)[D], std::int32_t device) {
+    static_assert(D <= kDim, "Invalid shape.");
+    std::copy(shape, shape + D, shape_);
+    for (auto i = D; i < kDim; ++i) {
+      shape_[i] = 1;
+    }
+    if (device >= 0) {
+      data_.SetDevice(device);
+      data_.DevicePointer();  // Pull to device;
+    }
+    CHECK_EQ(data_.Size(), detail::CalcSize(shape_));
+  }
+
  public:
   Tensor() = default;
 
@@ -665,20 +679,20 @@ class Tensor {
    */
   template <typename It, typename I, int32_t D>
   explicit Tensor(It begin, It end, I const (&shape)[D], int32_t device) {
-    // shape
-    static_assert(D <= kDim, "Invalid shape.");
-    std::copy(shape, shape + D, shape_);
-    for (auto i = D; i < kDim; ++i) {
-      shape_[i] = 1;
-    }
     auto &h_vec = data_.HostVector();
     h_vec.insert(h_vec.begin(), begin, end);
-    if (device >= 0) {
-      data_.SetDevice(device);
-      data_.DevicePointer();  // Pull to device;
-    }
-    CHECK_EQ(data_.Size(), detail::CalcSize(shape_));
+    // shape
+    this->Initialize(shape, device);
   }
+
+  template <typename I, int32_t D>
+  explicit Tensor(std::initializer_list<T> data, I const (&shape)[D], int32_t device) {
+    auto &h_vec = data_.HostVector();
+    h_vec = data;
+    // shape
+    this->Initialize(shape, device);
+  }
+
   /**
    * \brief Get a \ref TensorView for this tensor.
    */
@@ -703,6 +717,9 @@
     }
   }
 
+  auto HostView() const { return this->View(-1); }
+  auto HostView() { return this->View(-1); }
+
   size_t Size() const { return data_.Size(); }
   auto Shape() const { return common::Span<size_t const, kDim>{shape_}; }
   auto Shape(size_t i) const { return shape_[i]; }
@@ -756,14 +773,15 @@
   /**
    * \brief Set device ordinal for this tensor.
    */
-  void SetDevice(int32_t device) { data_.SetDevice(device); }
+  void SetDevice(int32_t device) const { data_.SetDevice(device); }
+  int32_t DeviceIdx() const { return data_.DeviceIdx(); }
 };
 
 // Only first axis is supported for now.
template void Stack(Tensor *l, Tensor const &r) { - if (r.Data()->DeviceIdx() >= 0) { - l->Data()->SetDevice(r.Data()->DeviceIdx()); + if (r.DeviceIdx() >= 0) { + l->SetDevice(r.DeviceIdx()); } l->ModifyInplace([&](HostDeviceVector *data, common::Span shape) { for (size_t i = 1; i < D; ++i) { diff --git a/plugin/example/custom_obj.cc b/plugin/example/custom_obj.cc index c38ad4fbd82d..b61073360e00 100644 --- a/plugin/example/custom_obj.cc +++ b/plugin/example/custom_obj.cc @@ -46,15 +46,15 @@ class MyLogistic : public ObjFunction { out_gpair->Resize(preds.Size()); const std::vector& preds_h = preds.HostVector(); std::vector& out_gpair_h = out_gpair->HostVector(); - const std::vector& labels_h = info.labels_.HostVector(); + auto const labels_h = info.labels.HostView(); for (size_t i = 0; i < preds_h.size(); ++i) { bst_float w = info.GetWeight(i); // scale the negative examples! - if (labels_h[i] == 0.0f) w *= param_.scale_neg_weight; + if (labels_h(i) == 0.0f) w *= param_.scale_neg_weight; // logistic transformation bst_float p = 1.0f / (1.0f + std::exp(-preds_h[i])); // this is the gradient - bst_float grad = (p - labels_h[i]) * w; + bst_float grad = (p - labels_h(i)) * w; // this is the second order gradient bst_float hess = p * (1.0f - p) * w; out_gpair_h.at(i) = GradientPair(grad, hess); diff --git a/src/common/device_helpers.cuh b/src/common/device_helpers.cuh index b316453f5d2c..c74718554bed 100644 --- a/src/common/device_helpers.cuh +++ b/src/common/device_helpers.cuh @@ -956,11 +956,21 @@ thrust::device_ptr tbegin(xgboost::common::Span& span) { // NOLINT return thrust::device_ptr(span.data()); } +template +thrust::device_ptr tbegin(xgboost::common::Span const& span) { // NOLINT + return thrust::device_ptr(span.data()); +} + template thrust::device_ptr tend(xgboost::common::Span& span) { // NOLINT return tbegin(span) + span.size(); } +template +thrust::device_ptr tend(xgboost::common::Span const& span) { // NOLINT + return tbegin(span) + span.size(); +} + template XGBOOST_DEVICE auto trbegin(xgboost::common::Span &span) { // NOLINT return thrust::make_reverse_iterator(span.data() + span.size()); diff --git a/src/data/data.cc b/src/data/data.cc index fd3f2b6db207..fa5b388eae24 100644 --- a/src/data/data.cc +++ b/src/data/data.cc @@ -176,7 +176,7 @@ uint64_t constexpr MetaInfo::kNumField; // implementation of inline functions void MetaInfo::Clear() { num_row_ = num_col_ = num_nonzero_ = 0; - labels_.HostVector().clear(); + labels = decltype(labels){}; group_ptr_.clear(); weights_.HostVector().clear(); base_margin_ = decltype(base_margin_){}; @@ -213,8 +213,7 @@ void MetaInfo::SaveBinary(dmlc::Stream *fo) const { SaveScalarField(fo, u8"num_row", DataType::kUInt64, num_row_); ++field_cnt; SaveScalarField(fo, u8"num_col", DataType::kUInt64, num_col_); ++field_cnt; SaveScalarField(fo, u8"num_nonzero", DataType::kUInt64, num_nonzero_); ++field_cnt; - SaveVectorField(fo, u8"labels", DataType::kFloat32, - {labels_.Size(), 1}, labels_); ++field_cnt; + SaveTensorField(fo, u8"labels", DataType::kFloat32, labels); ++field_cnt; SaveVectorField(fo, u8"group_ptr", DataType::kUInt32, {group_ptr_.size(), 1}, group_ptr_); ++field_cnt; SaveVectorField(fo, u8"weights", DataType::kFloat32, @@ -291,7 +290,7 @@ void MetaInfo::LoadBinary(dmlc::Stream *fi) { LoadScalarField(fi, u8"num_row", DataType::kUInt64, &num_row_); LoadScalarField(fi, u8"num_col", DataType::kUInt64, &num_col_); LoadScalarField(fi, u8"num_nonzero", DataType::kUInt64, &num_nonzero_); - LoadVectorField(fi, u8"labels", DataType::kFloat32, 
&labels_); + LoadTensorField(fi, u8"labels", DataType::kFloat32, &labels); LoadVectorField(fi, u8"group_ptr", DataType::kUInt32, &group_ptr_); LoadVectorField(fi, u8"weights", DataType::kFloat32, &weights_); LoadTensorField(fi, u8"base_margin", DataType::kFloat32, &base_margin_); @@ -326,7 +325,19 @@ MetaInfo MetaInfo::Slice(common::Span ridxs) const { out.num_col_ = this->num_col_; // Groups is maintained by a higher level Python function. We should aim at deprecating // the slice function. - out.labels_.HostVector() = Gather(this->labels_.HostVector(), ridxs); + if (this->labels.Size() != this->num_row_) { + auto t_labels = this->labels.View(this->labels.Data()->DeviceIdx()); + out.labels.Reshape(ridxs.size(), labels.Shape(1)); + out.labels.Data()->HostVector() = + Gather(this->labels.Data()->HostVector(), ridxs, t_labels.Stride(0)); + } else { + out.labels.ModifyInplace([&](auto* data, common::Span shape) { + data->HostVector() = Gather(this->labels.Data()->HostVector(), ridxs); + shape[0] = data->Size(); + shape[1] = 1; + }); + } + out.labels_upper_bound_.HostVector() = Gather(this->labels_upper_bound_.HostVector(), ridxs); out.labels_lower_bound_.HostVector() = @@ -343,13 +354,16 @@ MetaInfo MetaInfo::Slice(common::Span ridxs) const { if (this->base_margin_.Size() != this->num_row_) { CHECK_EQ(this->base_margin_.Size() % this->num_row_, 0) << "Incorrect size of base margin vector."; - auto margin = this->base_margin_.View(this->base_margin_.Data()->DeviceIdx()); - out.base_margin_.Reshape(ridxs.size(), margin.Shape()[1]); - size_t stride = margin.Stride(0); + auto t_margin = this->base_margin_.View(this->base_margin_.Data()->DeviceIdx()); + out.base_margin_.Reshape(ridxs.size(), t_margin.Shape(1)); out.base_margin_.Data()->HostVector() = - Gather(this->base_margin_.Data()->HostVector(), ridxs, stride); + Gather(this->base_margin_.Data()->HostVector(), ridxs, t_margin.Stride(0)); } else { - out.base_margin_.Data()->HostVector() = Gather(this->base_margin_.Data()->HostVector(), ridxs); + out.base_margin_.ModifyInplace([&](auto* data, common::Span shape) { + data->HostVector() = Gather(this->base_margin_.Data()->HostVector(), ridxs); + shape[0] = data->Size(); + shape[1] = 1; + }); } out.feature_weights.Resize(this->feature_weights.Size()); @@ -460,6 +474,17 @@ void MetaInfo::SetInfoFromHost(StringView key, Json arr) { this->base_margin_.Reshape(this->num_row_, n_groups); } return; + } else if (key == "label") { + CopyTensorInfoImpl(arr, &this->labels); + if (this->num_row_ != 0 && this->labels.Shape(0) != this->num_row_) { + CHECK_EQ(this->labels.Size() % this->num_row_, 0) << "Incorrect size for labels."; + size_t n_targets = this->labels.Size() / this->num_row_; + this->labels.Reshape(this->num_row_, n_targets); + } + auto const& h_labels = labels.Data()->ConstHostVector(); + auto valid = std::none_of(h_labels.cbegin(), h_labels.cend(), data::LabelsCheck{}); + CHECK(valid) << "Label contains NaN, infinity or a value too large."; + return; } // uint info if (key == "group") { @@ -500,12 +525,7 @@ void MetaInfo::SetInfoFromHost(StringView key, Json arr) { // float info linalg::Tensor t; CopyTensorInfoImpl<1>(arr, &t); - if (key == "label") { - this->labels_ = std::move(*t.Data()); - auto const& h_labels = labels_.ConstHostVector(); - auto valid = std::none_of(h_labels.cbegin(), h_labels.cend(), data::LabelsCheck{}); - CHECK(valid) << "Label contains NaN, infinity or a value too large."; - } else if (key == "weight") { + if (key == "weight") { this->weights_ = std::move(*t.Data()); auto 
const& h_weights = this->weights_.ConstHostVector(); auto valid = std::none_of(h_weights.cbegin(), h_weights.cend(), @@ -568,7 +588,7 @@ void MetaInfo::GetInfo(char const* key, bst_ulong* out_len, DataType dtype, if (dtype == DataType::kFloat32) { const std::vector* vec = nullptr; if (!std::strcmp(key, "label")) { - vec = &this->labels_.HostVector(); + vec = &this->labels.Data()->HostVector(); } else if (!std::strcmp(key, "weight")) { vec = &this->weights_.HostVector(); } else if (!std::strcmp(key, "base_margin")) { @@ -649,8 +669,7 @@ void MetaInfo::Extend(MetaInfo const& that, bool accumulate_rows, bool check_col } this->num_col_ = that.num_col_; - this->labels_.SetDevice(that.labels_.DeviceIdx()); - this->labels_.Extend(that.labels_); + linalg::Stack(&this->labels, that.labels); this->weights_.SetDevice(that.weights_.DeviceIdx()); this->weights_.Extend(that.weights_); @@ -702,7 +721,7 @@ void MetaInfo::Validate(int32_t device) const { << "Invalid group structure. Number of rows obtained from groups " "doesn't equal to actual number of rows given by data."; } - auto check_device = [device](HostDeviceVector const &v) { + auto check_device = [device](HostDeviceVector const& v) { CHECK(v.DeviceIdx() == GenericParameter::kCpuId || device == GenericParameter::kCpuId || v.DeviceIdx() == device) @@ -717,10 +736,10 @@ void MetaInfo::Validate(int32_t device) const { check_device(weights_); return; } - if (labels_.Size() != 0) { - CHECK_EQ(labels_.Size(), num_row_) + if (labels.Size() != 0) { + CHECK_EQ(labels.Size(), num_row_) << "Size of labels must equal to number of rows."; - check_device(labels_); + check_device(*labels.Data()); return; } if (labels_lower_bound_.Size() != 0) { diff --git a/src/data/data.cu b/src/data/data.cu index 6d85a85e261b..aada91a62e2a 100644 --- a/src/data/data.cu +++ b/src/data/data.cu @@ -119,6 +119,12 @@ void MetaInfo::SetInfoFromCUDA(StringView key, Json array) { if (key == "base_margin") { CopyTensorInfoImpl(array, &base_margin_); return; + } else if (key == "label") { + CopyTensorInfoImpl(array, &labels); + auto ptr = labels.Data()->ConstDevicePointer(); + auto valid = thrust::none_of(thrust::device, ptr, ptr + labels.Size(), data::LabelsCheck{}); + CHECK(valid) << "Label contains NaN, infinity or a value too large."; + return; } // uint info if (key == "group") { @@ -135,12 +141,7 @@ void MetaInfo::SetInfoFromCUDA(StringView key, Json array) { // float info linalg::Tensor t; CopyTensorInfoImpl(array, &t); - if (key == "label") { - this->labels_ = std::move(*t.Data()); - auto ptr = labels_.ConstDevicePointer(); - auto valid = thrust::none_of(thrust::device, ptr, ptr + labels_.Size(), data::LabelsCheck{}); - CHECK(valid) << "Label contains NaN, infinity or a value too large."; - } else if (key == "weight") { + if (key == "weight") { this->weights_ = std::move(*t.Data()); auto ptr = weights_.ConstDevicePointer(); auto valid = thrust::none_of(thrust::device, ptr, ptr + weights_.Size(), data::WeightsCheck{}); diff --git a/src/data/iterative_device_dmatrix.cu b/src/data/iterative_device_dmatrix.cu index d3869eff1c45..0f7b6d790492 100644 --- a/src/data/iterative_device_dmatrix.cu +++ b/src/data/iterative_device_dmatrix.cu @@ -153,7 +153,7 @@ void IterativeDeviceDMatrix::Initialize(DataIterHandle iter_handle, float missin if (batches == 1) { this->info_ = std::move(proxy->Info()); this->info_.num_nonzero_ = nnz; - CHECK_EQ(proxy->Info().labels_.Size(), 0); + CHECK_EQ(proxy->Info().labels.Size(), 0); } iter.Reset(); diff --git a/src/data/simple_dmatrix.cc 
b/src/data/simple_dmatrix.cc index ce2e262c113f..66a5c0d3ed0e 100644 --- a/src/data/simple_dmatrix.cc +++ b/src/data/simple_dmatrix.cc @@ -127,14 +127,16 @@ SimpleDMatrix::SimpleDMatrix(AdapterT* adapter, float missing, int nthread) { total_batch_size += batch.Size(); // Append meta information if available if (batch.Labels() != nullptr) { - auto& labels = info_.labels_.HostVector(); - labels.insert(labels.end(), batch.Labels(), - batch.Labels() + batch.Size()); + info_.labels.ModifyInplace([&](auto* data, common::Span shape) { + shape[1] = 1; + auto& labels = data->HostVector(); + labels.insert(labels.end(), batch.Labels(), batch.Labels() + batch.Size()); + shape[0] += batch.Size(); + }); } if (batch.Weights() != nullptr) { auto& weights = info_.weights_.HostVector(); - weights.insert(weights.end(), batch.Weights(), - batch.Weights() + batch.Size()); + weights.insert(weights.end(), batch.Weights(), batch.Weights() + batch.Size()); } if (batch.BaseMargin() != nullptr) { info_.base_margin_ = decltype(info_.base_margin_){batch.BaseMargin(), diff --git a/src/metric/auc.cc b/src/metric/auc.cc index 5097116fbb78..1957bcc9a083 100644 --- a/src/metric/auc.cc +++ b/src/metric/auc.cc @@ -32,17 +32,16 @@ namespace metric { */ template std::tuple -BinaryAUC(common::Span predts, common::Span labels, +BinaryAUC(common::Span predts, linalg::VectorView labels, OptionalWeights weights, std::vector const &sorted_idx, Fn &&area_fn) { - CHECK(!labels.empty()); - CHECK_EQ(labels.size(), predts.size()); + CHECK_NE(labels.Size(), 0); + CHECK_EQ(labels.Size(), predts.size()); auto p_predts = predts.data(); - auto p_labels = labels.data(); double auc{0}; - float label = p_labels[sorted_idx.front()]; + float label = labels(sorted_idx.front()); float w = weights[sorted_idx[0]]; double fp = (1.0 - label) * w, tp = label * w; double tp_prev = 0, fp_prev = 0; @@ -53,7 +52,7 @@ BinaryAUC(common::Span predts, common::Span labels, tp_prev = tp; fp_prev = fp; } - label = p_labels[sorted_idx[i]]; + label = labels(sorted_idx[i]); float w = weights[sorted_idx[i]]; fp += (1.0f - label) * w; tp += label * w; @@ -82,7 +81,10 @@ double MultiClassOVR(common::Span predts, MetaInfo const &info, size_t n_classes, int32_t n_threads, BinaryAUC &&binary_auc) { CHECK_NE(n_classes, 0); - auto const &labels = info.labels_.ConstHostVector(); + auto const labels = info.labels.View(GenericParameter::kCpuId); + if (labels.Shape(0) != 0) { + CHECK_EQ(labels.Shape(1), 1) << "AUC doesn't support multi-target model."; + } std::vector results_storage(n_classes * 3, 0); linalg::TensorView results(results_storage, {n_classes, static_cast(3)}, @@ -96,16 +98,17 @@ double MultiClassOVR(common::Span predts, MetaInfo const &info, predts, {static_cast(info.num_row_), n_classes}, GenericParameter::kCpuId); - if (!info.labels_.Empty()) { + if (info.labels.Size() != 0) { common::ParallelFor(n_classes, n_threads, [&](auto c) { - std::vector proba(info.labels_.Size()); - std::vector response(info.labels_.Size()); + std::vector proba(info.labels.Size()); + std::vector response(info.labels.Size()); for (size_t i = 0; i < proba.size(); ++i) { proba[i] = predts_t(i, c); - response[i] = labels[i] == c ? 1.0f : 0.0; + response[i] = labels(i) == c ? 
1.0f : 0.0; } double fp; - std::tie(fp, tp(c), auc(c)) = binary_auc(proba, response, weights); + std::tie(fp, tp(c), auc(c)) = + binary_auc(proba, linalg::MakeVec(response.data(), response.size(), -1), weights); local_area(c) = fp * tp(c); }); } @@ -135,9 +138,9 @@ double MultiClassOVR(common::Span predts, MetaInfo const &info, return auc_sum; } -std::tuple -BinaryROCAUC(common::Span predts, common::Span labels, - OptionalWeights weights) { +std::tuple BinaryROCAUC(common::Span predts, + linalg::VectorView labels, + OptionalWeights weights) { auto const sorted_idx = common::ArgSort(predts, std::greater<>{}); return BinaryAUC(predts, labels, weights, sorted_idx, TrapezoidArea); } @@ -146,15 +149,17 @@ BinaryROCAUC(common::Span predts, common::Span labels, * Calculate AUC for 1 ranking group; */ double GroupRankingROC(common::Span predts, - common::Span labels, float w) { + linalg::VectorView labels, float w) { // on ranking, we just count all pairs. double auc{0}; - auto const sorted_idx = common::ArgSort(labels, std::greater<>{}); + // argsort doesn't support tensor input yet. + auto raw_labels = labels.Values().subspan(0, labels.Size()); + auto const sorted_idx = common::ArgSort(raw_labels, std::greater<>{}); w = common::Sqr(w); double sum_w = 0.0f; - for (size_t i = 0; i < labels.size(); ++i) { - for (size_t j = i + 1; j < labels.size(); ++j) { + for (size_t i = 0; i < labels.Size(); ++i) { + for (size_t j = i + 1; j < labels.Size(); ++j) { auto predt = predts[sorted_idx[i]] - predts[sorted_idx[j]]; if (predt > 0) { predt = 1.0; @@ -180,14 +185,14 @@ double GroupRankingROC(common::Span predts, * https://doi.org/10.1371/journal.pone.0092209 */ std::tuple BinaryPRAUC(common::Span predts, - common::Span labels, + linalg::VectorView labels, OptionalWeights weights) { auto const sorted_idx = common::ArgSort(predts, std::greater<>{}); double total_pos{0}, total_neg{0}; - for (size_t i = 0; i < labels.size(); ++i) { + for (size_t i = 0; i < labels.Size(); ++i) { auto w = weights[i]; - total_pos += w * labels[i]; - total_neg += w * (1.0f - labels[i]); + total_pos += w * labels(i); + total_neg += w * (1.0f - labels(i)); } if (total_pos <= 0 || total_neg <= 0) { return {1.0f, 1.0f, std::numeric_limits::quiet_NaN()}; @@ -211,7 +216,7 @@ std::pair RankingAUC(std::vector const &predts, CHECK_GE(info.group_ptr_.size(), 2); uint32_t n_groups = info.group_ptr_.size() - 1; auto s_predts = common::Span{predts}; - auto s_labels = info.labels_.ConstHostSpan(); + auto labels = info.labels.View(GenericParameter::kCpuId); auto s_weights = info.weights_.ConstHostSpan(); std::atomic invalid_groups{0}; @@ -222,9 +227,9 @@ std::pair RankingAUC(std::vector const &predts, size_t cnt = info.group_ptr_[g] - info.group_ptr_[g - 1]; float w = s_weights.empty() ? 1.0f : s_weights[g - 1]; auto g_predts = s_predts.subspan(info.group_ptr_[g - 1], cnt); - auto g_labels = s_labels.subspan(info.group_ptr_[g - 1], cnt); + auto g_labels = labels.Slice(linalg::Range(info.group_ptr_[g - 1], info.group_ptr_[g])); double auc; - if (is_roc && g_labels.size() < 3) { + if (is_roc && g_labels.Size() < 3) { // With 2 documents, there's only 1 comparison can be made. So either // TP or FP will be zero. 
invalid_groups++; @@ -254,11 +259,11 @@ class EvalAUC : public Metric { double auc {0}; if (tparam_->gpu_id != GenericParameter::kCpuId) { preds.SetDevice(tparam_->gpu_id); - info.labels_.SetDevice(tparam_->gpu_id); + info.labels.SetDevice(tparam_->gpu_id); info.weights_.SetDevice(tparam_->gpu_id); } // We use the global size to handle empty dataset. - std::array meta{info.labels_.Size(), preds.Size()}; + std::array meta{info.labels.Size(), preds.Size()}; rabit::Allreduce(meta.data(), meta.size()); if (meta[0] == 0) { // Empty across all workers, which is not supported. @@ -271,8 +276,8 @@ class EvalAUC : public Metric { CHECK_EQ(info.weights_.Size(), info.group_ptr_.size() - 1); } uint32_t valid_groups = 0; - if (!info.labels_.Empty()) { - CHECK_EQ(info.group_ptr_.back(), info.labels_.Size()); + if (info.labels.Size() != 0) { + CHECK_EQ(info.group_ptr_.back(), info.labels.Size()); std::tie(auc, valid_groups) = static_cast(this)->EvalRanking(preds, info); } @@ -304,7 +309,7 @@ class EvalAUC : public Metric { * binary classification */ double fp{0}, tp{0}; - if (!(preds.Empty() || info.labels_.Empty())) { + if (!(preds.Empty() || info.labels.Size() == 0)) { std::tie(fp, tp, auc) = static_cast(this)->EvalBinary(preds, info); } @@ -367,7 +372,7 @@ class EvalROCAUC : public EvalAUC { double fp, tp, auc; if (tparam_->gpu_id == GenericParameter::kCpuId) { std::tie(fp, tp, auc) = - BinaryROCAUC(predts.ConstHostVector(), info.labels_.ConstHostVector(), + BinaryROCAUC(predts.ConstHostVector(), info.labels.HostView().Slice(linalg::All(), 0), OptionalWeights{info.weights_.ConstHostSpan()}); } else { std::tie(fp, tp, auc) = GPUBinaryROCAUC(predts.ConstDeviceSpan(), info, @@ -420,7 +425,7 @@ class EvalPRAUC : public EvalAUC { double pr, re, auc; if (tparam_->gpu_id == GenericParameter::kCpuId) { std::tie(pr, re, auc) = - BinaryPRAUC(predts.ConstHostSpan(), info.labels_.ConstHostSpan(), + BinaryPRAUC(predts.ConstHostSpan(), info.labels.HostView().Slice(linalg::All(), 0), OptionalWeights{info.weights_.ConstHostSpan()}); } else { std::tie(pr, re, auc) = GPUBinaryPRAUC(predts.ConstDeviceSpan(), info, @@ -447,7 +452,7 @@ class EvalPRAUC : public EvalAUC { uint32_t valid_groups = 0; auto n_threads = tparam_->Threads(); if (tparam_->gpu_id == GenericParameter::kCpuId) { - auto labels = info.labels_.ConstHostSpan(); + auto labels = info.labels.Data()->ConstHostSpan(); if (std::any_of(labels.cbegin(), labels.cend(), PRAUCLabelInvalid{})) { InvalidLabels(); } diff --git a/src/metric/auc.cu b/src/metric/auc.cu index 153a0290afcc..317ce7db2c84 100644 --- a/src/metric/auc.cu +++ b/src/metric/auc.cu @@ -89,12 +89,12 @@ std::tuple GPUBinaryAUC(common::Span predts, MetaInfo const &info, int32_t device, common::Span d_sorted_idx, Fn area_fn, std::shared_ptr cache) { - auto labels = info.labels_.ConstDeviceSpan(); + auto labels = info.labels.View(device); auto weights = info.weights_.ConstDeviceSpan(); dh::safe_cuda(cudaSetDevice(device)); - CHECK(!labels.empty()); - CHECK_EQ(labels.size(), predts.size()); + CHECK_NE(labels.Size(), 0); + CHECK_EQ(labels.Size(), predts.size()); /** * Linear scan @@ -103,7 +103,7 @@ GPUBinaryAUC(common::Span predts, MetaInfo const &info, auto get_fp_tp = [=]XGBOOST_DEVICE(size_t i) { size_t idx = d_sorted_idx[i]; - float label = labels[idx]; + float label = labels(idx); float w = get_weight[d_sorted_idx[i]]; float fp = (1.0 - label) * w; @@ -332,10 +332,10 @@ double GPUMultiClassAUCOVR(common::Span predts, // Index is sorted within class. 
auto d_sorted_idx = dh::ToSpan(cache->sorted_idx); - auto labels = info.labels_.ConstDeviceSpan(); + auto labels = info.labels.View(device); auto weights = info.weights_.ConstDeviceSpan(); - size_t n_samples = labels.size(); + size_t n_samples = labels.Shape(0); if (n_samples == 0) { dh::TemporaryArray resutls(n_classes * 4, 0.0f); @@ -360,7 +360,7 @@ double GPUMultiClassAUCOVR(common::Span predts, size_t class_id = i / n_samples; // labels is a vector of size n_samples. - float label = labels[idx % n_samples] == class_id; + float label = labels(idx % n_samples) == class_id; float w = get_weight[d_sorted_idx[i] % n_samples]; float fp = (1.0 - label) * w; @@ -528,10 +528,10 @@ GPURankingAUC(common::Span predts, MetaInfo const &info, /** * Sort the labels */ - auto d_labels = info.labels_.ConstDeviceSpan(); + auto d_labels = info.labels.View(device); auto d_sorted_idx = dh::ToSpan(cache->sorted_idx); - dh::SegmentedArgSort(d_labels, d_group_ptr, d_sorted_idx); + dh::SegmentedArgSort(d_labels.Values(), d_group_ptr, d_sorted_idx); auto d_weights = info.weights_.ConstDeviceSpan(); @@ -631,19 +631,19 @@ GPUBinaryPRAUC(common::Span predts, MetaInfo const &info, auto d_sorted_idx = dh::ToSpan(cache->sorted_idx); dh::ArgSort(predts, d_sorted_idx); - auto labels = info.labels_.ConstDeviceSpan(); + auto labels = info.labels.View(device); auto d_weights = info.weights_.ConstDeviceSpan(); auto get_weight = OptionalWeights{d_weights}; auto it = dh::MakeTransformIterator( thrust::make_counting_iterator(0ul), [=] XGBOOST_DEVICE(size_t i) { auto w = get_weight[d_sorted_idx[i]]; - return thrust::make_pair(labels[d_sorted_idx[i]] * w, - (1.0f - labels[d_sorted_idx[i]]) * w); + return thrust::make_pair(labels(d_sorted_idx[i]) * w, + (1.0f - labels(d_sorted_idx[i])) * w); }); dh::XGBCachingDeviceAllocator alloc; double total_pos, total_neg; thrust::tie(total_pos, total_neg) = - thrust::reduce(thrust::cuda::par(alloc), it, it + labels.size(), + thrust::reduce(thrust::cuda::par(alloc), it, it + labels.Size(), Pair{0.0, 0.0}, PairPlus{}); if (total_pos <= 0.0 || total_neg <= 0.0) { @@ -679,7 +679,7 @@ double GPUMultiClassPRAUC(common::Span predts, /** * Get total positive/negative */ - auto labels = info.labels_.ConstDeviceSpan(); + auto labels = info.labels.View(device); auto n_samples = info.num_row_; dh::caching_device_vector totals(n_classes); auto key_it = @@ -693,7 +693,7 @@ double GPUMultiClassPRAUC(common::Span predts, auto idx = d_sorted_idx[i] % n_samples; auto w = get_weight[idx]; auto class_id = i / n_samples; - auto y = labels[idx] == class_id; + auto y = labels(idx) == class_id; return thrust::make_pair(y * w, (1.0f - y) * w); }); dh::XGBCachingDeviceAllocator alloc; @@ -726,7 +726,7 @@ GPURankingPRAUCImpl(common::Span predts, MetaInfo const &info, */ auto d_sorted_idx = dh::ToSpan(cache->sorted_idx); - auto labels = info.labels_.ConstDeviceSpan(); + auto labels = info.labels.View(device); auto weights = info.weights_.ConstDeviceSpan(); uint32_t n_groups = static_cast(info.group_ptr_.size() - 1); @@ -734,7 +734,7 @@ GPURankingPRAUCImpl(common::Span predts, MetaInfo const &info, /** * Linear scan */ - size_t n_samples = labels.size(); + size_t n_samples = labels.Shape(0); dh::caching_device_vector d_auc(n_groups, 0); auto get_weight = OptionalWeights{weights}; auto d_fptp = dh::ToSpan(cache->fptp); @@ -742,7 +742,7 @@ GPURankingPRAUCImpl(common::Span predts, MetaInfo const &info, size_t idx = d_sorted_idx[i]; size_t group_id = dh::SegmentId(d_group_ptr, idx); - float label = labels[idx]; + float 
label = labels(idx); float w = get_weight[group_id]; float fp = (1.0 - label) * w; @@ -860,9 +860,9 @@ GPURankingPRAUC(common::Span predts, MetaInfo const &info, dh::SegmentedArgSort(predts, d_group_ptr, d_sorted_idx); dh::XGBDeviceAllocator alloc; - auto labels = info.labels_.ConstDeviceSpan(); - if (thrust::any_of(thrust::cuda::par(alloc), dh::tbegin(labels), - dh::tend(labels), PRAUCLabelInvalid{})) { + auto labels = info.labels.View(device); + if (thrust::any_of(thrust::cuda::par(alloc), dh::tbegin(labels.Values()), + dh::tend(labels.Values()), PRAUCLabelInvalid{})) { InvalidLabels(); } /** @@ -881,7 +881,7 @@ GPURankingPRAUC(common::Span predts, MetaInfo const &info, auto g = dh::SegmentId(d_group_ptr, i); w = d_weights[g]; } - auto y = labels[i]; + auto y = labels(i); return thrust::make_pair(y * w, (1.0 - y) * w); }); thrust::reduce_by_key(thrust::cuda::par(alloc), key_it, @@ -899,7 +899,7 @@ GPURankingPRAUC(common::Span predts, MetaInfo const &info, return detail::CalcDeltaPRAUC(fp_prev, fp, tp_prev, tp, d_totals[group_id].first); }; - return GPURankingPRAUCImpl(predts, info, d_group_ptr, n_groups, cache, fn); + return GPURankingPRAUCImpl(predts, info, d_group_ptr, device, cache, fn); } } // namespace metric } // namespace xgboost diff --git a/src/metric/elementwise_metric.cu b/src/metric/elementwise_metric.cu index ddc95576870d..9dc84da9853b 100644 --- a/src/metric/elementwise_metric.cu +++ b/src/metric/elementwise_metric.cu @@ -361,10 +361,10 @@ struct EvalEWiseBase : public Metric { double Eval(const HostDeviceVector &preds, const MetaInfo &info, bool distributed) override { - CHECK_EQ(preds.Size(), info.labels_.Size()) + CHECK_EQ(preds.Size(), info.labels.Size()) << "label and prediction size not match, " << "hint: use merror or mlogloss for multi-class classification"; - auto result = reducer_.Reduce(*tparam_, info.weights_, info.labels_, preds); + auto result = reducer_.Reduce(*tparam_, info.weights_, *info.labels.Data(), preds); double dat[2] { result.Residue(), result.Weights() }; diff --git a/src/metric/multiclass_metric.cu b/src/metric/multiclass_metric.cu index 3a42c46e7453..9ba8412e11c8 100644 --- a/src/metric/multiclass_metric.cu +++ b/src/metric/multiclass_metric.cu @@ -169,19 +169,20 @@ template struct EvalMClassBase : public Metric { double Eval(const HostDeviceVector &preds, const MetaInfo &info, bool distributed) override { - if (info.labels_.Size() == 0) { + if (info.labels.Size() == 0) { CHECK_EQ(preds.Size(), 0); } else { - CHECK(preds.Size() % info.labels_.Size() == 0) << "label and prediction size not match"; + CHECK(preds.Size() % info.labels.Size() == 0) << "label and prediction size not match"; } double dat[2] { 0.0, 0.0 }; - if (info.labels_.Size() != 0) { - const size_t nclass = preds.Size() / info.labels_.Size(); + if (info.labels.Size() != 0) { + const size_t nclass = preds.Size() / info.labels.Size(); CHECK_GE(nclass, 1U) << "mlogloss and merror are only used for multi-class classification," << " use logloss for binary classification"; int device = tparam_->gpu_id; - auto result = reducer_.Reduce(*tparam_, device, nclass, info.weights_, info.labels_, preds); + auto result = + reducer_.Reduce(*tparam_, device, nclass, info.weights_, *info.labels.Data(), preds); dat[0] = result.Residue(); dat[1] = result.Weights(); } diff --git a/src/metric/rank_metric.cc b/src/metric/rank_metric.cc index f57d1392662d..1aa0c4cb0f23 100644 --- a/src/metric/rank_metric.cc +++ b/src/metric/rank_metric.cc @@ -107,7 +107,7 @@ struct EvalAMS : public Metric { 
CHECK(!distributed) << "metric AMS do not support distributed evaluation"; using namespace std; // NOLINT(*) - const auto ndata = static_cast(info.labels_.Size()); + const auto ndata = static_cast(info.labels.Size()); PredIndPairContainer rec(ndata); const auto &h_preds = preds.ConstHostVector(); @@ -120,11 +120,11 @@ struct EvalAMS : public Metric { const double br = 10.0; unsigned thresindex = 0; double s_tp = 0.0, b_fp = 0.0, tams = 0.0; - const auto& labels = info.labels_.ConstHostVector(); + const auto& labels = info.labels.View(GenericParameter::kCpuId); for (unsigned i = 0; i < static_cast(ndata-1) && i < ntop; ++i) { const unsigned ridx = rec[i].second; const bst_float wt = info.GetWeight(ridx); - if (labels[ridx] > 0.5f) { + if (labels(ridx) > 0.5f) { s_tp += wt; } else { b_fp += wt; @@ -164,7 +164,7 @@ struct EvalRank : public Metric, public EvalRankConfig { public: double Eval(const HostDeviceVector &preds, const MetaInfo &info, bool distributed) override { - CHECK_EQ(preds.Size(), info.labels_.Size()) + CHECK_EQ(preds.Size(), info.labels.Size()) << "label size predict size not match"; // quick consistency when group is not available @@ -194,7 +194,7 @@ struct EvalRank : public Metric, public EvalRankConfig { std::vector sum_tloc(tparam_->Threads(), 0.0); if (!rank_gpu_ || tparam_->gpu_id < 0) { - const auto &labels = info.labels_.ConstHostVector(); + const auto& labels = info.labels.View(GenericParameter::kCpuId); const auto &h_preds = preds.ConstHostVector(); dmlc::OMPException exc; @@ -208,7 +208,7 @@ struct EvalRank : public Metric, public EvalRankConfig { exc.Run([&]() { rec.clear(); for (unsigned j = gptr[k]; j < gptr[k + 1]; ++j) { - rec.emplace_back(h_preds[j], static_cast(labels[j])); + rec.emplace_back(h_preds[j], static_cast(labels(j))); } sum_tloc[omp_get_thread_num()] += this->EvalGroup(&rec); }); @@ -348,7 +348,7 @@ struct EvalCox : public Metric { CHECK(!distributed) << "Cox metric does not support distributed evaluation"; using namespace std; // NOLINT(*) - const auto ndata = static_cast(info.labels_.Size()); + const auto ndata = static_cast(info.labels.Size()); const auto &label_order = info.LabelAbsSort(); // pre-compute a sum for the denominator @@ -362,10 +362,10 @@ struct EvalCox : public Metric { double out = 0; double accumulated_sum = 0; bst_omp_uint num_events = 0; - const auto& labels = info.labels_.ConstHostVector(); + const auto& labels = info.labels.HostView(); for (bst_omp_uint i = 0; i < ndata; ++i) { const size_t ind = label_order[i]; - const auto label = labels[ind]; + const auto label = labels(ind); if (label > 0) { out -= log(h_preds[ind]) - log(exp_p_sum); ++num_events; @@ -373,7 +373,7 @@ struct EvalCox : public Metric { // only update the denominator after we move forward in time (labels are sorted) accumulated_sum += h_preds[ind]; - if (i == ndata - 1 || std::abs(label) < std::abs(labels[label_order[i + 1]])) { + if (i == ndata - 1 || std::abs(label) < std::abs(labels(label_order[i + 1]))) { exp_p_sum -= accumulated_sum; accumulated_sum = 0; } diff --git a/src/metric/rank_metric.cu b/src/metric/rank_metric.cu index 0e7f9cc15324..36fca9482a96 100644 --- a/src/metric/rank_metric.cu +++ b/src/metric/rank_metric.cu @@ -41,18 +41,18 @@ struct EvalRankGpu : public Metric, public EvalRankConfig { auto device = tparam_->gpu_id; dh::safe_cuda(cudaSetDevice(device)); - info.labels_.SetDevice(device); + info.labels.SetDevice(device); preds.SetDevice(device); auto dpreds = preds.ConstDevicePointer(); - auto dlabels = info.labels_.ConstDevicePointer(); 
+ auto dlabels = info.labels.View(device); // Sort all the predictions dh::SegmentSorter segment_pred_sorter; segment_pred_sorter.SortItems(dpreds, preds.Size(), gptr); // Compute individual group metric and sum them up - return EvalMetricT::EvalMetric(segment_pred_sorter, dlabels, *this); + return EvalMetricT::EvalMetric(segment_pred_sorter, dlabels.Values().data(), *this); } const char* Name() const override { diff --git a/src/objective/hinge.cu b/src/objective/hinge.cu index 068a4eea61ce..09b379804f4a 100644 --- a/src/objective/hinge.cu +++ b/src/objective/hinge.cu @@ -33,11 +33,11 @@ class HingeObj : public ObjFunction { const MetaInfo &info, int iter, HostDeviceVector *out_gpair) override { - CHECK_NE(info.labels_.Size(), 0U) << "label set cannot be empty"; - CHECK_EQ(preds.Size(), info.labels_.Size()) + CHECK_NE(info.labels.Size(), 0U) << "label set cannot be empty"; + CHECK_EQ(preds.Size(), info.labels.Size()) << "labels are not correctly provided" << "preds.size=" << preds.Size() - << ", label.size=" << info.labels_.Size(); + << ", label.size=" << info.labels.Size(); const size_t ndata = preds.Size(); const bool is_null_weight = info.weights_.Size() == 0; @@ -67,7 +67,7 @@ class HingeObj : public ObjFunction { }, common::Range{0, static_cast(ndata)}, tparam_->gpu_id).Eval( - out_gpair, &preds, &info.labels_, &info.weights_); + out_gpair, &preds, info.labels.Data(), &info.weights_); } void PredTransform(HostDeviceVector *io_preds) const override { diff --git a/src/objective/multiclass_obj.cu b/src/objective/multiclass_obj.cu index 710428b00008..a3f01b419743 100644 --- a/src/objective/multiclass_obj.cu +++ b/src/objective/multiclass_obj.cu @@ -55,13 +55,13 @@ class SoftmaxMultiClassObj : public ObjFunction { // Remove unused parameter compiler warning. 
(void) iter; - if (info.labels_.Size() == 0) { + if (info.labels.Size() == 0) { return; } - CHECK(preds.Size() == (static_cast(param_.num_class) * info.labels_.Size())) + CHECK(preds.Size() == (static_cast(param_.num_class) * info.labels.Size())) << "SoftmaxMultiClassObj: label size and pred size does not match.\n" << "label.Size() * num_class: " - << info.labels_.Size() * static_cast(param_.num_class) << "\n" + << info.labels.Size() * static_cast(param_.num_class) << "\n" << "num_class: " << param_.num_class << "\n" << "preds.Size(): " << preds.Size(); @@ -70,7 +70,7 @@ class SoftmaxMultiClassObj : public ObjFunction { auto device = tparam_->gpu_id; out_gpair->SetDevice(device); - info.labels_.SetDevice(device); + info.labels.SetDevice(device); info.weights_.SetDevice(device); preds.SetDevice(device); @@ -115,7 +115,7 @@ class SoftmaxMultiClassObj : public ObjFunction { gpair[idx * nclass + k] = GradientPair(p * wt, h); } }, common::Range{0, ndata}, device, false) - .Eval(out_gpair, &info.labels_, &preds, &info.weights_, &label_correct_); + .Eval(out_gpair, info.labels.Data(), &preds, &info.weights_, &label_correct_); std::vector& label_correct_h = label_correct_.HostVector(); for (auto const flag : label_correct_h) { diff --git a/src/objective/rank_obj.cu b/src/objective/rank_obj.cu index 228c54642645..9f4d86aafb31 100644 --- a/src/objective/rank_obj.cu +++ b/src/objective/rank_obj.cu @@ -760,15 +760,15 @@ class LambdaRankObj : public ObjFunction { const MetaInfo& info, int iter, HostDeviceVector* out_gpair) override { - CHECK_EQ(preds.Size(), info.labels_.Size()) << "label size predict size not match"; + CHECK_EQ(preds.Size(), info.labels.Size()) << "label size predict size not match"; // quick consistency when group is not available - std::vector tgptr(2, 0); tgptr[1] = static_cast(info.labels_.Size()); + std::vector tgptr(2, 0); tgptr[1] = static_cast(info.labels.Size()); const std::vector &gptr = info.group_ptr_.size() == 0 ? tgptr : info.group_ptr_; - CHECK(gptr.size() != 0 && gptr.back() == info.labels_.Size()) + CHECK(gptr.size() != 0 && gptr.back() == info.labels.Size()) << "group structure not consistent with #rows" << ", " << "group ponter size: " << gptr.size() << ", " - << "labels size: " << info.labels_.Size() << ", " + << "labels size: " << info.labels.Size() << ", " << "group pointer back: " << (gptr.size() == 0 ? 
0 : gptr.back()); #if defined(__CUDACC__) @@ -820,7 +820,7 @@ class LambdaRankObj : public ObjFunction { bst_float weight_normalization_factor = ComputeWeightNormalizationFactor(info, gptr); const auto& preds_h = preds.HostVector(); - const auto& labels = info.labels_.HostVector(); + const auto& labels = info.labels.HostView(); std::vector& gpair = out_gpair->HostVector(); const auto ngroup = static_cast(gptr.size() - 1); out_gpair->Resize(preds.Size()); @@ -841,7 +841,7 @@ class LambdaRankObj : public ObjFunction { exc.Run([&]() { lst.clear(); pairs.clear(); for (unsigned j = gptr[k]; j < gptr[k+1]; ++j) { - lst.emplace_back(preds_h[j], labels[j], j); + lst.emplace_back(preds_h[j], labels(j), j); gpair[j] = GradientPair(0.0f, 0.0f); } std::stable_sort(lst.begin(), lst.end(), ListEntry::CmpPred); @@ -916,7 +916,7 @@ class LambdaRankObj : public ObjFunction { // Set the device ID and copy them to the device out_gpair->SetDevice(device); - info.labels_.SetDevice(device); + info.labels.SetDevice(device); preds.SetDevice(device); info.weights_.SetDevice(device); @@ -924,19 +924,19 @@ class LambdaRankObj : public ObjFunction { auto d_preds = preds.ConstDevicePointer(); auto d_gpair = out_gpair->DevicePointer(); - auto d_labels = info.labels_.ConstDevicePointer(); + auto d_labels = info.labels.View(device); SortedLabelList slist(param_); // Sort the labels within the groups on the device - slist.Sort(info.labels_, gptr); + slist.Sort(*info.labels.Data(), gptr); // Initialize the gradients next out_gpair->Fill(GradientPair(0.0f, 0.0f)); // Finally, compute the gradients - slist.ComputeGradients - (d_preds, d_labels, info.weights_, iter, d_gpair, weight_normalization_factor); + slist.ComputeGradients(d_preds, d_labels.Values().data(), info.weights_, + iter, d_gpair, weight_normalization_factor); } #endif diff --git a/src/objective/regression_obj.cu b/src/objective/regression_obj.cu index ca9ec2c7029d..5dd1a82dd99c 100644 --- a/src/objective/regression_obj.cu +++ b/src/objective/regression_obj.cu @@ -59,9 +59,9 @@ class RegLossObj : public ObjFunction { void GetGradient(const HostDeviceVector& preds, const MetaInfo &info, int, HostDeviceVector* out_gpair) override { - CHECK_EQ(preds.Size(), info.labels_.Size()) + CHECK_EQ(preds.Size(), info.labels.Size()) << " " << "labels are not correctly provided" - << "preds.size=" << preds.Size() << ", label.size=" << info.labels_.Size() << ", " + << "preds.size=" << preds.Size() << ", label.size=" << info.labels.Size() << ", " << "Loss: " << Loss::Name(); size_t const ndata = preds.Size(); out_gpair->Resize(ndata); @@ -81,8 +81,7 @@ class RegLossObj : public ObjFunction { bool on_device = device >= 0; // On CPU we run the transformation each thread processing a contigious block of data // for better performance. - const size_t n_data_blocks = - std::max(static_cast(1), (on_device ? ndata : nthreads)); + const size_t n_data_blocks = std::max(static_cast(1), (on_device ? 
ndata : nthreads)); const size_t block_size = ndata / n_data_blocks + !!(ndata % n_data_blocks); common::Transform<>::Init( [block_size, ndata] XGBOOST_DEVICE( @@ -116,7 +115,7 @@ class RegLossObj : public ObjFunction { } }, common::Range{0, static_cast(n_data_blocks)}, device) - .Eval(&additional_input_, out_gpair, &preds, &info.labels_, + .Eval(&additional_input_, out_gpair, &preds, info.labels.Data(), &info.weights_); auto const flag = additional_input_.HostVector().begin()[0]; @@ -218,8 +217,8 @@ class PoissonRegression : public ObjFunction { void GetGradient(const HostDeviceVector& preds, const MetaInfo &info, int, HostDeviceVector *out_gpair) override { - CHECK_NE(info.labels_.Size(), 0U) << "label set cannot be empty"; - CHECK_EQ(preds.Size(), info.labels_.Size()) << "labels are not correctly provided"; + CHECK_NE(info.labels.Size(), 0U) << "label set cannot be empty"; + CHECK_EQ(preds.Size(), info.labels.Size()) << "labels are not correctly provided"; size_t const ndata = preds.Size(); out_gpair->Resize(ndata); auto device = tparam_->gpu_id; @@ -249,7 +248,7 @@ class PoissonRegression : public ObjFunction { expf(p + max_delta_step) * w}; }, common::Range{0, static_cast(ndata)}, device).Eval( - &label_correct_, out_gpair, &preds, &info.labels_, &info.weights_); + &label_correct_, out_gpair, &preds, info.labels.Data(), &info.weights_); // copy "label correct" flags back to host std::vector& label_correct_h = label_correct_.HostVector(); for (auto const flag : label_correct_h) { @@ -313,8 +312,8 @@ class CoxRegression : public ObjFunction { void GetGradient(const HostDeviceVector& preds, const MetaInfo &info, int, HostDeviceVector *out_gpair) override { - CHECK_NE(info.labels_.Size(), 0U) << "label set cannot be empty"; - CHECK_EQ(preds.Size(), info.labels_.Size()) << "labels are not correctly provided"; + CHECK_NE(info.labels.Size(), 0U) << "label set cannot be empty"; + CHECK_EQ(preds.Size(), info.labels.Size()) << "labels are not correctly provided"; const auto& preds_h = preds.HostVector(); out_gpair->Resize(preds_h.size()); auto& gpair = out_gpair->HostVector(); @@ -334,7 +333,7 @@ class CoxRegression : public ObjFunction { } // start calculating grad and hess - const auto& labels = info.labels_.HostVector(); + const auto& labels = info.labels.HostView(); double r_k = 0; double s_k = 0; double last_exp_p = 0.0; @@ -345,7 +344,7 @@ class CoxRegression : public ObjFunction { const double p = preds_h[ind]; const double exp_p = std::exp(p); const double w = info.GetWeight(ind); - const double y = labels[ind]; + const double y = labels(ind); const double abs_y = std::abs(y); // only update the denominator after we move forward in time (labels are sorted) @@ -414,8 +413,8 @@ class GammaRegression : public ObjFunction { void GetGradient(const HostDeviceVector &preds, const MetaInfo &info, int, HostDeviceVector *out_gpair) override { - CHECK_NE(info.labels_.Size(), 0U) << "label set cannot be empty"; - CHECK_EQ(preds.Size(), info.labels_.Size()) << "labels are not correctly provided"; + CHECK_NE(info.labels.Size(), 0U) << "label set cannot be empty"; + CHECK_EQ(preds.Size(), info.labels.Size()) << "labels are not correctly provided"; const size_t ndata = preds.Size(); auto device = tparam_->gpu_id; out_gpair->Resize(ndata); @@ -443,7 +442,7 @@ class GammaRegression : public ObjFunction { _out_gpair[_idx] = GradientPair((1 - y / expf(p)) * w, y / expf(p) * w); }, common::Range{0, static_cast(ndata)}, device).Eval( - &label_correct_, out_gpair, &preds, &info.labels_, &info.weights_); + 
&label_correct_, out_gpair, &preds, info.labels.Data(), &info.weights_); // copy "label correct" flags back to host std::vector& label_correct_h = label_correct_.HostVector(); @@ -514,8 +513,8 @@ class TweedieRegression : public ObjFunction { void GetGradient(const HostDeviceVector& preds, const MetaInfo &info, int, HostDeviceVector *out_gpair) override { - CHECK_NE(info.labels_.Size(), 0U) << "label set cannot be empty"; - CHECK_EQ(preds.Size(), info.labels_.Size()) << "labels are not correctly provided"; + CHECK_NE(info.labels.Size(), 0U) << "label set cannot be empty"; + CHECK_EQ(preds.Size(), info.labels.Size()) << "labels are not correctly provided"; const size_t ndata = preds.Size(); out_gpair->Resize(ndata); @@ -550,7 +549,7 @@ class TweedieRegression : public ObjFunction { _out_gpair[_idx] = GradientPair(grad * w, hess * w); }, common::Range{0, static_cast(ndata), 1}, device) - .Eval(&label_correct_, out_gpair, &preds, &info.labels_, &info.weights_); + .Eval(&label_correct_, out_gpair, &preds, info.labels.Data(), &info.weights_); // copy "label correct" flags back to host std::vector& label_correct_h = label_correct_.HostVector(); diff --git a/tests/cpp/c_api/test_c_api.cc b/tests/cpp/c_api/test_c_api.cc index d5f284fe0424..810d39710849 100644 --- a/tests/cpp/c_api/test_c_api.cc +++ b/tests/cpp/c_api/test_c_api.cc @@ -91,7 +91,7 @@ TEST(CAPI, ConfigIO) { for (size_t i = 0; i < labels.size(); ++i) { labels[i] = i; } - p_dmat->Info().labels_.HostVector() = labels; + p_dmat->Info().labels.Data()->HostVector() = labels; std::shared_ptr learner { Learner::Create(mat) }; @@ -125,7 +125,7 @@ TEST(CAPI, JsonModelIO) { for (size_t i = 0; i < labels.size(); ++i) { labels[i] = i; } - p_dmat->Info().labels_.HostVector() = labels; + p_dmat->Info().labels.Data()->HostVector() = labels; std::shared_ptr learner { Learner::Create(mat) }; diff --git a/tests/cpp/data/test_metainfo.cc b/tests/cpp/data/test_metainfo.cc index c8fbe43a4c17..2f17f6bfe376 100644 --- a/tests/cpp/data/test_metainfo.cc +++ b/tests/cpp/data/test_metainfo.cc @@ -16,9 +16,9 @@ TEST(MetaInfo, GetSet) { double double2[2] = {1.0, 2.0}; - EXPECT_EQ(info.labels_.Size(), 0); + EXPECT_EQ(info.labels.Size(), 0); info.SetInfo("label", double2, xgboost::DataType::kFloat32, 2); - EXPECT_EQ(info.labels_.Size(), 2); + EXPECT_EQ(info.labels.Size(), 2); float float2[2] = {1.0f, 2.0f}; EXPECT_EQ(info.GetWeight(1), 1.0f) @@ -120,8 +120,8 @@ TEST(MetaInfo, SaveLoadBinary) { EXPECT_EQ(inforead.num_col_, info.num_col_); EXPECT_EQ(inforead.num_nonzero_, info.num_nonzero_); - ASSERT_EQ(inforead.labels_.HostVector(), values); - EXPECT_EQ(inforead.labels_.HostVector(), info.labels_.HostVector()); + ASSERT_EQ(inforead.labels.Data()->HostVector(), values); + EXPECT_EQ(inforead.labels.Data()->HostVector(), info.labels.Data()->HostVector()); EXPECT_EQ(inforead.group_ptr_, info.group_ptr_); EXPECT_EQ(inforead.weights_.HostVector(), info.weights_.HostVector()); @@ -236,8 +236,9 @@ TEST(MetaInfo, Validate) { EXPECT_THROW(info.Validate(0), dmlc::Error); std::vector labels(info.num_row_ + 1); - info.SetInfo("label", labels.data(), xgboost::DataType::kFloat32, info.num_row_ + 1); - EXPECT_THROW(info.Validate(0), dmlc::Error); + EXPECT_THROW( + { info.SetInfo("label", labels.data(), xgboost::DataType::kFloat32, info.num_row_ + 1); }, + dmlc::Error); // Make overflow data, which can happen when users pass group structure as int // or float. 
@@ -254,7 +255,7 @@ TEST(MetaInfo, Validate) { info.group_ptr_.clear(); labels.resize(info.num_row_); info.SetInfo("label", labels.data(), xgboost::DataType::kFloat32, info.num_row_); - info.labels_.SetDevice(0); + info.labels.SetDevice(0); EXPECT_THROW(info.Validate(1), dmlc::Error); xgboost::HostDeviceVector d_groups{groups}; @@ -269,12 +270,12 @@ TEST(MetaInfo, Validate) { TEST(MetaInfo, HostExtend) { xgboost::MetaInfo lhs, rhs; size_t const kRows = 100; - lhs.labels_.Resize(kRows); + lhs.labels.Reshape(kRows); lhs.num_row_ = kRows; - rhs.labels_.Resize(kRows); + rhs.labels.Reshape(kRows); rhs.num_row_ = kRows; - ASSERT_TRUE(lhs.labels_.HostCanRead()); - ASSERT_TRUE(rhs.labels_.HostCanRead()); + ASSERT_TRUE(lhs.labels.Data()->HostCanRead()); + ASSERT_TRUE(rhs.labels.Data()->HostCanRead()); size_t per_group = 10; std::vector groups; @@ -286,10 +287,10 @@ TEST(MetaInfo, HostExtend) { lhs.Extend(rhs, true, true); ASSERT_EQ(lhs.num_row_, kRows * 2); - ASSERT_TRUE(lhs.labels_.HostCanRead()); - ASSERT_TRUE(rhs.labels_.HostCanRead()); - ASSERT_FALSE(lhs.labels_.DeviceCanRead()); - ASSERT_FALSE(rhs.labels_.DeviceCanRead()); + ASSERT_TRUE(lhs.labels.Data()->HostCanRead()); + ASSERT_TRUE(rhs.labels.Data()->HostCanRead()); + ASSERT_FALSE(lhs.labels.Data()->DeviceCanRead()); + ASSERT_FALSE(rhs.labels.Data()->DeviceCanRead()); ASSERT_EQ(lhs.group_ptr_.front(), 0); ASSERT_EQ(lhs.group_ptr_.back(), kRows * 2); diff --git a/tests/cpp/data/test_metainfo.cu b/tests/cpp/data/test_metainfo.cu index bbb78e7924e7..c02597eef1fd 100644 --- a/tests/cpp/data/test_metainfo.cu +++ b/tests/cpp/data/test_metainfo.cu @@ -52,10 +52,10 @@ TEST(MetaInfo, FromInterface) { MetaInfo info; info.SetInfo("label", str.c_str()); - auto const& h_label = info.labels_.HostVector(); - ASSERT_EQ(h_label.size(), d_data.size()); + auto const& h_label = info.labels.HostView(); + ASSERT_EQ(h_label.Size(), d_data.size()); for (size_t i = 0; i < d_data.size(); ++i) { - ASSERT_EQ(h_label[i], d_data[i]); + ASSERT_EQ(h_label(i), d_data[i]); } info.SetInfo("weight", str.c_str()); @@ -147,15 +147,15 @@ TEST(MetaInfo, DeviceExtend) { std::string str = PrepareData("HostCanRead()); lhs.num_row_ = kRows; rhs.num_row_ = kRows; lhs.Extend(rhs, true, true); ASSERT_EQ(lhs.num_row_, kRows * 2); - ASSERT_FALSE(lhs.labels_.HostCanRead()); + ASSERT_FALSE(lhs.labels.Data()->HostCanRead()); - ASSERT_FALSE(lhs.labels_.HostCanRead()); - ASSERT_FALSE(rhs.labels_.HostCanRead()); + ASSERT_FALSE(lhs.labels.Data()->HostCanRead()); + ASSERT_FALSE(rhs.labels.Data()->HostCanRead()); } } // namespace xgboost diff --git a/tests/cpp/data/test_metainfo.h b/tests/cpp/data/test_metainfo.h index cef727851cb5..04bb2c9e7e2c 100644 --- a/tests/cpp/data/test_metainfo.h +++ b/tests/cpp/data/test_metainfo.h @@ -16,30 +16,27 @@ namespace xgboost { inline void TestMetaInfoStridedData(int32_t device) { MetaInfo info; { - // label - HostDeviceVector labels; - labels.Resize(64); - auto& h_labels = labels.HostVector(); - std::iota(h_labels.begin(), h_labels.end(), 0.0f); - bool is_gpu = device >= 0; - if (is_gpu) { - labels.SetDevice(0); - } + // labels + linalg::Tensor labels; + labels.Reshape(4, 2, 3); + auto& h_label = labels.Data()->HostVector(); + std::iota(h_label.begin(), h_label.end(), 0.0); + auto t_labels = labels.View(device).Slice(linalg::All(), 0, linalg::All()); + ASSERT_EQ(t_labels.Shape().size(), 2); - auto t = linalg::TensorView{ - is_gpu ? 
labels.ConstDeviceSpan() : labels.ConstHostSpan(), {32, 2}, device}; - auto s = t.Slice(linalg::All(), 0); - - auto str = ArrayInterfaceStr(s); - ASSERT_EQ(s.Size(), 32); - - info.SetInfo("label", StringView{str}); - auto const& h_result = info.labels_.HostVector(); - ASSERT_EQ(h_result.size(), 32); - - for (auto v : h_result) { - ASSERT_EQ(static_cast(v) % 2, 0); - } + info.SetInfo("label", StringView{ArrayInterfaceStr(t_labels)}); + auto const& h_result = info.labels.View(-1); + ASSERT_EQ(h_result.Shape().size(), 2); + auto in_labels = labels.View(-1); + linalg::ElementWiseKernelHost(h_result, omp_get_max_threads(), [&](size_t i, float v_0) { + auto tup = linalg::UnravelIndex(i, h_result.Shape()); + auto i0 = std::get<0>(tup); + auto i1 = std::get<1>(tup); + // Sliced at second dimension. + auto v_1 = in_labels(i0, 0, i1); + CHECK_EQ(v_0, v_1); + return v_0; + }); } { // qid diff --git a/tests/cpp/data/test_proxy_dmatrix.cu b/tests/cpp/data/test_proxy_dmatrix.cu index 19aa7a3ee90c..d9f315a8f144 100644 --- a/tests/cpp/data/test_proxy_dmatrix.cu +++ b/tests/cpp/data/test_proxy_dmatrix.cu @@ -23,7 +23,7 @@ TEST(ProxyDMatrix, DeviceData) { proxy.SetInfo("label", labels.c_str()); ASSERT_EQ(proxy.Adapter().type(), typeid(std::shared_ptr)); - ASSERT_EQ(proxy.Info().labels_.Size(), kRows); + ASSERT_EQ(proxy.Info().labels.Size(), kRows); ASSERT_EQ(dmlc::get>(proxy.Adapter())->NumRows(), kRows); ASSERT_EQ( diff --git a/tests/cpp/data/test_simple_dmatrix.cc b/tests/cpp/data/test_simple_dmatrix.cc index 40dd270a6b88..60bb181cc8a3 100644 --- a/tests/cpp/data/test_simple_dmatrix.cc +++ b/tests/cpp/data/test_simple_dmatrix.cc @@ -20,7 +20,7 @@ TEST(SimpleDMatrix, MetaInfo) { EXPECT_EQ(dmat->Info().num_row_, 2); EXPECT_EQ(dmat->Info().num_col_, 5); EXPECT_EQ(dmat->Info().num_nonzero_, 6); - EXPECT_EQ(dmat->Info().labels_.Size(), dmat->Info().num_row_); + EXPECT_EQ(dmat->Info().labels.Size(), dmat->Info().num_row_); delete dmat; } @@ -258,7 +258,7 @@ TEST(SimpleDMatrix, Slice) { std::array ridxs {1, 3, 5}; std::unique_ptr out { p_m->Slice(ridxs) }; - ASSERT_EQ(out->Info().labels_.Size(), ridxs.size()); + ASSERT_EQ(out->Info().labels.Size(), ridxs.size()); ASSERT_EQ(out->Info().labels_lower_bound_.Size(), ridxs.size()); ASSERT_EQ(out->Info().labels_upper_bound_.Size(), ridxs.size()); ASSERT_EQ(out->Info().base_margin_.Size(), ridxs.size() * kClasses); diff --git a/tests/cpp/data/test_sparse_page_dmatrix.cc b/tests/cpp/data/test_sparse_page_dmatrix.cc index 0ae69a67f00b..2cdef6641635 100644 --- a/tests/cpp/data/test_sparse_page_dmatrix.cc +++ b/tests/cpp/data/test_sparse_page_dmatrix.cc @@ -113,7 +113,7 @@ TEST(SparsePageDMatrix, MetaInfo) { EXPECT_EQ(dmat->Info().num_row_, 8ul); EXPECT_EQ(dmat->Info().num_col_, 5ul); EXPECT_EQ(dmat->Info().num_nonzero_, kEntries); - EXPECT_EQ(dmat->Info().labels_.Size(), dmat->Info().num_row_); + EXPECT_EQ(dmat->Info().labels.Size(), dmat->Info().num_row_); delete dmat; } diff --git a/tests/cpp/gbm/test_gbtree.cc b/tests/cpp/gbm/test_gbtree.cc index 6a454f96df16..34c1a52d9dd2 100644 --- a/tests/cpp/gbm/test_gbtree.cc +++ b/tests/cpp/gbm/test_gbtree.cc @@ -105,7 +105,7 @@ TEST(GBTree, WrongUpdater) { auto p_dmat = RandomDataGenerator(kRows, kCols, 0).GenerateDMatrix(); - p_dmat->Info().labels_.Resize(kRows); + p_dmat->Info().labels.Reshape(kRows); auto learner = std::unique_ptr(Learner::Create({p_dmat})); // Hist can not be used for updating tree. 
@@ -126,7 +126,7 @@ TEST(GBTree, ChoosePredictor) {
   auto p_dmat = RandomDataGenerator(kRows, kCols, 0).GenerateDMatrix();
 
   auto& data = (*(p_dmat->GetBatches<SparsePage>().begin())).data;
-  p_dmat->Info().labels_.Resize(kRows);
+  p_dmat->Info().labels.Reshape(kRows);
 
   auto learner = std::unique_ptr<Learner>(Learner::Create({p_dmat}));
   learner->SetParams(Args{{"tree_method", "gpu_hist"}, {"gpu_id", "0"}});
diff --git a/tests/cpp/helpers.cc b/tests/cpp/helpers.cc
index 0906d9ed87df..da627cdd1377 100644
--- a/tests/cpp/helpers.cc
+++ b/tests/cpp/helpers.cc
@@ -100,7 +100,8 @@ void CheckObjFunction(std::unique_ptr<xgboost::ObjFunction> const& obj,
                       std::vector<xgboost::bst_float> out_hess) {
   xgboost::MetaInfo info;
   info.num_row_ = labels.size();
-  info.labels_.HostVector() = labels;
+  info.labels =
+      xgboost::linalg::Tensor<float, 2>{labels.cbegin(), labels.cend(), {labels.size()}, -1};
   info.weights_.HostVector() = weights;
 
   CheckObjFunctionImpl(obj, preds, labels, weights, info, out_grad, out_hess);
@@ -135,7 +136,8 @@ void CheckRankingObjFunction(std::unique_ptr<xgboost::ObjFunction> const& obj,
                              std::vector<xgboost::bst_float> out_hess) {
   xgboost::MetaInfo info;
   info.num_row_ = labels.size();
-  info.labels_.HostVector() = labels;
+  info.labels =
+      xgboost::linalg::Tensor<float, 2>{labels.cbegin(), labels.cend(), {labels.size()}, -1};
   info.weights_.HostVector() = weights;
   info.group_ptr_ = groups;
 
@@ -149,7 +151,8 @@ xgboost::bst_float GetMetricEval(xgboost::Metric * metric,
                                  std::vector<xgboost::bst_uint> groups) {
   xgboost::MetaInfo info;
   info.num_row_ = labels.size();
-  info.labels_.HostVector() = labels;
+  info.labels =
+      xgboost::linalg::Tensor<float, 2>{labels.begin(), labels.end(), {labels.size()}, -1};
   info.weights_.HostVector() = weights;
   info.group_ptr_ = groups;
 
@@ -340,17 +343,18 @@ RandomDataGenerator::GenerateDMatrix(bool with_label, bool float_label,
   if (with_label) {
     RandomDataGenerator gen(rows_, 1, 0);
     if (!float_label) {
-      gen.Lower(0).Upper(classes).GenerateDense(&out->Info().labels_);
-      auto& h_labels = out->Info().labels_.HostVector();
+      gen.Lower(0).Upper(classes).GenerateDense(out->Info().labels.Data());
+      out->Info().labels.Reshape(out->Info().labels.Size());
+      auto& h_labels = out->Info().labels.Data()->HostVector();
       for (auto& v : h_labels) {
         v = static_cast<float>(static_cast<uint32_t>(v));
       }
     } else {
-      gen.GenerateDense(&out->Info().labels_);
+      gen.GenerateDense(out->Info().labels.Data());
    }
   }
   if (device_ >= 0) {
-    out->Info().labels_.SetDevice(device_);
+    out->Info().labels.SetDevice(device_);
     out->Info().feature_types.SetDevice(device_);
     for (auto const& page : out->GetBatches<SparsePage>()) {
       page.data.SetDevice(device_);
@@ -520,7 +524,8 @@ std::unique_ptr<GradientBooster> CreateTrainedGBM(
   for (size_t i = 0; i < kRows; ++i) {
     labels[i] = i;
   }
-  p_dmat->Info().labels_.HostVector() = labels;
+  p_dmat->Info().labels =
+      linalg::Tensor<float, 2>{labels.cbegin(), labels.cend(), {labels.size()}, -1};
   HostDeviceVector<GradientPair> gpair;
   auto& h_gpair = gpair.HostVector();
   h_gpair.resize(kRows);
@@ -636,7 +641,7 @@ class RMMAllocator {};
 void DeleteRMMResource(RMMAllocator* r) {}
 
 RMMAllocatorPtr SetUpRMMResourceForCppTests(int argc, char** argv) {
-  return RMMAllocatorPtr(nullptr, DeleteRMMResource);
+  return {nullptr, DeleteRMMResource};
 }
 #endif  // !defined(XGBOOST_USE_RMM) || XGBOOST_USE_RMM != 1
 }  // namespace xgboost
diff --git a/tests/cpp/metric/test_auc.cc b/tests/cpp/metric/test_auc.cc
index a8ae7eeafa8d..8fd700a73442 100644
--- a/tests/cpp/metric/test_auc.cc
+++ b/tests/cpp/metric/test_auc.cc
@@ -21,10 +21,10 @@ TEST(Metric, DeclareUnifiedTest(BinaryAUC)) {
 
   // Invalid dataset
   MetaInfo info;
-  info.labels_ = {0, 0};
+  info.labels = linalg::Tensor<float, 2>{{0.0f, 0.0f}, {2}, -1};
   float auc = metric->Eval({1, 1}, info, false);
   ASSERT_TRUE(std::isnan(auc));
-  info.labels_ = HostDeviceVector<float>{};
+  *info.labels.Data() = HostDeviceVector<float>{};
   auc = metric->Eval(HostDeviceVector<float>{}, info, false);
   ASSERT_TRUE(std::isnan(auc));
 
diff --git a/tests/cpp/metric/test_elementwise_metric.cc b/tests/cpp/metric/test_elementwise_metric.cc
index dfb188b6be6f..d5d460a6833f 100644
--- a/tests/cpp/metric/test_elementwise_metric.cc
+++ b/tests/cpp/metric/test_elementwise_metric.cc
@@ -17,7 +17,7 @@ inline void CheckDeterministicMetricElementWise(StringView name, int32_t device)
   HostDeviceVector<float> predts;
   MetaInfo info;
-  auto &h_labels = info.labels_.HostVector();
+  auto &h_labels = info.labels.Data()->HostVector();
   auto &h_predts = predts.HostVector();
 
   SimpleLCG lcg;
 
diff --git a/tests/cpp/metric/test_multiclass_metric.cc b/tests/cpp/metric/test_multiclass_metric.cc
index 6f8ff28094cd..5a2c939e9315 100644
--- a/tests/cpp/metric/test_multiclass_metric.cc
+++ b/tests/cpp/metric/test_multiclass_metric.cc
@@ -11,13 +11,14 @@ inline void CheckDeterministicMetricMultiClass(StringView name, int32_t device) {
 
   HostDeviceVector<float> predts;
   MetaInfo info;
-  auto &h_labels = info.labels_.HostVector();
   auto &h_predts = predts.HostVector();
 
   SimpleLCG lcg;
 
   size_t n_samples = 2048, n_classes = 4;
-  h_labels.resize(n_samples);
+
+  info.labels.Reshape(n_samples);
+  auto &h_labels = info.labels.Data()->HostVector();
   h_predts.resize(n_samples * n_classes);
 
   {
diff --git a/tests/cpp/objective/test_regression_obj.cc b/tests/cpp/objective/test_regression_obj.cc
index 112670269e93..6f396ea76e62 100644
--- a/tests/cpp/objective/test_regression_obj.cc
+++ b/tests/cpp/objective/test_regression_obj.cc
@@ -1,5 +1,5 @@
 /*!
- * Copyright 2017-2019 XGBoost contributors
+ * Copyright 2017-2021 XGBoost contributors
 */
 #include <gtest/gtest.h>
 #include <xgboost/generic_parameters.h>
@@ -293,8 +293,8 @@ TEST(Objective, CPU_vs_CUDA) {
   }
 
   auto& info = pdmat->Info();
-  info.labels_.Resize(kRows);
-  auto& h_labels = info.labels_.HostVector();
+  info.labels.Reshape(kRows);
+  auto& h_labels = info.labels.Data()->HostVector();
   for (size_t i = 0; i < h_labels.size(); ++i) {
     h_labels[i] = 1 / (float)(i+1);
   }
diff --git a/tests/cpp/predictor/test_predictor.cc b/tests/cpp/predictor/test_predictor.cc
index 1dc0e2deee24..b277da7d6eea 100644
--- a/tests/cpp/predictor/test_predictor.cc
+++ b/tests/cpp/predictor/test_predictor.cc
@@ -45,8 +45,8 @@ void TestTrainingPrediction(size_t rows, size_t bins,
 
   std::unique_ptr<Learner> learner;
   auto train = [&](std::string predictor, HostDeviceVector<float> *out) {
-    auto &h_label = p_hist->Info().labels_.HostVector();
-    h_label.resize(rows);
+    p_hist->Info().labels.Reshape(rows, 1);
+    auto &h_label = p_hist->Info().labels.Data()->HostVector();
 
     for (size_t i = 0; i < rows; ++i) {
       h_label[i] = i % kClasses;
diff --git a/tests/cpp/test_learner.cc b/tests/cpp/test_learner.cc
index 4a0b499fffbd..23859bc289d5 100644
--- a/tests/cpp/test_learner.cc
+++ b/tests/cpp/test_learner.cc
@@ -141,9 +141,8 @@ TEST(Learner, JsonModelIO) {
   size_t constexpr kRows = 8;
   int32_t constexpr kIters = 4;
 
-  std::shared_ptr<DMatrix> p_dmat{
-      RandomDataGenerator{kRows, 10, 0}.GenerateDMatrix()};
-  p_dmat->Info().labels_.Resize(kRows);
+  std::shared_ptr<DMatrix> p_dmat{RandomDataGenerator{kRows, 10, 0}.GenerateDMatrix()};
+  p_dmat->Info().labels.Reshape(kRows);
   CHECK_NE(p_dmat->Info().num_col_, 0);
 
   {
@@ -204,9 +203,8 @@ TEST(Learner, MultiThreadedPredict) {
   size_t constexpr kRows = 1000;
   size_t constexpr kCols = 100;
 
-  std::shared_ptr<DMatrix> p_dmat{
-      RandomDataGenerator{kRows, kCols, 0}.GenerateDMatrix()};
-  p_dmat->Info().labels_.Resize(kRows);
+  std::shared_ptr<DMatrix> p_dmat{RandomDataGenerator{kRows, kCols, 0}.GenerateDMatrix()};
+  p_dmat->Info().labels.Reshape(kRows);
   CHECK_NE(p_dmat->Info().num_col_, 0);
 
   std::shared_ptr<DMatrix> p_data{
@@ -240,7 +238,7 @@ TEST(Learner, BinaryModelIO) {
   size_t constexpr kRows = 8;
   int32_t constexpr kIters = 4;
   auto p_dmat = RandomDataGenerator{kRows, 10, 0}.GenerateDMatrix();
-  p_dmat->Info().labels_.Resize(kRows);
+  p_dmat->Info().labels.Reshape(kRows);
   std::unique_ptr<Learner> learner{Learner::Create({p_dmat})};
   learner->SetParam("eval_metric", "rmsle");
 
@@ -279,7 +277,7 @@ TEST(Learner, GPUConfiguration) {
   for (size_t i = 0; i < labels.size(); ++i) {
     labels[i] = i;
   }
-  p_dmat->Info().labels_.HostVector() = labels;
+  p_dmat->Info().labels.Data()->HostVector() = labels;
   {
     std::unique_ptr<Learner> learner {Learner::Create(mat)};
     learner->SetParams({Arg{"booster", "gblinear"},
diff --git a/tests/cpp/test_serialization.cc b/tests/cpp/test_serialization.cc
index 3c8514be507a..38954f6387c6 100644
--- a/tests/cpp/test_serialization.cc
+++ b/tests/cpp/test_serialization.cc
@@ -204,8 +204,8 @@ class SerializationTest : public ::testing::Test {
   void SetUp() override {
     p_dmat_ = RandomDataGenerator(kRows, kCols, .5f).GenerateDMatrix();
 
-    p_dmat_->Info().labels_.Resize(kRows);
-    auto &h_labels = p_dmat_->Info().labels_.HostVector();
+    p_dmat_->Info().labels.Reshape(kRows);
+    auto& h_labels = p_dmat_->Info().labels.Data()->HostVector();
 
     xgboost::SimpleLCG gen(0);
     SimpleRealUniformDistribution<float> dis(0.0f, 1.0f);
@@ -219,6 +219,9 @@ class SerializationTest : public ::testing::Test {
   }
 };
 
+size_t constexpr SerializationTest::kRows;
+size_t constexpr SerializationTest::kCols;
+
 TEST_F(SerializationTest, Exact) {
   TestLearnerSerialization({{"booster", "gbtree"},
                             {"seed", "0"},
@@ -389,8 +392,8 @@ class LogitSerializationTest : public SerializationTest {
     p_dmat_ = RandomDataGenerator(kRows, kCols, .5f).GenerateDMatrix();
 
     std::shared_ptr<DMatrix> p_dmat{p_dmat_};
-    p_dmat->Info().labels_.Resize(kRows);
-    auto &h_labels = p_dmat->Info().labels_.HostVector();
+    p_dmat->Info().labels.Reshape(kRows);
+    auto& h_labels = p_dmat->Info().labels.Data()->HostVector();
 
     std::bernoulli_distribution flip(0.5);
     auto& rnd = common::GlobalRandom();
@@ -513,8 +516,8 @@ class MultiClassesSerializationTest : public SerializationTest {
     p_dmat_ = RandomDataGenerator(kRows, kCols, .5f).GenerateDMatrix();
 
     std::shared_ptr<DMatrix> p_dmat{p_dmat_};
-    p_dmat->Info().labels_.Resize(kRows);
-    auto &h_labels = p_dmat->Info().labels_.HostVector();
+    p_dmat->Info().labels.Reshape(kRows);
+    auto &h_labels = p_dmat->Info().labels.Data()->HostVector();
 
     std::uniform_int_distribution<size_t> categorical(0, kClasses - 1);
     auto& rnd = common::GlobalRandom();
diff --git a/tests/python/test_with_pandas.py b/tests/python/test_with_pandas.py
index a1feaacd484a..88878458133c 100644
--- a/tests/python/test_with_pandas.py
+++ b/tests/python/test_with_pandas.py
@@ -148,7 +148,8 @@ def test_pandas_categorical(self):
         assert not np.any(arr == -1.0)
 
         X = X["f0"]
-        with pytest.raises(ValueError):
+        y = y[:X.shape[0]]
+        with pytest.raises(ValueError, match=r".*enable_categorical.*"):
             xgb.DMatrix(X, y)
 
         Xy = xgb.DMatrix(X, y, enable_categorical=True)