Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[backport] Fix index type for bitfield. (#7541) #7560

Merged
merged 1 commit into from Jan 13, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
29 changes: 15 additions & 14 deletions src/common/bitfield.h
Expand Up @@ -58,28 +58,29 @@ __forceinline__ __device__ BitFieldAtomicType AtomicAnd(BitFieldAtomicType* addr
template <typename VT, typename Direction, bool IsConst = false>
struct BitFieldContainer {
using value_type = std::conditional_t<IsConst, VT const, VT>; // NOLINT
using pointer = value_type*; // NOLINT
using index_type = size_t; // NOLINT
using pointer = value_type*; // NOLINT

static value_type constexpr kValueSize = sizeof(value_type) * 8;
static value_type constexpr kOne = 1; // force correct type.
static index_type constexpr kValueSize = sizeof(value_type) * 8;
static index_type constexpr kOne = 1; // force correct type.

struct Pos {
std::remove_const_t<value_type> int_pos {0};
std::remove_const_t<value_type> bit_pos {0};
index_type int_pos{0};
index_type bit_pos{0};
};

private:
common::Span<value_type> bits_;
static_assert(!std::is_signed<VT>::value, "Must use unsiged type as underlying storage.");

public:
XGBOOST_DEVICE static Pos ToBitPos(value_type pos) {
XGBOOST_DEVICE static Pos ToBitPos(index_type pos) {
Pos pos_v;
if (pos == 0) {
return pos_v;
}
pos_v.int_pos = pos / kValueSize;
pos_v.bit_pos = pos % kValueSize;
pos_v.int_pos = pos / kValueSize;
pos_v.bit_pos = pos % kValueSize;
return pos_v;
}

Expand All @@ -96,7 +97,7 @@ struct BitFieldContainer {
/*\brief Compute the size of needed memory allocation. The returned value is in terms
* of number of elements with `BitFieldContainer::value_type'.
*/
XGBOOST_DEVICE static size_t ComputeStorageSize(size_t size) {
XGBOOST_DEVICE static size_t ComputeStorageSize(index_type size) {
return common::DivRoundUp(size, kValueSize);
}
#if defined(__CUDA_ARCH__)
Expand Down Expand Up @@ -138,28 +139,28 @@ struct BitFieldContainer {
#endif // defined(__CUDA_ARCH__)

#if defined(__CUDA_ARCH__)
__device__ auto Set(value_type pos) {
__device__ auto Set(index_type pos) {
Pos pos_v = Direction::Shift(ToBitPos(pos));
value_type& value = bits_[pos_v.int_pos];
value_type set_bit = kOne << pos_v.bit_pos;
using Type = typename dh::detail::AtomicDispatcher<sizeof(value_type)>::Type;
atomicOr(reinterpret_cast<Type *>(&value), set_bit);
}
__device__ void Clear(value_type pos) {
__device__ void Clear(index_type pos) {
Pos pos_v = Direction::Shift(ToBitPos(pos));
value_type& value = bits_[pos_v.int_pos];
value_type clear_bit = ~(kOne << pos_v.bit_pos);
using Type = typename dh::detail::AtomicDispatcher<sizeof(value_type)>::Type;
atomicAnd(reinterpret_cast<Type *>(&value), clear_bit);
}
#else
void Set(value_type pos) {
void Set(index_type pos) {
Pos pos_v = Direction::Shift(ToBitPos(pos));
value_type& value = bits_[pos_v.int_pos];
value_type set_bit = kOne << pos_v.bit_pos;
value |= set_bit;
}
void Clear(value_type pos) {
void Clear(index_type pos) {
Pos pos_v = Direction::Shift(ToBitPos(pos));
value_type& value = bits_[pos_v.int_pos];
value_type clear_bit = ~(kOne << pos_v.bit_pos);
Expand All @@ -175,7 +176,7 @@ struct BitFieldContainer {
value_type result = test_bit & value;
return static_cast<bool>(result);
}
XGBOOST_DEVICE bool Check(value_type pos) const {
XGBOOST_DEVICE bool Check(index_type pos) const {
Pos pos_v = ToBitPos(pos);
return Check(pos_v);
}
Expand Down
8 changes: 8 additions & 0 deletions tests/cpp/common/test_bitfield.cc
Expand Up @@ -38,6 +38,14 @@ TEST(BitField, Check) {
ASSERT_FALSE(bits.Check(i));
}
}

{
// regression test for correct index type.
std::vector<RBitField8::value_type> storage(33, 0);
storage[32] = static_cast<uint8_t>(1);
auto bits = RBitField8({storage.data(), storage.size()});
ASSERT_TRUE(bits.Check(256));
}
}

template <typename BitFieldT, typename VT = typename BitFieldT::value_type>
Expand Down