Skip to content

Commit

Permalink
[BP] Fix index type for bitfield. (#7541) (#7560)
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis committed Jan 13, 2022
1 parent 1311a20 commit 35dac8a
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 14 deletions.
29 changes: 15 additions & 14 deletions src/common/bitfield.h
Expand Up @@ -58,28 +58,29 @@ __forceinline__ __device__ BitFieldAtomicType AtomicAnd(BitFieldAtomicType* addr
template <typename VT, typename Direction, bool IsConst = false>
struct BitFieldContainer {
using value_type = std::conditional_t<IsConst, VT const, VT>; // NOLINT
using pointer = value_type*; // NOLINT
using index_type = size_t; // NOLINT
using pointer = value_type*; // NOLINT

static value_type constexpr kValueSize = sizeof(value_type) * 8;
static value_type constexpr kOne = 1; // force correct type.
static index_type constexpr kValueSize = sizeof(value_type) * 8;
static index_type constexpr kOne = 1; // force correct type.

struct Pos {
std::remove_const_t<value_type> int_pos {0};
std::remove_const_t<value_type> bit_pos {0};
index_type int_pos{0};
index_type bit_pos{0};
};

private:
common::Span<value_type> bits_;
static_assert(!std::is_signed<VT>::value, "Must use unsiged type as underlying storage.");

public:
XGBOOST_DEVICE static Pos ToBitPos(value_type pos) {
XGBOOST_DEVICE static Pos ToBitPos(index_type pos) {
Pos pos_v;
if (pos == 0) {
return pos_v;
}
pos_v.int_pos = pos / kValueSize;
pos_v.bit_pos = pos % kValueSize;
pos_v.int_pos = pos / kValueSize;
pos_v.bit_pos = pos % kValueSize;
return pos_v;
}

Expand All @@ -96,7 +97,7 @@ struct BitFieldContainer {
/*\brief Compute the size of needed memory allocation. The returned value is in terms
* of number of elements with `BitFieldContainer::value_type'.
*/
XGBOOST_DEVICE static size_t ComputeStorageSize(size_t size) {
XGBOOST_DEVICE static size_t ComputeStorageSize(index_type size) {
return common::DivRoundUp(size, kValueSize);
}
#if defined(__CUDA_ARCH__)
Expand Down Expand Up @@ -138,28 +139,28 @@ struct BitFieldContainer {
#endif // defined(__CUDA_ARCH__)

#if defined(__CUDA_ARCH__)
__device__ auto Set(value_type pos) {
__device__ auto Set(index_type pos) {
Pos pos_v = Direction::Shift(ToBitPos(pos));
value_type& value = bits_[pos_v.int_pos];
value_type set_bit = kOne << pos_v.bit_pos;
using Type = typename dh::detail::AtomicDispatcher<sizeof(value_type)>::Type;
atomicOr(reinterpret_cast<Type *>(&value), set_bit);
}
__device__ void Clear(value_type pos) {
__device__ void Clear(index_type pos) {
Pos pos_v = Direction::Shift(ToBitPos(pos));
value_type& value = bits_[pos_v.int_pos];
value_type clear_bit = ~(kOne << pos_v.bit_pos);
using Type = typename dh::detail::AtomicDispatcher<sizeof(value_type)>::Type;
atomicAnd(reinterpret_cast<Type *>(&value), clear_bit);
}
#else
void Set(value_type pos) {
void Set(index_type pos) {
Pos pos_v = Direction::Shift(ToBitPos(pos));
value_type& value = bits_[pos_v.int_pos];
value_type set_bit = kOne << pos_v.bit_pos;
value |= set_bit;
}
void Clear(value_type pos) {
void Clear(index_type pos) {
Pos pos_v = Direction::Shift(ToBitPos(pos));
value_type& value = bits_[pos_v.int_pos];
value_type clear_bit = ~(kOne << pos_v.bit_pos);
Expand All @@ -175,7 +176,7 @@ struct BitFieldContainer {
value_type result = test_bit & value;
return static_cast<bool>(result);
}
XGBOOST_DEVICE bool Check(value_type pos) const {
XGBOOST_DEVICE bool Check(index_type pos) const {
Pos pos_v = ToBitPos(pos);
return Check(pos_v);
}
Expand Down
8 changes: 8 additions & 0 deletions tests/cpp/common/test_bitfield.cc
Expand Up @@ -38,6 +38,14 @@ TEST(BitField, Check) {
ASSERT_FALSE(bits.Check(i));
}
}

{
// regression test for correct index type.
std::vector<RBitField8::value_type> storage(33, 0);
storage[32] = static_cast<uint8_t>(1);
auto bits = RBitField8({storage.data(), storage.size()});
ASSERT_TRUE(bits.Check(256));
}
}

template <typename BitFieldT, typename VT = typename BitFieldT::value_type>
Expand Down

0 comments on commit 35dac8a

Please sign in to comment.