Batch UpdatePosition using cudaMemcpy #7964
Merged
Changes from 64 commits
Commits
72 commits
All 72 commits are by RAMitchell:

2b4cf67  Remove single_precision_histogram
f140ebc  Batch nodes from driver
80a3e78  Categoricals broken
e1fb702  Refactor categoricals
dc100cf  Refactor categoricals 2
bc74458  Skip copy if no categoricals
c4f8eac  Review comment
2a53849  Merge branch 'master' of github.com:dmlc/xgboost into categorical
a1cddaa  Revert "Categoricals broken"
829bda6  Merge branch 'master' of github.com:dmlc/xgboost into fuse
0bc8745  Merge branch 'categorical' of github.com:RAMitchell/xgboost into fuse
fd0e25e  Lint
9fab64e  Merge branch 'master' of github.com:dmlc/xgboost into fuse
56785f3  Revert "Revert "Categoricals broken""
1dd1a6c  Limit concurrent nodes
8751d14  Lint
49809bf  Basic blockwise partitioning
181d7cf  Working block partition
666eb9b  Reduction
66173c7  Some failing tests
ec7fea8  Handle empty candidate
49c5f90  Cleanup
bd48082  experiments
c3ef1f6  Improvements
ba8bbdf  Fused scan
f4ef4ca  Register blocking
9c27dd0  Cleanup
0bcc84a  Working tests
723ff47  Transplanted new code
199bed9  Optimised
0e35e99  Do not initialise data structures to maximum possible tree size.
daa9b56  Comments, cleanup
8ab989e  Refactor FinalizePosition
d50ec4b  Remove redundant functions
c34c3ad  Lint
e534edc  Merge branch 'master' of github.com:dmlc/xgboost into batch-position-…
47bfc6e  Remove old kernel
a53ba87  Add tests for AtomicIncrement
7450d68  Change lambda to kernel
6df1259  Smem + lineinfo
4010942  Use stream
1b13fe6  Fast global stores
24fb339  Fast load without shmem
f40fe94  Memcpy version
7d5d7e7  Remove left counts kernel
77f8550  Unstable partition
14d8663  Warp aggregates
ec968f7  Cleanup
a764986  Use pointer for shared memory
001c2f2  Row partitioner grid
70bad86  Custom FinalizePositionKernel
31e02f0  Revert "Custom FinalizePositionKernel"
b86cb29  Reduce grid size
c3944af  Tune items/thread
cdd134a  FinalisePosition custom kernel
edabc45  Fixing slow scatter
43eb83e  Remove unstable
d87e366  Merge branch 'master' of github.com:dmlc/xgboost into batch-position-…
968bb29  Format
1372ad8  Review comments
a910fb9  Reintroduce prediction caching for external memory.
ff05df5  Avoid initialising temp memory
c3a0e32  Merge branch 'master' of github.com:dmlc/xgboost into batch-position-…
0280b8c  Lint
9c642dc  Review comments.
b4f2128  Remove external memory prediction caching.
8caed98  Merge branch 'master' of github.com:dmlc/xgboost into batch-position-…
776ef9f  Remove constant memory in favour of __ldg().
33fea3d  Clang tidy
9de0692  Clang tidy
3cd5e41  Review comments.
9eddfce  Initialise memory in case zero training rows.
@@ -1,174 +1,49 @@
 /*!
- * Copyright 2017-2021 XGBoost contributors
+ * Copyright 2017-2022 XGBoost contributors
  */
