Skip to content

Commit

Permalink
Support null value in CUDA array interface. (#8486) (#8499)
Browse files Browse the repository at this point in the history
  • Loading branch information
hcho3 committed Nov 30, 2022
1 parent 9372370 commit db14e3f
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 13 deletions.
26 changes: 16 additions & 10 deletions src/data/array_interface.h
Expand Up @@ -101,7 +101,7 @@ class ArrayInterfaceHandler {
template <typename PtrType>
static PtrType GetPtrFromArrayData(Object::Map const &obj) {
auto data_it = obj.find("data");
if (data_it == obj.cend()) {
if (data_it == obj.cend() || IsA<Null>(data_it->second)) {
LOG(FATAL) << "Empty data passed in.";
}
auto p_data = reinterpret_cast<PtrType>(
Expand All @@ -111,25 +111,27 @@ class ArrayInterfaceHandler {

static void Validate(Object::Map const &array) {
auto version_it = array.find("version");
if (version_it == array.cend()) {
if (version_it == array.cend() || IsA<Null>(version_it->second)) {
LOG(FATAL) << "Missing `version' field for array interface";
}
if (get<Integer const>(version_it->second) > 3) {
LOG(FATAL) << ArrayInterfaceErrors::Version();
}

auto typestr_it = array.find("typestr");
if (typestr_it == array.cend()) {
if (typestr_it == array.cend() || IsA<Null>(typestr_it->second)) {
LOG(FATAL) << "Missing `typestr' field for array interface";
}

auto typestr = get<String const>(typestr_it->second);
CHECK(typestr.size() == 3 || typestr.size() == 4) << ArrayInterfaceErrors::TypestrFormat();

if (array.find("shape") == array.cend()) {
auto shape_it = array.find("shape");
if (shape_it == array.cend() || IsA<Null>(shape_it->second)) {
LOG(FATAL) << "Missing `shape' field for array interface";
}
if (array.find("data") == array.cend()) {
auto data_it = array.find("data");
if (data_it == array.cend() || IsA<Null>(data_it->second)) {
LOG(FATAL) << "Missing `data' field for array interface";
}
}
Expand All @@ -139,8 +141,9 @@ class ArrayInterfaceHandler {
static size_t ExtractMask(Object::Map const &column,
common::Span<RBitField8::value_type> *p_out) {
auto &s_mask = *p_out;
if (column.find("mask") != column.cend()) {
auto const &j_mask = get<Object const>(column.at("mask"));
auto const &mask_it = column.find("mask");
if (mask_it != column.cend() && !IsA<Null>(mask_it->second)) {
auto const &j_mask = get<Object const>(mask_it->second);
Validate(j_mask);

auto p_mask = GetPtrFromArrayData<RBitField8::value_type *>(j_mask);
Expand Down Expand Up @@ -173,8 +176,9 @@ class ArrayInterfaceHandler {
// assume 1 byte alignment.
size_t const span_size = RBitField8::ComputeStorageSize(n_bits);

if (j_mask.find("strides") != j_mask.cend()) {
auto strides = get<Array const>(column.at("strides"));
auto strides_it = j_mask.find("strides");
if (strides_it != j_mask.cend() && !IsA<Null>(strides_it->second)) {
auto strides = get<Array const>(strides_it->second);
CHECK_EQ(strides.size(), 1) << ArrayInterfaceErrors::Dimension(1);
CHECK_EQ(get<Integer>(strides.at(0)), type_length) << ArrayInterfaceErrors::Contiguous();
}
Expand Down Expand Up @@ -401,7 +405,9 @@ class ArrayInterface {
<< "XGBoost doesn't support internal broadcasting.";
}
} else {
CHECK(array.find("mask") == array.cend()) << "Masked array is not yet supported.";
auto mask_it = array.find("mask");
CHECK(mask_it == array.cend() || IsA<Null>(mask_it->second))
<< "Masked array is not yet supported.";
}

auto stream_it = array.find("stream");
Expand Down
14 changes: 11 additions & 3 deletions tests/cpp/data/test_array_interface.cc
Expand Up @@ -33,9 +33,8 @@ TEST(ArrayInterface, Error) {
Json column { Object() };
std::vector<Json> j_shape {Json(Integer(static_cast<Integer::Int>(kRows)))};
column["shape"] = Array(j_shape);
std::vector<Json> j_data {
Json(Integer(reinterpret_cast<Integer::Int>(nullptr))),
Json(Boolean(false))};
std::vector<Json> j_data{Json(Integer(reinterpret_cast<Integer::Int>(nullptr))),
Json(Boolean(false))};

auto const& column_obj = get<Object>(column);
std::string typestr{"<f4"};
Expand All @@ -45,6 +44,10 @@ TEST(ArrayInterface, Error) {
EXPECT_THROW(ArrayInterfaceHandler::ExtractData(column_obj, n), dmlc::Error);
column["version"] = 3;
// missing data
EXPECT_THROW(ArrayInterfaceHandler::ExtractData(column_obj, n),
dmlc::Error);
// null data
column["data"] = Null{};
EXPECT_THROW(ArrayInterfaceHandler::ExtractData(column_obj, n),
dmlc::Error);
column["data"] = j_data;
Expand All @@ -63,6 +66,11 @@ TEST(ArrayInterface, Error) {
Json(Boolean(false))};
column["data"] = j_data;
EXPECT_NO_THROW(ArrayInterfaceHandler::ExtractData(column_obj, n));
// null data in mask
column["mask"] = Object{};
column["mask"]["data"] = Null{};
common::Span<RBitField8::value_type> s_mask;
EXPECT_THROW(ArrayInterfaceHandler::ExtractMask(column_obj, &s_mask), dmlc::Error);
}

TEST(ArrayInterface, GetElement) {
Expand Down

0 comments on commit db14e3f

Please sign in to comment.