diff --git a/include/xgboost/data.h b/include/xgboost/data.h index 7d5000c48947..3a403d541b56 100644 --- a/include/xgboost/data.h +++ b/include/xgboost/data.h @@ -403,6 +403,7 @@ template class BatchIterator { public: using iterator_category = std::forward_iterator_tag; // NOLINT + explicit BatchIterator(BatchIteratorImpl* impl) { impl_.reset(impl); } explicit BatchIterator(std::shared_ptr> impl) { impl_ = impl; } BatchIterator &operator++() { diff --git a/src/data/iterative_device_dmatrix.cu b/src/data/iterative_device_dmatrix.cu index f2dfa498a729..00e502dfa767 100644 --- a/src/data/iterative_device_dmatrix.cu +++ b/src/data/iterative_device_dmatrix.cu @@ -162,8 +162,8 @@ void IterativeDeviceDMatrix::Initialize(DataIterHandle iter_handle, float missin BatchSet IterativeDeviceDMatrix::GetEllpackBatches(const BatchParam& param) { CHECK(page_); - auto begin_iter = BatchIterator( - std::make_shared>(page_)); + auto begin_iter = + BatchIterator(new SimpleBatchIteratorImpl(page_)); return BatchSet(begin_iter); } } // namespace data diff --git a/src/data/iterative_device_dmatrix.h b/src/data/iterative_device_dmatrix.h index 1097abb9c5cc..232b50102b56 100644 --- a/src/data/iterative_device_dmatrix.h +++ b/src/data/iterative_device_dmatrix.h @@ -83,8 +83,8 @@ inline void IterativeDeviceDMatrix::Initialize(DataIterHandle iter, float missin } inline BatchSet IterativeDeviceDMatrix::GetEllpackBatches(const BatchParam& param) { common::AssertGPUSupport(); - auto begin_iter = BatchIterator( - std::make_shared>(page_)); + auto begin_iter = + BatchIterator(new SimpleBatchIteratorImpl(page_)); return BatchSet(BatchIterator(begin_iter)); } #endif // !defined(XGBOOST_USE_CUDA) diff --git a/src/data/simple_dmatrix.cc b/src/data/simple_dmatrix.cc index cdd3a48d8174..a737c6d59071 100644 --- a/src/data/simple_dmatrix.cc +++ b/src/data/simple_dmatrix.cc @@ -48,7 +48,7 @@ DMatrix* SimpleDMatrix::Slice(common::Span ridxs) { BatchSet SimpleDMatrix::GetRowBatches() { // since csr is the default data structure so `source_` is always available. auto begin_iter = BatchIterator( - std::make_shared>(sparse_page_)); + new SimpleBatchIteratorImpl(sparse_page_)); return BatchSet(begin_iter); } @@ -57,8 +57,8 @@ BatchSet SimpleDMatrix::GetColumnBatches() { if (!column_page_) { column_page_.reset(new CSCPage(sparse_page_->GetTranspose(info_.num_col_))); } - auto begin_iter = BatchIterator( - std::make_shared>(column_page_)); + auto begin_iter = + BatchIterator(new SimpleBatchIteratorImpl(column_page_)); return BatchSet(begin_iter); } @@ -70,8 +70,7 @@ BatchSet SimpleDMatrix::GetSortedColumnBatches() { sorted_column_page_->SortRows(); } auto begin_iter = BatchIterator( - std::make_shared>( - sorted_column_page_)); + new SimpleBatchIteratorImpl(sorted_column_page_)); return BatchSet(begin_iter); } @@ -86,8 +85,8 @@ BatchSet SimpleDMatrix::GetEllpackBatches(const BatchParam& param) ellpack_page_.reset(new EllpackPage(this, param)); batch_param_ = param; } - auto begin_iter = BatchIterator( - std::make_shared>(ellpack_page_)); + auto begin_iter = + BatchIterator(new SimpleBatchIteratorImpl(ellpack_page_)); return BatchSet(begin_iter); } @@ -101,8 +100,7 @@ BatchSet SimpleDMatrix::GetGradientIndex(const BatchParam& par batch_param_ = param; } auto begin_iter = BatchIterator( - std::make_shared>( - gradient_index_)); + new SimpleBatchIteratorImpl(gradient_index_)); return BatchSet(begin_iter); } diff --git a/src/data/sparse_page_dmatrix.cc b/src/data/sparse_page_dmatrix.cc index 8977dd0dffad..d6e26195b804 100644 --- a/src/data/sparse_page_dmatrix.cc +++ b/src/data/sparse_page_dmatrix.cc @@ -132,8 +132,7 @@ BatchSet SparsePageDMatrix::GetGradientIndex(const BatchParam& } this->InitializeSparsePage(); auto begin_iter = BatchIterator( - std::make_shared>( - ghist_index_source_)); + new SimpleBatchIteratorImpl(ghist_index_source_)); return BatchSet(begin_iter); }