Skip to content

Commit

Permalink
Small cleanup to Column. (#7898)
Browse files Browse the repository at this point in the history

* Define forward iterator to hide the internal state.
  • Loading branch information
trivialfis committed May 15, 2022
1 parent ee382c4 commit 1baad86
Show file tree
Hide file tree
Showing 3 changed files with 125 additions and 146 deletions.
154 changes: 75 additions & 79 deletions src/common/column_matrix.h
Expand Up @@ -13,6 +13,7 @@
#include <algorithm>
#include <limits>
#include <memory>
#include <utility> // std::move
#include <vector>

#include "../data/gradient_index.h"
Expand All @@ -32,101 +33,96 @@ enum ColumnType : uint8_t { kDenseColumn, kSparseColumn };
template <typename BinIdxType>
class Column {
public:
static constexpr int32_t kMissingId = -1;

Column(ColumnType type, common::Span<const BinIdxType> index, const bst_bin_t index_base)
: type_(type), index_(index), index_base_{index_base} {}
static constexpr bst_bin_t kMissingId = -1;

Column(common::Span<const BinIdxType> index, bst_bin_t least_bin_idx)
: index_(index), index_base_(least_bin_idx) {}
virtual ~Column() = default;

uint32_t GetGlobalBinIdx(size_t idx) const {
return index_base_ + static_cast<uint32_t>(index_[idx]);
bst_bin_t GetGlobalBinIdx(size_t idx) const {
return index_base_ + static_cast<bst_bin_t>(index_[idx]);
}

BinIdxType GetFeatureBinIdx(size_t idx) const { return index_[idx]; }

uint32_t GetBaseIdx() const { return index_base_; }

common::Span<const BinIdxType> GetFeatureBinIdxPtr() const { return index_; }

ColumnType GetType() const { return type_; }

/* returns number of elements in column */
size_t Size() const { return index_.size(); }

private:
/* type of column */
ColumnType type_;
/* bin indexes in range [0, max_bins - 1] */
common::Span<const BinIdxType> index_;
/* bin index offset for specific feature */
bst_bin_t const index_base_;
};

template <typename BinIdxType>
class SparseColumn : public Column<BinIdxType> {
public:
SparseColumn(ColumnType type, common::Span<const BinIdxType> index, bst_bin_t index_base,
common::Span<const size_t> row_ind)
: Column<BinIdxType>(type, index, index_base), row_ind_(row_ind) {}
template <typename BinIdxT>
class SparseColumnIter : public Column<BinIdxT> {
private:
using Base = Column<BinIdxT>;
/* indexes of rows */
common::Span<const size_t> row_ind_;
size_t idx_;

const size_t* GetRowData() const { return row_ind_.data(); }
size_t const* RowIndices() const { return row_ind_.data(); }

bst_bin_t GetBinIdx(size_t rid, size_t* state) const {
public:
SparseColumnIter(common::Span<const BinIdxT> index, bst_bin_t least_bin_idx,
common::Span<const size_t> row_ind, bst_row_t first_row_idx)
: Base{index, least_bin_idx}, row_ind_(row_ind) {
// first_row_idx is the first row in the leaf partition
const size_t* row_data = RowIndices();
const size_t column_size = this->Size();
if (!((*state) < column_size)) {
// search for the first non-missing row with index >= first_row_idx
// note that the input row partition is always sorted.
const size_t* p = std::lower_bound(row_data, row_data + column_size, first_row_idx);
// column_size if all missing
idx_ = p - row_data;
}
SparseColumnIter(SparseColumnIter const&) = delete;
SparseColumnIter(SparseColumnIter&&) = default;

size_t GetRowIdx(size_t idx) const { return RowIndices()[idx]; }
bst_bin_t operator[](size_t rid) {
const size_t column_size = this->Size();
if (!((idx_) < column_size)) {
return this->kMissingId;
}
while ((*state) < column_size && GetRowIdx(*state) < rid) {
++(*state);
// find next non-missing row
while ((idx_) < column_size && GetRowIdx(idx_) < rid) {
++(idx_);
}
if (((*state) < column_size) && GetRowIdx(*state) == rid) {
return this->GetGlobalBinIdx(*state);
if (((idx_) < column_size) && GetRowIdx(idx_) == rid) {
// non-missing row found
return this->GetGlobalBinIdx(idx_);
} else {
// at the end of column, or the next stored row is past rid (rid is missing)
return this->kMissingId;
}
}
};

size_t GetInitialState(const size_t first_row_id) const {
const size_t* row_data = GetRowData();
const size_t column_size = this->Size();
// search first nonzero row with index >= rid_span.front()
const size_t* p = std::lower_bound(row_data, row_data + column_size, first_row_id);
// column_size if all messing
return p - row_data;
}

size_t GetRowIdx(size_t idx) const { return row_ind_.data()[idx]; }

template <typename BinIdxT, bool any_missing>
class DenseColumnIter : public Column<BinIdxT> {
private:
/* indexes of rows */
common::Span<const size_t> row_ind_;
};
using Base = Column<BinIdxT>;
/* flags for missing values in dense columns */
std::vector<bool> const& missing_flags_;
size_t feature_offset_;

template <typename BinIdxType, bool any_missing>
class DenseColumn : public Column<BinIdxType> {
public:
DenseColumn(ColumnType type, common::Span<const BinIdxType> index, uint32_t index_base,
const std::vector<bool>& missing_flags, size_t feature_offset)
: Column<BinIdxType>(type, index, index_base),
missing_flags_(missing_flags),
feature_offset_(feature_offset) {}
bool IsMissing(size_t idx) const { return missing_flags_[feature_offset_ + idx]; }

int32_t GetBinIdx(size_t idx, size_t* state) const {
explicit DenseColumnIter(common::Span<const BinIdxT> index, bst_bin_t index_base,
std::vector<bool> const& missing_flags, size_t feature_offset)
: Base{index, index_base}, missing_flags_{missing_flags}, feature_offset_{feature_offset} {}
DenseColumnIter(DenseColumnIter const&) = delete;
DenseColumnIter(DenseColumnIter&&) = default;

bool IsMissing(size_t ridx) const { return missing_flags_[feature_offset_ + ridx]; }

bst_bin_t operator[](size_t ridx) const {
if (any_missing) {
return IsMissing(idx) ? this->kMissingId : this->GetGlobalBinIdx(idx);
return IsMissing(ridx) ? this->kMissingId : this->GetGlobalBinIdx(ridx);
} else {
return this->GetGlobalBinIdx(idx);
return this->GetGlobalBinIdx(ridx);
}
}

size_t GetInitialState(const size_t first_row_id) const { return 0; }

private:
/* flags for missing values in dense columns */
const std::vector<bool>& missing_flags_;
size_t feature_offset_;
};

/*! \brief a collection of columns, with support for construction from
Expand Down Expand Up @@ -234,27 +230,26 @@ class ColumnMatrix {
}
}

/* Fetch an individual column. This code should be used with type swith
to determine type of bin id's */
template <typename BinIdxType, bool any_missing>
std::unique_ptr<const Column<BinIdxType> > GetColumn(unsigned fid) const {
CHECK_EQ(sizeof(BinIdxType), bins_type_size_);
template <typename BinIdxType>
auto SparseColumn(bst_feature_t fidx, bst_row_t first_row_idx) const {
const size_t feature_offset = feature_offsets_[fidx]; // to get right place for certain feature
const size_t column_size = feature_offsets_[fidx + 1] - feature_offset;
common::Span<const BinIdxType> bin_index = {
reinterpret_cast<const BinIdxType*>(&index_[feature_offset * bins_type_size_]),
column_size};
return SparseColumnIter<BinIdxType>(bin_index, index_base_[fidx],
{&row_ind_[feature_offset], column_size}, first_row_idx);
}

const size_t feature_offset = feature_offsets_[fid]; // to get right place for certain feature
const size_t column_size = feature_offsets_[fid + 1] - feature_offset;
template <typename BinIdxType, bool any_missing>
auto DenseColumn(bst_feature_t fidx) const {
const size_t feature_offset = feature_offsets_[fidx]; // to get right place for certain feature
const size_t column_size = feature_offsets_[fidx + 1] - feature_offset;
common::Span<const BinIdxType> bin_index = {
reinterpret_cast<const BinIdxType*>(&index_[feature_offset * bins_type_size_]),
column_size};
std::unique_ptr<const Column<BinIdxType> > res;
if (type_[fid] == ColumnType::kDenseColumn) {
CHECK_EQ(any_missing, any_missing_);
res.reset(new DenseColumn<BinIdxType, any_missing>(type_[fid], bin_index, index_base_[fid],
missing_flags_, feature_offset));
} else {
res.reset(new SparseColumn<BinIdxType>(type_[fid], bin_index, index_base_[fid],
{&row_ind_[feature_offset], column_size}));
}
return res;
return std::move(DenseColumnIter<BinIdxType, any_missing>{
bin_index, static_cast<bst_bin_t>(index_base_[fidx]), missing_flags_, feature_offset});
}

template <typename T>
Expand Down Expand Up @@ -342,6 +337,7 @@ class ColumnMatrix {
}

BinTypeSize GetTypeSize() const { return bins_type_size_; }
auto GetColumnType(bst_feature_t fidx) const { return type_[fidx]; }

// This is just a utility function
bool NoMissingValues(const size_t n_elements, const size_t n_row, const size_t n_features) {
Expand Down
24 changes: 10 additions & 14 deletions src/common/partition_builder.h
Expand Up @@ -52,23 +52,23 @@ class PartitionBuilder {
// Handle dense columns
// Analog of std::stable_partition, but not in-place: results go to separate left/right buffers
template <bool default_left, bool any_missing, typename ColumnType, typename Predicate>
inline std::pair<size_t, size_t> PartitionKernel(const ColumnType& column,
inline std::pair<size_t, size_t> PartitionKernel(ColumnType* p_column,
common::Span<const size_t> row_indices,
common::Span<size_t> left_part,
common::Span<size_t> right_part,
size_t base_rowid, Predicate&& pred) {
auto& column = *p_column;
size_t* p_left_part = left_part.data();
size_t* p_right_part = right_part.data();
size_t nleft_elems = 0;
size_t nright_elems = 0;
auto state = column.GetInitialState(row_indices.front() - base_rowid);

auto p_row_indices = row_indices.data();
auto n_samples = row_indices.size();

for (size_t i = 0; i < n_samples; ++i) {
auto rid = p_row_indices[i];
const int32_t bin_id = column.GetBinIdx(rid - base_rowid, &state);
const int32_t bin_id = column[rid - base_rowid];
if (any_missing && bin_id == ColumnType::kMissingId) {
if (default_left) {
p_left_part[nleft_elems++] = rid;
Expand Down Expand Up @@ -115,8 +115,6 @@ class PartitionBuilder {
common::Span<size_t> right = GetRightBuffer(node_in_set, range.begin(), range.end());
const bst_uint fid = tree[nid].SplitIndex();
const bool default_left = tree[nid].DefaultLeft();
const auto column_ptr = column_matrix.GetColumn<BinIdxType, any_missing>(fid);

bool is_cat = tree.GetSplitTypes()[nid] == FeatureType::kCategorical;
auto node_cats = tree.NodeCats(nid);

Expand Down Expand Up @@ -146,25 +144,23 @@ class PartitionBuilder {
};

std::pair<size_t, size_t> child_nodes_sizes;
if (column_ptr->GetType() == xgboost::common::kDenseColumn) {
const common::DenseColumn<BinIdxType, any_missing>& column =
static_cast<const common::DenseColumn<BinIdxType, any_missing>& >(*(column_ptr.get()));
if (column_matrix.GetColumnType(fid) == xgboost::common::kDenseColumn) {
auto column = column_matrix.DenseColumn<BinIdxType, any_missing>(fid);
if (default_left) {
child_nodes_sizes = PartitionKernel<true, any_missing>(column, rid_span, left, right,
child_nodes_sizes = PartitionKernel<true, any_missing>(&column, rid_span, left, right,
gmat.base_rowid, pred);
} else {
child_nodes_sizes = PartitionKernel<false, any_missing>(column, rid_span, left, right,
child_nodes_sizes = PartitionKernel<false, any_missing>(&column, rid_span, left, right,
gmat.base_rowid, pred);
}
} else {
CHECK_EQ(any_missing, true);
const common::SparseColumn<BinIdxType>& column
= static_cast<const common::SparseColumn<BinIdxType>& >(*(column_ptr.get()));
auto column = column_matrix.SparseColumn<BinIdxType>(fid, rid_span.front() - gmat.base_rowid);
if (default_left) {
child_nodes_sizes = PartitionKernel<true, any_missing>(column, rid_span, left, right,
child_nodes_sizes = PartitionKernel<true, any_missing>(&column, rid_span, left, right,
gmat.base_rowid, pred);
} else {
child_nodes_sizes = PartitionKernel<false, any_missing>(column, rid_span, left, right,
child_nodes_sizes = PartitionKernel<false, any_missing>(&column, rid_span, left, right,
gmat.base_rowid, pred);
}
}
Expand Down

0 comments on commit 1baad86

Please sign in to comment.