From 82d5d4d647f3f97a05d3ab6e45ebad4469a8d765 Mon Sep 17 00:00:00 2001
From: fis
Date: Tue, 13 Oct 2020 09:40:11 +0800
Subject: [PATCH] Loop over `thrust::reduce`.

* Check input chunk size of dqdm.
* Add doc for current limitation.
---
 doc/tutorials/saving_model.rst          |  2 +-
 python-package/xgboost/core.py          |  2 ++
 python-package/xgboost/dask.py          |  4 ++++
 src/common/device_helpers.cuh           | 17 +++++++++++++++++
 src/data/device_adapter.cuh             |  2 +-
 src/data/ellpack_page.cu                |  8 ++++++++
 src/tree/gpu_hist/histogram.cu          |  4 ++--
 src/tree/updater_gpu_hist.cu            |  5 +++--
 tests/cpp/common/test_device_helpers.cu | 11 +++++++++--
 tests/cpp/data/test_ellpack_page.cu     |  1 -
 10 files changed, 47 insertions(+), 9 deletions(-)

diff --git a/doc/tutorials/saving_model.rst b/doc/tutorials/saving_model.rst
index 544ef4c66a01..d3cd36998bcf 100644
--- a/doc/tutorials/saving_model.rst
+++ b/doc/tutorials/saving_model.rst
@@ -167,7 +167,7 @@ or in R:
 
 Will print out something similiar to (not actual output as it's too long for demonstration):
 
-.. code-block:: json
+.. code-block:: javascript
 
   {
     "Learner": {
diff --git a/python-package/xgboost/core.py b/python-package/xgboost/core.py
index ceba043f27ea..e834f409b6f2 100644
--- a/python-package/xgboost/core.py
+++ b/python-package/xgboost/core.py
@@ -871,6 +871,8 @@ class DeviceQuantileDMatrix(DMatrix):
 
     .. versionadded:: 1.1.0
 
+    Known limitation:
+      The data size (rows * cols) cannot exceed 2 ** 31 - 1000
     """
 
     def __init__(self, data, label=None, weight=None,  # pylint: disable=W0231
diff --git a/python-package/xgboost/dask.py b/python-package/xgboost/dask.py
index 6286e2c21172..6b793893ba36 100644
--- a/python-package/xgboost/dask.py
+++ b/python-package/xgboost/dask.py
@@ -509,6 +509,10 @@ class DaskDeviceQuantileDMatrix(DaskDMatrix):
     max_bin: Number of bins for histogram construction.
 
+    Known issue:
+        The size of each chunk (rows * cols for a single dask chunk/partition)
+        cannot exceed 2 ** 31 - 1000
+
     '''
     def __init__(self, client, data,
diff --git a/src/common/device_helpers.cuh b/src/common/device_helpers.cuh
index f667075525c3..be1d81dd0027 100644
--- a/src/common/device_helpers.cuh
+++ b/src/common/device_helpers.cuh
@@ -1107,4 +1107,21 @@ size_t SegmentedUnique(Inputs &&...inputs) {
   dh::XGBCachingDeviceAllocator<char> alloc;
   return SegmentedUnique(thrust::cuda::par(alloc), std::forward<Inputs>(inputs)...);
 }
+
+template <typename Policy, typename InputIt, typename Init, typename Func>
+auto Reduce(Policy policy, InputIt first, InputIt second, Init init, Func reduce_op) {
+  size_t constexpr kLimit = std::numeric_limits<int32_t>::max() / 2;
+  size_t size = std::distance(first, second);
+  using Ty = std::remove_cv_t<Init>;
+  Ty aggregate = init;
+  for (size_t offset = 0; offset < size; offset += kLimit) {
+    auto begin_it = first + offset;
+    auto end_it = first + std::min(offset + kLimit, size);
+    size_t batch_size = std::distance(begin_it, end_it);
+    CHECK_LE(batch_size, size);
+    auto ret = thrust::reduce(policy, begin_it, end_it, init, reduce_op);
+    aggregate = reduce_op(aggregate, ret);
+  }
+  return aggregate;
+}
 }  // namespace dh
diff --git a/src/data/device_adapter.cuh b/src/data/device_adapter.cuh
index 709368f5c756..5ebbac408c63 100644
--- a/src/data/device_adapter.cuh
+++ b/src/data/device_adapter.cuh
@@ -221,7 +221,7 @@ size_t GetRowCounts(const AdapterBatchT batch, common::Span<size_t> offset,
     }
   });
   dh::XGBCachingDeviceAllocator<char> alloc;
-  size_t row_stride = thrust::reduce(
+  size_t row_stride = dh::Reduce(
       thrust::cuda::par(alloc), thrust::device_pointer_cast(offset.data()),
       thrust::device_pointer_cast(offset.data()) + offset.size(), size_t(0),
       thrust::maximum<size_t>());
diff --git a/src/data/ellpack_page.cu b/src/data/ellpack_page.cu
index 6ce18340f771..0632172a2fdb 100644
--- a/src/data/ellpack_page.cu
+++ b/src/data/ellpack_page.cu
@@ -206,6 +206,14 @@ void CopyDataToEllpack(const AdapterBatchT& batch, EllpackPageImpl* dst,
     WriteCompressedEllpackFunctor<AdapterBatchT>, decltype(discard)> out(discard, functor);
   dh::XGBCachingDeviceAllocator<char> alloc;
+  // 1000 is used as a safe factor for inclusive_scan, otherwise it might overflow and
+  // lead to an OOM error.
+  // or:
+  // after reduction step 2: cudaErrorInvalidConfiguration: invalid configuration argument
+  // https://github.com/NVIDIA/thrust/issues/1299
+  CHECK_LE(batch.Size(), std::numeric_limits<int32_t>::max() - 1000)
+      << "Known limitation, size (rows * cols) of quantile based DMatrix "
+         "cannot exceed the limit of 32-bit integer.";
   thrust::inclusive_scan(thrust::cuda::par(alloc), key_value_index_iter,
                          key_value_index_iter + batch.Size(), out,
                          [=] __device__(Tuple a, Tuple b) {
diff --git a/src/tree/gpu_hist/histogram.cu b/src/tree/gpu_hist/histogram.cu
index adf973817bab..1c03034eaaf1 100644
--- a/src/tree/gpu_hist/histogram.cu
+++ b/src/tree/gpu_hist/histogram.cu
@@ -53,7 +53,7 @@ struct Pair {
   GradientPair first;
   GradientPair second;
 };
-XGBOOST_DEV_INLINE Pair operator+(Pair const& lhs, Pair const& rhs) {
+__host__ XGBOOST_DEV_INLINE Pair operator+(Pair const& lhs, Pair const& rhs) {
   return {lhs.first + rhs.first, lhs.second + rhs.second};
 }
 }  // anonymous namespace
@@ -86,7 +86,7 @@ GradientSumT CreateRoundingFactor(common::Span<GradientPair const> gpair) {
   thrust::device_ptr<GradientPair const> gpair_end {gpair.data() + gpair.size()};
   auto beg = thrust::make_transform_iterator(gpair_beg, Clip());
   auto end = thrust::make_transform_iterator(gpair_end, Clip());
-  Pair p = thrust::reduce(thrust::cuda::par(alloc), beg, end, Pair{});
+  Pair p = dh::Reduce(thrust::cuda::par(alloc), beg, end, Pair{}, thrust::plus<Pair>{});
   GradientPair positive_sum {p.first}, negative_sum {p.second};
 
   auto histogram_rounding = GradientSumT {
diff --git a/src/tree/updater_gpu_hist.cu b/src/tree/updater_gpu_hist.cu
index 394dfd5d5c1a..14eec53a676e 100644
--- a/src/tree/updater_gpu_hist.cu
+++ b/src/tree/updater_gpu_hist.cu
@@ -642,10 +642,11 @@ struct GPUHistMakerDevice {
   ExpandEntry InitRoot(RegTree* p_tree, dh::AllReducer* reducer) {
     constexpr bst_node_t kRootNIdx = 0;
     dh::XGBCachingDeviceAllocator<char> alloc;
-    GradientPair root_sum = thrust::reduce(
+    GradientPair root_sum = dh::Reduce(
         thrust::cuda::par(alloc), thrust::device_ptr<GradientPair const>(gpair.data()),
-        thrust::device_ptr<GradientPair const>(gpair.data() + gpair.size()));
+        thrust::device_ptr<GradientPair const>(gpair.data() + gpair.size()),
+        GradientPair{}, thrust::plus<GradientPair>{});
     rabit::Allreduce<rabit::op::Sum, float>(reinterpret_cast<float*>(&root_sum), 2);
diff --git a/tests/cpp/common/test_device_helpers.cu b/tests/cpp/common/test_device_helpers.cu
index 006c036d3479..881b6e9eed8c 100644
--- a/tests/cpp/common/test_device_helpers.cu
+++ b/tests/cpp/common/test_device_helpers.cu
@@ -1,4 +1,3 @@
-
 /*!
  * Copyright 2017 XGBoost contributors
  */
@@ -122,6 +121,14 @@ void TestSegmentedUniqueRegression(std::vector<SketchEntry> values, size_t n_dup
   ASSERT_EQ(segments.at(1), d_segments_out[1] + n_duplicated);
 }
 
+TEST(DeviceHelpers, Reduce) {
+  size_t kSize = std::numeric_limits<uint32_t>::max();
+  auto it = thrust::make_counting_iterator(0ul);
+  dh::XGBCachingDeviceAllocator<char> alloc;
+  auto batched = dh::Reduce(thrust::cuda::par(alloc), it, it + kSize, 0ul, thrust::maximum<size_t>{});
+  CHECK_EQ(batched, kSize - 1);
+}
+
 TEST(SegmentedUnique, Regression) {
   {
@@ -157,4 +164,4 @@ TEST(SegmentedUnique, Regression) {
   }
 }
 }  // namespace common
-}  // namespace xgboost
\ No newline at end of file
+}  // namespace xgboost
diff --git a/tests/cpp/data/test_ellpack_page.cu b/tests/cpp/data/test_ellpack_page.cu
index 23d566068619..2ea89331a2e3 100644
--- a/tests/cpp/data/test_ellpack_page.cu
+++ b/tests/cpp/data/test_ellpack_page.cu
@@ -234,5 +234,4 @@ TEST(EllpackPage, Compact) {
     }
   }
 }
-
 }  // namespace xgboost
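
For reference, the chunking idea behind `dh::Reduce` can be illustrated outside of the XGBoost tree. The sketch below is not part of the patch: `BatchedReduce`, the element count, and the `main` driver are illustrative names and values, and it assumes only Thrust and CUDA are available. It mirrors the patch's approach of capping each `thrust::reduce` call below the 32-bit element count that triggers the Thrust failure referenced above.

```cpp
// batched_reduce.cu -- standalone sketch of looping thrust::reduce over chunks,
// mirroring dh::Reduce from this patch (names and sizes here are illustrative).
#include <thrust/execution_policy.h>
#include <thrust/functional.h>
#include <thrust/iterator/counting_iterator.h>
#include <thrust/reduce.h>

#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <limits>

template <typename InputIt, typename Init, typename Func>
Init BatchedReduce(InputIt first, InputIt last, Init init, Func reduce_op) {
  // Keep each thrust::reduce call well below the 32-bit element limit that can
  // trigger cudaErrorInvalidConfiguration (see NVIDIA/thrust#1299).
  size_t constexpr kLimit = std::numeric_limits<int32_t>::max() / 2;
  size_t const size = std::distance(first, last);
  Init aggregate = init;
  for (size_t offset = 0; offset < size; offset += kLimit) {
    auto begin_it = first + offset;
    auto end_it = first + std::min(offset + kLimit, size);
    // Reduce one chunk on the device, then fold it into the running aggregate.
    auto chunk = thrust::reduce(thrust::device, begin_it, end_it, init, reduce_op);
    aggregate = reduce_op(aggregate, chunk);
  }
  return aggregate;
}

int main() {
  // More than 2^31 elements: a single thrust::reduce call may fail here,
  // while the chunked loop stays within the per-call limit.
  size_t const n = (1ull << 31) + 1000;
  auto it = thrust::make_counting_iterator<size_t>(0);
  size_t const max_value = BatchedReduce(it, it + n, size_t{0}, thrust::maximum<size_t>{});
  std::printf("max = %zu (expected %zu)\n", max_value, n - 1);
  return 0;
}
```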