diff --git a/src/common/hist_util.cuh b/src/common/hist_util.cuh
index 8fac9fca2f14..7dd62b38216e 100644
--- a/src/common/hist_util.cuh
+++ b/src/common/hist_util.cuh
@@ -17,6 +17,28 @@ namespace xgboost {
 namespace common {
+namespace cuda {
+/**
+ * Copy of the host version. We can't make this a __host__ __device__ function because
+ * `fn` might be a host-only or device-only callable object, which nvcc does not allow.
+ */
+template <typename Fn>
+auto __device__ DispatchBinType(BinTypeSize type, Fn&& fn) {
+  switch (type) {
+    case kUint8BinsTypeSize: {
+      return fn(uint8_t{});
+    }
+    case kUint16BinsTypeSize: {
+      return fn(uint16_t{});
+    }
+    case kUint32BinsTypeSize: {
+      return fn(uint32_t{});
+    }
+  }
+  SPAN_CHECK(false);
+  return fn(uint32_t{});
+}
+}  // namespace cuda
 namespace detail {
 
 struct EntryCompareOp {
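
The `cuda::DispatchBinType` helper above mirrors the host-side `common::DispatchBinType`: the callable receives a value of the selected storage type and recovers it with `decltype`, so one generic lambda handles every bin width. Below is a minimal host-only sketch of the same pattern; the enum values and helper names are illustrative placeholders, not code from this patch.

    #include <cstdint>
    #include <iostream>
    #include <vector>

    // Illustrative stand-in for xgboost's BinTypeSize; the values here are assumptions.
    enum BinTypeSize : std::uint8_t {
      kUint8BinsTypeSize = 1,
      kUint16BinsTypeSize = 2,
      kUint32BinsTypeSize = 4
    };

    template <typename Fn>
    auto DispatchBinType(BinTypeSize type, Fn&& fn) {
      switch (type) {
        case kUint8BinsTypeSize:  return fn(std::uint8_t{});
        case kUint16BinsTypeSize: return fn(std::uint16_t{});
        case kUint32BinsTypeSize: return fn(std::uint32_t{});
      }
      return fn(std::uint32_t{});  // unreachable; keeps the compiler happy
    }

    int main() {
      // Bin indices stored in their narrowest width, viewed through a byte buffer.
      std::vector<std::uint8_t> storage{1, 2, 3};
      DispatchBinType(kUint8BinsTypeSize, [&](auto t) {
        using T = decltype(t);
        auto const* ptr = reinterpret_cast<T const*>(storage.data());
        std::cout << static_cast<std::uint32_t>(ptr[2]) << "\n";  // prints 3
      });
      return 0;
    }

In the device version, the trailing `SPAN_CHECK(false)` plays the role of the unreachable guard.
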
diff --git a/src/data/device_adapter.cuh b/src/data/device_adapter.cuh
index e1c6e98cf65b..4a635e92d29c 100644
--- a/src/data/device_adapter.cuh
+++ b/src/data/device_adapter.cuh
@@ -108,12 +108,12 @@ class CudfAdapter : public detail::SingleBatchDataIter<CudfAdapterBatch> {
     }
     device_idx_ = dh::CudaGetPointerDevice(first_column.data);
-    CHECK_NE(device_idx_, -1);
+    CHECK_NE(device_idx_, Context::kCpuId);
     dh::safe_cuda(cudaSetDevice(device_idx_));
     for (auto& json_col : json_columns) {
       auto column = ArrayInterface<1>(get<Object const>(json_col));
       columns.push_back(column);
-      num_rows_ = std::max(num_rows_, size_t(column.Shape(0)));
+      num_rows_ = std::max(num_rows_, column.Shape(0));
       CHECK_EQ(device_idx_, dh::CudaGetPointerDevice(column.data))
           << "All columns should use the same device.";
       CHECK_EQ(num_rows_, column.Shape(0))
@@ -138,7 +138,7 @@ class CudfAdapter : public detail::SingleBatchDataIter<CudfAdapterBatch> {
   CudfAdapterBatch batch_;
   dh::device_vector<ArrayInterface<1>> columns_;
   size_t num_rows_{0};
-  int device_idx_;
+  int32_t device_idx_{Context::kCpuId};
 };
 
 class CupyAdapterBatch : public detail::NoMetaInfo {
@@ -173,7 +173,7 @@ class CupyAdapter : public detail::SingleBatchDataIter<CupyAdapterBatch> {
       return;
     }
     device_idx_ = dh::CudaGetPointerDevice(array_interface_.data);
-    CHECK_NE(device_idx_, -1);
+    CHECK_NE(device_idx_, Context::kCpuId);
   }
   explicit CupyAdapter(std::string cuda_interface_str)
       : CupyAdapter{StringView{cuda_interface_str}} {}
@@ -186,7 +186,7 @@ class CupyAdapter : public detail::SingleBatchDataIter<CupyAdapterBatch> {
  private:
   ArrayInterface<2> array_interface_;
   CupyAdapterBatch batch_;
-  int32_t device_idx_ {-1};
+  int32_t device_idx_ {Context::kCpuId};
 };
 
 // Returns maximum row length
diff --git a/src/data/ellpack_page.cu b/src/data/ellpack_page.cu
index 14a1b2bbf172..cf04ab16e7bf 100644
--- a/src/data/ellpack_page.cu
+++ b/src/data/ellpack_page.cu
@@ -1,14 +1,16 @@
 /*!
- * Copyright 2019-2020 XGBoost contributors
+ * Copyright 2019-2022 XGBoost contributors
  */
-#include <xgboost/data.h>
 #include <thrust/iterator/discard_iterator.h>
 #include <thrust/iterator/transform_output_iterator.h>
+
 #include "../common/categorical.h"
 #include "../common/hist_util.cuh"
 #include "../common/random.h"
 #include "./ellpack_page.cuh"
 #include "device_adapter.cuh"
+#include "gradient_index.h"
+#include "xgboost/data.h"
 
 namespace xgboost {
 
@@ -32,7 +34,7 @@ __global__ void CompressBinEllpackKernel(
     const size_t* __restrict__ row_ptrs,    // row offset of input data
     const Entry* __restrict__ entries,      // One batch of input data
     const float* __restrict__ cuts,         // HistogramCuts::cut_values_
-    const uint32_t* __restrict__ cut_rows,  // HistogramCuts::cut_ptrs_
+    const uint32_t* __restrict__ cut_ptrs,  // HistogramCuts::cut_ptrs_
     common::Span<FeatureType const> feature_types,
     size_t base_row,                        // batch_row_begin
     size_t n_rows,
@@ -50,8 +52,8 @@ __global__ void CompressBinEllpackKernel(
     int feature = entry.index;
     float fvalue = entry.fvalue;
     // {feature_cuts, ncuts} forms the array of cuts of `feature'.
-    const float* feature_cuts = &cuts[cut_rows[feature]];
-    int ncuts = cut_rows[feature + 1] - cut_rows[feature];
+    const float* feature_cuts = &cuts[cut_ptrs[feature]];
+    int ncuts = cut_ptrs[feature + 1] - cut_ptrs[feature];
     bool is_cat = common::IsCat(feature_types, ifeature);
     // Assigning the bin in current entry.
     // S.t.: fvalue < feature_cuts[bin]
@@ -69,7 +71,7 @@ __global__ void CompressBinEllpackKernel(
       bin = ncuts - 1;
     }
     // Add the number of bins in previous features.
-    bin += cut_rows[feature];
+    bin += cut_ptrs[feature];
   }
   // Write to gidx buffer.
   wr.AtomicWriteSymbol(buffer, bin, (irow + base_row) * row_stride + ifeature);
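
The kernel comments above describe the per-entry bin lookup: within one feature, `bin` is the first cut such that `fvalue < feature_cuts[bin]`, clamped to the feature's last bin, and then shifted by `cut_ptrs[feature]` into the global bin space. The actual search loop sits outside this hunk; the following is a hedged host-side sketch of that lookup using `std::upper_bound`, with invented data for illustration.

    #include <algorithm>
    #include <cstdint>
    #include <iostream>
    #include <vector>

    // Host sketch of the lookup described by the kernel comments.
    std::uint32_t SearchBin(std::vector<float> const& cut_values,
                            std::vector<std::uint32_t> const& cut_ptrs,
                            std::uint32_t feature, float fvalue) {
      auto begin = cut_values.begin() + cut_ptrs[feature];
      auto end = cut_values.begin() + cut_ptrs[feature + 1];
      auto it = std::upper_bound(begin, end, fvalue);  // first cut with fvalue < cut
      if (it == end) {
        --it;  // clamp to the last bin of this feature
      }
      return static_cast<std::uint32_t>(it - cut_values.begin());  // global bin index
    }

    int main() {
      // Two features: feature 0 has cuts {0.5, 1.5}, feature 1 has cuts {10, 20, 30}.
      std::vector<float> cut_values{0.5f, 1.5f, 10.f, 20.f, 30.f};
      std::vector<std::uint32_t> cut_ptrs{0, 2, 5};
      std::cout << SearchBin(cut_values, cut_ptrs, 0, 1.0f) << "\n";  // 1
      std::cout << SearchBin(cut_values, cut_ptrs, 1, 25.f) << "\n";  // 4
      return 0;
    }
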
@@ -284,6 +286,70 @@ EllpackPageImpl::EllpackPageImpl(AdapterBatch batch, float missing, int device,
 ELLPACK_BATCH_SPECIALIZE(data::CudfAdapterBatch)
 ELLPACK_BATCH_SPECIALIZE(data::CupyAdapterBatch)
 
+namespace {
+void CopyGHistToEllpack(GHistIndexMatrix const& page, common::Span<size_t const> d_row_ptr,
+                        size_t row_stride, common::CompressedByteT* d_compressed_buffer,
+                        size_t null) {
+  dh::device_vector<uint8_t> data(page.index.begin(), page.index.end());
+  auto d_data = dh::ToSpan(data);
+
+  dh::device_vector<size_t> csc_indptr(page.index.Offset(),
+                                       page.index.Offset() + page.index.OffsetSize());
+  auto d_csc_indptr = dh::ToSpan(csc_indptr);
+
+  auto bin_type = page.index.GetBinTypeSize();
+  common::CompressedBufferWriter writer{page.cut.TotalBins() + 1};  // +1 for null value
+
+  dh::LaunchN(row_stride * page.Size(), [=] __device__(size_t idx) mutable {
+    auto ridx = idx / row_stride;
+    auto ifeature = idx % row_stride;
+
+    auto r_begin = d_row_ptr[ridx];
+    auto r_end = d_row_ptr[ridx + 1];
+    size_t r_size = r_end - r_begin;
+
+    if (ifeature >= r_size) {
+      writer.AtomicWriteSymbol(d_compressed_buffer, null, idx);
+      return;
+    }
+
+    size_t offset = 0;
+    if (!d_csc_indptr.empty()) {
+      // page is dense, ifeature is the actual feature index.
+      offset = d_csc_indptr[ifeature];
+    }
+    common::cuda::DispatchBinType(bin_type, [&](auto t) {
+      using T = decltype(t);
+      auto ptr = reinterpret_cast<T const*>(d_data.data());
+      auto bin_idx = ptr[r_begin + ifeature] + offset;
+      writer.AtomicWriteSymbol(d_compressed_buffer, bin_idx, idx);
+    });
+  });
+}
+}  // anonymous namespace
+
+EllpackPageImpl::EllpackPageImpl(Context const* ctx, GHistIndexMatrix const& page,
+                                 common::Span<FeatureType const> ft)
+    : is_dense{page.IsDense()}, base_rowid{page.base_rowid}, n_rows{page.Size()}, cuts_{page.cut} {
+  auto it = common::MakeIndexTransformIter(
+      [&](size_t i) { return page.row_ptr[i + 1] - page.row_ptr[i]; });
+  row_stride = *std::max_element(it, it + page.Size());
+
+  CHECK_GE(ctx->gpu_id, 0);
+  monitor_.Start("InitCompressedData");
+  InitCompressedData(ctx->gpu_id);
+  monitor_.Stop("InitCompressedData");
+
+  // copy gidx
+  common::CompressedByteT* d_compressed_buffer = gidx_buffer.DevicePointer();
+  dh::device_vector<size_t> row_ptr(page.row_ptr);
+  auto d_row_ptr = dh::ToSpan(row_ptr);
+
+  auto accessor = this->GetDeviceAccessor(ctx->gpu_id, ft);
+  auto null = accessor.NullValue();
+  CopyGHistToEllpack(page, d_row_ptr, row_stride, d_compressed_buffer, null);
+}
+
 // A functor that copies the data from one EllpackPage to another.
 struct CopyPage {
   common::CompressedBufferWriter cbw;
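
The copy above expands the CSR-style gradient index into the fixed-stride ELLPACK layout: entry `idx` maps to row `idx / row_stride` and slot `idx % row_stride`, slots past the row's length receive the null symbol, and when the page is dense the per-feature starting bin from `csc_indptr` is added back (as the in-code comment notes, `ifeature` is then the actual feature index). A CPU-only sketch of the same index math follows, using plain vectors in place of device memory and the compressed-symbol writer; all names here are invented for illustration.

    #include <cstddef>
    #include <cstdint>
    #include <iostream>
    #include <vector>

    // CPU-only sketch of the GHist -> ELLPACK index math.
    std::vector<std::size_t> GHistToEllpack(std::vector<std::size_t> const& row_ptr,
                                            std::vector<std::uint32_t> const& gidx,      // CSR bin indices
                                            std::vector<std::size_t> const& csc_indptr,  // empty if sparse
                                            std::size_t row_stride, std::size_t null) {
      std::size_t n_rows = row_ptr.size() - 1;
      std::vector<std::size_t> out(n_rows * row_stride, null);
      for (std::size_t idx = 0; idx < out.size(); ++idx) {
        std::size_t ridx = idx / row_stride;
        std::size_t ifeature = idx % row_stride;
        std::size_t r_begin = row_ptr[ridx];
        std::size_t r_size = row_ptr[ridx + 1] - r_begin;
        if (ifeature >= r_size) {
          continue;  // padding slot keeps the null symbol
        }
        // For a dense page the stored bin is relative to its feature; add the feature's
        // starting bin back so all symbols share one global bin space.
        std::size_t offset = csc_indptr.empty() ? 0 : csc_indptr[ifeature];
        out[idx] = gidx[r_begin + ifeature] + offset;
      }
      return out;
    }

    int main() {
      // Two rows; row 0 has 2 entries, row 1 has 1 entry; row_stride = 2, null symbol = 5.
      std::vector<std::size_t> row_ptr{0, 2, 3};
      std::vector<std::uint32_t> gidx{0, 3, 2};
      auto out = GHistToEllpack(row_ptr, gidx, /*csc_indptr=*/{}, /*row_stride=*/2, /*null=*/5);
      for (auto v : out) std::cout << v << " ";  // 0 3 2 5
      std::cout << "\n";
      return 0;
    }
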
+)"; auto begin_iter = BatchIterator(new SimpleBatchIteratorImpl(ghist_)); return BatchSet(begin_iter); diff --git a/src/data/iterative_dmatrix.cu b/src/data/iterative_dmatrix.cu index b2159e978522..901662852a15 100644 --- a/src/data/iterative_dmatrix.cu +++ b/src/data/iterative_dmatrix.cu @@ -168,7 +168,17 @@ void IterativeDMatrix::InitFromCUDA(DataIterHandle iter_handle, float missing, BatchSet IterativeDMatrix::GetEllpackBatches(BatchParam const& param) { CheckParam(param); - CHECK(ellpack_) << "Not initialized with GPU data"; + if (!ellpack_ && !ghist_) { + LOG(FATAL) << "`QuantileDMatrix` not initialized."; + } + if (!ellpack_ && ghist_) { + ellpack_.reset(new EllpackPage()); + this->ctx_.gpu_id = param.gpu_id; + this->Info().feature_types.SetDevice(param.gpu_id); + *ellpack_->Impl() = + EllpackPageImpl(&ctx_, *this->ghist_, this->Info().feature_types.ConstDeviceSpan()); + } + CHECK(ellpack_); auto begin_iter = BatchIterator(new SimpleBatchIteratorImpl(ellpack_)); return BatchSet(begin_iter); } diff --git a/src/data/iterative_dmatrix.h b/src/data/iterative_dmatrix.h index 37c53e782973..06d061382ba8 100644 --- a/src/data/iterative_dmatrix.h +++ b/src/data/iterative_dmatrix.h @@ -22,7 +22,28 @@ class HistogramCuts; } namespace data { - +/** + * \brief DMatrix type for `QuantileDMatrix`, the naming `IterativeDMatix` is due to its + * construction process. + * + * `QuantileDMatrix` is an intermediate storage for quantilization results including + * quantile cuts and histogram index. Quantilization is designed to be performed on stream + * of data (or batches of it). As a result, the `QuantileDMatrix` is also designed to work + * with batches of data. During initializaion, it will walk through the data multiple + * times iteratively in order to perform quantilization. This design can help us reduce + * memory usage significantly by avoiding data concatenation along with removing the CSR + * matrix `SparsePage`. However, it has its limitation (can be fixed if needed): + * + * - It's only supported by hist tree method (both CPU and GPU) since approx requires a + * re-calculation of quantiles for each iteration. We can fix this by retaining a + * reference to the callback if there are feature requests. + * + * - The CPU format and the GPU format are different, the former uses a CSR + CSC for + * histogram index while the latter uses only Ellpack. This results into a design that + * we can obtain the GPU format from CPU but the other way around is not yet + * supported. We can search the bin value from ellpack to recover the feature index when + * we support copying data from GPU to CPU. + */ class IterativeDMatrix : public DMatrix { MetaInfo info_; Context ctx_; @@ -40,7 +61,8 @@ class IterativeDMatrix : public DMatrix { LOG(WARNING) << "Inconsistent max_bin between Quantile DMatrix and Booster:" << param.max_bin << " vs. 
" << batch_param_.max_bin; } - CHECK(!param.regen) << "Only `hist` and `gpu_hist` tree method can use `QuantileDMatrix`."; + CHECK(!param.regen && param.hess.empty()) + << "Only `hist` and `gpu_hist` tree method can use `QuantileDMatrix`."; } template @@ -49,7 +71,6 @@ class IterativeDMatrix : public DMatrix { return BatchSet(BatchIterator(nullptr)); } - public: void InitFromCUDA(DataIterHandle iter, float missing, std::shared_ptr ref); void InitFromCPU(DataIterHandle iter_handle, float missing, std::shared_ptr ref); @@ -73,8 +94,9 @@ class IterativeDMatrix : public DMatrix { batch_param_ = BatchParam{d, max_bin}; batch_param_.sparse_thresh = 0.2; // default from TrainParam - ctx_.UpdateAllowUnknown(Args{{"nthread", std::to_string(nthread)}}); - if (d == Context::kCpuId) { + ctx_.UpdateAllowUnknown( + Args{{"nthread", std::to_string(nthread)}, {"gpu_id", std::to_string(d)}}); + if (ctx_.IsCPU()) { this->InitFromCPU(iter_handle, missing, ref); } else { this->InitFromCUDA(iter_handle, missing, ref); diff --git a/tests/ci_build/lint_python.py b/tests/ci_build/lint_python.py index d34a5ab8a6ca..f669cdbb25c4 100644 --- a/tests/ci_build/lint_python.py +++ b/tests/ci_build/lint_python.py @@ -121,7 +121,6 @@ def print_summary_map(result_map: Dict[str, Dict[str, int]]) -> int: "python-package/xgboost/sklearn.py", "python-package/xgboost/spark", "python-package/xgboost/federated.py", - "python-package/xgboost/spark", # tests "tests/python/test_config.py", "tests/python/test_spark/", diff --git a/tests/cpp/data/test_ellpack_page.cu b/tests/cpp/data/test_ellpack_page.cu index a67ab1d59f02..dccf85092d7f 100644 --- a/tests/cpp/data/test_ellpack_page.cu +++ b/tests/cpp/data/test_ellpack_page.cu @@ -236,4 +236,45 @@ TEST(EllpackPage, Compact) { } } } + +namespace { +class EllpackPageTest : public testing::TestWithParam { + protected: + void Run(float sparsity) { + // Only testing with small sample size as the cuts might be different between host and + // device. 
diff --git a/tests/cpp/data/test_ellpack_page.cu b/tests/cpp/data/test_ellpack_page.cu
index a67ab1d59f02..dccf85092d7f 100644
--- a/tests/cpp/data/test_ellpack_page.cu
+++ b/tests/cpp/data/test_ellpack_page.cu
@@ -236,4 +236,45 @@ TEST(EllpackPage, Compact) {
     }
   }
 }
+
+namespace {
+class EllpackPageTest : public testing::TestWithParam<float> {
+ protected:
+  void Run(float sparsity) {
+    // Only testing with small sample size as the cuts might be different between host and
+    // device.
+    size_t n_samples{128}, n_features{13};
+    Context ctx;
+    ctx.gpu_id = 0;
+    auto Xy = RandomDataGenerator{n_samples, n_features, sparsity}.GenerateDMatrix(true);
+    std::unique_ptr<EllpackPageImpl> from_ghist;
+    ASSERT_TRUE(Xy->SingleColBlock());
+    for (auto const& page : Xy->GetBatches<GHistIndexMatrix>(BatchParam{17, 0.6})) {
+      from_ghist.reset(new EllpackPageImpl{&ctx, page, {}});
+    }
+
+    for (auto const& page : Xy->GetBatches<EllpackPage>(BatchParam{0, 17})) {
+      auto from_sparse_page = page.Impl();
+      ASSERT_EQ(from_sparse_page->is_dense, from_ghist->is_dense);
+      ASSERT_EQ(from_sparse_page->base_rowid, 0);
+      ASSERT_EQ(from_sparse_page->base_rowid, from_ghist->base_rowid);
+      ASSERT_EQ(from_sparse_page->n_rows, from_ghist->n_rows);
+      ASSERT_EQ(from_sparse_page->gidx_buffer.Size(), from_ghist->gidx_buffer.Size());
+      auto const& h_gidx_from_sparse = from_sparse_page->gidx_buffer.HostVector();
+      auto const& h_gidx_from_ghist = from_ghist->gidx_buffer.HostVector();
+      ASSERT_EQ(from_sparse_page->NumSymbols(), from_ghist->NumSymbols());
+      common::CompressedIterator<uint32_t> from_ghist_it(h_gidx_from_ghist.data(),
+                                                         from_ghist->NumSymbols());
+      common::CompressedIterator<uint32_t> from_sparse_it(h_gidx_from_sparse.data(),
+                                                          from_sparse_page->NumSymbols());
+      for (size_t i = 0; i < from_ghist->n_rows * from_ghist->row_stride; ++i) {
+        EXPECT_EQ(from_ghist_it[i], from_sparse_it[i]);
+      }
+    }
+  }
+};
+}  // namespace
+
+TEST_P(EllpackPageTest, FromGHistIndex) { this->Run(GetParam()); }
+INSTANTIATE_TEST_SUITE_P(EllpackPage, EllpackPageTest, testing::Values(.0f, .2f, .4f, .8f));
 }  // namespace xgboost
diff --git a/tests/python-gpu/test_device_quantile_dmatrix.py b/tests/python-gpu/test_device_quantile_dmatrix.py
index f91cf6bd9aa6..dee60392057a 100644
--- a/tests/python-gpu/test_device_quantile_dmatrix.py
+++ b/tests/python-gpu/test_device_quantile_dmatrix.py
@@ -31,6 +31,34 @@ def test_dmatrix_cupy_init(self) -> None:
         data = cp.random.randn(5, 5)
         xgb.DeviceQuantileDMatrix(data, cp.ones(5, dtype=np.float64))
 
+    @pytest.mark.skipif(**tm.no_cupy())
+    def test_from_host(self) -> None:
+        import cupy as cp
+        n_samples = 64
+        n_features = 3
+        X, y, w = tm.make_batches(
+            n_samples, n_features=n_features, n_batches=1, use_cupy=False
+        )
+        Xy = xgb.QuantileDMatrix(X[0], y[0], weight=w[0])
+        booster_0 = xgb.train({"tree_method": "gpu_hist"}, Xy, num_boost_round=4)
+
+        X[0] = cp.array(X[0])
+        y[0] = cp.array(y[0])
+        w[0] = cp.array(w[0])
+
+        Xy = xgb.QuantileDMatrix(X[0], y[0], weight=w[0])
+        booster_1 = xgb.train({"tree_method": "gpu_hist"}, Xy, num_boost_round=4)
+        cp.testing.assert_allclose(
+            booster_0.inplace_predict(X[0]), booster_1.inplace_predict(X[0])
+        )
+
+        with pytest.raises(ValueError, match="not initialized with CPU"):
+            # Training on CPU with GPU data is not supported.
+            xgb.train({"tree_method": "hist"}, Xy, num_boost_round=4)
+
+        with pytest.raises(ValueError, match=r"Only.*hist.*"):
+            xgb.train({"tree_method": "approx"}, Xy, num_boost_round=4)
+
     @pytest.mark.skipif(**tm.no_cupy())
     def test_metainfo(self) -> None:
         import cupy as cp