Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support host data in proxy DMatrix. #7087

Merged
merged 1 commit into from Jul 8, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
21 changes: 15 additions & 6 deletions src/data/adapter.h
Expand Up @@ -257,7 +257,10 @@ class ArrayAdapterBatch : public detail::NoMetaInfo {
Line const GetLine(size_t idx) const {
return Line{array_interface_, idx};
}
size_t Size() const { return array_interface_.num_rows; }

size_t NumRows() const { return array_interface_.num_rows; }
size_t NumCols() const { return array_interface_.num_cols; }
size_t Size() const { return this->NumRows(); }

explicit ArrayAdapterBatch(ArrayInterface array_interface)
: array_interface_{std::move(array_interface)} {}
Expand Down Expand Up @@ -288,6 +291,7 @@ class CSRArrayAdapterBatch : public detail::NoMetaInfo {
ArrayInterface indptr_;
ArrayInterface indices_;
ArrayInterface values_;
bst_feature_t n_features_;

class Line {
ArrayInterface indices_;
Expand All @@ -311,23 +315,27 @@ class CSRArrayAdapterBatch : public detail::NoMetaInfo {
}
};

public:
static constexpr bool kIsRowMajor = true;

public:
CSRArrayAdapterBatch() = default;
CSRArrayAdapterBatch(ArrayInterface indptr, ArrayInterface indices,
ArrayInterface values)
ArrayInterface values, bst_feature_t n_features)
: indptr_{std::move(indptr)}, indices_{std::move(indices)},
values_{std::move(values)} {
values_{std::move(values)}, n_features_{n_features} {
indptr_.AsColumnVector();
values_.AsColumnVector();
indices_.AsColumnVector();
}

size_t Size() const {
size_t NumRows() const {
size_t size = indptr_.num_rows * indptr_.num_cols;
size = size == 0 ? 0 : size - 1;
return size;
}
static constexpr bool kIsRowMajor = true;
size_t NumCols() const { return n_features_; }
size_t Size() const { return this->NumRows(); }

Line const GetLine(size_t idx) const {
auto begin_offset = indptr_.GetElement<size_t>(idx, 0);
Expand Down Expand Up @@ -356,7 +364,8 @@ class CSRArrayAdapter : public detail::SingleBatchDataIter<CSRArrayAdapterBatch>
CSRArrayAdapter(StringView indptr, StringView indices, StringView values,
size_t num_cols)
: indptr_{indptr}, indices_{indices}, values_{values}, num_cols_{num_cols} {
batch_ = CSRArrayAdapterBatch{indptr_, indices_, values_};
batch_ = CSRArrayAdapterBatch{indptr_, indices_, values_,
static_cast<bst_feature_t>(num_cols_)};
}

CSRArrayAdapterBatch const& Value() const override {
Expand Down
20 changes: 20 additions & 0 deletions src/data/array_interface.h
Expand Up @@ -11,6 +11,7 @@
#include <map>
#include <string>
#include <utility>
#include <vector>

#include "xgboost/base.h"
#include "xgboost/data.h"
Expand Down Expand Up @@ -416,5 +417,24 @@ class ArrayInterface {
Type type;
};

/*!
 * \brief Describe a raw buffer as an array-interface JSON string.
 *
 * The emitted object has four fields:
 *   - "data":    the address of \p data as an integer, paired with the boolean false,
 *   - "shape":   [n, 1],
 *   - "typestr": byte-order character, the type character for T, then sizeof(T),
 *   - "version": 3.
 *
 * Only the address is recorded; the buffer itself is not copied here.
 *
 * \tparam T    Element type of the buffer.
 * \param data  Pointer to the first element.
 * \param n     Number of elements, reported as the first dimension of the shape.
 *
 * \return The serialized JSON document.
 */
template <typename T>
std::string MakeArrayInterface(T const *data, size_t n) {
  // Assemble the type string first: '<' for little endian, '>' otherwise.
  std::string typestr(1, (DMLC_LITTLE_ENDIAN) ? '<' : '>');
  typestr.push_back(ArrayInterfaceHandler::TypeChar<T>());
  typestr += std::to_string(sizeof(T));

  Json arr{Object{}};
  arr["data"] = Array(std::vector<Json>{
      Json{Integer{reinterpret_cast<int64_t>(data)}}, Json{Boolean{false}}});
  arr["shape"] = Array{std::vector<Json>{Json{Integer{n}}, Json{Integer{1}}}};
  arr["typestr"] = typestr;
  arr["version"] = 3;

  std::string out;
  Json::Dump(arr, &out);
  return out;
}
} // namespace xgboost
#endif // XGBOOST_DATA_ARRAY_INTERFACE_H_
25 changes: 3 additions & 22 deletions src/data/iterative_device_dmatrix.cu
Expand Up @@ -11,35 +11,16 @@
#include "sparse_page_source.h"
#include "ellpack_page.cuh"
#include "proxy_dmatrix.h"
#include "proxy_dmatrix.cuh"
#include "device_adapter.cuh"

namespace xgboost {
namespace data {

/*!
 * \brief Invoke \p fn on the batch held by the proxy's device adapter.
 *
 * The adapter stored in the proxy is inspected by its runtime type. When it is
 * a shared pointer to CupyAdapter or CudfAdapter, the adapter's Value() is
 * copied into a local and handed to \p fn. Any other type is reported through
 * LOG(FATAL).
 *
 * \param proxy  Proxy whose Adapter() is inspected.
 * \param fn     Callable accepting either adapter's batch; it must return the
 *               same type for both so the return type can be deduced.
 *
 * \return Whatever \p fn returns.
 */
template <typename Fn>
decltype(auto) Dispatch(DMatrixProxy const* proxy, Fn fn) {
  using CupyPtr = std::shared_ptr<CupyAdapter>;
  using CudfPtr = std::shared_ptr<CudfAdapter>;

  if (proxy->Adapter().type() == typeid(CupyPtr)) {
    auto batch = dmlc::get<CupyPtr>(proxy->Adapter())->Value();
    return fn(batch);
  }
  if (proxy->Adapter().type() == typeid(CudfPtr)) {
    auto batch = dmlc::get<CudfPtr>(proxy->Adapter())->Value();
    return fn(batch);
  }

  LOG(FATAL) << "Unknown type: " << proxy->Adapter().type().name();
  // Presumably never reached once LOG(FATAL) fires; the statements below keep
  // a return on every path so the deduced return type stays consistent.
  auto batch = dmlc::get<CudfPtr>(proxy->Adapter())->Value();
  return fn(batch);
}

void IterativeDeviceDMatrix::Initialize(DataIterHandle iter_handle, float missing, int nthread) {
// A handle passed to external iterator.
auto handle = static_cast<std::shared_ptr<DMatrix>*>(proxy_);
CHECK(handle);
DMatrixProxy* proxy = static_cast<DMatrixProxy*>(handle->get());
DMatrixProxy* proxy = MakeProxy(proxy_);
CHECK(proxy);

// The external iterator
auto iter = DataIterProxy<DataIterResetCallback, XGDMatrixCallbackNext>{
iter_handle, reset_, next_};
Expand Down
29 changes: 29 additions & 0 deletions src/data/proxy_dmatrix.cc
@@ -0,0 +1,29 @@
/*!
* Copyright 2021 by Contributors
* \file proxy_dmatrix.cc
*/

#include "proxy_dmatrix.h"

namespace xgboost {
namespace data {
void DMatrixProxy::SetArrayData(char const *c_interface) {
std::shared_ptr<ArrayAdapter> adapter{
new ArrayAdapter(StringView{c_interface})};
this->batch_ = adapter;
this->Info().num_col_ = adapter->NumColumns();
this->Info().num_row_ = adapter->NumRows();
}

void DMatrixProxy::SetCSRData(char const *c_indptr, char const *c_indices,
char const *c_values, bst_feature_t n_features, bool on_host) {
CHECK(on_host) << "Not implemented on device.";
std::shared_ptr<CSRArrayAdapter> adapter{
new CSRArrayAdapter(StringView{c_indptr}, StringView{c_indices},
StringView{c_values}, n_features)};
this->batch_ = adapter;
this->Info().num_col_ = adapter->NumColumns();
this->Info().num_row_ = adapter->NumRows();
}
} // namespace data
} // namespace xgboost
27 changes: 27 additions & 0 deletions src/data/proxy_dmatrix.cuh
@@ -0,0 +1,27 @@
/*!
* Copyright 2021 XGBoost contributors
*/
#include "device_adapter.cuh"
#include "proxy_dmatrix.h"

namespace xgboost {
namespace data {
/*!
 * \brief Invoke \p fn on the batch held by the proxy's device adapter.
 *
 * The adapter stored in the proxy is inspected by its runtime type. When it is
 * a shared pointer to CupyAdapter or CudfAdapter, the adapter's Value() is
 * copied into a local and handed to \p fn. Any other type is reported through
 * LOG(FATAL).
 *
 * \param proxy  Proxy whose Adapter() is inspected.
 * \param fn     Callable accepting either adapter's batch; it must return the
 *               same type for both so the return type can be deduced.
 *
 * \return Whatever \p fn returns.
 */
template <typename Fn>
decltype(auto) Dispatch(DMatrixProxy const* proxy, Fn fn) {
  using CupyPtr = std::shared_ptr<CupyAdapter>;
  using CudfPtr = std::shared_ptr<CudfAdapter>;

  if (proxy->Adapter().type() == typeid(CupyPtr)) {
    auto batch = dmlc::get<CupyPtr>(proxy->Adapter())->Value();
    return fn(batch);
  }
  if (proxy->Adapter().type() == typeid(CudfPtr)) {
    auto batch = dmlc::get<CudfPtr>(proxy->Adapter())->Value();
    return fn(batch);
  }

  LOG(FATAL) << "Unknown type: " << proxy->Adapter().type().name();
  // Presumably never reached once LOG(FATAL) fires; the statements below keep
  // a return on every path so the deduced return type stays consistent.
  auto batch = dmlc::get<CudfPtr>(proxy->Adapter())->Value();
  return fn(batch);
}
} // namespace data
} // namespace xgboost
40 changes: 40 additions & 0 deletions src/data/proxy_dmatrix.h
Expand Up @@ -72,6 +72,11 @@ class DMatrixProxy : public DMatrix {
#endif // defined(XGBOOST_USE_CUDA)
}

void SetArrayData(char const* c_interface);
void SetCSRData(char const *c_indptr, char const *c_indices,
char const *c_values, bst_feature_t n_features,
bool on_host);

MetaInfo& Info() override { return info_; }
MetaInfo const& Info() const override { return info_; }
bool SingleColBlock() const override { return true; }
Expand Down Expand Up @@ -106,6 +111,41 @@ class DMatrixProxy : public DMatrix {
return batch_;
}
};

/*!
 * \brief Recover the DMatrixProxy pointer behind an opaque DMatrix handle.
 *
 * The handle is interpreted as a pointer to std::shared_ptr<DMatrix>; a null
 * handle fails the check. The held DMatrix pointer is then statically cast to
 * DMatrixProxy, so no runtime verification of the dynamic type happens here.
 *
 * \param proxy  Opaque handle.
 *
 * \return The pointer held by the shared pointer, typed as DMatrixProxy.
 */
inline DMatrixProxy *MakeProxy(DMatrixHandle proxy) {
  auto *handle = static_cast<std::shared_ptr<DMatrix> *>(proxy);
  CHECK(handle) << "Invalid proxy handle.";
  return static_cast<DMatrixProxy *>(handle->get());
}

/*!
 * \brief Invoke \p fn on the batch held by the proxy's host adapter.
 *
 * The adapter stored in the proxy is inspected by its runtime type. When it is
 * a shared pointer to CSRArrayAdapter or ArrayAdapter, the adapter's Value()
 * is copied into a local and handed to \p fn.
 *
 * \param proxy       Proxy whose Adapter() is inspected.
 * \param fn          Callable accepting either adapter's batch; it must return
 *                    the same type for both so the return type can be deduced.
 *                    Because of the soft-error path below, a non-void result
 *                    type must be value-initialisable.
 * \param type_error  Optional out-flag. When non-null it is set to false on a
 *                    successful dispatch and to true when the adapter type is
 *                    not recognised. When null, an unrecognised type is
 *                    reported through LOG(FATAL).
 *
 * \return Whatever \p fn returns; a value-initialised result when the adapter
 *         type is not recognised and \p type_error was supplied.
 */
template <typename Fn>
decltype(auto) HostAdapterDispatch(DMatrixProxy const* proxy, Fn fn, bool* type_error = nullptr) {
  if (proxy->Adapter().type() == typeid(std::shared_ptr<CSRArrayAdapter>)) {
    auto value =
        dmlc::get<std::shared_ptr<CSRArrayAdapter>>(proxy->Adapter())->Value();
    if (type_error) {
      *type_error = false;
    }
    return fn(value);
  } else if (proxy->Adapter().type() == typeid(std::shared_ptr<ArrayAdapter>)) {
    auto value = dmlc::get<std::shared_ptr<ArrayAdapter>>(
        proxy->Adapter())->Value();
    if (type_error) {
      *type_error = false;
    }
    return fn(value);
  } else {
    if (type_error) {
      *type_error = true;
    } else {
      LOG(FATAL) << "Unknown type: " << proxy->Adapter().type().name();
    }
    // Do not extract the adapter here: its stored type matches neither host
    // adapter, so a dmlc::get for ArrayAdapter would be a mismatched access and
    // the caller asking for a soft error would never see the flag. The decltype
    // operand is unevaluated; it only names the callback's result type.
    using ResultT = decltype(fn(
        dmlc::get<std::shared_ptr<ArrayAdapter>>(proxy->Adapter())->Value()));
    return ResultT();
  }
}
} // namespace data
} // namespace xgboost
#endif // XGBOOST_DATA_PROXY_DMATRIX_H_
23 changes: 22 additions & 1 deletion tests/cpp/data/test_adapter.cc
@@ -1,4 +1,4 @@
// Copyright (c) 2019 by Contributors
// Copyright (c) 2019-2021 by XGBoost Contributors
#include <gtest/gtest.h>
#include <type_traits>
#include <utility>
Expand Down Expand Up @@ -35,6 +35,27 @@ TEST(Adapter, CSRAdapter) {
EXPECT_EQ(line2.GetElement(0).column_idx, 1);
}

// Builds a CSRArrayAdapter from array-interface strings over randomly generated
// CSR buffers and checks the row/column counts reported by the adapter and by
// its batch.
TEST(Adapter, CSRArrayAdapter) {
  HostDeviceVector<bst_row_t> indptr;
  HostDeviceVector<float> values;
  HostDeviceVector<bst_feature_t> indices;
  size_t n_features = 100, n_samples = 10;
  RandomDataGenerator{n_samples, n_features, 0.5}.GenerateCSR(&values, &indptr, &indices);
  auto indptr_arr = MakeArrayInterface(indptr.HostPointer(), indptr.Size());
  auto values_arr = MakeArrayInterface(values.HostPointer(), values.Size());
  auto indices_arr = MakeArrayInterface(indices.HostPointer(), indices.Size());
  // The constructor's parameter order is (indptr, indices, values, num_cols);
  // pass the column indices before the values so each buffer lands in its slot.
  auto adapter = data::CSRArrayAdapter(
      StringView{indptr_arr.c_str(), indptr_arr.size()},
      StringView{indices_arr.c_str(), indices_arr.size()},
      StringView{values_arr.c_str(), values_arr.size()}, n_features);
  auto batch = adapter.Value();
  ASSERT_EQ(batch.NumRows(), n_samples);
  ASSERT_EQ(batch.NumCols(), n_features);

  ASSERT_EQ(adapter.NumRows(), n_samples);
  ASSERT_EQ(adapter.NumColumns(), n_features);
}

TEST(Adapter, CSCAdapterColsMoreThanRows) {
std::vector<float> data = {1, 2, 3, 4, 5, 6, 7, 8};
std::vector<unsigned> row_idx = {0, 1, 0, 1, 0, 1, 0, 1};
Expand Down
31 changes: 31 additions & 0 deletions tests/cpp/data/test_proxy_dmatrix.cc
@@ -0,0 +1,31 @@
/*!
* Copyright 2021 XGBoost contributors
*/
#include <gtest/gtest.h>
#include "../helpers.h"
#include "../../../src/data/proxy_dmatrix.h"
#include "../../../src/data/adapter.h"

namespace xgboost {
namespace data {
// Feeds a dense array to the proxy through SetArrayData and checks that
// HostAdapterDispatch exposes a batch with the expected number of rows and
// columns.
TEST(ProxyDMatrix, HostData) {
  DMatrixProxy proxy;
  size_t constexpr kRows = 100, kCols = 10;

  HostDeviceVector<float> storage;
  // NOTE(review): Device(0) is requested although this test exercises the host
  // path -- confirm this is intended for CPU-only builds.
  auto data = RandomDataGenerator(kRows, kCols, 0.5)
                  .Device(0)
                  .GenerateArrayInterface(&storage);

  proxy.SetArrayData(data.c_str());

  auto n_samples = HostAdapterDispatch(
      &proxy, [](auto const &value) { return value.Size(); });
  ASSERT_EQ(n_samples, kRows);
  auto n_features = HostAdapterDispatch(
      &proxy, [](auto const &value) { return value.NumCols(); });
  ASSERT_EQ(n_features, kCols);
}
} // namespace data
} // namespace xgboost
2 changes: 1 addition & 1 deletion tests/cpp/data/test_proxy_dmatrix.cu
Expand Up @@ -7,7 +7,7 @@

namespace xgboost {
namespace data {
TEST(ProxyDMatrix, Basic) {
TEST(ProxyDMatrix, DeviceData) {
constexpr size_t kRows{100}, kCols{100};
HostDeviceVector<float> storage;
auto data = RandomDataGenerator(kRows, kCols, 0.5)
Expand Down