Skip to content

Commit

Permalink
Fix accumulate rows.
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis committed Jul 1, 2021
1 parent f18ecd4 commit 2617120
Show file tree
Hide file tree
Showing 6 changed files with 9 additions and 7 deletions.
2 changes: 1 addition & 1 deletion include/xgboost/data.h
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ class MetaInfo {
*
* \param that The other MetaInfo object.
*/
void Extend(MetaInfo const& that, bool check_column);
void Extend(MetaInfo const& that, bool accumulate_rows, bool check_column);

private:
/*! \brief argsort of labels */
Expand Down
6 changes: 4 additions & 2 deletions src/data/data.cc
Original file line number Diff line number Diff line change
Expand Up @@ -499,8 +499,10 @@ void MetaInfo::GetFeatureInfo(const char *field,
}
}

void MetaInfo::Extend(MetaInfo const& that, bool check_column) {
this->num_row_ += that.num_row_;
void MetaInfo::Extend(MetaInfo const& that, bool accumulate_rows, bool check_column) {
if (accumulate_rows) {
this->num_row_ += that.num_row_;
}
if (this->num_col_ != 0) {
if (check_column) {
CHECK_EQ(this->num_col_, that.num_col_)
Expand Down
2 changes: 1 addition & 1 deletion src/data/iterative_device_dmatrix.cu
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ void IterativeDeviceDMatrix::Initialize(DataIterHandle iter_handle, float missin
proxy->Info().num_row_ = num_rows();
proxy->Info().num_col_ = cols;
if (batches != 1) {
this->info_.Extend(std::move(proxy->Info()), false);
this->info_.Extend(std::move(proxy->Info()), false, true);
}
n_batches_for_verification++;
}
Expand Down
2 changes: 1 addition & 1 deletion src/data/sparse_page_dmatrix.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ void SparsePageDMatrix::InitializeExternalMemory() {
};

for (auto const &page : this->GetRowBatches()) {
this->info_.Extend(std::move(proxy->Info()), false);
this->info_.Extend(std::move(proxy->Info()), false, false);
n_features = std::max(n_features, num_cols());
n_samples += num_rows();
nnz += page.data.Size();
Expand Down
2 changes: 1 addition & 1 deletion tests/cpp/data/test_metainfo.cc
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,7 @@ TEST(MetaInfo, HostExtend) {
lhs.SetInfo("group", groups.data(), xgboost::DataType::kUInt32, groups.size());
rhs.SetInfo("group", groups.data(), xgboost::DataType::kUInt32, groups.size());

lhs.Extend(rhs, true);
lhs.Extend(rhs, true, true);
ASSERT_EQ(lhs.num_row_, kRows * 2);
ASSERT_TRUE(lhs.labels_.HostCanRead());
ASSERT_TRUE(rhs.labels_.HostCanRead());
Expand Down
2 changes: 1 addition & 1 deletion tests/cpp/data/test_metainfo.cu
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ TEST(MetaInfo, DeviceExtend) {
lhs.num_row_ = kRows;
rhs.num_row_ = kRows;

lhs.Extend(rhs, true);
lhs.Extend(rhs, true, true);
ASSERT_EQ(lhs.num_row_, kRows * 2);
ASSERT_FALSE(lhs.labels_.HostCanRead());

Expand Down

0 comments on commit 2617120

Please sign in to comment.