-#include <thrust/iterator/discard_iterator.h>
-#include <thrust/iterator/transform_output_iterator.h>
+#include <thrust/sequence.h>
 
 #include <vector>
 
 #include "../../common/device_helpers.cuh"
 #include "row_partitioner.cuh"
 
 namespace xgboost {
 namespace tree {
-struct IndexFlagTuple {
-  size_t idx;
-  size_t flag;
-};
-
-struct IndexFlagOp {
-  __device__ IndexFlagTuple operator()(const IndexFlagTuple& a,
-                                       const IndexFlagTuple& b) const {
-    return {b.idx, a.flag + b.flag};
-  }
-};
-
-struct WriteResultsFunctor {
-  bst_node_t left_nidx;
-  common::Span<bst_node_t> position_in;
-  common::Span<bst_node_t> position_out;
-  common::Span<RowPartitioner::RowIndexT> ridx_in;
-  common::Span<RowPartitioner::RowIndexT> ridx_out;
-  int64_t* d_left_count;
-
-  __device__ IndexFlagTuple operator()(const IndexFlagTuple& x) {
-    // the ex_scan_result represents how many rows have been assigned to left
-    // node so far during scan.
-    int scatter_address;
-    if (position_in[x.idx] == left_nidx) {
-      scatter_address = x.flag - 1;  // -1 because inclusive scan
-    } else {
-      // current number of rows belonging to right node + total number of rows
-      // belonging to left node
-      scatter_address = (x.idx - x.flag) + *d_left_count;
-    }
-    // copy the node id to output
-    position_out[scatter_address] = position_in[x.idx];
-    ridx_out[scatter_address] = ridx_in[x.idx];
-
-    // Discard
-    return {};
-  }
-};
-
-// Implement partitioning via single scan operation using transform output to
-// write the result
-void RowPartitioner::SortPosition(common::Span<bst_node_t> position,
-                                  common::Span<bst_node_t> position_out,
-                                  common::Span<RowIndexT> ridx,
-                                  common::Span<RowIndexT> ridx_out,
-                                  bst_node_t left_nidx, bst_node_t,
-                                  int64_t* d_left_count, cudaStream_t stream) {
-  WriteResultsFunctor write_results{left_nidx, position, position_out,
-                                    ridx,      ridx_out, d_left_count};
-  auto discard_write_iterator =
-      thrust::make_transform_output_iterator(dh::TypedDiscard<IndexFlagTuple>(), write_results);
-  auto counting = thrust::make_counting_iterator(0llu);
-  auto input_iterator = dh::MakeTransformIterator<IndexFlagTuple>(
-      counting, [=] __device__(size_t idx) {
-        return IndexFlagTuple{idx, static_cast<size_t>(position[idx] == left_nidx)};
-      });
-  size_t temp_bytes = 0;
-  cub::DeviceScan::InclusiveScan(nullptr, temp_bytes, input_iterator,
-                                 discard_write_iterator, IndexFlagOp(),
-                                 position.size(), stream);
-  dh::TemporaryArray<int8_t> temp(temp_bytes);
-  cub::DeviceScan::InclusiveScan(temp.data().get(), temp_bytes, input_iterator,
-                                 discard_write_iterator, IndexFlagOp(),
-                                 position.size(), stream);
-}
-
-void Reset(int device_idx, common::Span<RowPartitioner::RowIndexT> ridx,
-           common::Span<bst_node_t> position) {
-  dh::safe_cuda(cudaSetDevice(device_idx));
-  CHECK_EQ(ridx.size(), position.size());
-  dh::LaunchN(ridx.size(), [=] __device__(size_t idx) {
-    ridx[idx] = idx;
-    position[idx] = 0;
-  });
-}
-
 RowPartitioner::RowPartitioner(int device_idx, size_t num_rows)
-    : device_idx_(device_idx), ridx_a_(num_rows), position_a_(num_rows),
-      ridx_b_(num_rows), position_b_(num_rows) {
+    : device_idx_(device_idx),
+      ridx_(num_rows),
+      ridx_tmp_(num_rows),
+      d_counts(kMaxUpdatePositionBatchSize) {
   dh::safe_cuda(cudaSetDevice(device_idx_));
-  ridx_ = dh::DoubleBuffer<RowIndexT>{&ridx_a_, &ridx_b_};
-  position_ = dh::DoubleBuffer<bst_node_t>{&position_a_, &position_b_};
-  ridx_segments_.emplace_back(static_cast<size_t>(0), num_rows);
-
-  Reset(device_idx, ridx_.CurrentSpan(), position_.CurrentSpan());
-  left_counts_.resize(256);
-  thrust::fill(left_counts_.begin(), left_counts_.end(), 0);
-  streams_.resize(2);
-  for (auto& stream : streams_) {
-    dh::safe_cuda(cudaStreamCreate(&stream));
-  }
+  ridx_segments_.emplace_back(NodePositionInfo{Segment(0, num_rows)});
+  thrust::sequence(thrust::device, ridx_.data(), ridx_.data() + ridx_.size());
+  dh::safe_cuda(cudaStreamCreate(&stream_));
 }
 
 RowPartitioner::~RowPartitioner() {
   dh::safe_cuda(cudaSetDevice(device_idx_));
-  for (auto& stream : streams_) {
-    dh::safe_cuda(cudaStreamDestroy(stream));
-  }
+  dh::safe_cuda(cudaStreamDestroy(stream_));
 }
 
-common::Span<const RowPartitioner::RowIndexT> RowPartitioner::GetRows(
-    bst_node_t nidx) {
-  auto segment = ridx_segments_.at(nidx);
-  // Return empty span here as a valid result
-  // Will error if we try to construct a span from a pointer with size 0
-  if (segment.Size() == 0) {
-    return {};
-  }
-  return ridx_.CurrentSpan().subspan(segment.begin, segment.Size());
+common::Span<const RowPartitioner::RowIndexT> RowPartitioner::GetRows(bst_node_t nidx) {
+  auto segment = ridx_segments_.at(nidx).segment;
+  return dh::ToSpan(ridx_).subspan(segment.begin, segment.Size());
 }
 
 common::Span<const RowPartitioner::RowIndexT> RowPartitioner::GetRows() {
-  return ridx_.CurrentSpan();
+  return dh::ToSpan(ridx_);
 }
 
-common::Span<const bst_node_t> RowPartitioner::GetPosition() {
-  return position_.CurrentSpan();
-}
-std::vector<RowPartitioner::RowIndexT> RowPartitioner::GetRowsHost(
-    bst_node_t nidx) {
+std::vector<RowPartitioner::RowIndexT> RowPartitioner::GetRowsHost(bst_node_t nidx) {
   auto span = GetRows(nidx);
   std::vector<RowIndexT> rows(span.size());
   dh::CopyDeviceSpanToVector(&rows, span);
   return rows;
 }
 
-std::vector<bst_node_t> RowPartitioner::GetPositionHost() {
-  auto span = GetPosition();
-  std::vector<bst_node_t> position(span.size());
-  dh::CopyDeviceSpanToVector(&position, span);
-  return position;
-}
-
-void RowPartitioner::SortPositionAndCopy(const Segment& segment,
-                                         bst_node_t left_nidx,
-                                         bst_node_t right_nidx,
-                                         int64_t* d_left_count,
-                                         cudaStream_t stream) {
-  SortPosition(
-      // position_in
-      common::Span<bst_node_t>(position_.Current() + segment.begin,
-                               segment.Size()),
-      // position_out
-      common::Span<bst_node_t>(position_.Other() + segment.begin,
-                               segment.Size()),
-      // row index in
-      common::Span<RowIndexT>(ridx_.Current() + segment.begin, segment.Size()),
-      // row index out
-      common::Span<RowIndexT>(ridx_.Other() + segment.begin, segment.Size()),
-      left_nidx, right_nidx, d_left_count, stream);
-  // Copy back key/value
-  const auto d_position_current = position_.Current() + segment.begin;
-  const auto d_position_other = position_.Other() + segment.begin;
-  const auto d_ridx_current = ridx_.Current() + segment.begin;
-  const auto d_ridx_other = ridx_.Other() + segment.begin;
-  dh::LaunchN(segment.Size(), stream, [=] __device__(size_t idx) {
-    d_position_current[idx] = d_position_other[idx];
-    d_ridx_current[idx] = d_ridx_other[idx];
-  });
-}
 };  // namespace tree
 };  // namespace xgboost
Review discussion:

Comment: Can we use dh::CUDAStream instead?

Reply: I see this class is using non-blocking streams with respect to the default stream. Directly swapping it in results in a crash: many places assume that kernels running on the default stream wait for previously launched kernels to finish.