diff --git a/doc/contrib/coding_guide.rst b/doc/contrib/coding_guide.rst
index 6d407ba129e9..b4880803cb0c 100644
--- a/doc/contrib/coding_guide.rst
+++ b/doc/contrib/coding_guide.rst
@@ -134,3 +134,49 @@ Similarly, if you want to exclude C++ source from linting:
   cd /path/to/xgboost/
   python3 tests/ci_build/tidy.py --cpp=0
+
+**********************************
+Guide for handling user input data
+**********************************
+
+This is a brief, non-comprehensive guide for handling user input data. XGBoost has a wide
+variety of natively supported data structures, mostly coming from higher-level language
+bindings. The inputs range from a basic contiguous 1-dimensional memory buffer to more
+sophisticated data structures like columnar data with a validity mask. Raw input data is
+used in 2 places: the construction of various ``DMatrix`` objects, and in-place
+prediction. For a plain memory buffer, there's not much to discuss since it's just a
+pointer with a size. But for general n-dimensional arrays and columnar data, there are
+many subtleties. XGBoost has 3 different data structures for handling optionally masked
+arrays (tensors); for consuming user inputs, ``ArrayInterface`` should be chosen. Many
+existing functions accept only a plain pointer due to legacy reasons (XGBoost started as
+a much simpler library and didn't care about memory usage that much back then). The
+``ArrayInterface`` is an in-memory representation of the ``__array_interface__`` protocol
+defined by NumPy or the ``__cuda_array_interface__`` defined by Numba. The following is a
+checklist of things to keep in mind when accepting related user inputs:
+
+- [ ] Is it strided? (identified by the ``strides`` field)
+- [ ] If it's a vector, is it a row vector or a column vector? (identified by both
+  ``shape`` and ``strides``)
+- [ ] Is the data type supported? Half-precision and 128-bit integer types should be
+  converted before going into XGBoost.
+- [ ] Does it have more than 1 dimension? (identified by the ``shape`` field)
+- [ ] Are some of the dimensions trivial? (``shape[dim] <= 1``)
+- [ ] Does it have a mask? (identified by the ``mask`` field)
+- [ ] Can the mask be broadcast? (unsupported at the moment)
+- [ ] Is it on CUDA memory? (identified by the ``data`` field, and optionally ``stream``)
+
+Most of the checks are handled by ``ArrayInterface`` during construction, except for the
+data type issue, since it doesn't know how to cast such pointers to C built-in types. For
+safety, one should still try to write related tests for all the items. The data type
+issue should be taken care of in the language binding for each specific data input. For
+the single-chunk columnar format, each column is just a masked array, so it should be
+treated uniformly as a normal array. For the input predictor ``X``, we have adapters for
+each type of input. Some are compositions of the others. For instance, a CSR matrix has 3
+potentially strided arrays for ``indptr``, ``indices`` and ``values``. No assumption
+should be made about these components (all the checkboxes should be considered); slicing
+a row of a CSR matrix should calculate the offset of each field based on the respective
+strides.
+
+For meta info like labels, which is growing both in size and complexity, we accept only a
+masked array at the moment (no specialized adapter). One should be careful about the
+input data shape. For base margin, it can be 2-dimensional or higher if we have multiple
+targets in the future. The getters in ``DMatrix`` return only 1-dimensional flattened
+vectors at the moment, which can be improved in the future when it's needed.
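+
+To make the checklist concrete, here is a minimal, hypothetical sketch (not XGBoost
+internal code) that reads several of the items above directly off the
+``__array_interface__`` of a strided NumPy view:
+
+.. code-block:: python
+
+    import numpy as np
+
+    X = np.arange(12, dtype=np.float32).reshape(4, 3)
+    col = X[:, 1]  # a non-contiguous, strided view
+    interface = col.__array_interface__
+
+    print(interface["shape"])    # (4,)  -> a vector; row vs column must still be decided
+    print(interface["strides"])  # (12,) -> 3 elements * 4 bytes each, so it is strided
+    print(interface["typestr"])  # '<f4' -> float32; '<f2' (half) must be converted first
+    assert interface.get("mask") is None  # plain NumPy arrays carry no validity mask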
diff --git a/include/xgboost/c_api.h b/include/xgboost/c_api.h
index 8d929593c1a4..9d657872e26b 100644
--- a/include/xgboost/c_api.h
+++ b/include/xgboost/c_api.h
@@ -249,7 +249,7 @@ XGB_DLL int XGDMatrixCreateFromCudaArrayInterface(char const *data,
                                                   char const* json_config,
                                                   DMatrixHandle *out);
-/*
+/**
 * ========================== Begin data callback APIs =========================
 *
 * Short notes for data callback
 *
@@ -258,9 +258,9 @@ XGB_DLL int XGDMatrixCreateFromCudaArrayInterface(char const *data,
 * used by JVM packages.  It uses `XGBoostBatchCSR` to accept batches for CSR formatted
 * input, and concatenates them into 1 final big CSR.  The related functions are:
 *
- * - XGBCallbackSetData
- * - XGBCallbackDataIterNext
- * - XGDMatrixCreateFromDataIter
+ * - \ref XGBCallbackSetData
+ * - \ref XGBCallbackDataIterNext
+ * - \ref XGDMatrixCreateFromDataIter
 *
 * Another set is used by the external data iterator.  It accepts foreign data iterators
 * as callbacks.  There are 2 different scenarios where users might want to pass in
 * callbacks
 *
@@ -276,17 +276,17 @@ XGB_DLL int XGDMatrixCreateFromCudaArrayInterface(char const *data,
 * Related functions are:
 *
 * # Factory functions
- * - `XGDMatrixCreateFromCallback` for external memory
- * - `XGDeviceQuantileDMatrixCreateFromCallback` for quantile DMatrix
+ * - \ref XGDMatrixCreateFromCallback for external memory
+ * - \ref XGDeviceQuantileDMatrixCreateFromCallback for quantile DMatrix
 *
 * # Proxy that callers can use to pass data to XGBoost
- * - XGProxyDMatrixCreate
- * - XGDMatrixCallbackNext
- * - DataIterResetCallback
- * - XGProxyDMatrixSetDataCudaArrayInterface
- * - XGProxyDMatrixSetDataCudaColumnar
- * - XGProxyDMatrixSetDataDense
- * - XGProxyDMatrixSetDataCSR
+ * - \ref XGProxyDMatrixCreate
+ * - \ref XGDMatrixCallbackNext
+ * - \ref DataIterResetCallback
+ * - \ref XGProxyDMatrixSetDataCudaArrayInterface
+ * - \ref XGProxyDMatrixSetDataCudaColumnar
+ * - \ref XGProxyDMatrixSetDataDense
+ * - \ref XGProxyDMatrixSetDataCSR
 * - ... (data setters)
 */
@@ -411,7 +411,7 @@ XGB_EXTERN_C typedef void DataIterResetCallback(DataIterHandle handle); // NOLINT
 * - cache_prefix: The path of cache file, caller must initialize all the directories in this path.
 * - nthread (optional): Number of threads used for initializing DMatrix.
 *
- * \param out The created external memory DMatrix
+ * \param[out] out The created external memory DMatrix
 *
 * \return 0 when success, -1 when failure happens
 */
@@ -605,7 +605,8 @@ XGB_DLL int XGDMatrixSetUIntInfo(DMatrixHandle handle,
 *   char const* feat_names [] {"feat_0", "feat_1"};
 *   XGDMatrixSetStrFeatureInfo(handle, "feature_name", feat_names, 2);
 *
- *   // i for integer, q for quantitative. Similarly "int" and "float" are also recognized.
+ *   // i for integer, q for quantitative, c for categorical. Similarly "int" and "float"
+ *   // are also recognized.
 *   char const* feat_types [] {"i", "q"};
 *   XGDMatrixSetStrFeatureInfo(handle, "feature_type", feat_types, 2);
 *
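For context on how the second callback family is consumed from a language binding: the
Python package wraps these C callbacks behind ``xgboost.DataIter``. A minimal sketch,
assuming the Python API of this era (``DataIter`` with ``next``/``reset`` and a
``cache_prefix`` for the external memory path; exact signatures live in
``python-package/xgboost/core.py``):

.. code-block:: python

    import os
    import numpy as np
    import xgboost as xgb

    class Batches(xgb.DataIter):
        """Feed a few in-memory batches to XGBoost through the data callbacks."""

        def __init__(self):
            self._data = [np.random.rand(32, 4) for _ in range(3)]
            self._labels = [np.random.rand(32) for _ in range(3)]
            self._it = 0
            # cache_prefix selects the external memory path during construction.
            super().__init__(cache_prefix=os.path.join(".", "cache"))

        def next(self, input_data):
            if self._it == len(self._data):
                return 0  # signal the end of iteration to XGBoost
            # input_data forwards the batch to the proxy DMatrix setters listed above.
            input_data(data=self._data[self._it], label=self._labels[self._it])
            self._it += 1
            return 1  # more batches available

        def reset(self):
            self._it = 0

    Xy = xgb.DMatrix(Batches())  # backed by XGDMatrixCreateFromCallback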
diff --git a/include/xgboost/data.h b/include/xgboost/data.h
index eb3977f989b4..c91451678856 100644
--- a/include/xgboost/data.h
+++ b/include/xgboost/data.h
@@ -47,7 +47,7 @@ enum class FeatureType : uint8_t {
 class MetaInfo {
  public:
   /*! \brief number of data fields in MetaInfo */
-  static constexpr uint64_t kNumField = 11;
+  static constexpr uint64_t kNumField = 12;
   /*! \brief number of rows in the data */
   uint64_t num_row_{0};  // NOLINT
@@ -69,7 +69,7 @@ class MetaInfo {
   * if specified, xgboost will start from this init margin
   * can be used to specify initial prediction to boost from.
   */
-  HostDeviceVector<bst_float> base_margin_;  // NOLINT
+  linalg::Tensor<float, 3> base_margin_;  // NOLINT
  /*!
   * \brief lower bound of the label, to be used for survival analysis (censored regression)
   */
@@ -154,12 +154,8 @@ class MetaInfo {
   * \brief Set information in the meta info with array interface.
   * \param key The key of the information.
   * \param interface_str String representation of json format array interface.
-   *
-   *   [ column_0, column_1, ... column_n ]
-   *
-   * Right now only 1 column is permitted.
   */
-  void SetInfo(StringView key, std::string const& interface_str);
+  void SetInfo(StringView key, StringView interface_str);
   void GetInfo(char const* key, bst_ulong* out_len, DataType dtype,
                const void** out_dptr) const;
@@ -181,6 +177,9 @@ class MetaInfo {
   void Extend(MetaInfo const& that, bool accumulate_rows, bool check_column);
  private:
+  void SetInfoFromHost(StringView key, Json arr);
+  void SetInfoFromCUDA(StringView key, Json arr);
+
   /*! \brief argsort of labels */
   mutable std::vector<size_t> label_order_cache_;
 };
@@ -479,7 +478,7 @@ class DMatrix {
     this->Info().SetInfo(key, dptr, dtype, num);
   }
   virtual void SetInfo(const char* key, std::string const& interface_str) {
-    this->Info().SetInfo(key, interface_str);
+    this->Info().SetInfo(key, StringView{interface_str});
   }
   /*! \brief meta information of the dataset */
   virtual const MetaInfo& Info() const = 0;
diff --git a/include/xgboost/intrusive_ptr.h b/include/xgboost/intrusive_ptr.h
index 1c58704c4ca3..df7ae30213d6 100644
--- a/include/xgboost/intrusive_ptr.h
+++ b/include/xgboost/intrusive_ptr.h
@@ -19,7 +19,7 @@ namespace xgboost {
 */
 class IntrusivePtrCell {
  private:
-  std::atomic<int32_t> count_;
+  std::atomic<int32_t> count_ {0};
   template <typename T> friend class IntrusivePtr;
   std::int32_t IncRef() noexcept {
@@ -31,7 +31,7 @@ class IntrusivePtrCell {
   bool IsZero() const { return Count() == 0; }
  public:
-  IntrusivePtrCell() noexcept : count_{0} {}
+  IntrusivePtrCell() noexcept = default;
   int32_t Count() const { return count_.load(std::memory_order_relaxed); }
 };
diff --git a/include/xgboost/predictor.h b/include/xgboost/predictor.h
index dd87c601f978..04f218c1e0c4 100644
--- a/include/xgboost/predictor.h
+++ b/include/xgboost/predictor.h
@@ -126,9 +126,8 @@ class Predictor {
   * \param out_predt Prediction vector to be initialized.
   * \param model Tree model used for prediction.
   */
-  virtual void InitOutPredictions(const MetaInfo &info,
-                                  HostDeviceVector<bst_float> *out_predt,
-                                  const gbm::GBTreeModel &model) const = 0;
+  void InitOutPredictions(const MetaInfo& info, HostDeviceVector<bst_float>* out_predt,
+                          const gbm::GBTreeModel& model) const;
 /**
  * \brief Generate batch predictions for a given feature matrix.
May use diff --git a/include/xgboost/task.h b/include/xgboost/task.h index 6430794c3ccd..69952d62c40d 100644 --- a/include/xgboost/task.h +++ b/include/xgboost/task.h @@ -33,7 +33,7 @@ struct ObjInfo { bool const_hess{false}; explicit ObjInfo(Task t) : task{t} {} - ObjInfo(Task t, bool khess) : const_hess{khess} {} + ObjInfo(Task t, bool khess) : task{t}, const_hess{khess} {} }; } // namespace xgboost #endif // XGBOOST_TASK_H_ diff --git a/python-package/xgboost/data.py b/python-package/xgboost/data.py index 285b2e840f59..c739f9267af5 100644 --- a/python-package/xgboost/data.py +++ b/python-package/xgboost/data.py @@ -5,7 +5,7 @@ import json import warnings import os -from typing import Any, Tuple, Callable, Optional, List +from typing import Any, Tuple, Callable, Optional, List, Union import numpy as np @@ -138,14 +138,14 @@ def _is_numpy_array(data): return isinstance(data, (np.ndarray, np.matrix)) -def _ensure_np_dtype(data, dtype): +def _ensure_np_dtype(data, dtype) -> Tuple[np.ndarray, np.dtype]: if data.dtype.hasobject or data.dtype in [np.float16, np.bool_]: data = data.astype(np.float32, copy=False) dtype = np.float32 return data, dtype -def _maybe_np_slice(data, dtype): +def _maybe_np_slice(data: np.ndarray, dtype) -> np.ndarray: '''Handle numpy slice. This can be removed if we use __array_interface__. ''' try: @@ -852,23 +852,17 @@ def _validate_meta_shape(data: Any, name: str) -> None: def _meta_from_numpy( - data: np.ndarray, field: str, dtype, handle: ctypes.c_void_p + data: np.ndarray, + field: str, + dtype: Optional[Union[np.dtype, str]], + handle: ctypes.c_void_p, ) -> None: - data = _maybe_np_slice(data, dtype) + data, dtype = _ensure_np_dtype(data, dtype) interface = data.__array_interface__ - assert interface.get('mask', None) is None, 'Masked array is not supported' - size = data.size - - c_type = _to_data_type(str(data.dtype), field) - ptr = interface['data'][0] - ptr = ctypes.c_void_p(ptr) - _check_call(_LIB.XGDMatrixSetDenseInfo( - handle, - c_str(field), - ptr, - c_bst_ulong(size), - c_type - )) + if interface.get("mask", None) is not None: + raise ValueError("Masked array is not supported.") + interface_str = _array_interface(data) + _check_call(_LIB.XGDMatrixSetInfoFromInterface(handle, c_str(field), interface_str)) def _meta_from_list(data, field, dtype, handle): @@ -911,7 +905,9 @@ def _meta_from_dt(data, field: str, dtype, handle: ctypes.c_void_p): _meta_from_numpy(data, field, dtype, handle) -def dispatch_meta_backend(matrix: DMatrix, data, name: str, dtype: str = None): +def dispatch_meta_backend( + matrix: DMatrix, data, name: str, dtype: Optional[Union[str, np.dtype]] = None +): '''Dispatch for meta info.''' handle = matrix.handle assert handle is not None diff --git a/src/common/common.cu b/src/common/common.cu index d30fbc0aeb88..4636a4cdcb7c 100644 --- a/src/common/common.cu +++ b/src/common/common.cu @@ -12,7 +12,8 @@ int AllVisibleGPUs() { // When compiled with CUDA but running on CPU only device, // cudaGetDeviceCount will fail. dh::safe_cuda(cudaGetDeviceCount(&n_visgpus)); - } catch(const dmlc::Error &except) { + } catch (const dmlc::Error &) { + cudaGetLastError(); // reset error. 
     return 0;
   }
   return n_visgpus;
 }
diff --git a/src/data/data.cc b/src/data/data.cc
index a6b76ee2a39c..3a2215180dce 100644
--- a/src/data/data.cc
+++ b/src/data/data.cc
@@ -3,6 +3,7 @@
 * \file data.cc
 */
 #include
+#include
 #include
 #include "dmlc/io.h"
@@ -12,10 +13,13 @@
 #include "xgboost/logging.h"
 #include "xgboost/version_config.h"
 #include "xgboost/learner.h"
+#include "xgboost/string_view.h"
+
 #include "sparse_page_writer.h"
 #include "simple_dmatrix.h"
 #include "../common/io.h"
+#include "../common/linalg_op.h"
 #include "../common/math.h"
 #include "../common/version.h"
 #include "../common/group_data.h"
@@ -66,10 +70,22 @@ void SaveVectorField(dmlc::Stream* strm, const std::string& name,
   SaveVectorField(strm, name, type, shape, field.ConstHostVector());
 }
+template <typename T, int32_t D>
+void SaveTensorField(dmlc::Stream* strm, const std::string& name, xgboost::DataType type,
+                     const xgboost::linalg::Tensor<T, D>& field) {
+  strm->Write(name);
+  strm->Write(static_cast<uint8_t>(type));
+  strm->Write(false);  // is_scalar=False
+  for (size_t i = 0; i < D; ++i) {
+    strm->Write(field.Shape(i));
+  }
+  strm->Write(field.Data()->HostVector());
+}
 template <typename T>
 void LoadScalarField(dmlc::Stream* strm, const std::string& expected_name,
                      xgboost::DataType expected_type, T* field) {
-  const std::string invalid {"MetaInfo: Invalid format. "};
+  const std::string invalid{"MetaInfo: Invalid format for " + expected_name};
   std::string name;
   xgboost::DataType type;
   bool is_scalar;
@@ -91,7 +107,7 @@ template <typename T>
 void LoadVectorField(dmlc::Stream* strm, const std::string& expected_name,
                      xgboost::DataType expected_type, std::vector<T>* field) {
-  const std::string invalid {"MetaInfo: Invalid format. "};
+  const std::string invalid{"MetaInfo: Invalid format for " + expected_name};
   std::string name;
   xgboost::DataType type;
   bool is_scalar;
@@ -124,6 +140,33 @@ void LoadVectorField(dmlc::Stream* strm, const std::string& expected_name,
   LoadVectorField(strm, expected_name, expected_type, &field->HostVector());
 }
+template <typename T, int32_t D>
+void LoadTensorField(dmlc::Stream* strm, std::string const& expected_name,
+                     xgboost::DataType expected_type, xgboost::linalg::Tensor<T, D>* p_out) {
+  const std::string invalid{"MetaInfo: Invalid format for " + expected_name};
+  std::string name;
+  xgboost::DataType type;
+  bool is_scalar;
+  CHECK(strm->Read(&name)) << invalid;
+  CHECK_EQ(name, expected_name) << invalid << " Expected field: " << expected_name
+                                << ", got: " << name;
+  uint8_t type_val;
+  CHECK(strm->Read(&type_val)) << invalid;
+  type = static_cast<xgboost::DataType>(type_val);
+  CHECK(type == expected_type) << invalid
+                               << "Expected field of type: " << static_cast<int>(expected_type)
+                               << ", "
+                               << "got field type: " << static_cast<int>(type);
+  CHECK(strm->Read(&is_scalar)) << invalid;
+  CHECK(!is_scalar) << invalid << "Expected field " << expected_name
+                    << " to be a tensor; got a scalar";
+  std::array<size_t, D> shape;
+  for (size_t i = 0; i < D; ++i) {
+    CHECK(strm->Read(&(shape[i])));
+  }
+  auto& field = p_out->Data()->HostVector();
+  CHECK(strm->Read(&field)) << invalid;
+}
 }  // anonymous namespace
 namespace xgboost {
@@ -136,25 +179,26 @@ void MetaInfo::Clear() {
   labels_.HostVector().clear();
   group_ptr_.clear();
   weights_.HostVector().clear();
-  base_margin_.HostVector().clear();
+  base_margin_ = decltype(base_margin_){};
 }
 /*
 * Binary serialization format for MetaInfo:
 *
- * | name               | type     | is_scalar | num_row | num_col | value                   |
- * |--------------------+----------+-----------+---------+---------+-------------------------|
- * | num_row            | kUInt64  | True      | NA      | NA      | ${num_row_}             |
- * | num_col            | kUInt64  | True      | NA      | NA      | ${num_col_}             |
- * | num_nonzero        | kUInt64  | True      | NA      | NA      | ${num_nonzero_}         |
- * | labels             | kFloat32 | False     | ${size} | 1       | ${labels_}              |
- * | group_ptr          | kUInt32  | False     | ${size} | 1       | ${group_ptr_}           |
- * | weights            | kFloat32 | False     | ${size} | 1       | ${weights_}             |
- * | base_margin        | kFloat32 | False     | ${size} | 1       | ${base_margin_}         |
- * | labels_lower_bound | kFloat32 | False     | ${size} | 1       | ${labels_lower_bound_}  |
- * | labels_upper_bound | kFloat32 | False     | ${size} | 1       | ${labels_upper_bound_}  |
- * | feature_names      | kStr     | False     | ${size} | 1       | ${feature_names}        |
- * | feature_types      | kStr     | False     | ${size} | 1       | ${feature_types}        |
+ * | name               | type     | is_scalar | num_row     | num_col     | dim3        | value                  |
+ * |--------------------+----------+-----------+-------------+-------------+-------------+------------------------|
+ * | num_row            | kUInt64  | True      | NA          | NA          | NA          | ${num_row_}            |
+ * | num_col            | kUInt64  | True      | NA          | NA          | NA          | ${num_col_}            |
+ * | num_nonzero        | kUInt64  | True      | NA          | NA          | NA          | ${num_nonzero_}        |
+ * | labels             | kFloat32 | False     | ${size}     | 1           | NA          | ${labels_}             |
+ * | group_ptr          | kUInt32  | False     | ${size}     | 1           | NA          | ${group_ptr_}          |
+ * | weights            | kFloat32 | False     | ${size}     | 1           | NA          | ${weights_}            |
+ * | base_margin        | kFloat32 | False     | ${Shape(0)} | ${Shape(1)} | ${Shape(2)} | ${base_margin_}        |
+ * | labels_lower_bound | kFloat32 | False     | ${size}     | 1           | NA          | ${labels_lower_bound_} |
+ * | labels_upper_bound | kFloat32 | False     | ${size}     | 1           | NA          | ${labels_upper_bound_} |
+ * | feature_names      | kStr     | False     | ${size}     | 1           | NA          | ${feature_names}       |
+ * | feature_types      | kStr     | False     | ${size}     | 1           | NA          | ${feature_types}       |
+ * | feature_weights    | kFloat32 | False     | ${size}     | 1           | NA          | ${feature_weights}     |
 *
 * Note that the scalar fields (is_scalar=True) will have num_row and num_col missing.
 * Also notice the difference between the saved name and the name used in `SetInfo':
@@ -175,8 +219,7 @@ void MetaInfo::SaveBinary(dmlc::Stream *fo) const {
                   {group_ptr_.size(), 1}, group_ptr_); ++field_cnt;
   SaveVectorField(fo, u8"weights", DataType::kFloat32,
                   {weights_.Size(), 1}, weights_); ++field_cnt;
-  SaveVectorField(fo, u8"base_margin", DataType::kFloat32,
-                  {base_margin_.Size(), 1}, base_margin_); ++field_cnt;
+  SaveTensorField(fo, u8"base_margin", DataType::kFloat32, base_margin_); ++field_cnt;
   SaveVectorField(fo, u8"labels_lower_bound", DataType::kFloat32,
                   {labels_lower_bound_.Size(), 1}, labels_lower_bound_); ++field_cnt;
   SaveVectorField(fo, u8"labels_upper_bound", DataType::kFloat32,
@@ -186,6 +229,9 @@ void MetaInfo::SaveBinary(dmlc::Stream *fo) const {
                   {feature_names.size(), 1}, feature_names); ++field_cnt;
   SaveVectorField(fo, u8"feature_types", DataType::kStr,
                   {feature_type_names.size(), 1}, feature_type_names); ++field_cnt;
+  SaveVectorField(fo, u8"feature_weights", DataType::kFloat32, {feature_weights.Size(), 1},
+                  feature_weights);
+  ++field_cnt;
   CHECK_EQ(field_cnt, kNumField) << "Wrong number of fields";
 }
@@ -214,10 +260,14 @@ void MetaInfo::LoadBinary(dmlc::Stream *fi) {
   auto major = std::get<0>(version);
   // MetaInfo is saved in `SparsePageSource'. So the version in MetaInfo represents the
   // version of DMatrix.
-  CHECK_EQ(major, 1) << "Binary DMatrix generated by XGBoost: "
-                     << Version::String(version) << " is no longer supported.
" - << "Please process and save your data in current version: " - << Version::String(Version::Self()) << " again."; + std::stringstream msg; + msg << "Binary DMatrix generated by XGBoost: " << Version::String(version) + << " is no longer supported. " + << "Please process and save your data in current version: " + << Version::String(Version::Self()) << " again."; + CHECK_EQ(major, 1) << msg.str(); + auto minor = std::get<1>(version); + CHECK_GE(minor, 6) << msg.str(); const uint64_t expected_num_field = kNumField; uint64_t num_field { 0 }; @@ -244,12 +294,13 @@ void MetaInfo::LoadBinary(dmlc::Stream *fi) { LoadVectorField(fi, u8"labels", DataType::kFloat32, &labels_); LoadVectorField(fi, u8"group_ptr", DataType::kUInt32, &group_ptr_); LoadVectorField(fi, u8"weights", DataType::kFloat32, &weights_); - LoadVectorField(fi, u8"base_margin", DataType::kFloat32, &base_margin_); + LoadTensorField(fi, u8"base_margin", DataType::kFloat32, &base_margin_); LoadVectorField(fi, u8"labels_lower_bound", DataType::kFloat32, &labels_lower_bound_); LoadVectorField(fi, u8"labels_upper_bound", DataType::kFloat32, &labels_upper_bound_); LoadVectorField(fi, u8"feature_names", DataType::kStr, &feature_names); LoadVectorField(fi, u8"feature_types", DataType::kStr, &feature_type_names); + LoadVectorField(fi, u8"feature_weights", DataType::kFloat32, &feature_weights); LoadFeatureType(feature_type_names, &feature_types.HostVector()); } @@ -292,10 +343,13 @@ MetaInfo MetaInfo::Slice(common::Span ridxs) const { if (this->base_margin_.Size() != this->num_row_) { CHECK_EQ(this->base_margin_.Size() % this->num_row_, 0) << "Incorrect size of base margin vector."; - size_t stride = this->base_margin_.Size() / this->num_row_; - out.base_margin_.HostVector() = Gather(this->base_margin_.HostVector(), ridxs, stride); + auto margin = this->base_margin_.View(this->base_margin_.Data()->DeviceIdx()); + out.base_margin_.Reshape(ridxs.size(), margin.Shape()[1], margin.Shape()[2]); + size_t stride = margin.Stride(0); + out.base_margin_.Data()->HostVector() = + Gather(this->base_margin_.Data()->HostVector(), ridxs, stride); } else { - out.base_margin_.HostVector() = Gather(this->base_margin_.HostVector(), ridxs); + out.base_margin_.Data()->HostVector() = Gather(this->base_margin_.Data()->HostVector(), ridxs); } out.feature_weights.Resize(this->feature_weights.Size()); @@ -338,105 +392,179 @@ inline bool MetaTryLoadFloatInfo(const std::string& fname, return true; } -// macro to dispatch according to specified pointer types -#define DISPATCH_CONST_PTR(dtype, old_ptr, cast_ptr, proc) \ - switch (dtype) { \ - case xgboost::DataType::kFloat32: { \ - auto cast_ptr = reinterpret_cast(old_ptr); proc; break; \ - } \ - case xgboost::DataType::kDouble: { \ - auto cast_ptr = reinterpret_cast(old_ptr); proc; break; \ - } \ - case xgboost::DataType::kUInt32: { \ - auto cast_ptr = reinterpret_cast(old_ptr); proc; break; \ - } \ - case xgboost::DataType::kUInt64: { \ - auto cast_ptr = reinterpret_cast(old_ptr); proc; break; \ - } \ - default: LOG(FATAL) << "Unknown data type" << static_cast(dtype); \ - } \ - -void MetaInfo::SetInfo(const char* key, const void* dptr, DataType dtype, size_t num) { - if (!std::strcmp(key, "label")) { - auto& labels = labels_.HostVector(); - labels.resize(num); - DISPATCH_CONST_PTR(dtype, dptr, cast_dptr, - std::copy(cast_dptr, cast_dptr + num, labels.begin())); - auto valid = std::none_of(labels.cbegin(), labels.cend(), [](auto y) { - return std::isnan(y) || std::isinf(y); - }); - CHECK(valid) << "Label contains NaN, 
infinity or a value too large."; - } else if (!std::strcmp(key, "weight")) { - auto& weights = weights_.HostVector(); - weights.resize(num); - DISPATCH_CONST_PTR(dtype, dptr, cast_dptr, - std::copy(cast_dptr, cast_dptr + num, weights.begin())); - auto valid = std::none_of(weights.cbegin(), weights.cend(), [](float w) { - return w < 0 || std::isinf(w) || std::isnan(w); +namespace { +template +void CopyTensorInfoImpl(Json arr_interface, linalg::Tensor* p_out) { + ArrayInterface array{arr_interface}; + if (array.n == 0) { + return; + } + CHECK(array.valid.Size() == 0) << "Meta info like label or weight can not have missing value."; + if (array.is_contiguous && array.type == ToDType::kType) { + // Handle contigious + p_out->ModifyInplace([&](HostDeviceVector* data, common::Span shape) { + // set shape + std::copy(array.shape, array.shape + D, shape.data()); + // set data + data->Resize(array.n); + std::memcpy(data->HostPointer(), array.data, array.n * sizeof(T)); }); - CHECK(valid) << "Weights must be positive values."; - } else if (!std::strcmp(key, "base_margin")) { - auto& base_margin = base_margin_.HostVector(); - base_margin.resize(num); - DISPATCH_CONST_PTR(dtype, dptr, cast_dptr, - std::copy(cast_dptr, cast_dptr + num, base_margin.begin())); - } else if (!std::strcmp(key, "group")) { - group_ptr_.clear(); group_ptr_.resize(num + 1, 0); - DISPATCH_CONST_PTR(dtype, dptr, cast_dptr, - std::copy(cast_dptr, cast_dptr + num, group_ptr_.begin() + 1)); - group_ptr_[0] = 0; - for (size_t i = 1; i < group_ptr_.size(); ++i) { - group_ptr_[i] = group_ptr_[i - 1] + group_ptr_[i]; + return; + } + p_out->Reshape(array.shape); + auto t = p_out->View(GenericParameter::kCpuId); + CHECK(t.Contiguous()); + // FIXME(jiamingy): Remove the use of this default thread. + linalg::ElementWiseKernelHost(t, common::OmpGetNumThreads(0), [&](auto i, auto) { + return linalg::detail::Apply(TypedIndex{array}, linalg::UnravelIndex(i, t.Shape())); + }); +} +} // namespace + +void MetaInfo::SetInfo(StringView key, StringView interface_str) { + Json j_interface = Json::Load(interface_str); + bool is_cuda{false}; + if (IsA(j_interface)) { + auto const& array = get(j_interface); + CHECK_GE(array.size(), 0) << "Invalid " << key + << ", must have at least 1 column even if it's empty."; + auto const& first = get(array.front()); + auto ptr = ArrayInterfaceHandler::GetPtrFromArrayData(first); + is_cuda = ArrayInterfaceHandler::IsCudaPtr(ptr); + } else { + auto const& first = get(j_interface); + auto ptr = ArrayInterfaceHandler::GetPtrFromArrayData(first); + is_cuda = ArrayInterfaceHandler::IsCudaPtr(ptr); + } + + if (is_cuda) { + this->SetInfoFromCUDA(key, j_interface); + } else { + this->SetInfoFromHost(key, j_interface); + } +} + +void MetaInfo::SetInfoFromHost(StringView key, Json arr) { + // multi-dim float info + if (key == "base_margin") { + CopyTensorInfoImpl<3>(arr, &this->base_margin_); + // FIXME(jiamingy): Remove the deprecated API and let all language bindings aware of + // input shape. This issue is CPU only since CUDA uses array interface from day 1. + // + // Python binding always understand the shape, so this condition should not occur for + // it. + if (this->num_row_ != 0 && this->base_margin_.Shape(0) != this->num_row_) { + // API functions that don't use array interface don't understand shape. 
+ CHECK(this->base_margin_.Size() % this->num_row_ == 0) << "Incorrect size for base margin."; + size_t n_groups = this->base_margin_.Size() / this->num_row_; + this->base_margin_.Reshape(this->num_row_, n_groups); } + return; + } + // uint info + if (key == "group") { + linalg::Tensor t; + CopyTensorInfoImpl(arr, &t); + auto const& h_groups = t.Data()->HostVector(); + group_ptr_.clear(); + group_ptr_.resize(h_groups.size() + 1, 0); + group_ptr_[0] = 0; + std::partial_sum(h_groups.cbegin(), h_groups.cend(), group_ptr_.begin() + 1); data::ValidateQueryGroup(group_ptr_); - } else if (!std::strcmp(key, "qid")) { - std::vector query_ids(num, 0); - DISPATCH_CONST_PTR(dtype, dptr, cast_dptr, - std::copy(cast_dptr, cast_dptr + num, query_ids.begin())); + return; + } else if (key == "qid") { + linalg::Tensor t; + CopyTensorInfoImpl(arr, &t); bool non_dec = true; + auto const& query_ids = t.Data()->HostVector(); for (size_t i = 1; i < query_ids.size(); ++i) { - if (query_ids[i] < query_ids[i-1]) { + if (query_ids[i] < query_ids[i - 1]) { non_dec = false; break; } } CHECK(non_dec) << "`qid` must be sorted in non-decreasing order along with data."; - group_ptr_.clear(); group_ptr_.push_back(0); + group_ptr_.clear(); + group_ptr_.push_back(0); for (size_t i = 1; i < query_ids.size(); ++i) { - if (query_ids[i] != query_ids[i-1]) { + if (query_ids[i] != query_ids[i - 1]) { group_ptr_.push_back(i); } } if (group_ptr_.back() != query_ids.size()) { group_ptr_.push_back(query_ids.size()); } - } else if (!std::strcmp(key, "label_lower_bound")) { - auto& labels = labels_lower_bound_.HostVector(); - labels.resize(num); - DISPATCH_CONST_PTR(dtype, dptr, cast_dptr, - std::copy(cast_dptr, cast_dptr + num, labels.begin())); - } else if (!std::strcmp(key, "label_upper_bound")) { - auto& labels = labels_upper_bound_.HostVector(); - labels.resize(num); - DISPATCH_CONST_PTR(dtype, dptr, cast_dptr, - std::copy(cast_dptr, cast_dptr + num, labels.begin())); - } else if (!std::strcmp(key, "feature_weights")) { - auto &h_feature_weights = feature_weights.HostVector(); - h_feature_weights.resize(num); - DISPATCH_CONST_PTR( - dtype, dptr, cast_dptr, - std::copy(cast_dptr, cast_dptr + num, h_feature_weights.begin())); + data::ValidateQueryGroup(group_ptr_); + return; + } + // float info + linalg::Tensor t; + CopyTensorInfoImpl<1>(arr, &t); + if (key == "label") { + this->labels_ = std::move(*t.Data()); + auto const& h_labels = labels_.ConstHostVector(); + auto valid = std::none_of(h_labels.cbegin(), h_labels.cend(), data::LabelsCheck{}); + CHECK(valid) << "Label contains NaN, infinity or a value too large."; + } else if (key == "weight") { + this->weights_ = std::move(*t.Data()); + auto const& h_weights = this->weights_.ConstHostVector(); + auto valid = std::none_of(h_weights.cbegin(), h_weights.cend(), + [](float w) { return w < 0 || std::isinf(w) || std::isnan(w); }); + CHECK(valid) << "Weights must be positive values."; + } else if (key == "label_lower_bound") { + this->labels_lower_bound_ = std::move(*t.Data()); + } else if (key == "label_upper_bound") { + this->labels_upper_bound_ = std::move(*t.Data()); + } else if (key == "feature_weights") { + this->feature_weights = std::move(*t.Data()); + auto const& h_feature_weights = feature_weights.ConstHostVector(); bool valid = - std::none_of(h_feature_weights.cbegin(), h_feature_weights.cend(), - [](float w) { return w < 0; }); + std::none_of(h_feature_weights.cbegin(), h_feature_weights.cend(), data::WeightsCheck{}); CHECK(valid) << "Feature weight must be greater than 0."; 
} else { LOG(FATAL) << "Unknown key for MetaInfo: " << key; } } -void MetaInfo::GetInfo(char const *key, bst_ulong *out_len, DataType dtype, - const void **out_dptr) const { +void MetaInfo::SetInfo(const char* key, const void* dptr, DataType dtype, size_t num) { + auto proc = [&](auto cast_d_ptr) { + using T = std::remove_pointer_t; + auto t = + linalg::TensorView(common::Span{cast_d_ptr, num}, {num}, GenericParameter::kCpuId); + CHECK(t.Contiguous()); + Json interface { t.ArrayInterface() }; + assert(ArrayInterface<1>{interface}.is_contiguous); + return interface; + }; + // Legacy code using XGBoost dtype, which is a small subset of array interface types. + switch (dtype) { + case xgboost::DataType::kFloat32: { + auto cast_ptr = reinterpret_cast(dptr); + this->SetInfoFromHost(key, proc(cast_ptr)); + break; + } + case xgboost::DataType::kDouble: { + auto cast_ptr = reinterpret_cast(dptr); + this->SetInfoFromHost(key, proc(cast_ptr)); + break; + } + case xgboost::DataType::kUInt32: { + auto cast_ptr = reinterpret_cast(dptr); + this->SetInfoFromHost(key, proc(cast_ptr)); + break; + } + case xgboost::DataType::kUInt64: { + auto cast_ptr = reinterpret_cast(dptr); + this->SetInfoFromHost(key, proc(cast_ptr)); + break; + } + default: + LOG(FATAL) << "Unknown data type" << static_cast(dtype); + } +} + +void MetaInfo::GetInfo(char const* key, bst_ulong* out_len, DataType dtype, + const void** out_dptr) const { if (dtype == DataType::kFloat32) { const std::vector* vec = nullptr; if (!std::strcmp(key, "label")) { @@ -444,7 +572,7 @@ void MetaInfo::GetInfo(char const *key, bst_ulong *out_len, DataType dtype, } else if (!std::strcmp(key, "weight")) { vec = &this->weights_.HostVector(); } else if (!std::strcmp(key, "base_margin")) { - vec = &this->base_margin_.HostVector(); + vec = &this->base_margin_.Data()->HostVector(); } else if (!std::strcmp(key, "label_lower_bound")) { vec = &this->labels_lower_bound_.HostVector(); } else if (!std::strcmp(key, "label_upper_bound")) { @@ -533,8 +661,7 @@ void MetaInfo::Extend(MetaInfo const& that, bool accumulate_rows, bool check_col this->labels_upper_bound_.SetDevice(that.labels_upper_bound_.DeviceIdx()); this->labels_upper_bound_.Extend(that.labels_upper_bound_); - this->base_margin_.SetDevice(that.base_margin_.DeviceIdx()); - this->base_margin_.Extend(that.base_margin_); + linalg::Stack(&this->base_margin_, that.base_margin_); if (this->group_ptr_.size() == 0) { this->group_ptr_ = that.group_ptr_; @@ -617,14 +744,12 @@ void MetaInfo::Validate(int32_t device) const { if (base_margin_.Size() != 0) { CHECK_EQ(base_margin_.Size() % num_row_, 0) << "Size of base margin must be a multiple of number of rows."; - check_device(base_margin_); + check_device(*base_margin_.Data()); } } #if !defined(XGBOOST_USE_CUDA) -void MetaInfo::SetInfo(StringView key, std::string const& interface_str) { - common::AssertGPUSupport(); -} +void MetaInfo::SetInfoFromCUDA(StringView key, Json arr) { common::AssertGPUSupport(); } #endif // !defined(XGBOOST_USE_CUDA) using DMatrixThreadLocal = @@ -778,10 +903,10 @@ DMatrix* DMatrix::Load(const std::string& uri, LOG(CONSOLE) << info.group_ptr_.size() - 1 << " groups are loaded from " << fname << ".group"; } - if (MetaTryLoadFloatInfo - (fname + ".base_margin", &info.base_margin_.HostVector()) && !silent) { - LOG(CONSOLE) << info.base_margin_.Size() - << " base_margin are loaded from " << fname << ".base_margin"; + if (MetaTryLoadFloatInfo(fname + ".base_margin", &info.base_margin_.Data()->HostVector()) && + !silent) { + LOG(CONSOLE) << 
info.base_margin_.Size() << " base_margin are loaded from " << fname + << ".base_margin"; } if (MetaTryLoadFloatInfo (fname + ".weight", &info.weights_.HostVector()) && !silent) { diff --git a/src/data/data.cu b/src/data/data.cu index 8e13db9ce751..6d85a85e261b 100644 --- a/src/data/data.cu +++ b/src/data/data.cu @@ -114,14 +114,10 @@ void CopyQidImpl(ArrayInterface<1> array_interface, std::vector* p_ } } // namespace -void MetaInfo::SetInfo(StringView key, std::string const& interface_str) { - Json array = Json::Load(StringView{interface_str}); +void MetaInfo::SetInfoFromCUDA(StringView key, Json array) { // multi-dim float info if (key == "base_margin") { - // FIXME(jiamingy): This is temporary until #7405 can be fully merged - linalg::Tensor t; - CopyTensorInfoImpl(array, &t); - base_margin_ = std::move(*t.Data()); + CopyTensorInfoImpl(array, &base_margin_); return; } // uint info diff --git a/src/data/simple_dmatrix.cc b/src/data/simple_dmatrix.cc index 44a8a3f8fe7c..e83559d3958a 100644 --- a/src/data/simple_dmatrix.cc +++ b/src/data/simple_dmatrix.cc @@ -137,9 +137,10 @@ SimpleDMatrix::SimpleDMatrix(AdapterT* adapter, float missing, int nthread) { batch.Weights() + batch.Size()); } if (batch.BaseMargin() != nullptr) { - auto& base_margin = info_.base_margin_.HostVector(); - base_margin.insert(base_margin.end(), batch.BaseMargin(), - batch.BaseMargin() + batch.Size()); + info_.base_margin_ = linalg::Tensor{batch.BaseMargin(), + batch.BaseMargin() + batch.Size(), + {batch.Size()}, + GenericParameter::kCpuId}; } if (batch.Qid() != nullptr) { qids.insert(qids.end(), batch.Qid(), batch.Qid() + batch.Size()); diff --git a/src/gbm/gblinear.cc b/src/gbm/gblinear.cc index e5f2a457214c..e5f5916211cc 100644 --- a/src/gbm/gblinear.cc +++ b/src/gbm/gblinear.cc @@ -178,7 +178,7 @@ class GBLinear : public GradientBooster { unsigned layer_begin, unsigned layer_end, bool, int, unsigned) override { model_.LazyInitModel(); LinearCheckLayer(layer_begin, layer_end); - const auto& base_margin = p_fmat->Info().base_margin_.ConstHostVector(); + auto base_margin = p_fmat->Info().base_margin_.View(GenericParameter::kCpuId); const int ngroup = model_.learner_model_param->num_output_group; const size_t ncolumns = model_.learner_model_param->num_feature + 1; // allocate space for (#features + bias) times #groups times #rows @@ -203,9 +203,9 @@ class GBLinear : public GradientBooster { p_contribs[ins.index] = ins.fvalue * model_[ins.index][gid]; } // add base margin to BIAS - p_contribs[ncolumns - 1] = model_.Bias()[gid] + - ((base_margin.size() != 0) ? base_margin[row_idx * ngroup + gid] : - learner_model_param_->base_score); + p_contribs[ncolumns - 1] = + model_.Bias()[gid] + ((base_margin.Size() != 0) ? 
base_margin(row_idx, gid) + : learner_model_param_->base_score); } }); } @@ -270,7 +270,7 @@ class GBLinear : public GradientBooster { monitor_.Start("PredictBatchInternal"); model_.LazyInitModel(); std::vector &preds = *out_preds; - const auto& base_margin = p_fmat->Info().base_margin_.ConstHostVector(); + auto base_margin = p_fmat->Info().base_margin_.View(GenericParameter::kCpuId); // start collecting the prediction const int ngroup = model_.learner_model_param->num_output_group; preds.resize(p_fmat->Info().num_row_ * ngroup); @@ -280,16 +280,15 @@ class GBLinear : public GradientBooster { // k is number of group // parallel over local batch const auto nsize = static_cast(batch.Size()); - if (base_margin.size() != 0) { - CHECK_EQ(base_margin.size(), nsize * ngroup); + if (base_margin.Size() != 0) { + CHECK_EQ(base_margin.Size(), nsize * ngroup); } common::ParallelFor(nsize, [&](omp_ulong i) { const size_t ridx = page.base_rowid + i; // loop over output groups for (int gid = 0; gid < ngroup; ++gid) { - bst_float margin = - (base_margin.size() != 0) ? - base_margin[ridx * ngroup + gid] : learner_model_param_->base_score; + float margin = + (base_margin.Size() != 0) ? base_margin(ridx, gid) : learner_model_param_->base_score; this->Pred(batch[i], &preds[ridx * ngroup], gid, margin); } }); diff --git a/src/predictor/cpu_predictor.cc b/src/predictor/cpu_predictor.cc index d581f64a1d56..92797235d16b 100644 --- a/src/predictor/cpu_predictor.cc +++ b/src/predictor/cpu_predictor.cc @@ -282,27 +282,6 @@ class CPUPredictor : public Predictor { } } - void InitOutPredictions(const MetaInfo& info, - HostDeviceVector* out_preds, - const gbm::GBTreeModel& model) const override { - CHECK_NE(model.learner_model_param->num_output_group, 0); - size_t n = model.learner_model_param->num_output_group * info.num_row_; - const auto& base_margin = info.base_margin_.HostVector(); - out_preds->Resize(n); - std::vector& out_preds_h = out_preds->HostVector(); - if (base_margin.empty()) { - std::fill(out_preds_h.begin(), out_preds_h.end(), - model.learner_model_param->base_score); - } else { - std::string expected{ - "(" + std::to_string(info.num_row_) + ", " + - std::to_string(model.learner_model_param->num_output_group) + ")"}; - CHECK_EQ(base_margin.size(), n) - << "Invalid shape of base_margin. 
Expected:" << expected; - std::copy(base_margin.begin(), base_margin.end(), out_preds_h.begin()); - } - } - public: explicit CPUPredictor(GenericParameter const* generic_param) : Predictor::Predictor{generic_param} {} @@ -456,7 +435,7 @@ class CPUPredictor : public Predictor { common::ParallelFor(bst_omp_uint(ntree_limit), [&](bst_omp_uint i) { FillNodeMeanValues(model.trees[i].get(), &(mean_values[i])); }); - const std::vector& base_margin = info.base_margin_.HostVector(); + auto base_margin = info.base_margin_.View(GenericParameter::kCpuId); // start collecting the contributions for (const auto &batch : p_fmat->GetBatches()) { auto page = batch.GetView(); @@ -496,8 +475,9 @@ class CPUPredictor : public Predictor { } feats.Drop(page[i]); // add base margin to BIAS - if (base_margin.size() != 0) { - p_contribs[ncolumns - 1] += base_margin[row_idx * ngroup + gid]; + if (base_margin.Size() != 0) { + CHECK_EQ(base_margin.Shape(1), ngroup); + p_contribs[ncolumns - 1] += base_margin(row_idx, gid); } else { p_contribs[ncolumns - 1] += model.learner_model_param->base_score; } diff --git a/src/predictor/gpu_predictor.cu b/src/predictor/gpu_predictor.cu index 51674237e973..71724f95236c 100644 --- a/src/predictor/gpu_predictor.cu +++ b/src/predictor/gpu_predictor.cu @@ -855,7 +855,7 @@ class GPUPredictor : public xgboost::Predictor { } // Add the base margin term to last column p_fmat->Info().base_margin_.SetDevice(generic_param_->gpu_id); - const auto margin = p_fmat->Info().base_margin_.ConstDeviceSpan(); + const auto margin = p_fmat->Info().base_margin_.Data()->ConstDeviceSpan(); float base_score = model.learner_model_param->base_score; dh::LaunchN( p_fmat->Info().num_row_ * model.learner_model_param->num_output_group, @@ -914,7 +914,7 @@ class GPUPredictor : public xgboost::Predictor { } // Add the base margin term to last column p_fmat->Info().base_margin_.SetDevice(generic_param_->gpu_id); - const auto margin = p_fmat->Info().base_margin_.ConstDeviceSpan(); + const auto margin = p_fmat->Info().base_margin_.Data()->ConstDeviceSpan(); float base_score = model.learner_model_param->base_score; size_t n_features = model.learner_model_param->num_feature; dh::LaunchN( @@ -928,27 +928,6 @@ class GPUPredictor : public xgboost::Predictor { }); } - protected: - void InitOutPredictions(const MetaInfo& info, - HostDeviceVector* out_preds, - const gbm::GBTreeModel& model) const override { - size_t n_classes = model.learner_model_param->num_output_group; - size_t n = n_classes * info.num_row_; - const HostDeviceVector& base_margin = info.base_margin_; - out_preds->SetDevice(generic_param_->gpu_id); - out_preds->Resize(n); - if (base_margin.Size() != 0) { - std::string expected{ - "(" + std::to_string(info.num_row_) + ", " + - std::to_string(model.learner_model_param->num_output_group) + ")"}; - CHECK_EQ(base_margin.Size(), n) - << "Invalid shape of base_margin. Expected:" << expected; - out_preds->Copy(base_margin); - } else { - out_preds->Fill(model.learner_model_param->base_score); - } - } - void PredictInstance(const SparsePage::Inst&, std::vector*, const gbm::GBTreeModel&, unsigned) const override { diff --git a/src/predictor/predictor.cc b/src/predictor/predictor.cc index 9aa18b19ce30..b86474184ccc 100644 --- a/src/predictor/predictor.cc +++ b/src/predictor/predictor.cc @@ -1,5 +1,5 @@ /*! 
- * Copyright 2017-2020 by Contributors + * Copyright 2017-2021 by Contributors */ #include #include @@ -8,6 +8,8 @@ #include "xgboost/data.h" #include "xgboost/generic_parameters.h" +#include "../gbm/gbtree.h" + namespace dmlc { DMLC_REGISTRY_ENABLE(::xgboost::PredictorReg); } // namespace dmlc @@ -58,6 +60,38 @@ Predictor* Predictor::Create( auto p_predictor = (e->body)(generic_param); return p_predictor; } + +void ValidateBaseMarginShape(linalg::Tensor const& margin, bst_row_t n_samples, + bst_group_t n_groups) { + // FIXME: Bindings other than Python doesn't have shape. + std::string expected{"Invalid shape of base_margin. Expected: (" + std::to_string(n_samples) + + ", " + std::to_string(n_groups) + ")"}; + CHECK_EQ(margin.Shape(0), n_samples) << expected; + CHECK_EQ(margin.Shape(1), n_groups) << expected; +} + +void Predictor::InitOutPredictions(const MetaInfo& info, HostDeviceVector* out_preds, + const gbm::GBTreeModel& model) const { + CHECK_NE(model.learner_model_param->num_output_group, 0); + size_t n_classes = model.learner_model_param->num_output_group; + size_t n = n_classes * info.num_row_; + const HostDeviceVector* base_margin = info.base_margin_.Data(); + if (generic_param_->gpu_id >= 0) { + out_preds->SetDevice(generic_param_->gpu_id); + } + if (base_margin->Size() != 0) { + out_preds->Resize(n); + ValidateBaseMarginShape(info.base_margin_, info.num_row_, n_classes); + out_preds->Copy(*base_margin); + } else { + if (out_preds->Empty()) { + out_preds->Resize(n, model.learner_model_param->base_score); + } else { + out_preds->Resize(n); + out_preds->Fill(model.learner_model_param->base_score); + } + } +} } // namespace xgboost namespace xgboost { diff --git a/tests/cpp/common/test_hist_util.h b/tests/cpp/common/test_hist_util.h index c7fc4972d3ee..49952d202706 100644 --- a/tests/cpp/common/test_hist_util.h +++ b/tests/cpp/common/test_hist_util.h @@ -57,7 +57,7 @@ inline data::CupyAdapter AdapterFromData(const thrust::device_vector &x, Json(Integer(reinterpret_cast(x.data().get()))), Json(Boolean(false))}; array_interface["data"] = j_data; - array_interface["version"] = Integer(static_cast(1)); + array_interface["version"] = 3; array_interface["typestr"] = String(" #include #include @@ -122,7 +124,10 @@ TEST(MetaInfo, SaveLoadBinary) { EXPECT_EQ(inforead.labels_.HostVector(), info.labels_.HostVector()); EXPECT_EQ(inforead.group_ptr_, info.group_ptr_); EXPECT_EQ(inforead.weights_.HostVector(), info.weights_.HostVector()); - EXPECT_EQ(inforead.base_margin_.HostVector(), info.base_margin_.HostVector()); + + auto orig_margin = info.base_margin_.View(xgboost::GenericParameter::kCpuId); + auto read_margin = inforead.base_margin_.View(xgboost::GenericParameter::kCpuId); + EXPECT_TRUE(std::equal(orig_margin.cbegin(), orig_margin.cend(), read_margin.cbegin())); EXPECT_EQ(inforead.feature_type_names.size(), kCols); EXPECT_EQ(inforead.feature_types.Size(), kCols); @@ -254,10 +259,10 @@ TEST(MetaInfo, Validate) { xgboost::HostDeviceVector d_groups{groups}; d_groups.SetDevice(0); d_groups.DevicePointer(); // pull to device - auto arr_interface = xgboost::GetArrayInterface(&d_groups, 64, 1); - std::string arr_interface_str; - xgboost::Json::Dump(arr_interface, &arr_interface_str); - EXPECT_THROW(info.SetInfo("group", arr_interface_str), dmlc::Error); + std::string arr_interface_str{ + xgboost::linalg::MakeVec(d_groups.ConstDevicePointer(), d_groups.Size(), 0) + .ArrayInterfaceStr()}; + EXPECT_THROW(info.SetInfo("group", xgboost::StringView{arr_interface_str}), dmlc::Error); #endif // 
defined(XGBOOST_USE_CUDA) } @@ -292,3 +297,7 @@ TEST(MetaInfo, HostExtend) { ASSERT_EQ(lhs.group_ptr_.at(i), per_group * i); } } + +namespace xgboost { +TEST(MetaInfo, CPUStridedData) { TestMetaInfoStridedData(GenericParameter::kCpuId); } +} // namespace xgboost diff --git a/tests/cpp/data/test_metainfo.cu b/tests/cpp/data/test_metainfo.cu index 205844a5e961..bbb78e7924e7 100644 --- a/tests/cpp/data/test_metainfo.cu +++ b/tests/cpp/data/test_metainfo.cu @@ -3,10 +3,13 @@ #include #include #include +#include #include #include "test_array_interface.h" #include "../../../src/common/device_helpers.cuh" +#include "test_metainfo.h" + namespace xgboost { template @@ -23,7 +26,7 @@ std::string PrepareData(std::string typestr, thrust::device_vector* out, cons std::vector j_shape {Json(Integer(static_cast(kRows)))}; column["shape"] = Array(j_shape); column["strides"] = Array(std::vector{Json(Integer(static_cast(sizeof(T))))}); - column["version"] = Integer(static_cast(1)); + column["version"] = 3; column["typestr"] = String(typestr); auto p_d_data = d_data.data().get(); @@ -31,6 +34,7 @@ std::string PrepareData(std::string typestr, thrust::device_vector* out, cons Json(Integer(reinterpret_cast(p_d_data))), Json(Boolean(false))}; column["data"] = j_data; + column["stream"] = nullptr; Json array(std::vector{column}); std::string str; @@ -49,6 +53,7 @@ TEST(MetaInfo, FromInterface) { info.SetInfo("label", str.c_str()); auto const& h_label = info.labels_.HostVector(); + ASSERT_EQ(h_label.size(), d_data.size()); for (size_t i = 0; i < d_data.size(); ++i) { ASSERT_EQ(h_label[i], d_data[i]); } @@ -60,9 +65,10 @@ TEST(MetaInfo, FromInterface) { } info.SetInfo("base_margin", str.c_str()); - auto const& h_base_margin = info.base_margin_.HostVector(); + auto const h_base_margin = info.base_margin_.View(GenericParameter::kCpuId); + ASSERT_EQ(h_base_margin.Size(), d_data.size()); for (size_t i = 0; i < d_data.size(); ++i) { - ASSERT_EQ(h_base_margin[i], d_data[i]); + ASSERT_EQ(h_base_margin(i), d_data[i]); } thrust::device_vector d_group_data; @@ -76,6 +82,10 @@ TEST(MetaInfo, FromInterface) { EXPECT_EQ(info.group_ptr_, expected_group_ptr); } +TEST(MetaInfo, GPUStridedData) { + TestMetaInfoStridedData(0); +} + TEST(MetaInfo, Group) { cudaSetDevice(0); MetaInfo info; diff --git a/tests/cpp/data/test_metainfo.h b/tests/cpp/data/test_metainfo.h new file mode 100644 index 000000000000..67da633d4be5 --- /dev/null +++ b/tests/cpp/data/test_metainfo.h @@ -0,0 +1,82 @@ +/*! + * Copyright 2021 by XGBoost Contributors + */ +#ifndef XGBOOST_TESTS_CPP_DATA_TEST_METAINFO_H_ +#define XGBOOST_TESTS_CPP_DATA_TEST_METAINFO_H_ +#include +#include +#include +#include + +#include +#include "../../../src/data/array_interface.h" +#include "../../../src/common/linalg_op.h" + +namespace xgboost { +inline void TestMetaInfoStridedData(int32_t device) { + MetaInfo info; + { + // label + HostDeviceVector labels; + labels.Resize(64); + auto& h_labels = labels.HostVector(); + std::iota(h_labels.begin(), h_labels.end(), 0.0f); + bool is_gpu = device >= 0; + if (is_gpu) { + labels.SetDevice(0); + } + + auto t = linalg::TensorView{ + is_gpu ? 
labels.ConstDeviceSpan() : labels.ConstHostSpan(), {32, 2}, device}; + auto s = t.Slice(linalg::All(), 0); + + auto str = s.ArrayInterfaceStr(); + ASSERT_EQ(s.Size(), 32); + + info.SetInfo("label", StringView{str}); + auto const& h_result = info.labels_.HostVector(); + ASSERT_EQ(h_result.size(), 32); + + for (auto v : h_result) { + ASSERT_EQ(static_cast(v) % 2, 0); + } + } + { + // qid + linalg::Tensor qid; + qid.Reshape(32, 2); + auto& h_qid = qid.Data()->HostVector(); + std::iota(h_qid.begin(), h_qid.end(), 0); + auto s = qid.View(device).Slice(linalg::All(), 0); + auto str = s.ArrayInterfaceStr(); + info.SetInfo("qid", StringView{str}); + auto const& h_result = info.group_ptr_; + ASSERT_EQ(h_result.size(), s.Size() + 1); + } + { + // base margin + linalg::Tensor base_margin; + base_margin.Reshape(4, 3, 2, 3); + auto& h_margin = base_margin.Data()->HostVector(); + std::iota(h_margin.begin(), h_margin.end(), 0.0); + auto t_margin = base_margin.View(device).Slice(linalg::All(), linalg::All(), 0, linalg::All()); + ASSERT_EQ(t_margin.Shape().size(), 3); + + info.SetInfo("base_margin", StringView{t_margin.ArrayInterfaceStr()}); + auto const& h_result = info.base_margin_.View(-1); + ASSERT_EQ(h_result.Shape().size(), 3); + auto in_margin = base_margin.View(-1); + linalg::ElementWiseKernelHost(h_result, omp_get_max_threads(), [&](size_t i, float v_0) { + auto tup = linalg::UnravelIndex(i, h_result.Shape()); + auto i0 = std::get<0>(tup); + auto i1 = std::get<1>(tup); + auto i2 = std::get<2>(tup); + // Sliced at 3^th dimension. + auto v_1 = in_margin(i0, i1, 0, i2); + CHECK_EQ(v_0, v_1); + return v_0; + }); + } +} +} // namespace xgboost +#endif // XGBOOST_TESTS_CPP_DATA_TEST_METAINFO_H_ diff --git a/tests/cpp/data/test_simple_dmatrix.cc b/tests/cpp/data/test_simple_dmatrix.cc index f777b00e244c..c25e877079d2 100644 --- a/tests/cpp/data/test_simple_dmatrix.cc +++ b/tests/cpp/data/test_simple_dmatrix.cc @@ -253,8 +253,8 @@ TEST(SimpleDMatrix, Slice) { std::iota(lower.begin(), lower.end(), 0.0f); std::iota(upper.begin(), upper.end(), 1.0f); - auto& margin = p_m->Info().base_margin_.HostVector(); - margin.resize(kRows * kClasses); + auto& margin = p_m->Info().base_margin_; + margin = linalg::Tensor{{kRows, kClasses}, GenericParameter::kCpuId}; std::array ridxs {1, 3, 5}; std::unique_ptr out { p_m->Slice(ridxs) }; @@ -284,10 +284,10 @@ TEST(SimpleDMatrix, Slice) { ASSERT_EQ(p_m->Info().weights_.HostVector().at(ridx), out->Info().weights_.HostVector().at(i)); - auto& out_margin = out->Info().base_margin_.HostVector(); + auto out_margin = out->Info().base_margin_.View(GenericParameter::kCpuId); + auto in_margin = margin.View(GenericParameter::kCpuId); for (size_t j = 0; j < kClasses; ++j) { - auto in_beg = ridx * kClasses; - ASSERT_EQ(out_margin.at(i * kClasses + j), margin.at(in_beg + j)); + ASSERT_EQ(out_margin(i, j), in_margin(ridx, j)); } } } diff --git a/tests/cpp/data/test_simple_dmatrix.cu b/tests/cpp/data/test_simple_dmatrix.cu index d74f5b150087..19f13b1fddfd 100644 --- a/tests/cpp/data/test_simple_dmatrix.cu +++ b/tests/cpp/data/test_simple_dmatrix.cu @@ -122,13 +122,13 @@ TEST(SimpleDMatrix, FromColumnarWithEmptyRows) { col["data"] = j_data; std::vector j_shape{Json(Integer(static_cast(kRows)))}; col["shape"] = Array(j_shape); - col["version"] = Integer(static_cast(1)); + col["version"] = 3; col["typestr"] = String("(1)); + j_mask["version"] = 3; auto& mask_storage = column_bitfields[i]; mask_storage.resize(16); // 16 bytes @@ -220,7 +220,7 @@ TEST(SimpleCSRSource, FromColumnarSparse) { for 
(size_t c = 0; c < kCols; ++c) { auto& column = j_columns[c]; column = Object(); - column["version"] = Integer(static_cast(1)); + column["version"] = 3; column["typestr"] = String(" j_data { @@ -229,12 +229,12 @@ TEST(SimpleCSRSource, FromColumnarSparse) { column["data"] = j_data; std::vector j_shape {Json(Integer(static_cast(kRows)))}; column["shape"] = Array(j_shape); - column["version"] = Integer(static_cast(1)); + column["version"] = 3; column["typestr"] = String("(1)); + j_mask["version"] = 3; j_mask["data"] = std::vector{ Json(Integer(reinterpret_cast(column_bitfields[c].data().get()))), Json(Boolean(false))}; diff --git a/tests/cpp/helpers.cc b/tests/cpp/helpers.cc index db4454c9eb1e..0906d9ed87df 100644 --- a/tests/cpp/helpers.cc +++ b/tests/cpp/helpers.cc @@ -228,6 +228,7 @@ RandomDataGenerator::GenerateArrayInterfaceBatch( if (device_ >= 0) { array_interface["data"][0] = Integer(reinterpret_cast(storage->DevicePointer() + offset)); + array_interface["stream"] = Null{}; } else { array_interface["data"][0] = Integer(reinterpret_cast(storage->HostPointer() + offset)); @@ -240,7 +241,7 @@ RandomDataGenerator::GenerateArrayInterfaceBatch( array_interface["shape"][1] = cols_; array_interface["typestr"] = String(" *storage, size_t rows, size_t cols) { if (storage->DeviceCanRead()) { array_interface["data"][0] = Integer(reinterpret_cast(storage->ConstDevicePointer())); + array_interface["stream"] = nullptr; } else { array_interface["data"][0] = Integer(reinterpret_cast(storage->ConstHostPointer())); @@ -200,7 +201,7 @@ Json GetArrayInterface(HostDeviceVector *storage, size_t rows, size_t cols) { char t = linalg::detail::ArrayInterfaceHandler::TypeChar(); array_interface["typestr"] = String(std::string{"<"} + t + std::to_string(sizeof(T))); - array_interface["version"] = 1; + array_interface["version"] = 3; return array_interface; } diff --git a/tests/cpp/predictor/test_gpu_predictor.cu b/tests/cpp/predictor/test_gpu_predictor.cu index ad1083f9161b..b36df742da4f 100644 --- a/tests/cpp/predictor/test_gpu_predictor.cu +++ b/tests/cpp/predictor/test_gpu_predictor.cu @@ -108,7 +108,9 @@ TEST(GPUPredictor, ExternalMemoryTest) { dmats.push_back(CreateSparsePageDMatrix(8000)); for (const auto& dmat: dmats) { - dmat->Info().base_margin_.Resize(dmat->Info().num_row_ * n_classes, 0.5); + dmat->Info().base_margin_ = + linalg::Tensor{{dmat->Info().num_row_, static_cast(n_classes)}, 0}; + dmat->Info().base_margin_.Data()->Fill(0.5); PredictionCacheEntry out_predictions; gpu_predictor->InitOutPredictions(dmat->Info(), &out_predictions.predictions, model); gpu_predictor->PredictBatch(dmat.get(), &out_predictions, model, 0); diff --git a/tests/python/test_dmatrix.py b/tests/python/test_dmatrix.py index 1b5ad266d3d9..6cd026d19bc3 100644 --- a/tests/python/test_dmatrix.py +++ b/tests/python/test_dmatrix.py @@ -17,7 +17,7 @@ def set_base_margin_info(DType, DMatrixT, tm: str): rng = np.random.default_rng() - X = DType(rng.normal(0, 1.0, size=100).reshape(50, 2)) + X = DType(rng.normal(0, 1.0, size=100).astype(np.float32).reshape(50, 2)) if hasattr(X, "iloc"): y = X.iloc[:, 0] else: @@ -29,8 +29,35 @@ def set_base_margin_info(DType, DMatrixT, tm: str): with pytest.raises(ValueError, match=r".*base_margin.*"): xgb.train({"tree_method": tm}, Xy) - # FIXME(jiamingy): Currently the metainfo has no concept of shape. If you pass a - # base_margin with shape (n_classes, n_samples) to XGBoost the result is undefined. 
+ if not hasattr(X, "iloc"): + # column major matrix + got = DType(Xy.get_base_margin().reshape(50, 2)) + assert (got == base_margin).all() + + assert base_margin.T.flags.c_contiguous is False + assert base_margin.T.flags.f_contiguous is True + Xy.set_info(base_margin=base_margin.T) + got = DType(Xy.get_base_margin().reshape(2, 50)) + assert (got == base_margin.T).all() + + # Row vs col vec. + base_margin = y + Xy.set_base_margin(base_margin) + bm_col = Xy.get_base_margin() + Xy.set_base_margin(base_margin.reshape(1, base_margin.size)) + bm_row = Xy.get_base_margin() + assert (bm_row == bm_col).all() + + # type + base_margin = base_margin.astype(np.float64) + Xy.set_base_margin(base_margin) + bm_f64 = Xy.get_base_margin() + assert (bm_f64 == bm_col).all() + + # too many dimensions + base_margin = X.reshape(2, 5, 2, 5) + with pytest.raises(ValueError, match=r".*base_margin.*"): + Xy.set_base_margin(base_margin) class TestDMatrix: @@ -141,6 +168,7 @@ def test_slice(self): # base margin is per-class in multi-class classifier base_margin = rng.randn(100, 3).astype(np.float32) d.set_base_margin(base_margin) + np.testing.assert_allclose(d.get_base_margin().reshape(100, 3), base_margin) ridxs = [1, 2, 3, 4, 5, 6] sliced = d.slice(ridxs) @@ -154,7 +182,7 @@ def test_slice(self): # Slicing a DMatrix results into a DMatrix that's equivalent to a DMatrix that's # constructed from the corresponding NumPy slice d2 = xgb.DMatrix(X[1:7, :], y[1:7]) - d2.set_base_margin(base_margin[1:7, :].flatten()) + d2.set_base_margin(base_margin[1:7, :]) eval_res = {} _ = xgb.train( {'num_class': 3, 'objective': 'multi:softprob', @@ -280,7 +308,7 @@ def test_feature_weights(self): m.set_info(feature_weights=fw) np.testing.assert_allclose(fw, m.get_float_info('feature_weights')) # Handle empty - m.set_info(feature_weights=np.empty((0, 0))) + m.set_info(feature_weights=np.empty((0, ))) assert m.get_float_info('feature_weights').shape[0] == 0
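Distilled from the Python tests above, a small end-to-end usage sketch of what this patch
enables (hypothetical user code, mirroring ``test_slice``):

.. code-block:: python

    import numpy as np
    import xgboost as xgb

    rng = np.random.default_rng(0)
    X = rng.normal(size=(50, 2)).astype(np.float32)
    y = (X[:, 0] > 0).astype(np.float32)
    Xy = xgb.DMatrix(X, label=y)

    # A per-class margin can now be passed with its natural (n_samples, n_classes)
    # shape; no manual flattening is required on the user side.
    margin = rng.normal(size=(50, 3)).astype(np.float32)
    Xy.set_base_margin(margin)

    # The getter still returns a flattened 1-dim vector (see the note in the new
    # documentation section at the top of this patch).
    got = Xy.get_base_margin().reshape(50, 3)
    np.testing.assert_allclose(got, margin)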