Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[backport] Support null value in CUDA array interface. (#8486) #8499

Merged
merged 1 commit into from Nov 30, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
26 changes: 16 additions & 10 deletions src/data/array_interface.h
Expand Up @@ -101,7 +101,7 @@ class ArrayInterfaceHandler {
template <typename PtrType>
static PtrType GetPtrFromArrayData(Object::Map const &obj) {
auto data_it = obj.find("data");
if (data_it == obj.cend()) {
if (data_it == obj.cend() || IsA<Null>(data_it->second)) {
LOG(FATAL) << "Empty data passed in.";
}
auto p_data = reinterpret_cast<PtrType>(
Expand All @@ -111,25 +111,27 @@ class ArrayInterfaceHandler {

// Checks that an array-interface object carries the fields this handler relies on:
// `version`, `typestr`, `shape` and `data`. Every failed check is reported through
// LOG(FATAL) or CHECK; nothing is returned.
//
// NOTE(review): this span is a rendered diff hunk whose +/- markers were lost. Each bare
// `... == array.cend()` condition appears to be the pre-change line, and the form that
// follows it (adding `|| IsA<Null>(...)`) its replacement -- confirm against the real
// header before treating this text as compilable code.
//
// \param array JSON object map describing one array interface.
static void Validate(Object::Map const &array) {
// `version` must be present (and, in the newer form of the check, not JSON null).
auto version_it = array.find("version");
if (version_it == array.cend()) {
if (version_it == array.cend() || IsA<Null>(version_it->second)) {
LOG(FATAL) << "Missing `version' field for array interface";
}
// Versions above 3 are rejected with the shared version error message.
if (get<Integer const>(version_it->second) > 3) {
LOG(FATAL) << ArrayInterfaceErrors::Version();
}

// `typestr` must be present (and, in the newer form of the check, not JSON null).
auto typestr_it = array.find("typestr");
if (typestr_it == array.cend()) {
if (typestr_it == array.cend() || IsA<Null>(typestr_it->second)) {
LOG(FATAL) << "Missing `typestr' field for array interface";
}

// Only the length of the type string is checked here: it must be 3 or 4 characters.
auto typestr = get<String const>(typestr_it->second);
CHECK(typestr.size() == 3 || typestr.size() == 4) << ArrayInterfaceErrors::TypestrFormat();

// `shape` and `data` only need to exist (and, in the newer form, be non-null); their
// contents are not inspected in this function.
if (array.find("shape") == array.cend()) {
auto shape_it = array.find("shape");
if (shape_it == array.cend() || IsA<Null>(shape_it->second)) {
LOG(FATAL) << "Missing `shape' field for array interface";
}
if (array.find("data") == array.cend()) {
auto data_it = array.find("data");
if (data_it == array.cend() || IsA<Null>(data_it->second)) {
LOG(FATAL) << "Missing `data' field for array interface";
}
}
Expand All @@ -139,8 +141,9 @@ class ArrayInterfaceHandler {
static size_t ExtractMask(Object::Map const &column,
common::Span<RBitField8::value_type> *p_out) {
auto &s_mask = *p_out;
if (column.find("mask") != column.cend()) {
auto const &j_mask = get<Object const>(column.at("mask"));
auto const &mask_it = column.find("mask");
if (mask_it != column.cend() && !IsA<Null>(mask_it->second)) {
auto const &j_mask = get<Object const>(mask_it->second);
Validate(j_mask);

auto p_mask = GetPtrFromArrayData<RBitField8::value_type *>(j_mask);
Expand Down Expand Up @@ -173,8 +176,9 @@ class ArrayInterfaceHandler {
// assume 1 byte alignment.
size_t const span_size = RBitField8::ComputeStorageSize(n_bits);

if (j_mask.find("strides") != j_mask.cend()) {
auto strides = get<Array const>(column.at("strides"));
auto strides_it = j_mask.find("strides");
if (strides_it != j_mask.cend() && !IsA<Null>(strides_it->second)) {
auto strides = get<Array const>(strides_it->second);
CHECK_EQ(strides.size(), 1) << ArrayInterfaceErrors::Dimension(1);
CHECK_EQ(get<Integer>(strides.at(0)), type_length) << ArrayInterfaceErrors::Contiguous();
}
Expand Down Expand Up @@ -401,7 +405,9 @@ class ArrayInterface {
<< "XGBoost doesn't support internal broadcasting.";
}
} else {
CHECK(array.find("mask") == array.cend()) << "Masked array is not yet supported.";
auto mask_it = array.find("mask");
CHECK(mask_it == array.cend() || IsA<Null>(mask_it->second))
<< "Masked array is not yet supported.";
}

auto stream_it = array.find("stream");
Expand Down
14 changes: 11 additions & 3 deletions tests/cpp/data/test_array_interface.cc
Expand Up @@ -33,9 +33,8 @@ TEST(ArrayInterface, Error) {
Json column { Object() };
std::vector<Json> j_shape {Json(Integer(static_cast<Integer::Int>(kRows)))};
column["shape"] = Array(j_shape);
std::vector<Json> j_data {
Json(Integer(reinterpret_cast<Integer::Int>(nullptr))),
Json(Boolean(false))};
std::vector<Json> j_data{Json(Integer(reinterpret_cast<Integer::Int>(nullptr))),
Json(Boolean(false))};

auto const& column_obj = get<Object>(column);
std::string typestr{"<f4"};
Expand All @@ -45,6 +44,10 @@ TEST(ArrayInterface, Error) {
EXPECT_THROW(ArrayInterfaceHandler::ExtractData(column_obj, n), dmlc::Error);
column["version"] = 3;
// missing data
EXPECT_THROW(ArrayInterfaceHandler::ExtractData(column_obj, n),
dmlc::Error);
// null data
column["data"] = Null{};
EXPECT_THROW(ArrayInterfaceHandler::ExtractData(column_obj, n),
dmlc::Error);
column["data"] = j_data;
Expand All @@ -63,6 +66,11 @@ TEST(ArrayInterface, Error) {
Json(Boolean(false))};
column["data"] = j_data;
EXPECT_NO_THROW(ArrayInterfaceHandler::ExtractData(column_obj, n));
// null data in mask
column["mask"] = Object{};
column["mask"]["data"] = Null{};
common::Span<RBitField8::value_type> s_mask;
EXPECT_THROW(ArrayInterfaceHandler::ExtractMask(column_obj, &s_mask), dmlc::Error);
}

TEST(ArrayInterface, GetElement) {
Expand Down