From f079bb3001c7cd3cf07d31249d46d8d34f77f5b1 Mon Sep 17 00:00:00 2001 From: fis Date: Fri, 25 Nov 2022 22:55:19 +0800 Subject: [PATCH 1/3] Support null value in CUDA array interface. - Fix for potential null value in array interface. - Fix incorrect check on mask stride. --- src/data/array_interface.h | 26 ++++++++++++++++---------- 1 file changed, 16 insertions(+), 10 deletions(-) diff --git a/src/data/array_interface.h b/src/data/array_interface.h index 0d7b4681accb..e755108069dc 100644 --- a/src/data/array_interface.h +++ b/src/data/array_interface.h @@ -101,7 +101,7 @@ class ArrayInterfaceHandler { template static PtrType GetPtrFromArrayData(Object::Map const &obj) { auto data_it = obj.find("data"); - if (data_it == obj.cend()) { + if (data_it == obj.cend() || IsA(data_it->second)) { LOG(FATAL) << "Empty data passed in."; } auto p_data = reinterpret_cast( @@ -111,7 +111,7 @@ class ArrayInterfaceHandler { static void Validate(Object::Map const &array) { auto version_it = array.find("version"); - if (version_it == array.cend()) { + if (version_it == array.cend() || IsA(version_it->second)) { LOG(FATAL) << "Missing `version' field for array interface"; } if (get(version_it->second) > 3) { @@ -119,17 +119,19 @@ class ArrayInterfaceHandler { } auto typestr_it = array.find("typestr"); - if (typestr_it == array.cend()) { + if (typestr_it == array.cend() || IsA(typestr_it->second)) { LOG(FATAL) << "Missing `typestr' field for array interface"; } auto typestr = get(typestr_it->second); CHECK(typestr.size() == 3 || typestr.size() == 4) << ArrayInterfaceErrors::TypestrFormat(); - if (array.find("shape") == array.cend()) { + auto shape_it = array.find("shape"); + if (shape_it == array.cend() || IsA(shape_it->second)) { LOG(FATAL) << "Missing `shape' field for array interface"; } - if (array.find("data") == array.cend()) { + auto data_it = array.find("data"); + if (data_it == array.cend() || IsA(data_it->second)) { LOG(FATAL) << "Missing `data' field for array interface"; } } @@ -139,8 +141,9 @@ class ArrayInterfaceHandler { static size_t ExtractMask(Object::Map const &column, common::Span *p_out) { auto &s_mask = *p_out; - if (column.find("mask") != column.cend()) { - auto const &j_mask = get(column.at("mask")); + auto const &mask_it = column.find("mask"); + if (mask_it != column.cend() && !IsA(mask_it->second)) { + auto const &j_mask = get(mask_it->second); Validate(j_mask); auto p_mask = GetPtrFromArrayData(j_mask); @@ -173,8 +176,9 @@ class ArrayInterfaceHandler { // assume 1 byte alignment. size_t const span_size = RBitField8::ComputeStorageSize(n_bits); - if (j_mask.find("strides") != j_mask.cend()) { - auto strides = get(column.at("strides")); + auto strides_it = j_mask.find("strides"); + if (strides_it != j_mask.cend() && !IsA(strides_it->second)) { + auto strides = get(strides_it->second); CHECK_EQ(strides.size(), 1) << ArrayInterfaceErrors::Dimension(1); CHECK_EQ(get(strides.at(0)), type_length) << ArrayInterfaceErrors::Contiguous(); } @@ -401,7 +405,9 @@ class ArrayInterface { << "XGBoost doesn't support internal broadcasting."; } } else { - CHECK(array.find("mask") == array.cend()) << "Masked array is not yet supported."; + auto mask_it = array.find("mask"); + CHECK(mask_it == array.cend() || IsA(mask_it->second)) + << "Masked array is not yet supported."; } auto stream_it = array.find("stream"); From 0fe51e4c0754de1f9cca426e987f6c33a8c7067b Mon Sep 17 00:00:00 2001 From: fis Date: Mon, 28 Nov 2022 20:32:25 +0800 Subject: [PATCH 2/3] Simple tests. --- tests/cpp/data/test_array_interface.cc | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/tests/cpp/data/test_array_interface.cc b/tests/cpp/data/test_array_interface.cc index 3c2e0e38d5c3..4e321fdc2ee1 100644 --- a/tests/cpp/data/test_array_interface.cc +++ b/tests/cpp/data/test_array_interface.cc @@ -33,9 +33,8 @@ TEST(ArrayInterface, Error) { Json column { Object() }; std::vector j_shape {Json(Integer(static_cast(kRows)))}; column["shape"] = Array(j_shape); - std::vector j_data { - Json(Integer(reinterpret_cast(nullptr))), - Json(Boolean(false))}; + std::vector j_data{Json(Integer(reinterpret_cast(nullptr))), + Json(Boolean(false))}; auto const& column_obj = get(column); std::string typestr{" storage; auto array = RandomDataGenerator{kRows, kCols, 0}.GenerateArrayInterface(&storage); From 566a88e86483d908bcd1e86135907db664afffc6 Mon Sep 17 00:00:00 2001 From: fis Date: Mon, 28 Nov 2022 22:44:15 +0800 Subject: [PATCH 3/3] Extract mask. --- tests/cpp/data/test_array_interface.cc | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/cpp/data/test_array_interface.cc b/tests/cpp/data/test_array_interface.cc index 4e321fdc2ee1..5bd771ff08e2 100644 --- a/tests/cpp/data/test_array_interface.cc +++ b/tests/cpp/data/test_array_interface.cc @@ -58,11 +58,6 @@ TEST(ArrayInterface, Error) { // nullptr is not valid EXPECT_THROW(ArrayInterfaceHandler::ExtractData(column_obj, n), dmlc::Error); - column["mask"] = Object{}; - column["mask"]["data"] = Null{}; - EXPECT_THROW(ArrayInterfaceHandler::ExtractData(column_obj, n), - dmlc::Error); - column["mask"] = Null{}; HostDeviceVector storage; auto array = RandomDataGenerator{kRows, kCols, 0}.GenerateArrayInterface(&storage); @@ -71,6 +66,11 @@ TEST(ArrayInterface, Error) { Json(Boolean(false))}; column["data"] = j_data; EXPECT_NO_THROW(ArrayInterfaceHandler::ExtractData(column_obj, n)); + // null data in mask + column["mask"] = Object{}; + column["mask"]["data"] = Null{}; + common::Span s_mask; + EXPECT_THROW(ArrayInterfaceHandler::ExtractMask(column_obj, &s_mask), dmlc::Error); } TEST(ArrayInterface, GetElement) {