diff --git a/src/common/column_matrix.h b/src/common/column_matrix.h index 57e602114453..01dfe548b181 100644 --- a/src/common/column_matrix.h +++ b/src/common/column_matrix.h @@ -13,6 +13,7 @@ #include #include #include +#include // std::move #include #include "../data/gradient_index.h" @@ -32,101 +33,96 @@ enum ColumnType : uint8_t { kDenseColumn, kSparseColumn }; template class Column { public: - static constexpr int32_t kMissingId = -1; - - Column(ColumnType type, common::Span index, const bst_bin_t index_base) - : type_(type), index_(index), index_base_{index_base} {} + static constexpr bst_bin_t kMissingId = -1; + Column(common::Span index, bst_bin_t least_bin_idx) + : index_(index), index_base_(least_bin_idx) {} virtual ~Column() = default; - uint32_t GetGlobalBinIdx(size_t idx) const { - return index_base_ + static_cast(index_[idx]); + bst_bin_t GetGlobalBinIdx(size_t idx) const { + return index_base_ + static_cast(index_[idx]); } - BinIdxType GetFeatureBinIdx(size_t idx) const { return index_[idx]; } - - uint32_t GetBaseIdx() const { return index_base_; } - - common::Span GetFeatureBinIdxPtr() const { return index_; } - - ColumnType GetType() const { return type_; } - /* returns number of elements in column */ size_t Size() const { return index_.size(); } private: - /* type of column */ - ColumnType type_; /* bin indexes in range [0, max_bins - 1] */ common::Span index_; /* bin index offset for specific feature */ bst_bin_t const index_base_; }; -template -class SparseColumn : public Column { - public: - SparseColumn(ColumnType type, common::Span index, bst_bin_t index_base, - common::Span row_ind) - : Column(type, index, index_base), row_ind_(row_ind) {} +template +class SparseColumnIter : public Column { + private: + using Base = Column; + /* indexes of rows */ + common::Span row_ind_; + size_t idx_; - const size_t* GetRowData() const { return row_ind_.data(); } + size_t const* RowIndices() const { return row_ind_.data(); } - bst_bin_t GetBinIdx(size_t rid, size_t* state) const { + public: + SparseColumnIter(common::Span index, bst_bin_t least_bin_idx, + common::Span row_ind, bst_row_t first_row_idx) + : Base{index, least_bin_idx}, row_ind_(row_ind) { + // first_row_id is the first row in the leaf partition + const size_t* row_data = RowIndices(); const size_t column_size = this->Size(); - if (!((*state) < column_size)) { + // search first nonzero row with index >= rid_span.front() + // note that the input row partition is always sorted. + const size_t* p = std::lower_bound(row_data, row_data + column_size, first_row_idx); + // column_size if all missing + idx_ = p - row_data; + } + SparseColumnIter(SparseColumnIter const&) = delete; + SparseColumnIter(SparseColumnIter&&) = default; + + size_t GetRowIdx(size_t idx) const { return RowIndices()[idx]; } + bst_bin_t operator[](size_t rid) { + const size_t column_size = this->Size(); + if (!((idx_) < column_size)) { return this->kMissingId; } - while ((*state) < column_size && GetRowIdx(*state) < rid) { - ++(*state); + // find next non-missing row + while ((idx_) < column_size && GetRowIdx(idx_) < rid) { + ++(idx_); } - if (((*state) < column_size) && GetRowIdx(*state) == rid) { - return this->GetGlobalBinIdx(*state); + if (((idx_) < column_size) && GetRowIdx(idx_) == rid) { + // non-missing row found + return this->GetGlobalBinIdx(idx_); } else { + // at the end of column return this->kMissingId; } } +}; - size_t GetInitialState(const size_t first_row_id) const { - const size_t* row_data = GetRowData(); - const size_t column_size = this->Size(); - // search first nonzero row with index >= rid_span.front() - const size_t* p = std::lower_bound(row_data, row_data + column_size, first_row_id); - // column_size if all messing - return p - row_data; - } - - size_t GetRowIdx(size_t idx) const { return row_ind_.data()[idx]; } - +template +class DenseColumnIter : public Column { private: - /* indexes of rows */ - common::Span row_ind_; -}; + using Base = Column; + /* flags for missing values in dense columns */ + std::vector const& missing_flags_; + size_t feature_offset_; -template -class DenseColumn : public Column { public: - DenseColumn(ColumnType type, common::Span index, uint32_t index_base, - const std::vector& missing_flags, size_t feature_offset) - : Column(type, index, index_base), - missing_flags_(missing_flags), - feature_offset_(feature_offset) {} - bool IsMissing(size_t idx) const { return missing_flags_[feature_offset_ + idx]; } - - int32_t GetBinIdx(size_t idx, size_t* state) const { + explicit DenseColumnIter(common::Span index, bst_bin_t index_base, + std::vector const& missing_flags, size_t feature_offset) + : Base{index, index_base}, missing_flags_{missing_flags}, feature_offset_{feature_offset} {} + DenseColumnIter(DenseColumnIter const&) = delete; + DenseColumnIter(DenseColumnIter&&) = default; + + bool IsMissing(size_t ridx) const { return missing_flags_[feature_offset_ + ridx]; } + + bst_bin_t operator[](size_t ridx) const { if (any_missing) { - return IsMissing(idx) ? this->kMissingId : this->GetGlobalBinIdx(idx); + return IsMissing(ridx) ? this->kMissingId : this->GetGlobalBinIdx(ridx); } else { - return this->GetGlobalBinIdx(idx); + return this->GetGlobalBinIdx(ridx); } } - - size_t GetInitialState(const size_t first_row_id) const { return 0; } - - private: - /* flags for missing values in dense columns */ - const std::vector& missing_flags_; - size_t feature_offset_; }; /*! \brief a collection of columns, with support for construction from @@ -234,27 +230,26 @@ class ColumnMatrix { } } - /* Fetch an individual column. This code should be used with type swith - to determine type of bin id's */ - template - std::unique_ptr > GetColumn(unsigned fid) const { - CHECK_EQ(sizeof(BinIdxType), bins_type_size_); + template + auto SparseColumn(bst_feature_t fidx, bst_row_t first_row_idx) const { + const size_t feature_offset = feature_offsets_[fidx]; // to get right place for certain feature + const size_t column_size = feature_offsets_[fidx + 1] - feature_offset; + common::Span bin_index = { + reinterpret_cast(&index_[feature_offset * bins_type_size_]), + column_size}; + return SparseColumnIter(bin_index, index_base_[fidx], + {&row_ind_[feature_offset], column_size}, first_row_idx); + } - const size_t feature_offset = feature_offsets_[fid]; // to get right place for certain feature - const size_t column_size = feature_offsets_[fid + 1] - feature_offset; + template + auto DenseColumn(bst_feature_t fidx) const { + const size_t feature_offset = feature_offsets_[fidx]; // to get right place for certain feature + const size_t column_size = feature_offsets_[fidx + 1] - feature_offset; common::Span bin_index = { reinterpret_cast(&index_[feature_offset * bins_type_size_]), column_size}; - std::unique_ptr > res; - if (type_[fid] == ColumnType::kDenseColumn) { - CHECK_EQ(any_missing, any_missing_); - res.reset(new DenseColumn(type_[fid], bin_index, index_base_[fid], - missing_flags_, feature_offset)); - } else { - res.reset(new SparseColumn(type_[fid], bin_index, index_base_[fid], - {&row_ind_[feature_offset], column_size})); - } - return res; + return std::move(DenseColumnIter{ + bin_index, static_cast(index_base_[fidx]), missing_flags_, feature_offset}); } template @@ -342,6 +337,7 @@ class ColumnMatrix { } BinTypeSize GetTypeSize() const { return bins_type_size_; } + auto GetColumnType(bst_feature_t fidx) const { return type_[fidx]; } // This is just an utility function bool NoMissingValues(const size_t n_elements, const size_t n_row, const size_t n_features) { diff --git a/src/common/partition_builder.h b/src/common/partition_builder.h index 648cbe61a3a3..7d065e1c4ff0 100644 --- a/src/common/partition_builder.h +++ b/src/common/partition_builder.h @@ -52,23 +52,23 @@ class PartitionBuilder { // Handle dense columns // Analog of std::stable_partition, but in no-inplace manner template - inline std::pair PartitionKernel(const ColumnType& column, + inline std::pair PartitionKernel(ColumnType* p_column, common::Span row_indices, common::Span left_part, common::Span right_part, size_t base_rowid, Predicate&& pred) { + auto& column = *p_column; size_t* p_left_part = left_part.data(); size_t* p_right_part = right_part.data(); size_t nleft_elems = 0; size_t nright_elems = 0; - auto state = column.GetInitialState(row_indices.front() - base_rowid); auto p_row_indices = row_indices.data(); auto n_samples = row_indices.size(); for (size_t i = 0; i < n_samples; ++i) { auto rid = p_row_indices[i]; - const int32_t bin_id = column.GetBinIdx(rid - base_rowid, &state); + const int32_t bin_id = column[rid - base_rowid]; if (any_missing && bin_id == ColumnType::kMissingId) { if (default_left) { p_left_part[nleft_elems++] = rid; @@ -115,8 +115,6 @@ class PartitionBuilder { common::Span right = GetRightBuffer(node_in_set, range.begin(), range.end()); const bst_uint fid = tree[nid].SplitIndex(); const bool default_left = tree[nid].DefaultLeft(); - const auto column_ptr = column_matrix.GetColumn(fid); - bool is_cat = tree.GetSplitTypes()[nid] == FeatureType::kCategorical; auto node_cats = tree.NodeCats(nid); @@ -146,25 +144,23 @@ class PartitionBuilder { }; std::pair child_nodes_sizes; - if (column_ptr->GetType() == xgboost::common::kDenseColumn) { - const common::DenseColumn& column = - static_cast& >(*(column_ptr.get())); + if (column_matrix.GetColumnType(fid) == xgboost::common::kDenseColumn) { + auto column = column_matrix.DenseColumn(fid); if (default_left) { - child_nodes_sizes = PartitionKernel(column, rid_span, left, right, + child_nodes_sizes = PartitionKernel(&column, rid_span, left, right, gmat.base_rowid, pred); } else { - child_nodes_sizes = PartitionKernel(column, rid_span, left, right, + child_nodes_sizes = PartitionKernel(&column, rid_span, left, right, gmat.base_rowid, pred); } } else { CHECK_EQ(any_missing, true); - const common::SparseColumn& column - = static_cast& >(*(column_ptr.get())); + auto column = column_matrix.SparseColumn(fid, rid_span.front() - gmat.base_rowid); if (default_left) { - child_nodes_sizes = PartitionKernel(column, rid_span, left, right, + child_nodes_sizes = PartitionKernel(&column, rid_span, left, right, gmat.base_rowid, pred); } else { - child_nodes_sizes = PartitionKernel(column, rid_span, left, right, + child_nodes_sizes = PartitionKernel(&column, rid_span, left, right, gmat.base_rowid, pred); } } diff --git a/tests/cpp/common/test_column_matrix.cc b/tests/cpp/common/test_column_matrix.cc index 2626b6fb3c1d..363a22176cdf 100644 --- a/tests/cpp/common/test_column_matrix.cc +++ b/tests/cpp/common/test_column_matrix.cc @@ -27,38 +27,33 @@ TEST(DenseColumn, Test) { for (auto i = 0ull; i < dmat->Info().num_row_; i++) { for (auto j = 0ull; j < dmat->Info().num_col_; j++) { - switch (column_matrix.GetTypeSize()) { - case kUint8BinsTypeSize: { - auto col = column_matrix.GetColumn(j); - ASSERT_EQ(gmat.index[i * dmat->Info().num_col_ + j], - (*col.get()).GetGlobalBinIdx(i)); - } - break; - case kUint16BinsTypeSize: { - auto col = column_matrix.GetColumn(j); - ASSERT_EQ(gmat.index[i * dmat->Info().num_col_ + j], - (*col.get()).GetGlobalBinIdx(i)); - } - break; - case kUint32BinsTypeSize: { - auto col = column_matrix.GetColumn(j); - ASSERT_EQ(gmat.index[i * dmat->Info().num_col_ + j], - (*col.get()).GetGlobalBinIdx(i)); - } - break; + switch (column_matrix.GetTypeSize()) { + case kUint8BinsTypeSize: { + auto col = column_matrix.DenseColumn(j); + ASSERT_EQ(gmat.index[i * dmat->Info().num_col_ + j], col.GetGlobalBinIdx(i)); + } break; + case kUint16BinsTypeSize: { + auto col = column_matrix.DenseColumn(j); + ASSERT_EQ(gmat.index[i * dmat->Info().num_col_ + j], col.GetGlobalBinIdx(i)); + } break; + case kUint32BinsTypeSize: { + auto col = column_matrix.DenseColumn(j); + ASSERT_EQ(gmat.index[i * dmat->Info().num_col_ + j], col.GetGlobalBinIdx(i)); + } break; } } } } } -template -inline void CheckSparseColumn(const Column& col_input, const GHistIndexMatrix& gmat) { - const SparseColumn& col = static_cast& >(col_input); +template +inline void CheckSparseColumn(const SparseColumnIter& col_input, + const GHistIndexMatrix& gmat) { + const SparseColumnIter& col = + static_cast&>(col_input); ASSERT_EQ(col.Size(), gmat.index.Size()); for (auto i = 0ull; i < col.Size(); i++) { - ASSERT_EQ(gmat.index[gmat.row_ptr[col.GetRowIdx(i)]], - col.GetGlobalBinIdx(i)); + ASSERT_EQ(gmat.index[gmat.row_ptr[col.GetRowIdx(i)]], col.GetGlobalBinIdx(i)); } } @@ -75,32 +70,27 @@ TEST(SparseColumn, Test) { } switch (column_matrix.GetTypeSize()) { case kUint8BinsTypeSize: { - auto col = column_matrix.GetColumn(0); - CheckSparseColumn(*col.get(), gmat); - } - break; + auto col = column_matrix.SparseColumn(0, 0); + CheckSparseColumn(col, gmat); + } break; case kUint16BinsTypeSize: { - auto col = column_matrix.GetColumn(0); - CheckSparseColumn(*col.get(), gmat); - } - break; + auto col = column_matrix.SparseColumn(0, 0); + CheckSparseColumn(col, gmat); + } break; case kUint32BinsTypeSize: { - auto col = column_matrix.GetColumn(0); - CheckSparseColumn(*col.get(), gmat); - } - break; + auto col = column_matrix.SparseColumn(0, 0); + CheckSparseColumn(col, gmat); + } break; } } } -template -inline void CheckColumWithMissingValue(const Column& col_input, +template +inline void CheckColumWithMissingValue(const DenseColumnIter& col, const GHistIndexMatrix& gmat) { - const DenseColumn& col = static_cast& >(col_input); for (auto i = 0ull; i < col.Size(); i++) { if (col.IsMissing(i)) continue; - EXPECT_EQ(gmat.index[gmat.row_ptr[i]], - col.GetGlobalBinIdx(i)); + EXPECT_EQ(gmat.index[gmat.row_ptr[i]], col.GetGlobalBinIdx(i)); } } @@ -117,20 +107,17 @@ TEST(DenseColumnWithMissing, Test) { } switch (column_matrix.GetTypeSize()) { case kUint8BinsTypeSize: { - auto col = column_matrix.GetColumn(0); - CheckColumWithMissingValue(*col.get(), gmat); - } - break; + auto col = column_matrix.DenseColumn(0); + CheckColumWithMissingValue(col, gmat); + } break; case kUint16BinsTypeSize: { - auto col = column_matrix.GetColumn(0); - CheckColumWithMissingValue(*col.get(), gmat); - } - break; + auto col = column_matrix.DenseColumn(0); + CheckColumWithMissingValue(col, gmat); + } break; case kUint32BinsTypeSize: { - auto col = column_matrix.GetColumn(0); - CheckColumWithMissingValue(*col.get(), gmat); - } - break; + auto col = column_matrix.DenseColumn(0); + CheckColumWithMissingValue(col, gmat); + } break; } } }