diff --git a/cmake/Utils.cmake b/cmake/Utils.cmake
index 3e671af1c9ab..3d0973c745e2 100644
--- a/cmake/Utils.cmake
+++ b/cmake/Utils.cmake
@@ -144,10 +144,8 @@ function(xgboost_set_cuda_flags target)
   endif (CMAKE_VERSION VERSION_GREATER_EQUAL "3.18")
 
   if (USE_DEVICE_DEBUG)
-    if (CMAKE_BUILD_TYPE MATCHES "Debug")
-      target_compile_options(${target} PRIVATE
-        $<$<COMPILE_LANGUAGE:CUDA>:-G;-src-in-ptx>)
-    endif(CMAKE_BUILD_TYPE MATCHES "Debug")
+    target_compile_options(${target} PRIVATE
+      $<$<AND:$<COMPILE_LANGUAGE:CUDA>,$<CONFIG:DEBUG>>:-G;-src-in-ptx>)
   else (USE_DEVICE_DEBUG)
     target_compile_options(${target} PRIVATE
       $<$<COMPILE_LANGUAGE:CUDA>:-lineinfo>)
@@ -157,10 +155,8 @@
     enable_nvtx(${target})
   endif (USE_NVTX)
 
-  target_compile_definitions(${target} PRIVATE -DXGBOOST_USE_CUDA=1)
-  if (CMAKE_CUDA_COMPILER_VERSION VERSION_LESS 11.0)
-    target_include_directories(${target} PRIVATE ${xgboost_SOURCE_DIR}/cub/)
-  endif (CMAKE_CUDA_COMPILER_VERSION VERSION_LESS 11.0)
+  target_compile_definitions(${target} PRIVATE -DXGBOOST_USE_CUDA=1 -DTHRUST_IGNORE_CUB_VERSION_CHECK=1)
+  target_include_directories(${target} PRIVATE ${xgboost_SOURCE_DIR}/cub/)
 
   if (MSVC)
     target_compile_options(${target} PRIVATE
diff --git a/cub b/cub
index c3cceac115c0..af39ee264f46 160000
--- a/cub
+++ b/cub
@@ -1 +1 @@
-Subproject commit c3cceac115c072fb63df1836ff46d8c60d9eb304
+Subproject commit af39ee264f4627608072bf54730bf3a862e56875
diff --git a/python-package/xgboost/core.py b/python-package/xgboost/core.py
index e834f409b6f2..d9e1363fd6c7 100644
--- a/python-package/xgboost/core.py
+++ b/python-package/xgboost/core.py
@@ -870,9 +870,6 @@ class DeviceQuantileDMatrix(DMatrix):
     You can construct DeviceQuantileDMatrix from cupy/cudf/dlpack.
 
     .. versionadded:: 1.1.0
-
-    Known limitation:
-    The data size (rows * cols) can not exceed 2 ** 31 - 1000
     """
 
     def __init__(self, data, label=None, weight=None,  # pylint: disable=W0231
diff --git a/python-package/xgboost/dask.py b/python-package/xgboost/dask.py
index 7a2221f27cdb..b3a1f770af57 100644
--- a/python-package/xgboost/dask.py
+++ b/python-package/xgboost/dask.py
@@ -509,10 +509,6 @@ class DaskDeviceQuantileDMatrix(DaskDMatrix):
     max_bin: Number of bins for histogram construction.
 
-    Know issue:
-    The size of each chunk (rows * cols for a single dask chunk/partition) can
-    not exceed 2 ** 31 - 1000
-
     '''
     def __init__(self, client,
                  data,
diff --git a/src/common/device_helpers.cuh b/src/common/device_helpers.cuh
index aab7054fc16d..0112d299a28b 100644
--- a/src/common/device_helpers.cuh
+++ b/src/common/device_helpers.cuh
@@ -307,7 +307,7 @@ class MemoryLogger {
   void RegisterDeallocation(void *ptr, size_t n, int current_device) {
     auto itr = device_allocations.find(ptr);
     if (itr == device_allocations.end()) {
-      LOG(FATAL) << "Attempting to deallocate " << n << " bytes on device "
+      LOG(WARNING) << "Attempting to deallocate " << n << " bytes on device "
                  << current_device << " that was never allocated ";
     }
     num_deallocations++;
diff --git a/src/data/ellpack_page.cu b/src/data/ellpack_page.cu
index 0632172a2fdb..d560acc6a7cc 100644
--- a/src/data/ellpack_page.cu
+++ b/src/data/ellpack_page.cu
@@ -161,6 +161,26 @@ struct WriteCompressedEllpackFunctor {
   }
 };
 
+template <typename Tuple>
+struct TupleScanOp {
+  __device__ Tuple operator()(Tuple a, Tuple b) {
+    // Key equal
+    if (a.template get<0>() == b.template get<0>()) {
+      b.template get<1>() += a.template get<1>();
+      return b;
+    }
+    // Not equal
+    return b;
+  }
+};
+
+// Change the value type of thrust discard iterator so we can use it with cub
+template <typename T>
+class TypedDiscard : public thrust::discard_iterator<T> {
+ public:
+  using value_type = T;  // NOLINT
+};
+
 // Here the data is already correctly ordered and simply needs to be compacted
 // to remove missing data
 template <typename AdapterBatchT>
@@ -201,30 +221,23 @@ void CopyDataToEllpack(const AdapterBatchT& batch, EllpackPageImpl* dst,
   // We redirect the scan output into this functor to do the actual writing
   WriteCompressedEllpackFunctor<AdapterBatchT> functor(
       d_compressed_buffer, writer, batch, device_accessor, is_valid);
-  thrust::discard_iterator<size_t> discard;
+  TypedDiscard<Tuple> discard;
   thrust::transform_output_iterator<
       WriteCompressedEllpackFunctor<AdapterBatchT>, decltype(discard)>
       out(discard, functor);
-  dh::XGBCachingDeviceAllocator<char> alloc;
-  // 1000 as a safe factor for inclusive_scan, otherwise it might generate overflow and
-  // lead to oom error.
-  // or:
-  // after reduction step 2: cudaErrorInvalidConfiguration: invalid configuration argument
-  // https://github.com/NVIDIA/thrust/issues/1299
-  CHECK_LE(batch.Size(), std::numeric_limits<int32_t>::max() - 1000)
-      << "Known limitation, size (rows * cols) of quantile based DMatrix "
-         "cannot exceed the limit of 32-bit integer.";
-  thrust::inclusive_scan(thrust::cuda::par(alloc), key_value_index_iter,
-                         key_value_index_iter + batch.Size(), out,
-                         [=] __device__(Tuple a, Tuple b) {
-                           // Key equal
-                           if (a.get<0>() == b.get<0>()) {
-                             b.get<1>() += a.get<1>();
-                             return b;
-                           }
-                           // Not equal
-                           return b;
-                         });
+  // Go one level down into cub::DeviceScan API to set OffsetT as 64 bit
+  // So we don't crash on n > 2^31
+  size_t temp_storage_bytes = 0;
+  using DispatchScan =
+      cub::DispatchScan<decltype(key_value_index_iter), decltype(out),
+                        TupleScanOp<Tuple>, cub::NullType, int64_t>;
+  DispatchScan::Dispatch(nullptr, temp_storage_bytes, key_value_index_iter, out,
+                         TupleScanOp<Tuple>(), cub::NullType(), batch.Size(),
+                         nullptr, false);
+  dh::TemporaryArray<char> temp_storage(temp_storage_bytes);
+  DispatchScan::Dispatch(temp_storage.data().get(), temp_storage_bytes,
+                         key_value_index_iter, out, TupleScanOp<Tuple>(),
+                         cub::NullType(), batch.Size(), nullptr, false);
 }
 
 void WriteNullValues(EllpackPageImpl* dst, int device_idx,
diff --git a/tests/python-gpu/test_large_input.py b/tests/python-gpu/test_large_input.py
new file mode 100644
index 000000000000..b99c7a0d9809
--- /dev/null
+++ b/tests/python-gpu/test_large_input.py
@@ -0,0 +1,21 @@
+import numpy as np
+import xgboost as xgb
+import cupy as cp
+import time
+import pytest
+
+
+# Test for integer overflow or out of memory exceptions
+def test_large_input():
+    available_bytes, _ = cp.cuda.runtime.memGetInfo()
+    # 15 GB
+    required_bytes = 1.5e+10
+    if available_bytes < required_bytes:
+        pytest.skip("Not enough memory on this device")
+    n = 1000
+    m = ((1 << 31) + n - 1) // n
+    assert (np.log2(m * n) > 31)
+    X = cp.ones((m, n), dtype=np.float32)
+    y = cp.ones(m)
+    dmat = xgb.DeviceQuantileDMatrix(X, y)
+    xgb.train({"tree_method": "gpu_hist", "max_depth": 1}, dmat, 1)
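
The core of the change is replacing thrust::inclusive_scan (which goes through a 32-bit offset path) with a direct cub::DispatchScan instantiation that fixes OffsetT to int64_t; that is what lifts the 2^31 element ceiling the removed docstrings described. Below is a minimal standalone sketch of the same two-phase call pattern, separate from the patch itself: it assumes the DispatchScan signature of the CUB version pinned by the submodule update above (DispatchScan is an internal CUB API and changed in later releases), and the input data, cub::Sum operator, and main() harness are illustrative only.

// Standalone sketch (not part of the patch): inclusive scan driven through
// cub::DispatchScan with a 64-bit OffsetT, mirroring the CopyDataToEllpack
// change above. Assumes the CUB 1.9-era Dispatch() signature used by the
// pinned submodule; n, cub::Sum and main() are illustrative only.
#include <cstdint>
#include <cstdio>
#include <thrust/device_vector.h>
#include <cub/cub.cuh>

int main() {
  const std::int64_t n = 1 << 20;  // small here; the point is the 64-bit offset type
  thrust::device_vector<int> in(n, 1);
  thrust::device_vector<int> out(n);

  using Scan = cub::DispatchScan<const int*, int*, cub::Sum, cub::NullType,
                                 std::int64_t>;
  size_t temp_bytes = 0;
  // First call only sizes the temporary storage (same two-phase pattern as the patch).
  Scan::Dispatch(nullptr, temp_bytes, thrust::raw_pointer_cast(in.data()),
                 thrust::raw_pointer_cast(out.data()), cub::Sum(),
                 cub::NullType(), n, nullptr, false);
  thrust::device_vector<char> temp(temp_bytes);
  // Second call performs the scan; num_items is an int64_t, so inputs larger
  // than 2^31 - 1 elements no longer overflow the offset arithmetic.
  Scan::Dispatch(thrust::raw_pointer_cast(temp.data()), temp_bytes,
                 thrust::raw_pointer_cast(in.data()),
                 thrust::raw_pointer_cast(out.data()), cub::Sum(),
                 cub::NullType(), n, nullptr, false);

  std::printf("last prefix sum = %lld\n",
              static_cast<long long>(out.back()));  // expect n
  return 0;
}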