diff --git a/src/data/ellpack_page_raw_format.cu b/src/data/ellpack_page_raw_format.cu index d4caf37e2be3..13a0a1766d32 100644 --- a/src/data/ellpack_page_raw_format.cu +++ b/src/data/ellpack_page_raw_format.cu @@ -1,7 +1,6 @@ /*! - * Copyright 2019 XGBoost contributors + * Copyright 2019-2021 XGBoost contributors */ - #include #include @@ -13,6 +12,7 @@ namespace data { DMLC_REGISTRY_FILE_TAG(ellpack_page_raw_format); + class EllpackPageRawFormat : public SparsePageFormat { public: bool Read(EllpackPage* page, dmlc::SeekStream* fi) override { @@ -23,29 +23,34 @@ class EllpackPageRawFormat : public SparsePageFormat { fi->Read(&impl->n_rows); fi->Read(&impl->is_dense); fi->Read(&impl->row_stride); - if (!fi->Read(&impl->gidx_buffer.HostVector())) { + fi->Read(&impl->gidx_buffer.HostVector()); + if (!fi->Read(&impl->base_rowid)) { return false; } return true; } - bool Read(EllpackPage* page, - dmlc::SeekStream* fi, - const std::vector& sorted_index_set) override { - LOG(FATAL) << "Not implemented"; - return false; - } - - void Write(const EllpackPage& page, dmlc::Stream* fo) override { + size_t Write(const EllpackPage& page, dmlc::Stream* fo) override { + size_t bytes = 0; auto* impl = page.Impl(); fo->Write(impl->Cuts().cut_values_.ConstHostVector()); + bytes += impl->Cuts().cut_values_.ConstHostSpan().size_bytes() + sizeof(uint64_t); fo->Write(impl->Cuts().cut_ptrs_.ConstHostVector()); + bytes += impl->Cuts().cut_ptrs_.ConstHostSpan().size_bytes() + sizeof(uint64_t); fo->Write(impl->Cuts().min_vals_.ConstHostVector()); + bytes += impl->Cuts().min_vals_.ConstHostSpan().size_bytes() + sizeof(uint64_t); fo->Write(impl->n_rows); + bytes += sizeof(impl->n_rows); fo->Write(impl->is_dense); + bytes += sizeof(impl->is_dense); fo->Write(impl->row_stride); + bytes += sizeof(impl->row_stride); CHECK(!impl->gidx_buffer.ConstHostVector().empty()); fo->Write(impl->gidx_buffer.HostVector()); + bytes += impl->gidx_buffer.ConstHostSpan().size_bytes() + sizeof(uint64_t); + fo->Write(impl->base_rowid); + bytes += sizeof(impl->base_rowid); + return bytes; } }; diff --git a/src/data/sparse_page_raw_format.cc b/src/data/sparse_page_raw_format.cc index b9a82bbdfa70..1e5d1ec71f72 100644 --- a/src/data/sparse_page_raw_format.cc +++ b/src/data/sparse_page_raw_format.cc @@ -1,10 +1,12 @@ /*! - * Copyright (c) 2015 by Contributors + * Copyright (c) 2015-2021 by Contributors * \file sparse_page_raw_format.cc * Raw binary format of sparse page. */ #include #include + +#include "xgboost/logging.h" #include "./sparse_page_writer.h" namespace xgboost { @@ -17,78 +19,36 @@ class SparsePageRawFormat : public SparsePageFormat { public: bool Read(T* page, dmlc::SeekStream* fi) override { auto& offset_vec = page->offset.HostVector(); - if (!fi->Read(&offset_vec)) return false; + if (!fi->Read(&offset_vec)) { + return false; + } auto& data_vec = page->data.HostVector(); CHECK_NE(page->offset.Size(), 0U) << "Invalid SparsePage file"; data_vec.resize(offset_vec.back()); if (page->data.Size() != 0) { - CHECK_EQ(fi->Read(dmlc::BeginPtr(data_vec), - (page->data).Size() * sizeof(Entry)), - (page->data).Size() * sizeof(Entry)) + size_t n_bytes = fi->Read(dmlc::BeginPtr(data_vec), + (page->data).Size() * sizeof(Entry)); + CHECK_EQ(n_bytes, (page->data).Size() * sizeof(Entry)) << "Invalid SparsePage file"; } + fi->Read(&page->base_rowid, sizeof(page->base_rowid)); return true; } - bool Read(T* page, - dmlc::SeekStream* fi, - const std::vector& sorted_index_set) override { - if (!fi->Read(&disk_offset_)) return false; - auto& offset_vec = page->offset.HostVector(); - auto& data_vec = page->data.HostVector(); - // setup the offset - offset_vec.clear(); - offset_vec.push_back(0); - for (unsigned int fid : sorted_index_set) { - CHECK_LT(fid + 1, disk_offset_.size()); - size_t size = disk_offset_[fid + 1] - disk_offset_[fid]; - offset_vec.push_back(offset_vec.back() + size); - } - data_vec.resize(offset_vec.back()); - // read in the data - size_t begin = fi->Tell(); - size_t curr_offset = 0; - for (size_t i = 0; i < sorted_index_set.size();) { - bst_uint fid = sorted_index_set[i]; - if (disk_offset_[fid] != curr_offset) { - CHECK_GT(disk_offset_[fid], curr_offset); - fi->Seek(begin + disk_offset_[fid] * sizeof(Entry)); - curr_offset = disk_offset_[fid]; - } - size_t j, size_to_read = 0; - for (j = i; j < sorted_index_set.size(); ++j) { - if (disk_offset_[sorted_index_set[j]] == disk_offset_[fid] + size_to_read) { - size_to_read += offset_vec[j + 1] - offset_vec[j]; - } else { - break; - } - } - - if (size_to_read != 0) { - CHECK_EQ(fi->Read(dmlc::BeginPtr(data_vec) + offset_vec[i], - size_to_read * sizeof(Entry)), - size_to_read * sizeof(Entry)) - << "Invalid SparsePage file"; - curr_offset += size_to_read; - } - i = j; - } - // seek to end of record - if (curr_offset != disk_offset_.back()) { - fi->Seek(begin + disk_offset_.back() * sizeof(Entry)); - } - return true; - } - - void Write(const T& page, dmlc::Stream* fo) override { + size_t Write(const T& page, dmlc::Stream* fo) override { const auto& offset_vec = page.offset.HostVector(); const auto& data_vec = page.data.HostVector(); CHECK(page.offset.Size() != 0 && offset_vec[0] == 0); CHECK_EQ(offset_vec.back(), page.data.Size()); fo->Write(offset_vec); + auto bytes = page.MemCostBytes(); + bytes += sizeof(uint64_t); if (page.data.Size() != 0) { fo->Write(dmlc::BeginPtr(data_vec), page.data.Size() * sizeof(Entry)); } + fo->Write(&page.base_rowid, sizeof(page.base_rowid)); + bytes += sizeof(page.base_rowid); + return bytes; } private: diff --git a/src/data/sparse_page_writer.h b/src/data/sparse_page_writer.h index f63fcf0f8d7d..2b079fb6c680 100644 --- a/src/data/sparse_page_writer.h +++ b/src/data/sparse_page_writer.h @@ -42,22 +42,11 @@ class SparsePageFormat { * \return true of the loading as successful, false if end of file was reached */ virtual bool Read(T* page, dmlc::SeekStream* fi) = 0; - - /*! - * \brief read only the segments we are interested in, advance fi to end of the block. - * \param page The page to load the data into. - * \param fi the input stream of the file - * \param sorted_index_set sorted index of segments we are interested in - * \return true of the loading as successful, false if end of file was reached - */ - virtual bool Read(T* page, - dmlc::SeekStream* fi, - const std::vector& sorted_index_set) = 0; /*! * \brief save the data to fo, when a page was written. * \param fo output stream */ - virtual void Write(const T& page, dmlc::Stream* fo) = 0; + virtual size_t Write(const T& page, dmlc::Stream* fo) = 0; }; /*! diff --git a/tests/cpp/data/test_ellpack_page_raw_format.cu b/tests/cpp/data/test_ellpack_page_raw_format.cu new file mode 100644 index 000000000000..d4b5722eabf6 --- /dev/null +++ b/tests/cpp/data/test_ellpack_page_raw_format.cu @@ -0,0 +1,45 @@ +/*! + * Copyright 2021 XGBoost contributors + */ +#include +#include +#include + +#include "../../../src/data/sparse_page_source.h" +#include "../../../src/data/ellpack_page.cuh" + +#include "../helpers.h" + +namespace xgboost { +namespace data { +TEST(EllpackPageRawFormat, IO) { + std::unique_ptr> format{CreatePageFormat("raw")}; + + auto m = RandomDataGenerator{100, 14, 0.5}.GenerateDMatrix(); + dmlc::TemporaryDirectory tmpdir; + std::string path = tmpdir.path + "/ellpack.page"; + + { + std::unique_ptr fo{dmlc::Stream::Create(path.c_str(), "w")}; + for (auto const &ellpack : m->GetBatches({0, 256})) { + format->Write(ellpack, fo.get()); + } + } + + EllpackPage page; + std::unique_ptr fi{dmlc::SeekStream::CreateForRead(path.c_str())}; + format->Read(&page, fi.get()); + + for (auto const &ellpack : m->GetBatches({0, 256})) { + auto loaded = page.Impl(); + auto orig = ellpack.Impl(); + ASSERT_EQ(loaded->Cuts().Ptrs(), orig->Cuts().Ptrs()); + ASSERT_EQ(loaded->Cuts().MinValues(), orig->Cuts().MinValues()); + ASSERT_EQ(loaded->Cuts().Values(), orig->Cuts().Values()); + ASSERT_EQ(loaded->base_rowid, orig->base_rowid); + ASSERT_EQ(loaded->row_stride, orig->row_stride); + ASSERT_EQ(loaded->gidx_buffer.HostVector(), orig->gidx_buffer.HostVector()); + } +} +} // namespace data +} // namespace xgboost diff --git a/tests/cpp/data/test_sparse_page_raw_format.cc b/tests/cpp/data/test_sparse_page_raw_format.cc new file mode 100644 index 000000000000..dc7c5b2be77f --- /dev/null +++ b/tests/cpp/data/test_sparse_page_raw_format.cc @@ -0,0 +1,56 @@ +/*! + * Copyright 2021 XGBoost contributors + */ +#include +#include +#include + +#include "../../../src/data/sparse_page_source.h" +#include "../helpers.h" + +namespace xgboost { +namespace data { +template void TestSparsePageRawFormat() { + std::unique_ptr> format{CreatePageFormat("raw")}; + + auto m = RandomDataGenerator{100, 14, 0.5}.GenerateDMatrix(); + ASSERT_TRUE(m->SingleColBlock()); + dmlc::TemporaryDirectory tmpdir; + std::string path = tmpdir.path + "/sparse.page"; + S orig; + { + // block code to flush the stream + std::unique_ptr fo{dmlc::Stream::Create(path.c_str(), "w")}; + for (auto const &page : m->GetBatches()) { + orig.Push(page); + format->Write(page, fo.get()); + } + } + + S page; + std::unique_ptr fi{dmlc::SeekStream::CreateForRead(path.c_str())}; + format->Read(&page, fi.get()); + for (size_t i = 0; i < orig.data.Size(); ++i) { + ASSERT_EQ(page.data.HostVector()[i].fvalue, + orig.data.HostVector()[i].fvalue); + ASSERT_EQ(page.data.HostVector()[i].index, orig.data.HostVector()[i].index); + } + for (size_t i = 0; i < orig.offset.Size(); ++i) { + ASSERT_EQ(page.offset.HostVector()[i], orig.offset.HostVector()[i]); + } + ASSERT_EQ(page.base_rowid, orig.base_rowid); +} + +TEST(SparsePageRawFormat, SparsePage) { + TestSparsePageRawFormat(); +} + +TEST(SparsePageRawFormat, CSCPage) { + TestSparsePageRawFormat(); +} + +TEST(SparsePageRawFormat, SortedCSCPage) { + TestSparsePageRawFormat(); +} +} // namespace data +} // namespace xgboost