From 2dc4cb18e9a2b09bff6d7f9a50cf19ec6af15ce4 Mon Sep 17 00:00:00 2001 From: fis Date: Tue, 26 Jul 2022 23:35:52 +0800 Subject: [PATCH 01/10] Support CPU input for device `QuantileDMatrix`. - Copy `GHistIndexMatrix` to `Ellpack` when needed. --- src/common/hist_util.cuh | 22 +++++ src/data/ellpack_page.cu | 85 +++++++++++++++++-- src/data/ellpack_page.cuh | 7 ++ src/data/gradient_index.cc | 1 + src/data/iterative_dmatrix.cc | 7 +- src/data/iterative_dmatrix.cu | 12 ++- src/data/iterative_dmatrix.h | 28 +++++- tests/cpp/data/test_ellpack_page.cu | 39 +++++++++ .../test_device_quantile_dmatrix.py | 28 ++++++ 9 files changed, 219 insertions(+), 10 deletions(-) diff --git a/src/common/hist_util.cuh b/src/common/hist_util.cuh index 8fac9fca2f14..7dd62b38216e 100644 --- a/src/common/hist_util.cuh +++ b/src/common/hist_util.cuh @@ -17,6 +17,28 @@ namespace xgboost { namespace common { +namespace cuda { +/** + * copy and paste of the host version, we can't make it a __host__ __device__ function as + * the fn might be a host only or device only callable object, which is not allowed by nvcc. + */ +template +auto __device__ DispatchBinType(BinTypeSize type, Fn&& fn) { + switch (type) { + case kUint8BinsTypeSize: { + return fn(uint8_t{}); + } + case kUint16BinsTypeSize: { + return fn(uint16_t{}); + } + case kUint32BinsTypeSize: { + return fn(uint32_t{}); + } + } + SPAN_CHECK(false); + return fn(uint32_t{}); +} +} // namespace cuda namespace detail { struct EntryCompareOp { diff --git a/src/data/ellpack_page.cu b/src/data/ellpack_page.cu index 14a1b2bbf172..0834bb6fd7a3 100644 --- a/src/data/ellpack_page.cu +++ b/src/data/ellpack_page.cu @@ -1,14 +1,16 @@ /*! - * Copyright 2019-2020 XGBoost contributors + * Copyright 2019-2022 XGBoost contributors */ -#include #include #include + #include "../common/categorical.h" #include "../common/hist_util.cuh" #include "../common/random.h" #include "./ellpack_page.cuh" #include "device_adapter.cuh" +#include "gradient_index.h" +#include "xgboost/data.h" namespace xgboost { @@ -32,7 +34,7 @@ __global__ void CompressBinEllpackKernel( const size_t* __restrict__ row_ptrs, // row offset of input data const Entry* __restrict__ entries, // One batch of input data const float* __restrict__ cuts, // HistogramCuts::cut_values_ - const uint32_t* __restrict__ cut_rows, // HistogramCuts::cut_ptrs_ + const uint32_t* __restrict__ cut_ptrs, // HistogramCuts::cut_ptrs_ common::Span feature_types, size_t base_row, // batch_row_begin size_t n_rows, @@ -50,8 +52,8 @@ __global__ void CompressBinEllpackKernel( int feature = entry.index; float fvalue = entry.fvalue; // {feature_cuts, ncuts} forms the array of cuts of `feature'. - const float* feature_cuts = &cuts[cut_rows[feature]]; - int ncuts = cut_rows[feature + 1] - cut_rows[feature]; + const float* feature_cuts = &cuts[cut_ptrs[feature]]; + int ncuts = cut_ptrs[feature + 1] - cut_ptrs[feature]; bool is_cat = common::IsCat(feature_types, ifeature); // Assigning the bin in current entry. // S.t.: fvalue < feature_cuts[bin] @@ -69,7 +71,7 @@ __global__ void CompressBinEllpackKernel( bin = ncuts - 1; } // Add the number of bins in previous features. - bin += cut_rows[feature]; + bin += cut_ptrs[feature]; } // Write to gidx buffer. wr.AtomicWriteSymbol(buffer, bin, (irow + base_row) * row_stride + ifeature); @@ -284,6 +286,77 @@ EllpackPageImpl::EllpackPageImpl(AdapterBatch batch, float missing, int device, ELLPACK_BATCH_SPECIALIZE(data::CudfAdapterBatch) ELLPACK_BATCH_SPECIALIZE(data::CupyAdapterBatch) +namespace { +void CopyGHistToEllpack(GHistIndexMatrix const& page, common::Span d_row_ptr, + size_t row_stride, common::CompressedByteT* d_compressed_buffer) { + dh::device_vector data(page.index.begin(), page.index.end()); + auto d_data = dh::ToSpan(data); + + dh::device_vector csc_indptr(page.index.Offset(), + page.index.Offset() + page.index.OffsetSize()); + auto d_csc_indptr = dh::ToSpan(csc_indptr); + + auto bin_type = page.index.GetBinTypeSize(); + common::CompressedBufferWriter writer{page.cut.TotalBins() + 1}; // +1 for null value + + dh::LaunchN(row_stride * page.Size(), [=] __device__(size_t idx) mutable { + auto ridx = idx / row_stride; + auto ifeature = idx % row_stride; + + auto r_begin = d_row_ptr[ridx]; + auto r_end = d_row_ptr[ridx + 1]; + size_t rsize = r_end - r_begin; + + if (ifeature >= rsize) { + return; + } + + size_t offset = 0; + if (!d_csc_indptr.empty()) { + // is dense, ifeature is the actual feature index. + offset = d_csc_indptr[ifeature]; + } + common::cuda::DispatchBinType(bin_type, [&](auto t) { + using T = decltype(t); + auto ptr = reinterpret_cast(d_data.data()); + auto bin_idx = ptr[r_begin + ifeature] + offset; + writer.AtomicWriteSymbol(d_compressed_buffer, bin_idx, ridx * row_stride + ifeature); + }); + }); +} + +void RowCountsFromIndptr(common::Span d_row_ptr, common::Span row_counts) { + dh::LaunchN(row_counts.size(), + [=] XGBOOST_DEVICE(size_t i) { row_counts[i] = d_row_ptr[i + 1] - d_row_ptr[i]; }); +} +} // anonymous namespace + +EllpackPageImpl::EllpackPageImpl(Context const* ctx, GHistIndexMatrix const& page, + common::Span ft) + : is_dense{page.IsDense()}, base_rowid{page.base_rowid}, n_rows{page.Size()}, cuts_{page.cut} { + auto it = common::MakeIndexTransformIter( + [&](size_t i) { return page.row_ptr[i + 1] - page.row_ptr[i]; }); + row_stride = *std::max_element(it, it + page.Size()); + + CHECK_GE(ctx->gpu_id, 0); + monitor_.Start("InitCompressedData"); + InitCompressedData(ctx->gpu_id); + monitor_.Stop("InitCompressedData"); + + // copy gidx + auto accessor = this->GetDeviceAccessor(ctx->gpu_id, ft); + common::CompressedByteT* d_compressed_buffer = gidx_buffer.DevicePointer(); + dh::device_vector row_ptr(page.row_ptr); + auto d_row_ptr = dh::ToSpan(row_ptr); + CopyGHistToEllpack(page, d_row_ptr, row_stride, d_compressed_buffer); + + // write null value + dh::device_vector row_counts(page.Size()); + auto row_counts_span = dh::ToSpan(row_counts); + RowCountsFromIndptr(d_row_ptr, row_counts_span); + WriteNullValues(this, ctx->gpu_id, row_counts_span); +} + // A functor that copies the data from one EllpackPage to another. struct CopyPage { common::CompressedBufferWriter cbw; diff --git a/src/data/ellpack_page.cuh b/src/data/ellpack_page.cuh index 7a2020c8b0b4..75d394e30e57 100644 --- a/src/data/ellpack_page.cuh +++ b/src/data/ellpack_page.cuh @@ -116,6 +116,8 @@ struct EllpackDeviceAccessor { }; +class GHistIndexMatrix; + class EllpackPageImpl { public: /*! @@ -154,6 +156,11 @@ class EllpackPageImpl { common::Span row_counts_span, common::Span feature_types, size_t row_stride, size_t n_rows, common::HistogramCuts const& cuts); + /** + * \brief Constructor from an existing CPU gradient index. + */ + explicit EllpackPageImpl(Context const* ctx, GHistIndexMatrix const& page, + common::Span ft); /*! \brief Copy the elements of the given ELLPACK page into this page. * diff --git a/src/data/gradient_index.cc b/src/data/gradient_index.cc index 08f03f1a14d3..e34db5495952 100644 --- a/src/data/gradient_index.cc +++ b/src/data/gradient_index.cc @@ -66,6 +66,7 @@ GHistIndexMatrix::GHistIndexMatrix(MetaInfo const &info, common::HistogramCuts & max_num_bins(max_bin_per_feat), isDense_{info.num_col_ * info.num_row_ == info.num_nonzero_} {} + GHistIndexMatrix::~GHistIndexMatrix() = default; void GHistIndexMatrix::PushBatch(SparsePage const &batch, common::Span ft, diff --git a/src/data/iterative_dmatrix.cc b/src/data/iterative_dmatrix.cc index 4d8c602842df..e43fcccbc4b2 100644 --- a/src/data/iterative_dmatrix.cc +++ b/src/data/iterative_dmatrix.cc @@ -205,7 +205,12 @@ void IterativeDMatrix::InitFromCPU(DataIterHandle iter_handle, float missing, BatchSet IterativeDMatrix::GetGradientIndex(BatchParam const& param) { CheckParam(param); - CHECK(ghist_) << "Not initialized with CPU data"; + CHECK(ghist_) << R"(`QuantileDMatrix` is not initialized with CPU data but used for CPU training. +Possible solutions: +- Use `DMatrix` instead. +- Use CPU input for `QuantileDMatrix`. +- Run training on GPU. +)"; auto begin_iter = BatchIterator(new SimpleBatchIteratorImpl(ghist_)); return BatchSet(begin_iter); diff --git a/src/data/iterative_dmatrix.cu b/src/data/iterative_dmatrix.cu index b2159e978522..901662852a15 100644 --- a/src/data/iterative_dmatrix.cu +++ b/src/data/iterative_dmatrix.cu @@ -168,7 +168,17 @@ void IterativeDMatrix::InitFromCUDA(DataIterHandle iter_handle, float missing, BatchSet IterativeDMatrix::GetEllpackBatches(BatchParam const& param) { CheckParam(param); - CHECK(ellpack_) << "Not initialized with GPU data"; + if (!ellpack_ && !ghist_) { + LOG(FATAL) << "`QuantileDMatrix` not initialized."; + } + if (!ellpack_ && ghist_) { + ellpack_.reset(new EllpackPage()); + this->ctx_.gpu_id = param.gpu_id; + this->Info().feature_types.SetDevice(param.gpu_id); + *ellpack_->Impl() = + EllpackPageImpl(&ctx_, *this->ghist_, this->Info().feature_types.ConstDeviceSpan()); + } + CHECK(ellpack_); auto begin_iter = BatchIterator(new SimpleBatchIteratorImpl(ellpack_)); return BatchSet(begin_iter); } diff --git a/src/data/iterative_dmatrix.h b/src/data/iterative_dmatrix.h index 37c53e782973..49c1258ddb01 100644 --- a/src/data/iterative_dmatrix.h +++ b/src/data/iterative_dmatrix.h @@ -22,7 +22,30 @@ class HistogramCuts; } namespace data { - +/** + * \brief DMatrix type for `QuantileDMatrix`, the naming `IterativeDMatix` is due to its + * construction process. + * + * `QuantileDMatrix` is an intermediate storage for quantilization results including + * quantile cuts and histogram index. Quantilization is designed to be performed on stream + * of data (or batches of it). As a result, the `QuantileDMatrix` is also designed to work + * with batches of data. During initializaion, it will walk through the data multiple + * times iteratively in order to perform quantilization. This design can help us reduce + * memory usage significantly by avoiding data concatenation along with removing the CSR + * matrix `SparsePage`. However, it has its limitation (can be fixed if needed): + * + * - It's only supported by hist tree method (both CPU and GPU) since approx requires a + * re-calculation of quantiles for each iteration. We can fix this by retaining a + * reference to the callback if there are feature requests. + * + * - The CPU format and the GPU format are different, the former uses a CSR + CSC for + * histogram index while the latter uses only Ellpack. This results into a design that + * we can obtain the GPU format from CPU but not the other way around since we can't + * recover the CSC from Ellpack. More concretely, if users want to construct a CPU + * version of `QuantileDMatrix`, input data must be on CPU. However, if users want to + * have a GPU version of `QuantileDMatrix`, data can be anywhere. We can fix this by + * retaining the feature index information in ellpack if there are feature requests. + */ class IterativeDMatrix : public DMatrix { MetaInfo info_; Context ctx_; @@ -40,7 +63,8 @@ class IterativeDMatrix : public DMatrix { LOG(WARNING) << "Inconsistent max_bin between Quantile DMatrix and Booster:" << param.max_bin << " vs. " << batch_param_.max_bin; } - CHECK(!param.regen) << "Only `hist` and `gpu_hist` tree method can use `QuantileDMatrix`."; + CHECK(!param.regen && param.hess.empty()) + << "Only `hist` and `gpu_hist` tree method can use `QuantileDMatrix`."; } template diff --git a/tests/cpp/data/test_ellpack_page.cu b/tests/cpp/data/test_ellpack_page.cu index a67ab1d59f02..f2bf41981703 100644 --- a/tests/cpp/data/test_ellpack_page.cu +++ b/tests/cpp/data/test_ellpack_page.cu @@ -236,4 +236,43 @@ TEST(EllpackPage, Compact) { } } } + +TEST(EllpackPage, FromGHistIndex) { + auto test = [&](float sparsity) { + // Only testing with small sample size as the cuts might be different between host and + // device. + size_t n_samples{128}, n_features{13}; + Context ctx; + ctx.gpu_id = 0; + auto Xy = RandomDataGenerator{n_samples, n_features, sparsity}.GenerateDMatrix(true); + std::unique_ptr from_ghist; + ASSERT_TRUE(Xy->SingleColBlock()); + for (auto const& page : Xy->GetBatches(BatchParam{17, 0.6})) { + from_ghist.reset(new EllpackPageImpl{&ctx, page, {}}); + } + + for (auto const& page : Xy->GetBatches(BatchParam{0, 17})) { + auto from_sparse_page = page.Impl(); + ASSERT_EQ(from_sparse_page->is_dense, from_ghist->is_dense); + ASSERT_EQ(from_sparse_page->base_rowid, 0); + ASSERT_EQ(from_sparse_page->base_rowid, from_ghist->base_rowid); + ASSERT_EQ(from_sparse_page->n_rows, from_ghist->n_rows); + ASSERT_EQ(from_sparse_page->gidx_buffer.Size(), from_ghist->gidx_buffer.Size()); + auto const& h_gidx_from_sparse = from_sparse_page->gidx_buffer.HostVector(); + auto const& h_gidx_from_ghist = from_ghist->gidx_buffer.HostVector(); + ASSERT_EQ(from_sparse_page->NumSymbols(), from_ghist->NumSymbols()); + common::CompressedIterator from_ghist_it(h_gidx_from_ghist.data(), + from_ghist->NumSymbols()); + common::CompressedIterator from_sparse_it(h_gidx_from_sparse.data(), + from_sparse_page->NumSymbols()); + for (size_t i = 0; i < from_ghist->n_rows * from_ghist->row_stride; ++i) { + EXPECT_EQ(from_ghist_it[i], from_sparse_it[i]); + } + } + }; + + for (auto s : {0.0, 0.2, 0.4, 0.8}) { + test(s); + } +} } // namespace xgboost diff --git a/tests/python-gpu/test_device_quantile_dmatrix.py b/tests/python-gpu/test_device_quantile_dmatrix.py index f91cf6bd9aa6..dee60392057a 100644 --- a/tests/python-gpu/test_device_quantile_dmatrix.py +++ b/tests/python-gpu/test_device_quantile_dmatrix.py @@ -31,6 +31,34 @@ def test_dmatrix_cupy_init(self) -> None: data = cp.random.randn(5, 5) xgb.DeviceQuantileDMatrix(data, cp.ones(5, dtype=np.float64)) + @pytest.mark.skipif(**tm.no_cupy()) + def test_from_host(self) -> None: + import cupy as cp + n_samples = 64 + n_features = 3 + X, y, w = tm.make_batches( + n_samples, n_features=n_features, n_batches=1, use_cupy=False + ) + Xy = xgb.QuantileDMatrix(X[0], y[0], weight=w[0]) + booster_0 = xgb.train({"tree_method": "gpu_hist"}, Xy, num_boost_round=4) + + X[0] = cp.array(X[0]) + y[0] = cp.array(y[0]) + w[0] = cp.array(w[0]) + + Xy = xgb.QuantileDMatrix(X[0], y[0], weight=w[0]) + booster_1 = xgb.train({"tree_method": "gpu_hist"}, Xy, num_boost_round=4) + cp.testing.assert_allclose( + booster_0.inplace_predict(X[0]), booster_1.inplace_predict(X[0]) + ) + + with pytest.raises(ValueError, match="not initialized with CPU"): + # Training on CPU with GPU data is not supported. + xgb.train({"tree_method": "hist"}, Xy, num_boost_round=4) + + with pytest.raises(ValueError, match=r"Only.*hist.*"): + xgb.train({"tree_method": "approx"}, Xy, num_boost_round=4) + @pytest.mark.skipif(**tm.no_cupy()) def test_metainfo(self) -> None: import cupy as cp From 8f9e851b6d7aa052fdc6c4afaef245cc835d36d9 Mon Sep 17 00:00:00 2001 From: fis Date: Wed, 3 Aug 2022 18:17:08 +0800 Subject: [PATCH 02/10] Make functions private. --- src/data/iterative_dmatrix.h | 1 - 1 file changed, 1 deletion(-) diff --git a/src/data/iterative_dmatrix.h b/src/data/iterative_dmatrix.h index 49c1258ddb01..9dcb19f6e6a3 100644 --- a/src/data/iterative_dmatrix.h +++ b/src/data/iterative_dmatrix.h @@ -73,7 +73,6 @@ class IterativeDMatrix : public DMatrix { return BatchSet(BatchIterator(nullptr)); } - public: void InitFromCUDA(DataIterHandle iter, float missing, std::shared_ptr ref); void InitFromCPU(DataIterHandle iter_handle, float missing, std::shared_ptr ref); From 921967d5c36847ed71c8a68ad930c4b45e58985b Mon Sep 17 00:00:00 2001 From: fis Date: Wed, 3 Aug 2022 18:20:14 +0800 Subject: [PATCH 03/10] note. --- src/data/iterative_dmatrix.h | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/data/iterative_dmatrix.h b/src/data/iterative_dmatrix.h index 9dcb19f6e6a3..1f49631e68a8 100644 --- a/src/data/iterative_dmatrix.h +++ b/src/data/iterative_dmatrix.h @@ -43,8 +43,9 @@ namespace data { * we can obtain the GPU format from CPU but not the other way around since we can't * recover the CSC from Ellpack. More concretely, if users want to construct a CPU * version of `QuantileDMatrix`, input data must be on CPU. However, if users want to - * have a GPU version of `QuantileDMatrix`, data can be anywhere. We can fix this by - * retaining the feature index information in ellpack if there are feature requests. + * have a GPU version of `QuantileDMatrix`, data can be on either place. We can fix this + * by retaining the feature index information in ellpack if there are feature + * requests. Or by retaining the callback and run sketching again. */ class IterativeDMatrix : public DMatrix { MetaInfo info_; From 69b3dd47a1f46bcefe2cc6d288ba54e1f4c8d228 Mon Sep 17 00:00:00 2001 From: fis Date: Wed, 3 Aug 2022 18:21:25 +0800 Subject: [PATCH 04/10] note. --- src/data/iterative_dmatrix.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/data/iterative_dmatrix.h b/src/data/iterative_dmatrix.h index 1f49631e68a8..bcb98093be91 100644 --- a/src/data/iterative_dmatrix.h +++ b/src/data/iterative_dmatrix.h @@ -45,7 +45,7 @@ namespace data { * version of `QuantileDMatrix`, input data must be on CPU. However, if users want to * have a GPU version of `QuantileDMatrix`, data can be on either place. We can fix this * by retaining the feature index information in ellpack if there are feature - * requests. Or by retaining the callback and run sketching again. + * requests. */ class IterativeDMatrix : public DMatrix { MetaInfo info_; From 3a7bf80d3c13186b346ccbb596f098718d5413f0 Mon Sep 17 00:00:00 2001 From: fis Date: Wed, 3 Aug 2022 18:23:37 +0800 Subject: [PATCH 05/10] Cleanup. --- src/data/iterative_dmatrix.h | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/data/iterative_dmatrix.h b/src/data/iterative_dmatrix.h index bcb98093be91..976b87d56ee8 100644 --- a/src/data/iterative_dmatrix.h +++ b/src/data/iterative_dmatrix.h @@ -97,8 +97,9 @@ class IterativeDMatrix : public DMatrix { batch_param_ = BatchParam{d, max_bin}; batch_param_.sparse_thresh = 0.2; // default from TrainParam - ctx_.UpdateAllowUnknown(Args{{"nthread", std::to_string(nthread)}}); - if (d == Context::kCpuId) { + ctx_.UpdateAllowUnknown( + Args{{"nthread", std::to_string(nthread)}, {"gpu_id", std::to_string(d)}}); + if (ctx_.IsCPU()) { this->InitFromCPU(iter_handle, missing, ref); } else { this->InitFromCUDA(iter_handle, missing, ref); From b7735432cf9674eff57d91b6f0784225bc5b2054 Mon Sep 17 00:00:00 2001 From: fis Date: Wed, 10 Aug 2022 11:40:31 +0800 Subject: [PATCH 06/10] ub. --- src/data/device_adapter.cuh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/data/device_adapter.cuh b/src/data/device_adapter.cuh index e1c6e98cf65b..4000c18b916a 100644 --- a/src/data/device_adapter.cuh +++ b/src/data/device_adapter.cuh @@ -113,7 +113,7 @@ class CudfAdapter : public detail::SingleBatchDataIter { for (auto& json_col : json_columns) { auto column = ArrayInterface<1>(get(json_col)); columns.push_back(column); - num_rows_ = std::max(num_rows_, size_t(column.Shape(0))); + num_rows_ = std::max(num_rows_, column.Shape(0)); CHECK_EQ(device_idx_, dh::CudaGetPointerDevice(column.data)) << "All columns should use the same device."; CHECK_EQ(num_rows_, column.Shape(0)) @@ -138,7 +138,7 @@ class CudfAdapter : public detail::SingleBatchDataIter { CudfAdapterBatch batch_; dh::device_vector> columns_; size_t num_rows_{0}; - int device_idx_; + int32_t device_idx_{-1}; }; class CupyAdapterBatch : public detail::NoMetaInfo { From bd9adb94b31d5b023bcbaadaaf739535f03503c7 Mon Sep 17 00:00:00 2001 From: fis Date: Wed, 10 Aug 2022 11:43:05 +0800 Subject: [PATCH 07/10] Constant. --- src/data/device_adapter.cuh | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/data/device_adapter.cuh b/src/data/device_adapter.cuh index 4000c18b916a..4a635e92d29c 100644 --- a/src/data/device_adapter.cuh +++ b/src/data/device_adapter.cuh @@ -108,7 +108,7 @@ class CudfAdapter : public detail::SingleBatchDataIter { } device_idx_ = dh::CudaGetPointerDevice(first_column.data); - CHECK_NE(device_idx_, -1); + CHECK_NE(device_idx_, Context::kCpuId); dh::safe_cuda(cudaSetDevice(device_idx_)); for (auto& json_col : json_columns) { auto column = ArrayInterface<1>(get(json_col)); @@ -138,7 +138,7 @@ class CudfAdapter : public detail::SingleBatchDataIter { CudfAdapterBatch batch_; dh::device_vector> columns_; size_t num_rows_{0}; - int32_t device_idx_{-1}; + int32_t device_idx_{Context::kCpuId}; }; class CupyAdapterBatch : public detail::NoMetaInfo { @@ -173,7 +173,7 @@ class CupyAdapter : public detail::SingleBatchDataIter { return; } device_idx_ = dh::CudaGetPointerDevice(array_interface_.data); - CHECK_NE(device_idx_, -1); + CHECK_NE(device_idx_, Context::kCpuId); } explicit CupyAdapter(std::string cuda_interface_str) : CupyAdapter{StringView{cuda_interface_str}} {} @@ -186,7 +186,7 @@ class CupyAdapter : public detail::SingleBatchDataIter { private: ArrayInterface<2> array_interface_; CupyAdapterBatch batch_; - int32_t device_idx_ {-1}; + int32_t device_idx_ {Context::kCpuId}; }; // Returns maximum row length From 555a9120c15b54aa83cd25c241f73c6913286d48 Mon Sep 17 00:00:00 2001 From: fis Date: Wed, 10 Aug 2022 11:59:00 +0800 Subject: [PATCH 08/10] Use gtest instead. --- tests/cpp/data/test_ellpack_page.cu | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/tests/cpp/data/test_ellpack_page.cu b/tests/cpp/data/test_ellpack_page.cu index f2bf41981703..dccf85092d7f 100644 --- a/tests/cpp/data/test_ellpack_page.cu +++ b/tests/cpp/data/test_ellpack_page.cu @@ -237,8 +237,10 @@ TEST(EllpackPage, Compact) { } } -TEST(EllpackPage, FromGHistIndex) { - auto test = [&](float sparsity) { +namespace { +class EllpackPageTest : public testing::TestWithParam { + protected: + void Run(float sparsity) { // Only testing with small sample size as the cuts might be different between host and // device. size_t n_samples{128}, n_features{13}; @@ -269,10 +271,10 @@ TEST(EllpackPage, FromGHistIndex) { EXPECT_EQ(from_ghist_it[i], from_sparse_it[i]); } } - }; - - for (auto s : {0.0, 0.2, 0.4, 0.8}) { - test(s); } -} +}; +} // namespace + +TEST_P(EllpackPageTest, FromGHistIndex) { this->Run(GetParam()); } +INSTANTIATE_TEST_SUITE_P(EllpackPage, EllpackPageTest, testing::Values(.0f, .2f, .4f, .8f)); } // namespace xgboost From 43ae4ad9371ff39a069bbf010727877c005d8e66 Mon Sep 17 00:00:00 2001 From: fis Date: Thu, 11 Aug 2022 15:07:06 +0800 Subject: [PATCH 09/10] Merge kernels. --- src/data/ellpack_page.cu | 25 +++++++++---------------- src/data/iterative_dmatrix.h | 9 +++------ 2 files changed, 12 insertions(+), 22 deletions(-) diff --git a/src/data/ellpack_page.cu b/src/data/ellpack_page.cu index 0834bb6fd7a3..cf04ab16e7bf 100644 --- a/src/data/ellpack_page.cu +++ b/src/data/ellpack_page.cu @@ -288,7 +288,8 @@ ELLPACK_BATCH_SPECIALIZE(data::CupyAdapterBatch) namespace { void CopyGHistToEllpack(GHistIndexMatrix const& page, common::Span d_row_ptr, - size_t row_stride, common::CompressedByteT* d_compressed_buffer) { + size_t row_stride, common::CompressedByteT* d_compressed_buffer, + size_t null) { dh::device_vector data(page.index.begin(), page.index.end()); auto d_data = dh::ToSpan(data); @@ -305,9 +306,10 @@ void CopyGHistToEllpack(GHistIndexMatrix const& page, common::Span auto r_begin = d_row_ptr[ridx]; auto r_end = d_row_ptr[ridx + 1]; - size_t rsize = r_end - r_begin; + size_t r_size = r_end - r_begin; - if (ifeature >= rsize) { + if (ifeature >= r_size) { + writer.AtomicWriteSymbol(d_compressed_buffer, null, idx); return; } @@ -320,15 +322,10 @@ void CopyGHistToEllpack(GHistIndexMatrix const& page, common::Span using T = decltype(t); auto ptr = reinterpret_cast(d_data.data()); auto bin_idx = ptr[r_begin + ifeature] + offset; - writer.AtomicWriteSymbol(d_compressed_buffer, bin_idx, ridx * row_stride + ifeature); + writer.AtomicWriteSymbol(d_compressed_buffer, bin_idx, idx); }); }); } - -void RowCountsFromIndptr(common::Span d_row_ptr, common::Span row_counts) { - dh::LaunchN(row_counts.size(), - [=] XGBOOST_DEVICE(size_t i) { row_counts[i] = d_row_ptr[i + 1] - d_row_ptr[i]; }); -} } // anonymous namespace EllpackPageImpl::EllpackPageImpl(Context const* ctx, GHistIndexMatrix const& page, @@ -344,17 +341,13 @@ EllpackPageImpl::EllpackPageImpl(Context const* ctx, GHistIndexMatrix const& pag monitor_.Stop("InitCompressedData"); // copy gidx - auto accessor = this->GetDeviceAccessor(ctx->gpu_id, ft); common::CompressedByteT* d_compressed_buffer = gidx_buffer.DevicePointer(); dh::device_vector row_ptr(page.row_ptr); auto d_row_ptr = dh::ToSpan(row_ptr); - CopyGHistToEllpack(page, d_row_ptr, row_stride, d_compressed_buffer); - // write null value - dh::device_vector row_counts(page.Size()); - auto row_counts_span = dh::ToSpan(row_counts); - RowCountsFromIndptr(d_row_ptr, row_counts_span); - WriteNullValues(this, ctx->gpu_id, row_counts_span); + auto accessor = this->GetDeviceAccessor(ctx->gpu_id, ft); + auto null = accessor.NullValue(); + CopyGHistToEllpack(page, d_row_ptr, row_stride, d_compressed_buffer, null); } // A functor that copies the data from one EllpackPage to another. diff --git a/src/data/iterative_dmatrix.h b/src/data/iterative_dmatrix.h index 976b87d56ee8..06d061382ba8 100644 --- a/src/data/iterative_dmatrix.h +++ b/src/data/iterative_dmatrix.h @@ -40,12 +40,9 @@ namespace data { * * - The CPU format and the GPU format are different, the former uses a CSR + CSC for * histogram index while the latter uses only Ellpack. This results into a design that - * we can obtain the GPU format from CPU but not the other way around since we can't - * recover the CSC from Ellpack. More concretely, if users want to construct a CPU - * version of `QuantileDMatrix`, input data must be on CPU. However, if users want to - * have a GPU version of `QuantileDMatrix`, data can be on either place. We can fix this - * by retaining the feature index information in ellpack if there are feature - * requests. + * we can obtain the GPU format from CPU but the other way around is not yet + * supported. We can search the bin value from ellpack to recover the feature index when + * we support copying data from GPU to CPU. */ class IterativeDMatrix : public DMatrix { MetaInfo info_; From c34790a6841e2086edadf41485e4711c2e9bfa61 Mon Sep 17 00:00:00 2001 From: fis Date: Thu, 11 Aug 2022 15:07:58 +0800 Subject: [PATCH 10/10] rebase error. --- tests/ci_build/lint_python.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/ci_build/lint_python.py b/tests/ci_build/lint_python.py index d34a5ab8a6ca..f669cdbb25c4 100644 --- a/tests/ci_build/lint_python.py +++ b/tests/ci_build/lint_python.py @@ -121,7 +121,6 @@ def print_summary_map(result_map: Dict[str, Dict[str, int]]) -> int: "python-package/xgboost/sklearn.py", "python-package/xgboost/spark", "python-package/xgboost/federated.py", - "python-package/xgboost/spark", # tests "tests/python/test_config.py", "tests/python/test_spark/",