Skip to content

Commit

Permalink
Expand categorical node. (#6028)
Browse files Browse the repository at this point in the history

Co-authored-by: Philip Hyunsu Cho <chohyu01@cs.washington.edu>
  • Loading branch information
trivialfis and hcho3 committed Aug 25, 2020
1 parent 9a4e8b1 commit 20c95be
Show file tree
Hide file tree
Showing 12 changed files with 341 additions and 104 deletions.
5 changes: 3 additions & 2 deletions R-package/tests/testthat/test_basic.R
Expand Up @@ -245,11 +245,12 @@ test_that("training continuation works", {
expect_equal(bst$raw, bst2$raw)
expect_equal(dim(bst2$evaluation_log), c(2, 2))
# test continuing from a model in file
xgb.save(bst1, "xgboost.model")
bst2 <- xgb.train(param, dtrain, nrounds = 2, watchlist, verbose = 0, xgb_model = "xgboost.model")
xgb.save(bst1, "xgboost.json")
bst2 <- xgb.train(param, dtrain, nrounds = 2, watchlist, verbose = 0, xgb_model = "xgboost.json")
if (!windows_flag && !solaris_flag)
expect_equal(bst$raw, bst2$raw)
expect_equal(dim(bst2$evaluation_log), c(2, 2))
file.remove("xgboost.json")
})

test_that("model serialization works", {
Expand Down
18 changes: 9 additions & 9 deletions R-package/tests/testthat/test_callbacks.R
Expand Up @@ -173,16 +173,16 @@ test_that("cb.reset.parameters works as expected", {
})

test_that("cb.save.model works as expected", {
files <- c('xgboost_01.model', 'xgboost_02.model', 'xgboost.model')
files <- c('xgboost_01.json', 'xgboost_02.json', 'xgboost.json')
for (f in files) if (file.exists(f)) file.remove(f)

bst <- xgb.train(param, dtrain, nrounds = 2, watchlist, eta = 1, verbose = 0,
save_period = 1, save_name = "xgboost_%02d.model")
expect_true(file.exists('xgboost_01.model'))
expect_true(file.exists('xgboost_02.model'))
b1 <- xgb.load('xgboost_01.model')
save_period = 1, save_name = "xgboost_%02d.json")
expect_true(file.exists('xgboost_01.json'))
expect_true(file.exists('xgboost_02.json'))
b1 <- xgb.load('xgboost_01.json')
expect_equal(xgb.ntree(b1), 1)
b2 <- xgb.load('xgboost_02.model')
b2 <- xgb.load('xgboost_02.json')
expect_equal(xgb.ntree(b2), 2)

xgb.config(b2) <- xgb.config(bst)
Expand All @@ -191,9 +191,9 @@ test_that("cb.save.model works as expected", {

# save_period = 0 saves the last iteration's model
bst <- xgb.train(param, dtrain, nrounds = 2, watchlist, eta = 1, verbose = 0,
save_period = 0)
expect_true(file.exists('xgboost.model'))
b2 <- xgb.load('xgboost.model')
save_period = 0, save_name = 'xgboost.json')
expect_true(file.exists('xgboost.json'))
b2 <- xgb.load('xgboost.json')
xgb.config(b2) <- xgb.config(bst)
expect_equal(bst$raw, b2$raw)

Expand Down
3 changes: 2 additions & 1 deletion include/xgboost/base.h
Expand Up @@ -109,7 +109,8 @@ using bst_int = int32_t; // NOLINT
using bst_ulong = uint64_t; // NOLINT
/*! \brief float type, used for storing statistics */
using bst_float = float; // NOLINT

/*! \brief Categorical value type. */
using bst_cat_t = int32_t; // NOLINT
/*! \brief Type for data column (feature) index. */
using bst_feature_t = uint32_t; // NOLINT
/*! \brief Type for data row index.
Expand Down
9 changes: 2 additions & 7 deletions include/xgboost/data.h
Expand Up @@ -35,7 +35,8 @@ enum class DataType : uint8_t {
};

enum class FeatureType : uint8_t {
kNumerical
kNumerical,
kCategorical
};

/*!
Expand Down Expand Up @@ -309,12 +310,6 @@ class SparsePage {
}
}

/*!
* \brief Push row block into the page.
* \param batch the row batch.
*/
void Push(const dmlc::RowBlock<uint32_t>& batch);

/**
* \brief Pushes external data batch onto this page
*
Expand Down
15 changes: 13 additions & 2 deletions include/xgboost/span.h
Expand Up @@ -101,6 +101,18 @@ namespace common {
} while (0);
#endif // __CUDA_ARCH__

#if defined(__CUDA_ARCH__)
#define SPAN_LT(lhs, rhs) \
if (!((lhs) < (rhs))) { \
printf("%lu < %lu failed\n", static_cast<size_t>(lhs), \
static_cast<size_t>(rhs)); \
asm("trap;"); \
}
#else
#define SPAN_LT(lhs, rhs) \
SPAN_CHECK((lhs) < (rhs))
#endif // defined(__CUDA_ARCH__)

namespace detail {
/*!
* By default, XGBoost uses uint32_t for indexing data. int64_t covers all
Expand Down Expand Up @@ -515,7 +527,7 @@ class Span {
}

XGBOOST_DEVICE reference operator[](index_type _idx) const {
SPAN_CHECK(_idx < size());
SPAN_LT(_idx, size());
return data()[_idx];
}

Expand Down Expand Up @@ -575,7 +587,6 @@ class Span {
detail::ExtentValue<Extent, Offset, Count>::value> {
SPAN_CHECK((Count == dynamic_extent) ?
(Offset <= size()) : (Offset + Count <= size()));

return {data() + Offset, Count == dynamic_extent ? size() - Offset : Count};
}

Expand Down
78 changes: 57 additions & 21 deletions include/xgboost/tree_model.h
Expand Up @@ -318,6 +318,8 @@ class RegTree : public Model {
param.num_deleted = 0;
nodes_.resize(param.num_nodes);
stats_.resize(param.num_nodes);
split_types_.resize(param.num_nodes, FeatureType::kNumerical);
split_categories_segments_.resize(param.num_nodes);
for (int i = 0; i < param.num_nodes; i ++) {
nodes_[i].SetLeaf(0.0f);
nodes_[i].SetParent(kInvalidNodeId);
Expand Down Expand Up @@ -412,30 +414,33 @@ class RegTree : public Model {
* \param leaf_right_child The right child index of leaf, by default kInvalidNodeId,
* some updaters use the right child index of leaf as a marker
*/
void ExpandNode(int nid, unsigned split_index, bst_float split_value,
void ExpandNode(bst_node_t nid, unsigned split_index, bst_float split_value,
bool default_left, bst_float base_weight,
bst_float left_leaf_weight, bst_float right_leaf_weight,
bst_float loss_change, float sum_hess, float left_sum,
float right_sum,
bst_node_t leaf_right_child = kInvalidNodeId) {
int pleft = this->AllocNode();
int pright = this->AllocNode();
auto &node = nodes_[nid];
CHECK(node.IsLeaf());
node.SetLeftChild(pleft);
node.SetRightChild(pright);
nodes_[node.LeftChild()].SetParent(nid, true);
nodes_[node.RightChild()].SetParent(nid, false);
node.SetSplit(split_index, split_value,
default_left);

nodes_[pleft].SetLeaf(left_leaf_weight, leaf_right_child);
nodes_[pright].SetLeaf(right_leaf_weight, leaf_right_child);

this->Stat(nid) = {loss_change, sum_hess, base_weight};
this->Stat(pleft) = {0.0f, left_sum, left_leaf_weight};
this->Stat(pright) = {0.0f, right_sum, right_leaf_weight};
}
bst_node_t leaf_right_child = kInvalidNodeId);

/**
* \brief Expands a leaf node with categories
*
* \param nid The node index to expand.
* \param split_index Feature index of the split.
* \param split_cat The bitset containing categories
* \param default_left True to default left.
* \param base_weight The base weight, before learning rate.
* \param left_leaf_weight The left leaf weight for prediction, modified by learning rate.
* \param right_leaf_weight The right leaf weight for prediction, modified by learning rate.
* \param loss_change The loss change.
* \param sum_hess The sum hess.
* \param left_sum The sum hess of left leaf.
* \param right_sum The sum hess of right leaf.
*/
void ExpandCategorical(bst_node_t nid, unsigned split_index,
common::Span<uint32_t> split_cat, bool default_left,
bst_float base_weight, bst_float left_leaf_weight,
bst_float right_leaf_weight, bst_float loss_change,
float sum_hess, float left_sum, float right_sum);

/*!
* \brief get current depth
Expand Down Expand Up @@ -588,6 +593,28 @@ class RegTree : public Model {
* \brief calculate the mean value for each node, required for feature contributions
*/
void FillNodeMeanValues();
/*!
* \brief Get split type for a node.
* \param nidx Index of node.
* \return The type of this split. For leaf node it's always kNumerical.
*/
FeatureType NodeSplitType(bst_node_t nidx) const {
return split_types_.at(nidx);
}
/*!
* \brief Get split types for all nodes.
*/
std::vector<FeatureType> const &GetSplitTypes() const { return split_types_; }
common::Span<uint32_t const> GetSplitCategories() const { return split_categories_; }
auto const& GetSplitCategoriesPtr() const { return split_categories_segments_; }

// The fields of split_categories_segments_[i] are set such that
// the range split_categories_[beg:(beg+size)] stores the bitset for
// the matching categories for the i-th node.
struct Segment {
size_t beg {0};
size_t size {0};
};

private:
// vector of nodes
Expand All @@ -597,9 +624,16 @@ class RegTree : public Model {
// stats of nodes
std::vector<RTreeNodeStat> stats_;
std::vector<bst_float> node_mean_values_;
std::vector<FeatureType> split_types_;

// Categories for each internal node.
std::vector<uint32_t> split_categories_;
// Ptr to split categories of each node.
std::vector<Segment> split_categories_segments_;

// allocate a new node,
// !!!!!! NOTE: may cause BUG here, nodes.resize
int AllocNode() {
bst_node_t AllocNode() {
if (param.num_deleted != 0) {
int nid = deleted_nodes_.back();
deleted_nodes_.pop_back();
Expand All @@ -612,6 +646,8 @@ class RegTree : public Model {
<< "number of nodes in the tree exceed 2^31";
nodes_.resize(param.num_nodes);
stats_.resize(param.num_nodes);
split_types_.resize(param.num_nodes, FeatureType::kNumerical);
split_categories_segments_.resize(param.num_nodes);
return nd;
}
// delete a tree node, keep the parent field to allow trace back
Expand Down
67 changes: 23 additions & 44 deletions src/common/bitfield.h
Expand Up @@ -16,6 +16,7 @@
#if defined(__CUDACC__)
#include <thrust/copy.h>
#include <thrust/device_ptr.h>
#include "device_helpers.cuh"
#endif // defined(__CUDACC__)

#include "xgboost/span.h"
Expand Down Expand Up @@ -54,23 +55,24 @@ __forceinline__ __device__ BitFieldAtomicType AtomicAnd(BitFieldAtomicType* addr
*
* \tparam Direction Whether the bits start from left or from right.
*/
template <typename VT, typename Direction>
template <typename VT, typename Direction, bool IsConst = false>
struct BitFieldContainer {
using value_type = VT; // NOLINT
using value_type = std::conditional_t<IsConst, VT const, VT>; // NOLINT
using pointer = value_type*; // NOLINT

static value_type constexpr kValueSize = sizeof(value_type) * 8;
static value_type constexpr kOne = 1; // force correct type.

struct Pos {
value_type int_pos {0};
value_type bit_pos {0};
std::remove_const_t<value_type> int_pos {0};
std::remove_const_t<value_type> bit_pos {0};
};

private:
common::Span<value_type> bits_;
static_assert(!std::is_signed<VT>::value, "Must use unsiged type as underlying storage.");

public:
XGBOOST_DEVICE static Pos ToBitPos(value_type pos) {
Pos pos_v;
if (pos == 0) {
Expand All @@ -92,7 +94,7 @@ struct BitFieldContainer {
/*\brief Compute the size of needed memory allocation. The returned value is in terms
* of number of elements with `BitFieldContainer::value_type'.
*/
static size_t ComputeStorageSize(size_t size) {
XGBOOST_DEVICE static size_t ComputeStorageSize(size_t size) {
return common::DivRoundUp(size, kValueSize);
}
#if defined(__CUDA_ARCH__)
Expand Down Expand Up @@ -134,19 +136,19 @@ struct BitFieldContainer {
#endif // defined(__CUDA_ARCH__)

#if defined(__CUDA_ARCH__)
__device__ void Set(value_type pos) {
__device__ auto Set(value_type pos) {
Pos pos_v = Direction::Shift(ToBitPos(pos));
value_type& value = bits_[pos_v.int_pos];
value_type set_bit = kOne << pos_v.bit_pos;
static_assert(sizeof(BitFieldAtomicType) == sizeof(value_type), "");
AtomicOr(reinterpret_cast<BitFieldAtomicType*>(&value), set_bit);
using Type = typename dh::detail::AtomicDispatcher<sizeof(value_type)>::Type;
atomicOr(reinterpret_cast<Type *>(&value), set_bit);
}
__device__ void Clear(value_type pos) {
Pos pos_v = Direction::Shift(ToBitPos(pos));
value_type& value = bits_[pos_v.int_pos];
value_type clear_bit = ~(kOne << pos_v.bit_pos);
static_assert(sizeof(BitFieldAtomicType) == sizeof(value_type), "");
AtomicAnd(reinterpret_cast<BitFieldAtomicType*>(&value), clear_bit);
using Type = typename dh::detail::AtomicDispatcher<sizeof(value_type)>::Type;
atomicAnd(reinterpret_cast<Type *>(&value), clear_bit);
}
#else
void Set(value_type pos) {
Expand All @@ -165,6 +167,7 @@ struct BitFieldContainer {

XGBOOST_DEVICE bool Check(Pos pos_v) const {
pos_v = Direction::Shift(pos_v);
SPAN_LT(pos_v.int_pos, bits_.size());
value_type const value = bits_[pos_v.int_pos];
value_type const test_bit = kOne << pos_v.bit_pos;
value_type result = test_bit & value;
Expand All @@ -179,20 +182,21 @@ struct BitFieldContainer {

XGBOOST_DEVICE pointer Data() const { return bits_.data(); }

friend std::ostream& operator<<(std::ostream& os, BitFieldContainer<VT, Direction> field) {
inline friend std::ostream &
operator<<(std::ostream &os, BitFieldContainer<VT, Direction, IsConst> field) {
os << "Bits " << "storage size: " << field.bits_.size() << "\n";
for (typename common::Span<value_type>::index_type i = 0; i < field.bits_.size(); ++i) {
std::bitset<BitFieldContainer<VT, Direction>::kValueSize> bset(field.bits_[i]);
std::bitset<BitFieldContainer<VT, Direction, IsConst>::kValueSize> bset(field.bits_[i]);
os << bset << "\n";
}
return os;
}
};

// Bits start from left most bits (most significant bit).
template <typename VT>
struct LBitsPolicy : public BitFieldContainer<VT, LBitsPolicy<VT>> {
using Container = BitFieldContainer<VT, LBitsPolicy<VT>>;
template <typename VT, bool IsConst = false>
struct LBitsPolicy : public BitFieldContainer<VT, LBitsPolicy<VT, IsConst>, IsConst> {
using Container = BitFieldContainer<VT, LBitsPolicy<VT, IsConst>, IsConst>;
using Pos = typename Container::Pos;
using value_type = typename Container::value_type; // NOLINT

Expand All @@ -215,38 +219,13 @@ struct RBitsPolicy : public BitFieldContainer<VT, RBitsPolicy<VT>> {
}
};

// Format: <Direction>BitField<size of underlying type in bits>, underlying type must be unsigned.
// Format: <Const><Direction>BitField<size of underlying type in bits>, underlying type
// must be unsigned.
using LBitField64 = BitFieldContainer<uint64_t, LBitsPolicy<uint64_t>>;
using RBitField8 = BitFieldContainer<uint8_t, RBitsPolicy<unsigned char>>;

#if defined(__CUDACC__)

template <typename V, typename D>
inline void PrintDeviceBits(std::string name, BitFieldContainer<V, D> field) {
std::cout << "Bits: " << name << std::endl;
std::vector<typename BitFieldContainer<V, D>::value_type> h_field_bits(field.bits_.size());
thrust::copy(thrust::device_ptr<typename BitFieldContainer<V, D>::value_type>(field.bits_.data()),
thrust::device_ptr<typename BitFieldContainer<V, D>::value_type>(
field.bits_.data() + field.bits_.size()),
h_field_bits.data());
BitFieldContainer<V, D> h_field;
h_field.bits_ = {h_field_bits.data(), h_field_bits.data() + h_field_bits.size()};
std::cout << h_field;
}

inline void PrintDeviceStorage(std::string name, common::Span<int32_t> list) {
std::cout << name << std::endl;
std::vector<int32_t> h_list(list.size());
thrust::copy(thrust::device_ptr<int32_t>(list.data()),
thrust::device_ptr<int32_t>(list.data() + list.size()),
h_list.data());
for (auto v : h_list) {
std::cout << v << ", ";
}
std::cout << std::endl;
}

#endif // defined(__CUDACC__)
using LBitField32 = BitFieldContainer<uint32_t, LBitsPolicy<uint32_t>>;
using CLBitField32 = BitFieldContainer<uint32_t, LBitsPolicy<uint32_t, true>, true>;
} // namespace xgboost

#endif // XGBOOST_COMMON_BITFIELD_H_

0 comments on commit 20c95be

Please sign in to comment.