
Minimal inclusive scan #6234

Merged: 10 commits, Nov 3, 2020
12 changes: 4 additions & 8 deletions cmake/Utils.cmake
@@ -144,10 +144,8 @@ function(xgboost_set_cuda_flags target)
endif (CMAKE_VERSION VERSION_GREATER_EQUAL "3.18")

if (USE_DEVICE_DEBUG)
if (CMAKE_BUILD_TYPE MATCHES "Debug")
target_compile_options(${target} PRIVATE
$<$<COMPILE_LANGUAGE:CUDA>:-G;-src-in-ptx>)
endif(CMAKE_BUILD_TYPE MATCHES "Debug")
target_compile_options(${target} PRIVATE
$<$<AND:$<CONFIG:DEBUG>,$<COMPILE_LANGUAGE:CUDA>>:-G;-src-in-ptx>)
else (USE_DEVICE_DEBUG)
target_compile_options(${target} PRIVATE
$<$<COMPILE_LANGUAGE:CUDA>:-lineinfo>)
@@ -157,10 +155,8 @@ function(xgboost_set_cuda_flags target)
enable_nvtx(${target})
endif (USE_NVTX)

target_compile_definitions(${target} PRIVATE -DXGBOOST_USE_CUDA=1)
if (CMAKE_CUDA_COMPILER_VERSION VERSION_LESS 11.0)
target_include_directories(${target} PRIVATE ${xgboost_SOURCE_DIR}/cub/)
endif (CMAKE_CUDA_COMPILER_VERSION VERSION_LESS 11.0)
target_compile_definitions(${target} PRIVATE -DXGBOOST_USE_CUDA=1 -DTHRUST_IGNORE_CUB_VERSION_CHECK=1)
target_include_directories(${target} PRIVATE ${xgboost_SOURCE_DIR}/cub/)

if (MSVC)
target_compile_options(${target} PRIVATE
2 changes: 1 addition & 1 deletion cub
Submodule cub updated 137 files
3 changes: 0 additions & 3 deletions python-package/xgboost/core.py
@@ -870,9 +870,6 @@ class DeviceQuantileDMatrix(DMatrix):
You can construct DeviceQuantileDMatrix from cupy/cudf/dlpack.

.. versionadded:: 1.1.0

Known limitation:
The data size (rows * cols) can not exceed 2 ** 31 - 1000
"""

def __init__(self, data, label=None, weight=None, # pylint: disable=W0231
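
The docstring above notes that DeviceQuantileDMatrix can be constructed from cupy/cudf/dlpack data. As a hedged illustration (not part of this PR; assumes cupy and a CUDA-enabled xgboost build, with arbitrary example shapes), construction from cupy arrays looks roughly like this:

import cupy as cp
import xgboost as xgb

# Data already resident on the GPU; shapes and values are illustrative only.
X = cp.random.rand(1000, 10, dtype=cp.float32)
y = cp.random.rand(1000, dtype=cp.float32)

# DeviceQuantileDMatrix quantizes directly from device memory.
dtrain = xgb.DeviceQuantileDMatrix(X, y)
booster = xgb.train({"tree_method": "gpu_hist"}, dtrain, num_boost_round=10)
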
4 changes: 0 additions & 4 deletions python-package/xgboost/dask.py
@@ -509,10 +509,6 @@ class DaskDeviceQuantileDMatrix(DaskDMatrix):
max_bin: Number of bins for histogram construction.


Know issue:
The size of each chunk (rows * cols for a single dask chunk/partition) can
not exceed 2 ** 31 - 1000

'''
def __init__(self, client,
data,
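
For context, DaskDeviceQuantileDMatrix is typically paired with a GPU Dask cluster and device-resident partitions. A rough sketch under those assumptions (cluster setup, chunk sizes, and data are illustrative, not from this PR):

import cupy as cp
import dask.array as da
import xgboost as xgb
from dask.distributed import Client
from dask_cuda import LocalCUDACluster

with LocalCUDACluster() as cluster, Client(cluster) as client:
    # cupy-backed dask arrays keep each partition on the GPU
    X = da.random.random((100_000, 20), chunks=(10_000, 20)).map_blocks(cp.asarray)
    y = da.random.random((100_000,), chunks=(10_000,)).map_blocks(cp.asarray)
    dtrain = xgb.dask.DaskDeviceQuantileDMatrix(client, X, y)
    xgb.dask.train(client, {"tree_method": "gpu_hist"}, dtrain, num_boost_round=10)
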
2 changes: 1 addition & 1 deletion src/common/device_helpers.cuh
@@ -307,7 +307,7 @@ class MemoryLogger {
void RegisterDeallocation(void *ptr, size_t n, int current_device) {
auto itr = device_allocations.find(ptr);
if (itr == device_allocations.end()) {
LOG(FATAL) << "Attempting to deallocate " << n << " bytes on device "
LOG(WARNING) << "Attempting to deallocate " << n << " bytes on device "
<< current_device << " that was never allocated ";
}
num_deallocations++;
55 changes: 34 additions & 21 deletions src/data/ellpack_page.cu
@@ -161,6 +161,26 @@ struct WriteCompressedEllpackFunctor {
}
};

template <typename Tuple>
struct TupleScanOp {
__device__ Tuple operator()(Tuple a, Tuple b) {
// Key equal
if (a.template get<0>() == b.template get<0>()) {
b.template get<1>() += a.template get<1>();
return b;
}
// Not equal
return b;
}
};

// Change the value type of thrust discard iterator so we can use it with cub
template <typename T>
class TypedDiscard : public thrust::discard_iterator<T> {
public:
using value_type = T; // NOLINT
};

// Here the data is already correctly ordered and simply needs to be compacted
// to remove missing data
template <typename AdapterBatchT>
@@ -201,30 +221,23 @@ void CopyDataToEllpack(const AdapterBatchT& batch, EllpackPageImpl* dst,
// We redirect the scan output into this functor to do the actual writing
WriteCompressedEllpackFunctor<AdapterBatchT> functor(
d_compressed_buffer, writer, batch, device_accessor, is_valid);
thrust::discard_iterator<size_t> discard;
TypedDiscard<Tuple> discard;
thrust::transform_output_iterator<
WriteCompressedEllpackFunctor<AdapterBatchT>, decltype(discard)>
out(discard, functor);
dh::XGBCachingDeviceAllocator<char> alloc;
// 1000 as a safe factor for inclusive_scan, otherwise it might generate overflow and
// lead to oom error.
// or:
// after reduction step 2: cudaErrorInvalidConfiguration: invalid configuration argument
// https://github.com/NVIDIA/thrust/issues/1299
CHECK_LE(batch.Size(), std::numeric_limits<int32_t>::max() - 1000)
<< "Known limitation, size (rows * cols) of quantile based DMatrix "
"cannot exceed the limit of 32-bit integer.";
thrust::inclusive_scan(thrust::cuda::par(alloc), key_value_index_iter,
key_value_index_iter + batch.Size(), out,
[=] __device__(Tuple a, Tuple b) {
// Key equal
if (a.get<0>() == b.get<0>()) {
b.get<1>() += a.get<1>();
return b;
}
// Not equal
return b;
});
// Go one level down into cub::DeviceScan API to set OffsetT as 64 bit
// So we don't crash on n > 2^31
size_t temp_storage_bytes = 0;
using DispatchScan =
cub::DispatchScan<decltype(key_value_index_iter), decltype(out),
TupleScanOp<Tuple>, cub::NullType, int64_t>;
DispatchScan::Dispatch(nullptr, temp_storage_bytes, key_value_index_iter, out,
TupleScanOp<Tuple>(), cub::NullType(), batch.Size(),
nullptr, false);
dh::TemporaryArray<char> temp_storage(temp_storage_bytes);
DispatchScan::Dispatch(temp_storage.data().get(), temp_storage_bytes,
key_value_index_iter, out, TupleScanOp<Tuple>(),
cub::NullType(), batch.Size(), nullptr, false);
}

void WriteNullValues(EllpackPageImpl* dst, int device_idx,
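
For intuition, TupleScanOp above gives the inclusive scan segmented semantics: the running count in the second tuple element accumulates while the key in the first element stays the same, and restarts when the key changes. A small host-side Python sketch of the same combine rule (illustrative only, not part of the PR):

def segmented_inclusive_scan(pairs):
    # pairs: sequence of (key, value); values accumulate within runs of equal keys
    out = []
    for key, value in pairs:
        if out and out[-1][0] == key:
            value += out[-1][1]
        out.append((key, value))
    return out

# Runs of equal keys accumulate independently:
print(segmented_inclusive_scan([(0, 1), (0, 1), (1, 1), (1, 1), (1, 1)]))
# [(0, 1), (0, 2), (1, 1), (1, 2), (1, 3)]
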
21 changes: 21 additions & 0 deletions tests/python-gpu/test_large_input.py
@@ -0,0 +1,21 @@
import numpy as np
import xgboost as xgb
import cupy as cp
import time
import pytest


# Test for integer overflow or out of memory exceptions
def test_large_input():
available_bytes, _ = cp.cuda.runtime.memGetInfo()
hcho3 (Collaborator) commented on Nov 3, 2020:

Currently, this test will skip when --use-rmm-pool is set, as cuPy is not configured to use the RMM allocator. To make it work with RMM, we'll need to run

cupy.cuda.set_allocator(rmm.rmm_cupy_allocator)

Do we want to enable this test with RMM?

The PR author (Member) replied:

Don't think this is necessary for now.

# 15 GB
required_bytes = 1.5e+10
if available_bytes < required_bytes:
pytest.skip("Not enough memory on this device")
n = 1000
m = ((1 << 31) + n - 1) // n
assert (np.log2(m * n) > 31)
X = cp.ones((m, n), dtype=np.float32)
y = cp.ones(m)
dmat = xgb.DeviceQuantileDMatrix(X, y)
xgb.train({"tree_method": "gpu_hist", "max_depth": 1}, dmat, 1)
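
As hcho3's review comment above notes, this test currently skips when --use-rmm-pool is set because cuPy is not configured to use the RMM allocator. A minimal sketch of the setup that would be needed (not part of this PR; assumes the rmm Python package is installed and that no pool has been configured yet by the test harness):

import cupy
import rmm

# Route cuPy allocations through RMM so both libraries share one memory pool
rmm.reinitialize(pool_allocator=True)
cupy.cuda.set_allocator(rmm.rmm_cupy_allocator)
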