Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor external memory formats. #7089

Merged
merged 1 commit into from Jul 7, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
27 changes: 16 additions & 11 deletions src/data/ellpack_page_raw_format.cu
@@ -1,7 +1,6 @@
/*!
* Copyright 2019 XGBoost contributors
* Copyright 2019-2021 XGBoost contributors
*/

#include <xgboost/data.h>
#include <dmlc/registry.h>

Expand All @@ -13,6 +12,7 @@ namespace data {

DMLC_REGISTRY_FILE_TAG(ellpack_page_raw_format);


class EllpackPageRawFormat : public SparsePageFormat<EllpackPage> {
public:
bool Read(EllpackPage* page, dmlc::SeekStream* fi) override {
Expand All @@ -23,29 +23,34 @@ class EllpackPageRawFormat : public SparsePageFormat<EllpackPage> {
fi->Read(&impl->n_rows);
fi->Read(&impl->is_dense);
fi->Read(&impl->row_stride);
if (!fi->Read(&impl->gidx_buffer.HostVector())) {
fi->Read(&impl->gidx_buffer.HostVector());
if (!fi->Read(&impl->base_rowid)) {
return false;
}
return true;
}

bool Read(EllpackPage* page,
dmlc::SeekStream* fi,
const std::vector<bst_uint>& sorted_index_set) override {
LOG(FATAL) << "Not implemented";
return false;
}

void Write(const EllpackPage& page, dmlc::Stream* fo) override {
size_t Write(const EllpackPage& page, dmlc::Stream* fo) override {
size_t bytes = 0;
auto* impl = page.Impl();
fo->Write(impl->Cuts().cut_values_.ConstHostVector());
bytes += impl->Cuts().cut_values_.ConstHostSpan().size_bytes() + sizeof(uint64_t);
fo->Write(impl->Cuts().cut_ptrs_.ConstHostVector());
bytes += impl->Cuts().cut_ptrs_.ConstHostSpan().size_bytes() + sizeof(uint64_t);
fo->Write(impl->Cuts().min_vals_.ConstHostVector());
bytes += impl->Cuts().min_vals_.ConstHostSpan().size_bytes() + sizeof(uint64_t);
fo->Write(impl->n_rows);
bytes += sizeof(impl->n_rows);
fo->Write(impl->is_dense);
bytes += sizeof(impl->is_dense);
fo->Write(impl->row_stride);
bytes += sizeof(impl->row_stride);
CHECK(!impl->gidx_buffer.ConstHostVector().empty());
fo->Write(impl->gidx_buffer.HostVector());
bytes += impl->gidx_buffer.ConstHostSpan().size_bytes() + sizeof(uint64_t);
fo->Write(impl->base_rowid);
bytes += sizeof(impl->base_rowid);
return bytes;
}
};

Expand Down
72 changes: 16 additions & 56 deletions src/data/sparse_page_raw_format.cc
@@ -1,10 +1,12 @@
/*!
* Copyright (c) 2015 by Contributors
* Copyright (c) 2015-2021 by Contributors
* \file sparse_page_raw_format.cc
* Raw binary format of sparse page.
*/
#include <xgboost/data.h>
#include <dmlc/registry.h>

#include "xgboost/logging.h"
#include "./sparse_page_writer.h"

namespace xgboost {
Expand All @@ -17,78 +19,36 @@ class SparsePageRawFormat : public SparsePageFormat<T> {
public:
bool Read(T* page, dmlc::SeekStream* fi) override {
auto& offset_vec = page->offset.HostVector();
if (!fi->Read(&offset_vec)) return false;
if (!fi->Read(&offset_vec)) {
return false;
}
auto& data_vec = page->data.HostVector();
CHECK_NE(page->offset.Size(), 0U) << "Invalid SparsePage file";
data_vec.resize(offset_vec.back());
if (page->data.Size() != 0) {
CHECK_EQ(fi->Read(dmlc::BeginPtr(data_vec),
(page->data).Size() * sizeof(Entry)),
(page->data).Size() * sizeof(Entry))
size_t n_bytes = fi->Read(dmlc::BeginPtr(data_vec),
(page->data).Size() * sizeof(Entry));
CHECK_EQ(n_bytes, (page->data).Size() * sizeof(Entry))
<< "Invalid SparsePage file";
}
fi->Read(&page->base_rowid, sizeof(page->base_rowid));
return true;
}

bool Read(T* page,
dmlc::SeekStream* fi,
const std::vector<bst_uint>& sorted_index_set) override {
if (!fi->Read(&disk_offset_)) return false;
auto& offset_vec = page->offset.HostVector();
auto& data_vec = page->data.HostVector();
// setup the offset
offset_vec.clear();
offset_vec.push_back(0);
for (unsigned int fid : sorted_index_set) {
CHECK_LT(fid + 1, disk_offset_.size());
size_t size = disk_offset_[fid + 1] - disk_offset_[fid];
offset_vec.push_back(offset_vec.back() + size);
}
data_vec.resize(offset_vec.back());
// read in the data
size_t begin = fi->Tell();
size_t curr_offset = 0;
for (size_t i = 0; i < sorted_index_set.size();) {
bst_uint fid = sorted_index_set[i];
if (disk_offset_[fid] != curr_offset) {
CHECK_GT(disk_offset_[fid], curr_offset);
fi->Seek(begin + disk_offset_[fid] * sizeof(Entry));
curr_offset = disk_offset_[fid];
}
size_t j, size_to_read = 0;
for (j = i; j < sorted_index_set.size(); ++j) {
if (disk_offset_[sorted_index_set[j]] == disk_offset_[fid] + size_to_read) {
size_to_read += offset_vec[j + 1] - offset_vec[j];
} else {
break;
}
}

if (size_to_read != 0) {
CHECK_EQ(fi->Read(dmlc::BeginPtr(data_vec) + offset_vec[i],
size_to_read * sizeof(Entry)),
size_to_read * sizeof(Entry))
<< "Invalid SparsePage file";
curr_offset += size_to_read;
}
i = j;
}
// seek to end of record
if (curr_offset != disk_offset_.back()) {
fi->Seek(begin + disk_offset_.back() * sizeof(Entry));
}
return true;
}

void Write(const T& page, dmlc::Stream* fo) override {
size_t Write(const T& page, dmlc::Stream* fo) override {
const auto& offset_vec = page.offset.HostVector();
const auto& data_vec = page.data.HostVector();
CHECK(page.offset.Size() != 0 && offset_vec[0] == 0);
CHECK_EQ(offset_vec.back(), page.data.Size());
fo->Write(offset_vec);
auto bytes = page.MemCostBytes();
bytes += sizeof(uint64_t);
if (page.data.Size() != 0) {
fo->Write(dmlc::BeginPtr(data_vec), page.data.Size() * sizeof(Entry));
}
fo->Write(&page.base_rowid, sizeof(page.base_rowid));
bytes += sizeof(page.base_rowid);
return bytes;
}

private:
Expand Down
13 changes: 1 addition & 12 deletions src/data/sparse_page_writer.h
Expand Up @@ -42,22 +42,11 @@ class SparsePageFormat {
* \return true of the loading as successful, false if end of file was reached
*/
virtual bool Read(T* page, dmlc::SeekStream* fi) = 0;

/*!
* \brief read only the segments we are interested in, advance fi to end of the block.
* \param page The page to load the data into.
* \param fi the input stream of the file
* \param sorted_index_set sorted index of segments we are interested in
* \return true of the loading as successful, false if end of file was reached
*/
virtual bool Read(T* page,
dmlc::SeekStream* fi,
const std::vector<bst_uint>& sorted_index_set) = 0;
/*!
* \brief save the data to fo, when a page was written.
* \param fo output stream
*/
virtual void Write(const T& page, dmlc::Stream* fo) = 0;
virtual size_t Write(const T& page, dmlc::Stream* fo) = 0;
};

/*!
Expand Down
45 changes: 45 additions & 0 deletions tests/cpp/data/test_ellpack_page_raw_format.cu
@@ -0,0 +1,45 @@
/*!
* Copyright 2021 XGBoost contributors
*/
#include <gtest/gtest.h>
#include <dmlc/filesystem.h>
#include <xgboost/data.h>

#include "../../../src/data/sparse_page_source.h"
#include "../../../src/data/ellpack_page.cuh"

#include "../helpers.h"

namespace xgboost {
namespace data {
TEST(EllpackPageRawFormat, IO) {
std::unique_ptr<SparsePageFormat<EllpackPage>> format{CreatePageFormat<EllpackPage>("raw")};

auto m = RandomDataGenerator{100, 14, 0.5}.GenerateDMatrix();
dmlc::TemporaryDirectory tmpdir;
std::string path = tmpdir.path + "/ellpack.page";

{
std::unique_ptr<dmlc::Stream> fo{dmlc::Stream::Create(path.c_str(), "w")};
for (auto const &ellpack : m->GetBatches<EllpackPage>({0, 256})) {
format->Write(ellpack, fo.get());
}
}

EllpackPage page;
std::unique_ptr<dmlc::SeekStream> fi{dmlc::SeekStream::CreateForRead(path.c_str())};
format->Read(&page, fi.get());

for (auto const &ellpack : m->GetBatches<EllpackPage>({0, 256})) {
auto loaded = page.Impl();
auto orig = ellpack.Impl();
ASSERT_EQ(loaded->Cuts().Ptrs(), orig->Cuts().Ptrs());
ASSERT_EQ(loaded->Cuts().MinValues(), orig->Cuts().MinValues());
ASSERT_EQ(loaded->Cuts().Values(), orig->Cuts().Values());
ASSERT_EQ(loaded->base_rowid, orig->base_rowid);
ASSERT_EQ(loaded->row_stride, orig->row_stride);
ASSERT_EQ(loaded->gidx_buffer.HostVector(), orig->gidx_buffer.HostVector());
}
hcho3 marked this conversation as resolved.
Show resolved Hide resolved
}
} // namespace data
} // namespace xgboost
56 changes: 56 additions & 0 deletions tests/cpp/data/test_sparse_page_raw_format.cc
@@ -0,0 +1,56 @@
/*!
* Copyright 2021 XGBoost contributors
*/
#include <gtest/gtest.h>
#include <dmlc/filesystem.h>
#include <xgboost/data.h>

#include "../../../src/data/sparse_page_source.h"
#include "../helpers.h"

namespace xgboost {
namespace data {
template <typename S> void TestSparsePageRawFormat() {
std::unique_ptr<SparsePageFormat<S>> format{CreatePageFormat<S>("raw")};

auto m = RandomDataGenerator{100, 14, 0.5}.GenerateDMatrix();
ASSERT_TRUE(m->SingleColBlock());
dmlc::TemporaryDirectory tmpdir;
std::string path = tmpdir.path + "/sparse.page";
S orig;
{
// block code to flush the stream
std::unique_ptr<dmlc::Stream> fo{dmlc::Stream::Create(path.c_str(), "w")};
for (auto const &page : m->GetBatches<S>()) {
orig.Push(page);
format->Write(page, fo.get());
}
}

S page;
std::unique_ptr<dmlc::SeekStream> fi{dmlc::SeekStream::CreateForRead(path.c_str())};
format->Read(&page, fi.get());
for (size_t i = 0; i < orig.data.Size(); ++i) {
ASSERT_EQ(page.data.HostVector()[i].fvalue,
orig.data.HostVector()[i].fvalue);
ASSERT_EQ(page.data.HostVector()[i].index, orig.data.HostVector()[i].index);
}
for (size_t i = 0; i < orig.offset.Size(); ++i) {
ASSERT_EQ(page.offset.HostVector()[i], orig.offset.HostVector()[i]);
}
ASSERT_EQ(page.base_rowid, orig.base_rowid);
}

TEST(SparsePageRawFormat, SparsePage) {
TestSparsePageRawFormat<SparsePage>();
}

TEST(SparsePageRawFormat, CSCPage) {
TestSparsePageRawFormat<CSCPage>();
}

TEST(SparsePageRawFormat, SortedCSCPage) {
TestSparsePageRawFormat<SortedCSCPage>();
}
} // namespace data
} // namespace xgboost