From e46a59d973a569c2684ac1e21b9c246f3659f25e Mon Sep 17 00:00:00 2001
From: fis
Date: Tue, 26 Oct 2021 21:13:34 +0800
Subject: [PATCH 1/2] Implement a general array view.

* Replace existing matrix and vector view.

lint.
Remove const too.
Doc/Test.
Include.
Use it in AUC.
Win build.
Use int32_t.
Use integral.
force the same type.
Use constexpr for old nvcc.
Test for empty tensor.
Rename to view.
Format.
Better document and perf.
Address reviewer's comment.
---
 include/xgboost/linalg.h          | 348 +++++++++++++++++++++++-------
 include/xgboost/tree_updater.h    |   2 +-
 src/gbm/gblinear.cc               |   5 +-
 src/gbm/gbtree.cc                 |  13 +-
 src/gbm/gbtree.cu                 |  11 +-
 src/metric/auc.cc                 |  31 +--
 src/tree/updater_gpu_hist.cu      |  16 +-
 src/tree/updater_quantile_hist.cc |   6 +-
 src/tree/updater_quantile_hist.h  |   4 +-
 tests/cpp/common/test_linalg.cc   |  90 +++++++-
 tests/cpp/tree/test_gpu_hist.cu   |   6 +-
 11 files changed, 402 insertions(+), 130 deletions(-)

diff --git a/include/xgboost/linalg.h b/include/xgboost/linalg.h
index 5bd6f913a0ee..dcdccc1ed3af 100644
--- a/include/xgboost/linalg.h
+++ b/include/xgboost/linalg.h
@@ -1,113 +1,301 @@
 /*!
- * Copyright 2021 by Contributors
+ * Copyright 2021 by XGBoost Contributors
  * \file linalg.h
- * \brief Linear algebra related utilities.
+ * \brief Linear algebra related utilities.
  */
 #ifndef XGBOOST_LINALG_H_
 #define XGBOOST_LINALG_H_
 
-#include <xgboost/base.h>
-#include <xgboost/generic_parameters.h>
 #include <xgboost/host_device_vector.h>
+#include <xgboost/span.h>
 
-#include <array>
 #include <algorithm>
+#include <cassert>
+#include <type_traits>
 #include <utility>
 #include <vector>
 
 namespace xgboost {
-/*!
- * \brief A view over a matrix on contiguous storage.
- *
- * \tparam T data type of matrix
+namespace linalg {
+namespace detail {
+template <typename S, size_t D, typename Head>
+constexpr size_t Offset(S (&strides)[D], size_t n, size_t dim, Head head) {
+  assert(dim < D);
+  return n + head * strides[dim];
+}
+
+template <typename S, size_t D, typename Head, typename... Tail>
+constexpr size_t Offset(S (&strides)[D], size_t n, size_t dim, Head head, Tail &&...rest) {
+  assert(dim < D);
+  return Offset(strides, n + (head * strides[dim]), dim + 1, rest...);
+}
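+// For example, with strides {12, 4, 1} (a {2, 3, 4} tensor in row-major order),
+// Offset(strides, 0, 0, 1, 2, 3) computes 1 * 12 + 2 * 4 + 3 * 1 = 23.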
+
+struct AllTag {};
+struct IntTag {};
+
+/**
+ * \brief Calculate the dimension of the sliced tensor.
  */
-template <typename T> class MatrixView {
-  int32_t device_;
-  common::Span<T> values_;
-  size_t strides_[2];
-  size_t shape_[2];
-
-  template <typename Vec> static auto InferValues(Vec *vec, int32_t device) {
-    return device == GenericParameter::kCpuId ? vec->HostSpan()
-                                              : vec->DeviceSpan();
-  }
+template <typename T>
+constexpr int32_t CalcSliceDim() {
+  return std::is_same<T, IntTag>::value ? 0 : 1;
+}
 
- public:
-  /*!
-   * \param vec storage.
-   * \param strides Strides for matrix.
-   * \param shape Rows and columns.
-   * \param device Where the data is stored in.
-   */
-  MatrixView(HostDeviceVector<T> *vec, std::array<size_t, 2> strides,
-             std::array<size_t, 2> shape, int32_t device)
-      : device_{device}, values_{InferValues(vec, device)} {
-    std::copy(strides.cbegin(), strides.cend(), strides_);
-    std::copy(shape.cbegin(), shape.cend(), shape_);
-  }
-  MatrixView(HostDeviceVector<std::remove_const_t<T>> const *vec,
-             std::array<size_t, 2> strides, std::array<size_t, 2> shape,
-             int32_t device)
-      : device_{device}, values_{InferValues(vec, device)} {
-    std::copy(strides.cbegin(), strides.cend(), strides_);
-    std::copy(shape.cbegin(), shape.cend(), shape_);
+template <typename T, typename... S>
+constexpr std::enable_if_t<sizeof...(S) != 0, int32_t> CalcSliceDim() {
+  return CalcSliceDim<T>() + CalcSliceDim<S...>();
+}
+
+template <int32_t D>
+constexpr size_t CalcSize(size_t (&shape)[D]) {
+  size_t size = 1;
+  for (auto d : shape) {
+    size *= d;
   }
-  /*! \brief Row major constructor. */
-  MatrixView(HostDeviceVector<T> *vec, std::array<size_t, 2> shape,
-             int32_t device)
-      : device_{device}, values_{InferValues(vec, device)} {
-    std::copy(shape.cbegin(), shape.cend(), shape_);
-    strides_[0] = shape[1];
-    strides_[1] = 1;
+  return size;
+}
+
+template <typename S>
+using RemoveCRType = std::remove_const_t<std::remove_reference_t<S>>;
+
+template <typename S>
+using IndexToTag = std::conditional_t<std::is_integral<RemoveCRType<S>>::value, IntTag, AllTag>;
+
+template <int32_t n, typename Fn>
+XGBOOST_DEVICE constexpr auto UnrollLoop(Fn fn) {
+#if defined __CUDA_ARCH__
+#pragma unroll(n)
+#endif  // defined __CUDA_ARCH__
+  for (int32_t i = 0; i < n; ++i) {
+    fn(i);
   }
-  MatrixView(std::vector<T> *vec, std::array<size_t, 2> shape)
-      : device_{GenericParameter::kCpuId}, values_{*vec} {
-    CHECK_EQ(vec->size(), shape[0] * shape[1]);
-    std::copy(shape.cbegin(), shape.cend(), shape_);
-    strides_[0] = shape[1];
-    strides_[1] = 1;
+}
+}  // namespace detail
+
+/**
+ * \brief Specify that all elements in an axis are used for the slice.
+ */
+constexpr detail::AllTag All() { return {}; }
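+// For instance, given a 2-dim view m, m.Slice(All(), 1) selects every row of
+// column 1 and yields a 1-dim (vector) view.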
+
+/**
+ * \brief A tensor view with static type and shape. It implements indexing and slicing.
+ *
+ * Most of the algorithms in XGBoost are implemented for both CPU and GPU without using
+ * many linear algebra routines; this class is a helper intended to ease some high-level
+ * operations like indexing into a prediction tensor or a gradient matrix. It can be
+ * passed into a CUDA kernel as a normal argument for GPU algorithms.
+ */
+template <typename T, int32_t kDim = 5>
+class TensorView {
+ public:
+  using ShapeT = size_t[kDim];
+  using StrideT = ShapeT;
+
+ private:
+  StrideT stride_{1};
+  ShapeT shape_{0};
+  common::Span<T> data_;
+  T *ptr_{nullptr};  // pointer to data_ to avoid bound checks.
+
+  size_t size_{0};
+  int32_t device_{-1};
+
+  // Unlike `Tensor`, the data_ can have arbitrary size since this is just a view.
+  XGBOOST_DEVICE void CalcSize() {
+    if (data_.empty()) {
+      size_ = 0;
+    } else {
+      size_ = detail::CalcSize(shape_);
+    }
   }
-  MatrixView(HostDeviceVector<std::remove_const_t<T>> const *vec,
-             std::array<size_t, 2> shape, int32_t device)
-      : device_{device}, values_{InferValues(vec, device)} {
-    std::copy(shape.cbegin(), shape.cend(), shape_);
-    strides_[0] = shape[1];
-    strides_[1] = 1;
+
+  struct SliceHelper {
+    size_t old_dim;
+    size_t new_dim;
+    size_t offset;
+  };
+
+  template <int32_t D>
+  XGBOOST_DEVICE SliceHelper MakeSliceDim(size_t old_dim, size_t new_dim, size_t new_shape[D],
+                                          size_t new_stride[D], detail::AllTag) const {
+    new_stride[new_dim] = stride_[old_dim];
+    new_shape[new_dim] = shape_[old_dim];
+    return {old_dim + 1, new_dim + 1, 0};
   }
-  XGBOOST_DEVICE T const &operator()(size_t r, size_t c) const {
-    return values_[strides_[0] * r + strides_[1] * c];
+  template <int32_t D, typename... S>
+  XGBOOST_DEVICE SliceHelper MakeSliceDim(size_t old_dim, size_t new_dim, size_t new_shape[D],
+                                          size_t new_stride[D], detail::AllTag,
+                                          S &&...slices) const {
+    new_stride[new_dim] = stride_[old_dim];
+    new_shape[new_dim] = shape_[old_dim];
+    return MakeSliceDim(old_dim + 1, new_dim + 1, new_shape, new_stride, slices...);
   }
-  XGBOOST_DEVICE T &operator()(size_t r, size_t c) {
-    return values_[strides_[0] * r + strides_[1] * c];
+
+  template <int32_t D, typename Index>
+  XGBOOST_DEVICE SliceHelper MakeSliceDim(size_t old_dim, size_t new_dim, size_t new_shape[D],
+                                          size_t new_stride[D], Index i) const {
+    return {old_dim + 1, new_dim, stride_[old_dim] * i};
   }
-  auto Strides() const { return strides_; }
-  auto Shape() const { return shape_; }
-  auto Values() const { return values_; }
-  auto Size() const { return shape_[0] * shape_[1]; }
-  auto DeviceIdx() const { return device_; }
-};
+  template <int32_t D, typename Index, typename... S>
+  XGBOOST_DEVICE std::enable_if_t<std::is_integral<Index>::value, SliceHelper> MakeSliceDim(
+      size_t old_dim, size_t new_dim, size_t new_shape[D], size_t new_stride[D], Index i,
+      S &&...slices) const {
+    auto offset = stride_[old_dim] * i;
+    auto res = MakeSliceDim(old_dim + 1, new_dim, new_shape, new_stride, slices...);
+    return {res.old_dim, res.new_dim, res.offset + offset};
+  }
 
-/*! \brief A slice for 1 column of MatrixView. Can be extended to row if needed. */
-template <typename T> class VectorView {
-  MatrixView<T> matrix_;
-  size_t column_;
+ public:
+  size_t constexpr static kValueSize = sizeof(T);
+  size_t constexpr static kDimension = kDim;
 
  public:
-  explicit VectorView(MatrixView<T> matrix, size_t column)
-      : matrix_{std::move(matrix)}, column_{column} {}
+  /**
+   * \brief Create a tensor with data and shape.
+   *
+   * \tparam I     Type of the shape array element.
+   * \tparam D     Size of the shape array, can be lesser than or equal to tensor dimension.
+   *
+   * \param data   Raw data input, can be const if this tensor has const type in its
+   *               template parameter.
+   * \param shape  shape of the tensor
+   * \param device Device ordinal
+   */
+  template <typename I, int32_t D>
+  XGBOOST_DEVICE TensorView(common::Span<T> data, I const (&shape)[D], int32_t device)
+      : data_{data}, ptr_{data_.data()}, device_{device} {
+    static_assert(D > 0 && D <= kDim, "Invalid shape.");
+    // shape
+    detail::UnrollLoop<D>([&](auto i) { shape_[i] = shape[i]; });
+    for (auto i = D; i < kDim; ++i) {
+      shape_[i] = 1;
+    }
+    // stride
+    stride_[kDim - 1] = 1;
+    for (int32_t s = kDim - 2; s >= 0; --s) {
+      stride_[s] = shape_[s + 1] * stride_[s + 1];
+    }
+    this->CalcSize();
+  };
 
-  XGBOOST_DEVICE T &operator[](size_t i) {
-    return matrix_(i, column_);
+  /**
+   * \brief Create a tensor with data, shape and strides.  Don't use this constructor if
+   *        stride can be calculated from shape.
+   */
+  template <typename I, int32_t D>
+  XGBOOST_DEVICE TensorView(common::Span<T> data, I const (&shape)[D], I const (&stride)[D],
+                            int32_t device)
+      : data_{data}, ptr_{data_.data()}, device_{device} {
+    static_assert(D == kDim, "Invalid shape & stride.");
+    detail::UnrollLoop<D>([&](auto i) {
+      shape_[i] = shape[i];
+      stride_[i] = stride[i];
+    });
+    this->CalcSize();
+  };
+
+  XGBOOST_DEVICE TensorView(TensorView const &that)
+      : data_{that.data_}, ptr_{data_.data()}, size_{that.size_}, device_{that.device_} {
+    detail::UnrollLoop<kDim>([&](auto i) {
+      stride_[i] = that.stride_[i];
+      shape_[i] = that.shape_[i];
+    });
   }
-  XGBOOST_DEVICE T const &operator[](size_t i) const {
-    return matrix_(i, column_);
+  /**
+   * \brief Index the tensor to obtain a scalar value.
+   *
+   * \code
+   *
+   *   // Create a 3-dim tensor.
+   *   TensorView<float, 3> t {data, shape, 0};
+   *   float pi = 3.14159;
+   *   t(1, 2, 3) = pi;
+   *   ASSERT_EQ(t(1, 2, 3), pi);
+   *
+   * \endcode
+   */
+  template <typename... Index>
+  XGBOOST_DEVICE T &operator()(Index &&...index) {
+    static_assert(sizeof...(index) <= kDim, "Invalid index.");
+    size_t offset = detail::Offset(stride_, 0ul, 0ul, index...);
+    return ptr_[offset];
+  }
+  /**
+   * \brief Index the tensor to obtain a scalar value.
+   */
+  template <typename... Index>
+  XGBOOST_DEVICE T const &operator()(Index &&...index) const {
+    static_assert(sizeof...(index) <= kDim, "Invalid index.");
+    size_t offset = detail::Offset(stride_, 0ul, 0ul, index...);
+    return ptr_[offset];
   }
-  size_t Size() { return matrix_.Shape()[0]; }
-  int32_t DeviceIdx() const { return matrix_.DeviceIdx(); }
+  /**
+   * \brief Slice the tensor.  The returned tensor has inferred dim and shape.
+   *
+   * \code
+   *
+   *   // Create a 3-dim tensor.
+   *   TensorView<float, 3> t {data, shape, 0};
+   *   // s has 2 dimensions (matrix)
+   *   auto s = t.Slice(1, All(), All());
+   *
+   * \endcode
+   */
+  template <typename... S>
+  XGBOOST_DEVICE auto Slice(S &&...slices) const {
+    static_assert(sizeof...(slices) <= kDim, "Invalid slice.");
+    int32_t constexpr kNewDim{detail::CalcSliceDim<detail::IndexToTag<S>...>()};
+    size_t new_shape[kNewDim];
+    size_t new_stride[kNewDim];
+    auto res = MakeSliceDim(size_t(0), size_t(0), new_shape, new_stride, slices...);
+    // ret is a different type due to changed dimension, so we cannot access its private
+    // fields.
+    TensorView<T, kNewDim> ret{data_.subspan(data_.empty() ? 0 : res.offset), new_shape,
+                               new_stride, device_};
+    return ret;
+  }
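+  // Note that a sliced view aliases the parent's storage: data_.subspan only moves the
+  // starting offset, so no elements are copied and writes stay visible to the parent.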
+
+  XGBOOST_DEVICE auto Shape() const { return common::Span<size_t const, kDim>{shape_}; }
+  /**
+   * Get the shape for i^th dimension
+   */
+  XGBOOST_DEVICE auto Shape(size_t i) const { return shape_[i]; }
+  XGBOOST_DEVICE auto Stride() const { return common::Span<size_t const, kDim>{stride_}; }
+  /**
+   * Get the stride for i^th dimension, stride is specified as number of items instead of bytes.
+   */
+  XGBOOST_DEVICE auto Stride(size_t i) const { return stride_[i]; }
+
+  XGBOOST_DEVICE auto cbegin() const { return data_.cbegin(); }  // NOLINT
+  XGBOOST_DEVICE auto cend() const { return data_.cend(); }      // NOLINT
+  XGBOOST_DEVICE auto begin() { return data_.begin(); }          // NOLINT
+  XGBOOST_DEVICE auto end() { return data_.end(); }              // NOLINT
+
+  XGBOOST_DEVICE size_t Size() const { return size_; }
+  XGBOOST_DEVICE auto Values() const { return data_; }
+  XGBOOST_DEVICE auto DeviceIdx() const { return device_; }
 };
-}  // namespace xgboost
+
+/**
+ * \brief A view over a vector, specialization of Tensor.
+ *
+ * \tparam T data type of vector
+ */
+template <typename T>
+using VectorView = TensorView<T, 1>;
+
+/**
+ * \brief A view over a matrix, specialization of Tensor.
+ *
+ * \tparam T data type of matrix
+ */
+template <typename T>
+using MatrixView = TensorView<T, 2>;
+}  // namespace linalg
+}  // namespace xgboost
 #endif  // XGBOOST_LINALG_H_
diff --git a/include/xgboost/tree_updater.h b/include/xgboost/tree_updater.h
index f36005a9a69e..0d2b2ed5436c 100644
--- a/include/xgboost/tree_updater.h
+++ b/include/xgboost/tree_updater.h
@@ -72,7 +72,7 @@ class TreeUpdater : public Configurable {
    * updated by the time this function returns.
    */
   virtual bool UpdatePredictionCache(const DMatrix * /*data*/,
-                                     VectorView<float> /*out_preds*/) {
+                                     linalg::VectorView<float> /*out_preds*/) {
     return false;
   }
 
diff --git a/src/gbm/gblinear.cc b/src/gbm/gblinear.cc
index 8da1f67f40e2..c9059f436f21 100644
--- a/src/gbm/gblinear.cc
+++ b/src/gbm/gblinear.cc
@@ -243,7 +243,10 @@ class GBLinear : public GradientBooster {
     // The bias is the last weight
     out_scores->resize(model_.weight.size() - learner_model_param_->num_output_group, 0);
     auto n_groups = learner_model_param_->num_output_group;
-    MatrixView<float> scores{out_scores, {learner_model_param_->num_feature, n_groups}};
+    linalg::TensorView<float, 2> scores{
+        *out_scores,
+        {learner_model_param_->num_feature, n_groups},
+        GenericParameter::kCpuId};
     for (size_t i = 0; i < learner_model_param_->num_feature; ++i) {
       for (bst_group_t g = 0; g < n_groups; ++g) {
         scores(i, g) = model_[i][g];
diff --git a/src/gbm/gbtree.cc b/src/gbm/gbtree.cc
index 859e5ba9d3ad..b2145e02f78f 100644
--- a/src/gbm/gbtree.cc
+++ b/src/gbm/gbtree.cc
@@ -229,16 +229,19 @@ void GBTree::DoBoost(DMatrix* p_fmat,
   auto device = tparam_.tree_method != TreeMethod::kGPUHist ?
                     GenericParameter::kCpuId : generic_param_->gpu_id;
-  auto out = MatrixView<float>(
-      &predt->predictions,
-      {static_cast<size_t>(p_fmat->Info().num_row_), static_cast<size_t>(ngroup)}, device);
+  auto out = linalg::TensorView<float, 2>{
+      device == GenericParameter::kCpuId ? predt->predictions.HostSpan()
+                                         : predt->predictions.DeviceSpan(),
+      {static_cast<size_t>(p_fmat->Info().num_row_),
+       static_cast<size_t>(ngroup)},
+      device};
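+  // out views the prediction cache as a {num_row, n_groups} matrix; the column
+  // slices taken below give per-group predictions without copying anything.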
   CHECK_NE(ngroup, 0);
   if (ngroup == 1) {
     std::vector<std::unique_ptr<RegTree>> ret;
     BoostNewTrees(in_gpair, p_fmat, 0, &ret);
     const size_t num_new_trees = ret.size();
     new_trees.push_back(std::move(ret));
-    auto v_predt = VectorView<float>{out, 0};
+    auto v_predt = out.Slice(linalg::All(), 0);
     if (updaters_.size() > 0 && num_new_trees == 1 &&
         predt->predictions.Size() > 0 &&
         updaters_.back()->UpdatePredictionCache(p_fmat, v_predt)) {
@@ -257,7 +260,7 @@
       BoostNewTrees(&tmp, p_fmat, gid, &ret);
       const size_t num_new_trees = ret.size();
       new_trees.push_back(std::move(ret));
-      auto v_predt = VectorView<float>{out, static_cast<size_t>(gid)};
+      auto v_predt = out.Slice(linalg::All(), gid);
       if (!(updaters_.size() > 0 && predt->predictions.Size() > 0 &&
             num_new_trees == 1 &&
             updaters_.back()->UpdatePredictionCache(p_fmat, v_predt))) {
diff --git a/src/gbm/gbtree.cu b/src/gbm/gbtree.cu
index eca8f4dbd269..0b81fff23e5c 100644
--- a/src/gbm/gbtree.cu
+++ b/src/gbm/gbtree.cu
@@ -12,15 +12,14 @@ namespace gbm {
 void GPUCopyGradient(HostDeviceVector<GradientPair> const *in_gpair,
                      bst_group_t n_groups, bst_group_t group_id,
                      HostDeviceVector<GradientPair> *out_gpair) {
-  MatrixView<GradientPair const> in{
-      in_gpair,
-      {n_groups, 1ul},
+  auto mat = linalg::TensorView<GradientPair const, 2>(
+      in_gpair->ConstDeviceSpan(),
       {in_gpair->Size() / n_groups, static_cast<size_t>(n_groups)},
-      in_gpair->DeviceIdx()};
-  auto v_in = VectorView<GradientPair const>{in, group_id};
+      in_gpair->DeviceIdx());
+  auto v_in = mat.Slice(linalg::All(), group_id);
   out_gpair->Resize(v_in.Size());
   auto d_out = out_gpair->DeviceSpan();
-  dh::LaunchN(v_in.Size(), [=] __device__(size_t i) { d_out[i] = v_in[i]; });
+  dh::LaunchN(v_in.Size(), [=] __device__(size_t i) { d_out[i] = v_in(i); });
 }
 
 void GPUDartPredictInc(common::Span<float> out_predts,
diff --git a/src/metric/auc.cc b/src/metric/auc.cc
index b657c72ea6d0..ec8b6ee01bdd 100644
--- a/src/metric/auc.cc
+++ b/src/metric/auc.cc
@@ -13,6 +13,7 @@
 #include <vector>
 
 #include "rabit/rabit.h"
+#include "xgboost/linalg.h"
 #include "xgboost/host_device_vector.h"
 #include "xgboost/metric.h"
 
@@ -83,41 +84,45 @@ double MultiClassOVR(common::Span<float const> predts, MetaInfo const &info,
   CHECK_NE(n_classes, 0);
   auto const &labels = info.labels_.ConstHostVector();
 
-  std::vector<double> results(n_classes * 3, 0);
-  auto s_results = common::Span<double>(results);
-
-  auto local_area = s_results.subspan(0, n_classes);
-  auto tp = s_results.subspan(n_classes, n_classes);
-  auto auc = s_results.subspan(2 * n_classes, n_classes);
+  std::vector<double> results_storage(n_classes * 3, 0);
+  linalg::TensorView<double, 2> results(results_storage,
+                                        {n_classes, static_cast<size_t>(3)},
+                                        GenericParameter::kCpuId);
+  auto local_area = results.Slice(linalg::All(), 0);
+  auto tp = results.Slice(linalg::All(), 1);
+  auto auc = results.Slice(linalg::All(), 2);
   auto weights = OptionalWeights{info.weights_.ConstHostSpan()};
 
+  auto predts_t = linalg::TensorView<float const, 2>(
+      predts, {static_cast<size_t>(info.num_row_), n_classes},
+      GenericParameter::kCpuId);
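+  // Given the row-major view above, predts_t(i, c) reads predts[i * n_classes + c].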
   if (!info.labels_.Empty()) {
     common::ParallelFor(n_classes, n_threads, [&](auto c) {
       std::vector<float> proba(info.labels_.Size());
       std::vector<float> response(info.labels_.Size());
       for (size_t i = 0; i < proba.size(); ++i) {
-        proba[i] = predts[i * n_classes + c];
+        proba[i] = predts_t(i, c);
         response[i] = labels[i] == c ? 1.0f : 0.0;
       }
       double fp;
-      std::tie(fp, tp[c], auc[c]) = binary_auc(proba, response, weights);
-      local_area[c] = fp * tp[c];
+      std::tie(fp, tp(c), auc(c)) = binary_auc(proba, response, weights);
+      local_area(c) = fp * tp(c);
     });
   }
 
   // we have 2 averages going in here, first is among workers, second is among
   // classes. allreduce sums up fp/tp auc for each class.
-  rabit::Allreduce<rabit::op::Sum>(results.data(), results.size());
+  rabit::Allreduce<rabit::op::Sum>(results.Values().data(), results.Values().size());
   double auc_sum{0};
   double tp_sum{0};
   for (size_t c = 0; c < n_classes; ++c) {
-    if (local_area[c] != 0) {
+    if (local_area(c) != 0) {
       // normalize and weight it by prevalence.  After allreduce, `local_area`
       // means the total covered area (not area under curve, rather it's the
       // accessible area for each worker) for each class.
-      auc_sum += auc[c] / local_area[c] * tp[c];
-      tp_sum += tp[c];
+      auc_sum += auc(c) / local_area(c) * tp(c);
+      tp_sum += tp(c);
     } else {
       auc_sum = std::numeric_limits<double>::quiet_NaN();
       break;
diff --git a/src/tree/updater_gpu_hist.cu b/src/tree/updater_gpu_hist.cu
index cbe63d243da4..fa9447bd8078 100644
--- a/src/tree/updater_gpu_hist.cu
+++ b/src/tree/updater_gpu_hist.cu
@@ -496,7 +496,7 @@ struct GPUHistMakerDevice {
     });
   }
 
-  void UpdatePredictionCache(VectorView<float> out_preds_d) {
+  void UpdatePredictionCache(linalg::VectorView<float> out_preds_d) {
     dh::safe_cuda(cudaSetDevice(device_id));
     CHECK_EQ(out_preds_d.DeviceIdx(), device_id);
     auto d_ridx = row_partitioner->GetRows();
@@ -512,13 +512,13 @@ struct GPUHistMakerDevice {
     auto d_node_sum_gradients = device_node_sum_gradients.data().get();
     auto evaluator = tree_evaluator.GetEvaluator();
 
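+    // Capture the view by value and mark the lambda mutable so the kernel can write
+    // through out_preds_d: by-value captures are const inside a non-mutable lambda.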
-    dh::LaunchN(d_ridx.size(), [=] __device__(int local_idx) {
+    dh::LaunchN(d_ridx.size(), [=, out_preds_d = out_preds_d] __device__(
+                                   int local_idx) mutable {
       int pos = d_position[local_idx];
       bst_float weight = evaluator.CalcWeight(
           pos, param_d, GradStats{d_node_sum_gradients[pos]});
       static_assert(!std::is_const<decltype(out_preds_d)>::value, "");
-      auto v_predt = out_preds_d;  // for some reason out_preds_d is const by
-                                   // both nvcc and clang.
-      v_predt[d_ridx[local_idx]] += weight * param_d.learning_rate;
+      out_preds_d(d_ridx[local_idx]) += weight * param_d.learning_rate;
     });
     row_partitioner.reset();
   }
@@ -832,7 +832,8 @@ class GPUHistMakerSpecialised {
     maker->UpdateTree(gpair, p_fmat, p_tree, &reducer_);
   }
 
-  bool UpdatePredictionCache(const DMatrix* data, VectorView<float> p_out_preds) {
+  bool UpdatePredictionCache(const DMatrix *data,
+                             linalg::VectorView<float> p_out_preds) {
     if (maker == nullptr || p_last_fmat_ == nullptr || p_last_fmat_ != data) {
       return false;
     }
@@ -916,8 +917,9 @@ class GPUHistMaker : public TreeUpdater {
     }
   }
 
-  bool UpdatePredictionCache(const DMatrix *data,
-                             VectorView<float> p_out_preds) override {
+  bool
+  UpdatePredictionCache(const DMatrix *data,
+                        linalg::VectorView<float> p_out_preds) override {
     if (hist_maker_param_.single_precision_histogram) {
       return float_maker_->UpdatePredictionCache(data, p_out_preds);
     } else {
diff --git a/src/tree/updater_quantile_hist.cc b/src/tree/updater_quantile_hist.cc
index 19c300b30672..8471214d5c3d 100644
--- a/src/tree/updater_quantile_hist.cc
+++ b/src/tree/updater_quantile_hist.cc
@@ -105,7 +105,7 @@ void QuantileHistMaker::Update(HostDeviceVector<GradientPair> *gpair,
 }
 
 bool QuantileHistMaker::UpdatePredictionCache(
-    const DMatrix* data, VectorView<float> out_preds) {
+    const DMatrix* data, linalg::VectorView<float> out_preds) {
   if (hist_maker_param_.single_precision_histogram && float_builder_) {
     return float_builder_->UpdatePredictionCache(data, out_preds);
   } else if (double_builder_) {
@@ -319,7 +319,7 @@ void QuantileHistMaker::Builder<GradientSumT>::Update(
 template <typename GradientSumT>
 bool QuantileHistMaker::Builder<GradientSumT>::UpdatePredictionCache(
     const DMatrix* data,
-    VectorView<float> out_preds) {
+    linalg::VectorView<float> out_preds) {
   // p_last_fmat_ is a valid pointer as long as UpdatePredictionCache() is called in
  // conjunction with Update().
   if (!p_last_fmat_ || !p_last_tree_ || data != p_last_fmat_ ||
@@ -352,7 +352,7 @@ bool QuantileHistMaker::Builder<GradientSumT>::UpdatePredictionCache(
         leaf_value = (*p_last_tree_)[nid].LeafValue();
         for (const size_t* it = rowset.begin + r.begin(); it < rowset.begin + r.end();
              ++it) {
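+          // out_preds is a linalg::VectorView<float>; elements are addressed with
+          // operator() rather than operator[].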
-          out_preds[*it] += leaf_value;
+          out_preds(*it) += leaf_value;
         }
       }
     });
diff --git a/src/tree/updater_quantile_hist.h b/src/tree/updater_quantile_hist.h
index 9654ab00a7c0..feaa51544afd 100644
--- a/src/tree/updater_quantile_hist.h
+++ b/src/tree/updater_quantile_hist.h
@@ -105,7 +105,7 @@ class QuantileHistMaker: public TreeUpdater {
                const std::vector<RegTree*>& trees) override;
 
   bool UpdatePredictionCache(const DMatrix *data,
-                             VectorView<float> out_preds) override;
+                             linalg::VectorView<float> out_preds) override;
 
   void LoadConfig(Json const& in) override {
     auto const& config = get<Object const>(in);
@@ -171,7 +171,7 @@ class QuantileHistMaker: public TreeUpdater {
                     RegTree* p_tree);
 
     bool UpdatePredictionCache(const DMatrix* data,
-                               VectorView<float> out_preds);
+                               linalg::VectorView<float> out_preds);
 
    protected:
     // initialize temp data structure
diff --git a/tests/cpp/common/test_linalg.cc b/tests/cpp/common/test_linalg.cc
index 935e82dde107..e4ca2d86594b 100644
--- a/tests/cpp/common/test_linalg.cc
+++ b/tests/cpp/common/test_linalg.cc
@@ -1,18 +1,19 @@
 #include <gtest/gtest.h>
+#include <xgboost/host_device_vector.h>
 #include <xgboost/linalg.h>
+
 #include <numeric>
 
 namespace xgboost {
-
+namespace linalg {
 auto MakeMatrixFromTest(HostDeviceVector<float> *storage, size_t n_rows, size_t n_cols) {
   storage->Resize(n_rows * n_cols);
-  auto& h_storage = storage->HostVector();
+  auto &h_storage = storage->HostVector();
 
   std::iota(h_storage.begin(), h_storage.end(), 0);
-  auto m = MatrixView<float>{storage, {n_cols, 1}, {n_rows, n_cols}, -1};
+  auto m = linalg::TensorView<float, 2>{h_storage, {n_rows, static_cast<size_t>(n_cols)}, -1};
   return m;
-
 }
 
 TEST(Linalg, Matrix) {
@@ -28,11 +29,84 @@ TEST(Linalg, Vector) {
   size_t kRows = 31, kCols = 77;
   HostDeviceVector<float> storage;
   auto m = MakeMatrixFromTest(&storage, kRows, kCols);
-  auto v = VectorView<float>(m, 3);
+  auto v = m.Slice(linalg::All(), 3);
   for (size_t i = 0; i < v.Size(); ++i) {
-    ASSERT_EQ(v[i], m(i, 3));
+    ASSERT_EQ(v(i), m(i, 3));
+  }
+
+  ASSERT_EQ(v(0), 3);
+}
+
+TEST(Linalg, Tensor) {
+  std::vector<double> data(2 * 3 * 4, 0);
+  std::iota(data.begin(), data.end(), 0);
+
+  TensorView<double, 3> t{data, {2, 3, 4}, -1};
+  ASSERT_EQ(t.Shape()[0], 2);
+  ASSERT_EQ(t.Shape()[1], 3);
+  ASSERT_EQ(t.Shape()[2], 4);
+
+  float v = t(0, 1, 2);
+  ASSERT_EQ(v, 6);
+
+  auto s = t.Slice(1, All(), All());
+  ASSERT_EQ(s.Shape().size(), 2);
+  ASSERT_EQ(s.Shape()[0], 3);
+  ASSERT_EQ(s.Shape()[1], 4);
+
+  std::vector<std::vector<double>> sol{
+      {12.0, 13.0, 14.0, 15.0}, {16.0, 17.0, 18.0, 19.0}, {20.0, 21.0, 22.0, 23.0}};
+  for (size_t i = 0; i < s.Shape()[0]; ++i) {
+    for (size_t j = 0; j < s.Shape()[1]; ++j) {
+      ASSERT_EQ(s(i, j), sol[i][j]);
+    }
   }
-  ASSERT_EQ(v[0], 3);
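+
+  {
+    // Illustrative check of slice strides: with shape {2, 3, 4} the row-major strides
+    // are {12, 4, 1}, so fixing the first axis keeps strides {4, 1}.
+    auto sliced = t.Slice(1, All(), All());
+    ASSERT_EQ(sliced.Stride(0), 4);
+    ASSERT_EQ(sliced.Stride(1), 1);
+  }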
+
+  {
+    // as vector
+    TensorView<double, 1> vec{data, {data.size()}, -1};
+    ASSERT_EQ(vec.Size(), data.size());
+    ASSERT_EQ(vec.Shape(0), data.size());
+    ASSERT_EQ(vec.Shape().size(), 1);
+    for (size_t i = 0; i < data.size(); ++i) {
+      ASSERT_EQ(vec(i), data[i]);
+    }
+  }
+
+  {
+    // as matrix
+    TensorView<double, 2> mat(data, {6, 4}, -1);
+    auto s = mat.Slice(2, All());
+    ASSERT_EQ(s.Shape().size(), 1);
+    s = mat.Slice(All(), 1);
+    ASSERT_EQ(s.Shape().size(), 1);
+  }
+
+  {
+    // assignment
+    TensorView<double, 3> t{data, {2, 3, 4}, 0};
+    double pi = 3.14159;
+    t(1, 2, 3) = pi;
+    ASSERT_EQ(t(1, 2, 3), pi);
+  }
+
+  {
+    // Don't assign the initial dimension; the tensor should be able to deduce the
+    // correct dim for Slice.
+    TensorView<double, 3> t{data, {2, 3, 4}, 0};
+    auto s = t.Slice(1, 2, All());
+    static_assert(decltype(s)::kDimension == 1, "");
+  }
+}
+
+TEST(Linalg, Empty) {
+  auto t = TensorView<float, 2>{{}, {0, 3}, GenericParameter::kCpuId};
+  for (int32_t i : {0, 1, 2}) {
+    auto s = t.Slice(All(), i);
+    ASSERT_EQ(s.Size(), 0);
+    ASSERT_EQ(s.Shape().size(), 1);
+    ASSERT_EQ(s.Shape(0), 0);
+  }
 }
-}  // namespace xgboost
+}  // namespace linalg
+}  // namespace xgboost
diff --git a/tests/cpp/tree/test_gpu_hist.cu b/tests/cpp/tree/test_gpu_hist.cu
index 72c22539679f..6c8676e6ac61 100644
--- a/tests/cpp/tree/test_gpu_hist.cu
+++ b/tests/cpp/tree/test_gpu_hist.cu
@@ -399,10 +399,8 @@ void UpdateTree(HostDeviceVector<GradientPair>* gpair, DMatrix* dmat,
   hist_maker.Configure(args, &generic_param);
 
   hist_maker.Update(gpair, dmat, {tree});
-  hist_maker.UpdatePredictionCache(
-      dmat,
-      VectorView<float>{
-          MatrixView<float>(preds, {preds->Size(), 1}, preds->DeviceIdx()), 0});
+  auto cache = linalg::VectorView<float>{preds->DeviceSpan(), {preds->Size()}, 0};
+  hist_maker.UpdatePredictionCache(dmat, cache);
 }
 
 TEST(GpuHist, UniformSampling) {

From b4fccbba908f789faac52b09ddf0691d9f21a9aa Mon Sep 17 00:00:00 2001
From: fis
Date: Tue, 2 Nov 2021 17:55:16 +0800
Subject: [PATCH 2/2] tidy.

---
 include/xgboost/linalg.h | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/include/xgboost/linalg.h b/include/xgboost/linalg.h
index dcdccc1ed3af..a801228e903e 100644
--- a/include/xgboost/linalg.h
+++ b/include/xgboost/linalg.h
@@ -65,7 +65,7 @@ using IndexToTag = std::conditional_t<std::is_integral<RemoveCRType<S>>::value,
 template <int32_t n, typename Fn>
 XGBOOST_DEVICE constexpr auto UnrollLoop(Fn fn) {
 #if defined __CUDA_ARCH__
-#pragma unroll(n)
+#pragma unroll n
 #endif  // defined __CUDA_ARCH__
   for (int32_t i = 0; i < n; ++i) {
     fn(i);