From b1b1cde7cc8177efddf86dfba00ad326ecb42f8a Mon Sep 17 00:00:00 2001 From: fis Date: Tue, 26 Oct 2021 21:13:34 +0800 Subject: [PATCH] Implement typed and type erased tensor. * Use typed tensor for storing meta info like base margin. * Extend the array interface handler to multi-dim. Implement a general array view. * Replace existing matrix and vector view. lint. Remove const too. Doc/Test. Include. Use it in AUC. Win build. Use int32_t. Use integral. force the same type. Use constexpr for old nvcc. Test for empty tensor. Rename to view. Format. Better document and perf. Address reviewer's comment. tidy. Implement a general array view. * Replace existing matrix and vector view. lint. Remove const too. Doc/Test. Include. Use it in AUC. Win build. Use int32_t. Use integral. force the same type. Use constexpr for old nvcc. Test for empty tensor. Rename to view. Format. Prototype. Move string view. Compile on CPU. Some fixes for GPU compilation. Array interface. Use it in file iter. Cleanup. itemsize. Documents. Cache the pointer. port. cuda compilation. Start working on ari. Add clang-format config. (#7383) Generated using `clang-format -style=google -dump-config > .clang-format`, with column width changed from 80 to 100 to be consistent with existing cpplint check. Define shape and stride. Convert some uses. Prototype for copy tensor info. proto unravel Indexer. Unravel. Cleanup. Fixes. fixe.s WAR. as column vector. Convert some code. some more. some more. Compile. Ensure column vector from the beginning. IO. Add code comments. Test for trivial dimension. Start CPU implementation. Refactor. Dispatch. Compile. Cleanup slice and index for array interface. Add document about handling user input. Validate the shape of base margin. Cleanup out prediction. Use it in Python. Optimization. Cleanup. Define unravel index as an interface. Dispatch DType. Err. By pass the iterator. Cleanup old code. comment. Cleanup. Remove duplicated function. Add contiguous. Typo. Fix copying for group. comment. Fix CSR. Fix empty dimensions. Use linalg for utilities. Fix test. Fix streams. Basic elementwise kernel. Fixes. fix dtype. Fix index. Comment. Cleanup. popcnt implementation. Move to compile time. Fix. Fix. Move to compile time. Forward. Forward. Include. Lint. Reintroduce the checks. Fix long double. Remove check. Comment. Remove macro. Some changes in jvm package. Include. Ignore 0 shape. Fix. Restore bound check for now. Fix slice. Fix test feature weights. Stricter requirements. Reshape. Require v3 interface. Use CUDA ptr attr. Fix test. Invalid stream. tidy. Fix versions. test stack. CPU only. Simplifies. Reverse version. Force c++14 for msvc. Lint. Remove imports. Revert cmake changes. Reset. Check. Unify device initialization. Tests. Just bypass the error. Fix. Fix. test qid. Fix test. Cleanup. lint. Fix. Restore the heuristic. Update metainfo binary. Fix typo. Restore the optimization. Cleanup sklearn tests. Cleanup dask test. Tidy. Tidy is happy. Polish. Unittest. Typo. --- doc/contrib/coding_guide.rst | 46 ++ include/xgboost/c_api.h | 31 +- include/xgboost/data.h | 18 +- include/xgboost/intrusive_ptr.h | 4 +- include/xgboost/predictor.h | 2 +- .../xgboost4j-gpu/src/native/xgboost4j-gpu.cu | 30 +- python-package/xgboost/data.py | 34 +- src/common/common.cu | 3 +- src/data/adapter.h | 80 ++- src/data/array_interface.cu | 53 +- src/data/array_interface.h | 485 +++++++++++------- src/data/data.cc | 357 ++++++++----- src/data/data.cu | 182 +++---- src/data/device_adapter.cuh | 47 +- src/data/file_iterator.h | 15 +- src/data/simple_dmatrix.cc | 7 +- src/data/validation.h | 40 ++ src/gbm/gblinear.cc | 19 +- src/predictor/cpu_predictor.cc | 28 +- src/predictor/gpu_predictor.cu | 29 +- src/predictor/predictor.cc | 31 +- tests/cpp/common/test_hist_util.h | 2 +- tests/cpp/data/test_adapter.cc | 7 +- tests/cpp/data/test_array_interface.cc | 55 +- tests/cpp/data/test_array_interface.cu | 15 +- tests/cpp/data/test_array_interface.h | 10 +- .../cpp/data/test_iterative_device_dmatrix.cu | 4 +- tests/cpp/data/test_metainfo.cc | 21 +- tests/cpp/data/test_metainfo.cu | 19 +- tests/cpp/data/test_metainfo.h | 82 +++ tests/cpp/data/test_simple_dmatrix.cc | 10 +- tests/cpp/data/test_simple_dmatrix.cu | 10 +- tests/cpp/helpers.cc | 3 +- tests/cpp/helpers.h | 5 +- tests/cpp/predictor/test_gpu_predictor.cu | 4 +- tests/python/test_dmatrix.py | 5 +- 36 files changed, 1120 insertions(+), 673 deletions(-) create mode 100644 src/data/validation.h create mode 100644 tests/cpp/data/test_metainfo.h diff --git a/doc/contrib/coding_guide.rst b/doc/contrib/coding_guide.rst index 6d407ba129e9..b4880803cb0c 100644 --- a/doc/contrib/coding_guide.rst +++ b/doc/contrib/coding_guide.rst @@ -134,3 +134,49 @@ Similarly, if you want to exclude C++ source from linting: cd /path/to/xgboost/ python3 tests/ci_build/tidy.py --cpp=0 +********************************** +Guide for handling user input data +********************************** + +This is an in-comprehensive guide for handling user input data. XGBoost has wide verity +of native supported data structures, mostly come from higher level language bindings. The +inputs ranges from basic contiguous 1 dimension memory buffer to more sophisticated data +structures like columnar data with validity mask. Raw input data can be used in 2 places, +firstly it's the construction of various ``DMatrix``, secondly it's the in-place +prediction. For plain memory buffer, there's not much to discuss since it's just a +pointer with a size. But for general n-dimension array and columnar data, there are many +subtleties. XGBoost has 3 different data structures for handling optionally masked arrays +(tensors), for consuming user inputs ``ArrayInterface`` should be chosen. There are many +existing functions that accept only plain pointer due to legacy reasons (XGBoost started +as a much simpler library and didn't care about memory usage that much back then). The +``ArrayInterface`` is a in memory representation of ``__array_interface__`` protocol +defined by numpy or the ``__cuda_array_interface__`` defined by numba. Following is a +check list of things to have in mind when accepting related user inputs: + +- [ ] Is it strided? (identified by the ``strides`` field) +- [ ] If it's a vector, is it row vector or column vector? (Identified by both ``shape`` + and ``strides``). +- [ ] Is the data type supported? Half type and 128 integer types should be converted + before going into XGBoost. +- [ ] Does it have higher than 1 dimension? (identified by ``shape`` field) +- [ ] Are some of dimensions trivial? (shape[dim] <= 1) +- [ ] Does it have mask? (identified by ``mask`` field) +- [ ] Can the mask be broadcasted? (unsupported at the moment) +- [ ] Is it on CUDA memory? (identified by ``data`` field, and optionally ``stream``) + +Most of the checks are handled by the ``ArrayInterface`` during construction, except for +the data type issue since it doesn't know how to cast such pointers with C builtin types. +But for safety reason one should still try to write related tests for the all items. The +data type issue should be taken care of in language binding for each of the specific data +input. For single-chunk columnar format, it's just a masked array for each column so it +should be treated uniformly as normal array. For input predictor ``X``, we have adapters +for each type of input. Some are composition of the others. For instance, CSR matrix has 3 +potentially strided arrays for ``indptr``, ``indices`` and ``values``. No assumption +should be made to these components (all the check boxes should be considered). Slicing row +of CSR matrix should calculate the offset of each field based on respective strides. + +For meta info like labels, which is growing both in size and complexity, we accept only +masked array at the moment (no specialized adapter). One should be careful about the +input data shape. For base margin it can be 2 dim or higher if we have multiple targets in +the future. The getters in ``DMatrix`` returns only 1 dimension flatten vectors at the +moment, which can be improved in the future when it's needed. diff --git a/include/xgboost/c_api.h b/include/xgboost/c_api.h index 8d929593c1a4..9d657872e26b 100644 --- a/include/xgboost/c_api.h +++ b/include/xgboost/c_api.h @@ -249,7 +249,7 @@ XGB_DLL int XGDMatrixCreateFromCudaArrayInterface(char const *data, char const* json_config, DMatrixHandle *out); -/* +/** * ========================== Begin data callback APIs ========================= * * Short notes for data callback @@ -258,9 +258,9 @@ XGB_DLL int XGDMatrixCreateFromCudaArrayInterface(char const *data, * used by JVM packages. It uses `XGBoostBatchCSR` to accept batches for CSR formated * input, and concatenate them into 1 final big CSR. The related functions are: * - * - XGBCallbackSetData - * - XGBCallbackDataIterNext - * - XGDMatrixCreateFromDataIter + * - \ref XGBCallbackSetData + * - \ref XGBCallbackDataIterNext + * - \ref XGDMatrixCreateFromDataIter * * Another set is used by external data iterator. It accept foreign data iterators as * callbacks. There are 2 different senarios where users might want to pass in callbacks @@ -276,17 +276,17 @@ XGB_DLL int XGDMatrixCreateFromCudaArrayInterface(char const *data, * Related functions are: * * # Factory functions - * - `XGDMatrixCreateFromCallback` for external memory - * - `XGDeviceQuantileDMatrixCreateFromCallback` for quantile DMatrix + * - \ref XGDMatrixCreateFromCallback for external memory + * - \ref XGDeviceQuantileDMatrixCreateFromCallback for quantile DMatrix * * # Proxy that callers can use to pass data to XGBoost - * - XGProxyDMatrixCreate - * - XGDMatrixCallbackNext - * - DataIterResetCallback - * - XGProxyDMatrixSetDataCudaArrayInterface - * - XGProxyDMatrixSetDataCudaColumnar - * - XGProxyDMatrixSetDataDense - * - XGProxyDMatrixSetDataCSR + * - \ref XGProxyDMatrixCreate + * - \ref XGDMatrixCallbackNext + * - \ref DataIterResetCallback + * - \ref XGProxyDMatrixSetDataCudaArrayInterface + * - \ref XGProxyDMatrixSetDataCudaColumnar + * - \ref XGProxyDMatrixSetDataDense + * - \ref XGProxyDMatrixSetDataCSR * - ... (data setters) */ @@ -411,7 +411,7 @@ XGB_EXTERN_C typedef void DataIterResetCallback(DataIterHandle handle); // NOLIN * - cache_prefix: The path of cache file, caller must initialize all the directories in this path. * - nthread (optional): Number of threads used for initializing DMatrix. * - * \param out The created external memory DMatrix + * \param[out] out The created external memory DMatrix * * \return 0 when success, -1 when failure happens */ @@ -605,7 +605,8 @@ XGB_DLL int XGDMatrixSetUIntInfo(DMatrixHandle handle, * char const* feat_names [] {"feat_0", "feat_1"}; * XGDMatrixSetStrFeatureInfo(handle, "feature_name", feat_names, 2); * - * // i for integer, q for quantitive. Similarly "int" and "float" are also recognized. + * // i for integer, q for quantitive, c for categorical. Similarly "int" and "float" + * // are also recognized. * char const* feat_types [] {"i", "q"}; * XGDMatrixSetStrFeatureInfo(handle, "feature_type", feat_types, 2); * diff --git a/include/xgboost/data.h b/include/xgboost/data.h index 1678f4b1f4f1..90cbc1373713 100644 --- a/include/xgboost/data.h +++ b/include/xgboost/data.h @@ -11,12 +11,14 @@ #include #include #include -#include #include +#include +#include +#include +#include #include #include -#include #include #include #include @@ -45,7 +47,7 @@ enum class FeatureType : uint8_t { class MetaInfo { public: /*! \brief number of data fields in MetaInfo */ - static constexpr uint64_t kNumField = 11; + static constexpr uint64_t kNumField = 12; /*! \brief number of rows in the data */ uint64_t num_row_{0}; // NOLINT @@ -67,7 +69,7 @@ class MetaInfo { * if specified, xgboost will start from this init margin * can be used to specify initial prediction to boost from. */ - HostDeviceVector base_margin_; // NOLINT + linalg::Tensor base_margin_; // NOLINT /*! * \brief lower bound of the label, to be used for survival analysis (censored regression) */ @@ -157,7 +159,8 @@ class MetaInfo { * * Right now only 1 column is permitted. */ - void SetInfo(const char* key, std::string const& interface_str); + + void SetInfo(StringView key, StringView interface_str); void GetInfo(char const* key, bst_ulong* out_len, DataType dtype, const void** out_dptr) const; @@ -179,6 +182,9 @@ class MetaInfo { void Extend(MetaInfo const& that, bool accumulate_rows, bool check_column); private: + void SetInfoFromHost(StringView key, Json arr); + void SetInfoFromCUDA(StringView key, Json arr); + /*! \brief argsort of labels */ mutable std::vector label_order_cache_; }; @@ -477,7 +483,7 @@ class DMatrix { this->Info().SetInfo(key, dptr, dtype, num); } virtual void SetInfo(const char* key, std::string const& interface_str) { - this->Info().SetInfo(key, interface_str); + this->Info().SetInfo(key, StringView{interface_str}); } /*! \brief meta information of the dataset */ virtual const MetaInfo& Info() const = 0; diff --git a/include/xgboost/intrusive_ptr.h b/include/xgboost/intrusive_ptr.h index 1c58704c4ca3..df7ae30213d6 100644 --- a/include/xgboost/intrusive_ptr.h +++ b/include/xgboost/intrusive_ptr.h @@ -19,7 +19,7 @@ namespace xgboost { */ class IntrusivePtrCell { private: - std::atomic count_; + std::atomic count_ {0}; template friend class IntrusivePtr; std::int32_t IncRef() noexcept { @@ -31,7 +31,7 @@ class IntrusivePtrCell { bool IsZero() const { return Count() == 0; } public: - IntrusivePtrCell() noexcept : count_{0} {} + IntrusivePtrCell() noexcept = default; int32_t Count() const { return count_.load(std::memory_order_relaxed); } }; diff --git a/include/xgboost/predictor.h b/include/xgboost/predictor.h index dd87c601f978..859669a47ada 100644 --- a/include/xgboost/predictor.h +++ b/include/xgboost/predictor.h @@ -128,7 +128,7 @@ class Predictor { */ virtual void InitOutPredictions(const MetaInfo &info, HostDeviceVector *out_predt, - const gbm::GBTreeModel &model) const = 0; + const gbm::GBTreeModel &model) const; /** * \brief Generate batch predictions for a given feature matrix. May use diff --git a/jvm-packages/xgboost4j-gpu/src/native/xgboost4j-gpu.cu b/jvm-packages/xgboost4j-gpu/src/native/xgboost4j-gpu.cu index de7a1fc41495..4ecf8b0f1da1 100644 --- a/jvm-packages/xgboost4j-gpu/src/native/xgboost4j-gpu.cu +++ b/jvm-packages/xgboost4j-gpu/src/native/xgboost4j-gpu.cu @@ -35,13 +35,12 @@ template T CheckJvmCall(T const &v, JNIEnv *jenv) { } template -void CopyColumnMask(xgboost::ArrayInterface const &interface, +void CopyColumnMask(xgboost::ArrayInterface<1> const &interface, std::vector const &columns, cudaMemcpyKind kind, size_t c, VCont *p_mask, Json *p_out, cudaStream_t stream) { auto &mask = *p_mask; auto &out = *p_out; - auto size = sizeof(typename VCont::value_type) * interface.num_rows * - interface.num_cols; + auto size = sizeof(typename VCont::value_type) * interface.n; mask.resize(size); CHECK(RawPtr(mask)); CHECK(size); @@ -67,11 +66,11 @@ void CopyColumnMask(xgboost::ArrayInterface const &interface, LOG(FATAL) << "Invalid shape of mask"; } out["mask"]["typestr"] = String(" -void CopyInterface(std::vector &interface_arr, +void CopyInterface(std::vector> &interface_arr, std::vector const &columns, cudaMemcpyKind kind, std::vector *p_data, std::vector *p_mask, std::vector *p_out, cudaStream_t stream) { @@ -81,7 +80,7 @@ void CopyInterface(std::vector &interface_arr, for (size_t c = 0; c < interface_arr.size(); ++c) { auto &interface = interface_arr.at(c); size_t element_size = interface.ElementSize(); - size_t size = element_size * interface.num_rows * interface.num_cols; + size_t size = element_size * interface.n; auto &data = (*p_data)[c]; auto &mask = (*p_mask)[c]; @@ -95,14 +94,13 @@ void CopyInterface(std::vector &interface_arr, Json{Boolean{false}}}; out["data"] = Array(std::move(j_data)); - out["shape"] = Array(std::vector{Json(Integer(interface.num_rows)), - Json(Integer(interface.num_cols))}); + out["shape"] = Array(std::vector{Json(Integer(interface.Shape(0)))}); if (interface.valid.Data()) { CopyColumnMask(interface, columns, kind, c, &mask, &out, stream); } out["typestr"] = String(" *out, cudaStream_t auto &j_interface = *p_interface; CHECK_EQ(get(j_interface).size(), 1); auto object = get(get(j_interface)[0]); - ArrayInterface interface(object); - out->resize(interface.num_rows); + ArrayInterface<1> interface(object); + out->resize(interface.Shape(0)); size_t element_size = interface.ElementSize(); - size_t size = element_size * interface.num_rows; + size_t size = element_size * interface.n; dh::safe_cuda(cudaMemcpyAsync(RawPtr(*out), interface.data, size, cudaMemcpyDeviceToDevice, stream)); j_interface[0]["data"][0] = reinterpret_cast(RawPtr(*out)); @@ -285,11 +283,11 @@ class DataIteratorProxy { Json features = json_interface["features_str"]; auto json_columns = get(features); - std::vector interfaces; + std::vector> interfaces; // Stage the data for (auto &json_col : json_columns) { - auto column = ArrayInterface(get(json_col)); + auto column = ArrayInterface<1>(get(json_col)); interfaces.emplace_back(column); } Json::Dump(features, &interface_str); @@ -342,9 +340,9 @@ class DataIteratorProxy { // Data auto const &json_interface = host_columns_.at(it_)->interfaces; - std::vector in; + std::vector> in; for (auto interface : json_interface) { - auto column = ArrayInterface(get(interface)); + auto column = ArrayInterface<1>(get(interface)); in.emplace_back(column); } std::vector out; diff --git a/python-package/xgboost/data.py b/python-package/xgboost/data.py index 285b2e840f59..c739f9267af5 100644 --- a/python-package/xgboost/data.py +++ b/python-package/xgboost/data.py @@ -5,7 +5,7 @@ import json import warnings import os -from typing import Any, Tuple, Callable, Optional, List +from typing import Any, Tuple, Callable, Optional, List, Union import numpy as np @@ -138,14 +138,14 @@ def _is_numpy_array(data): return isinstance(data, (np.ndarray, np.matrix)) -def _ensure_np_dtype(data, dtype): +def _ensure_np_dtype(data, dtype) -> Tuple[np.ndarray, np.dtype]: if data.dtype.hasobject or data.dtype in [np.float16, np.bool_]: data = data.astype(np.float32, copy=False) dtype = np.float32 return data, dtype -def _maybe_np_slice(data, dtype): +def _maybe_np_slice(data: np.ndarray, dtype) -> np.ndarray: '''Handle numpy slice. This can be removed if we use __array_interface__. ''' try: @@ -852,23 +852,17 @@ def _validate_meta_shape(data: Any, name: str) -> None: def _meta_from_numpy( - data: np.ndarray, field: str, dtype, handle: ctypes.c_void_p + data: np.ndarray, + field: str, + dtype: Optional[Union[np.dtype, str]], + handle: ctypes.c_void_p, ) -> None: - data = _maybe_np_slice(data, dtype) + data, dtype = _ensure_np_dtype(data, dtype) interface = data.__array_interface__ - assert interface.get('mask', None) is None, 'Masked array is not supported' - size = data.size - - c_type = _to_data_type(str(data.dtype), field) - ptr = interface['data'][0] - ptr = ctypes.c_void_p(ptr) - _check_call(_LIB.XGDMatrixSetDenseInfo( - handle, - c_str(field), - ptr, - c_bst_ulong(size), - c_type - )) + if interface.get("mask", None) is not None: + raise ValueError("Masked array is not supported.") + interface_str = _array_interface(data) + _check_call(_LIB.XGDMatrixSetInfoFromInterface(handle, c_str(field), interface_str)) def _meta_from_list(data, field, dtype, handle): @@ -911,7 +905,9 @@ def _meta_from_dt(data, field: str, dtype, handle: ctypes.c_void_p): _meta_from_numpy(data, field, dtype, handle) -def dispatch_meta_backend(matrix: DMatrix, data, name: str, dtype: str = None): +def dispatch_meta_backend( + matrix: DMatrix, data, name: str, dtype: Optional[Union[str, np.dtype]] = None +): '''Dispatch for meta info.''' handle = matrix.handle assert handle is not None diff --git a/src/common/common.cu b/src/common/common.cu index d30fbc0aeb88..4636a4cdcb7c 100644 --- a/src/common/common.cu +++ b/src/common/common.cu @@ -12,7 +12,8 @@ int AllVisibleGPUs() { // When compiled with CUDA but running on CPU only device, // cudaGetDeviceCount will fail. dh::safe_cuda(cudaGetDeviceCount(&n_visgpus)); - } catch(const dmlc::Error &except) { + } catch (const dmlc::Error &) { + cudaGetLastError(); // reset error. return 0; } return n_visgpus; diff --git a/src/data/adapter.h b/src/data/adapter.h index 27da8c6e3b36..c4253b7e7a0e 100644 --- a/src/data/adapter.h +++ b/src/data/adapter.h @@ -254,20 +254,20 @@ class ArrayAdapterBatch : public detail::NoMetaInfo { static constexpr bool kIsRowMajor = true; private: - ArrayInterface array_interface_; + ArrayInterface<2> array_interface_; class Line { - ArrayInterface array_interface_; + ArrayInterface<2> array_interface_; size_t ridx_; public: - Line(ArrayInterface array_interface, size_t ridx) + Line(ArrayInterface<2> array_interface, size_t ridx) : array_interface_{std::move(array_interface)}, ridx_{ridx} {} - size_t Size() const { return array_interface_.num_cols; } + size_t Size() const { return array_interface_.Shape(1); } COOTuple GetElement(size_t idx) const { - return {ridx_, idx, array_interface_.GetElement(ridx_, idx)}; + return {ridx_, idx, array_interface_(ridx_, idx)}; } }; @@ -277,11 +277,11 @@ class ArrayAdapterBatch : public detail::NoMetaInfo { return Line{array_interface_, idx}; } - size_t NumRows() const { return array_interface_.num_rows; } - size_t NumCols() const { return array_interface_.num_cols; } + size_t NumRows() const { return array_interface_.Shape(0); } + size_t NumCols() const { return array_interface_.Shape(1); } size_t Size() const { return this->NumRows(); } - explicit ArrayAdapterBatch(ArrayInterface array_interface) + explicit ArrayAdapterBatch(ArrayInterface<2> array_interface) : array_interface_{std::move(array_interface)} {} }; @@ -294,43 +294,42 @@ class ArrayAdapter : public detail::SingleBatchDataIter { public: explicit ArrayAdapter(StringView array_interface) { auto j = Json::Load(array_interface); - array_interface_ = ArrayInterface(get(j)); + array_interface_ = ArrayInterface<2>(get(j)); batch_ = ArrayAdapterBatch{array_interface_}; } ArrayAdapterBatch const& Value() const override { return batch_; } - size_t NumRows() const { return array_interface_.num_rows; } - size_t NumColumns() const { return array_interface_.num_cols; } + size_t NumRows() const { return array_interface_.Shape(0); } + size_t NumColumns() const { return array_interface_.Shape(1); } private: ArrayAdapterBatch batch_; - ArrayInterface array_interface_; + ArrayInterface<2> array_interface_; }; class CSRArrayAdapterBatch : public detail::NoMetaInfo { - ArrayInterface indptr_; - ArrayInterface indices_; - ArrayInterface values_; + ArrayInterface<1> indptr_; + ArrayInterface<1> indices_; + ArrayInterface<1> values_; bst_feature_t n_features_; class Line { - ArrayInterface indices_; - ArrayInterface values_; + ArrayInterface<1> indices_; + ArrayInterface<1> values_; size_t ridx_; size_t offset_; public: - Line(ArrayInterface indices, ArrayInterface values, size_t ridx, + Line(ArrayInterface<1> indices, ArrayInterface<1> values, size_t ridx, size_t offset) : indices_{std::move(indices)}, values_{std::move(values)}, ridx_{ridx}, offset_{offset} {} COOTuple GetElement(size_t idx) const { - return {ridx_, indices_.GetElement(offset_ + idx, 0), - values_.GetElement(offset_ + idx, 0)}; + return {ridx_, indices_.operator()(offset_ + idx), values_(offset_ + idx)}; } size_t Size() const { - return values_.num_rows * values_.num_cols; + return values_.Shape(0); } }; @@ -339,17 +338,16 @@ class CSRArrayAdapterBatch : public detail::NoMetaInfo { public: CSRArrayAdapterBatch() = default; - CSRArrayAdapterBatch(ArrayInterface indptr, ArrayInterface indices, - ArrayInterface values, bst_feature_t n_features) - : indptr_{std::move(indptr)}, indices_{std::move(indices)}, - values_{std::move(values)}, n_features_{n_features} { - indptr_.AsColumnVector(); - values_.AsColumnVector(); - indices_.AsColumnVector(); + CSRArrayAdapterBatch(ArrayInterface<1> indptr, ArrayInterface<1> indices, + ArrayInterface<1> values, bst_feature_t n_features) + : indptr_{std::move(indptr)}, + indices_{std::move(indices)}, + values_{std::move(values)}, + n_features_{n_features} { } size_t NumRows() const { - size_t size = indptr_.num_rows * indptr_.num_cols; + size_t size = indptr_.Shape(0); size = size == 0 ? 0 : size - 1; return size; } @@ -357,19 +355,19 @@ class CSRArrayAdapterBatch : public detail::NoMetaInfo { size_t Size() const { return this->NumRows(); } Line const GetLine(size_t idx) const { - auto begin_offset = indptr_.GetElement(idx, 0); - auto end_offset = indptr_.GetElement(idx + 1, 0); + auto begin_no_stride = indptr_.operator()(idx); + auto end_no_stride = indptr_.operator()(idx + 1); auto indices = indices_; auto values = values_; + // Slice indices and values, stride remains unchanged since this is slicing by + // specific index. + auto offset = indices.strides[0] * begin_no_stride; - values.num_cols = end_offset - begin_offset; - values.num_rows = 1; + indices.shape[0] = end_no_stride - begin_no_stride; + values.shape[0] = end_no_stride - begin_no_stride; - indices.num_cols = values.num_cols; - indices.num_rows = values.num_rows; - - return Line{indices, values, idx, begin_offset}; + return Line{indices, values, idx, offset}; } }; @@ -391,7 +389,7 @@ class CSRArrayAdapter : public detail::SingleBatchDataIter return batch_; } size_t NumRows() const { - size_t size = indptr_.num_cols * indptr_.num_rows; + size_t size = indptr_.Shape(0); size = size == 0 ? 0 : size - 1; return size; } @@ -399,9 +397,9 @@ class CSRArrayAdapter : public detail::SingleBatchDataIter private: CSRArrayAdapterBatch batch_; - ArrayInterface indptr_; - ArrayInterface indices_; - ArrayInterface values_; + ArrayInterface<1> indptr_; + ArrayInterface<1> indices_; + ArrayInterface<1> values_; size_t num_cols_; }; diff --git a/src/data/array_interface.cu b/src/data/array_interface.cu index def4de195523..08dbfaefb647 100644 --- a/src/data/array_interface.cu +++ b/src/data/array_interface.cu @@ -7,15 +7,50 @@ namespace xgboost { void ArrayInterfaceHandler::SyncCudaStream(int64_t stream) { switch (stream) { - case 0: - LOG(FATAL) << "Invalid stream ID in array interface: " << stream; - case 1: - // default legacy stream - break; - case 2: - // default per-thread stream - default: - dh::safe_cuda(cudaStreamSynchronize(reinterpret_cast(stream))); + case 0: + /** + * disallowed by the `__cuda_array_interface__`. Quote: + * + * This is disallowed as it would be ambiguous between None and the default + * stream, and also between the legacy and per-thread default streams. Any use + * case where 0 might be given should either use None, 1, or 2 instead for + * clarity. + */ + LOG(FATAL) << "Invalid stream ID in array interface: " << stream; + case 1: + // default legacy stream + break; + case 2: + // default per-thread stream + default: + dh::safe_cuda(cudaStreamSynchronize(reinterpret_cast(stream))); + } +} + +bool ArrayInterfaceHandler::IsCudaPtr(void const* ptr) { + if (!ptr) { + return false; + } + cudaPointerAttributes attr; + auto err = cudaPointerGetAttributes(&attr, ptr); + // reset error + CHECK_EQ(err, cudaGetLastError()); + if (err == cudaErrorInvalidValue) { + // CUDA < 11 + return false; + } else if (err == cudaSuccess) { + // CUDA >= 11 + switch (attr.type) { + case cudaMemoryTypeUnregistered: + case cudaMemoryTypeHost: + return false; + default: + return true; + } + return true; + } else { + // other errors, no `cudaErrorNoDevice`, `cudaErrorInsufficientDriver` etc. + return false; } } } // namespace xgboost diff --git a/src/data/array_interface.h b/src/data/array_interface.h index 6524f4512407..25d2361a9182 100644 --- a/src/data/array_interface.h +++ b/src/data/array_interface.h @@ -13,24 +13,23 @@ #include #include +#include "../common/bitfield.h" +#include "../common/common.h" #include "xgboost/base.h" #include "xgboost/data.h" #include "xgboost/json.h" +#include "xgboost/linalg.h" #include "xgboost/logging.h" #include "xgboost/span.h" -#include "../common/bitfield.h" -#include "../common/common.h" namespace xgboost { // Common errors in parsing columnar format. struct ArrayInterfaceErrors { - static char const* Contigious() { - return "Memory should be contigious."; - } - static char const* TypestrFormat() { + static char const *Contiguous() { return "Memory should be contiguous."; } + static char const *TypestrFormat() { return "`typestr' should be of format ."; } - static char const* Dimension(int32_t d) { + static char const *Dimension(int32_t d) { static std::string str; str.clear(); str += "Only "; @@ -38,11 +37,11 @@ struct ArrayInterfaceErrors { str += " dimensional array is valid."; return str.c_str(); } - static char const* Version() { - return "Only version <= 3 of " - "`__cuda_array_interface__/__array_interface__' are supported."; + static char const *Version() { + return "Only version <= 3 of `__cuda_array_interface__' and `__array_interface__' are " + "supported."; } - static char const* OfType(std::string const& type) { + static char const *OfType(std::string const &type) { static std::string str; str.clear(); str += " should be of "; @@ -96,38 +95,25 @@ struct ArrayInterfaceErrors { // object and turn it into an array (for cupy and numba). class ArrayInterfaceHandler { public: - template - static constexpr char TypeChar() { - return - (std::is_floating_point::value ? 'f' : - (std::is_integral::value ? - (std::is_signed::value ? 'i' : 'u') : '\0')); - } + enum Type : std::int8_t { kF4, kF8, kF16, kI1, kI2, kI4, kI8, kU1, kU2, kU4, kU8 }; template - static PtrType GetPtrFromArrayData(std::map const& obj) { + static PtrType GetPtrFromArrayData(std::map const &obj) { if (obj.find("data") == obj.cend()) { LOG(FATAL) << "Empty data passed in."; } - auto p_data = reinterpret_cast(static_cast( - get( - get( - obj.at("data")) - .at(0)))); + auto p_data = reinterpret_cast( + static_cast(get(get(obj.at("data")).at(0)))); return p_data; } - static void Validate(std::map const& array) { + static void Validate(std::map const &array) { auto version_it = array.find("version"); if (version_it == array.cend()) { LOG(FATAL) << "Missing `version' field for array interface"; } - auto stream_it = array.find("stream"); - if (stream_it != array.cend() && !IsA(stream_it->second)) { - // is cuda, check the version. - if (get(version_it->second) > 3) { - LOG(FATAL) << ArrayInterfaceErrors::Version(); - } + if (get(version_it->second) > 3) { + LOG(FATAL) << ArrayInterfaceErrors::Version(); } if (array.find("typestr") == array.cend()) { @@ -149,12 +135,12 @@ class ArrayInterfaceHandler { // Mask object is also an array interface, but with different requirements. static size_t ExtractMask(std::map const &column, common::Span *p_out) { - auto& s_mask = *p_out; + auto &s_mask = *p_out; if (column.find("mask") != column.cend()) { - auto const& j_mask = get(column.at("mask")); + auto const &j_mask = get(column.at("mask")); Validate(j_mask); - auto p_mask = GetPtrFromArrayData(j_mask); + auto p_mask = GetPtrFromArrayData(j_mask); auto j_shape = get(j_mask.at("shape")); CHECK_EQ(j_shape.size(), 1) << ArrayInterfaceErrors::Dimension(1); @@ -186,8 +172,8 @@ class ArrayInterfaceHandler { if (j_mask.find("strides") != j_mask.cend()) { auto strides = get(column.at("strides")); - CHECK_EQ(strides.size(), 1) << ArrayInterfaceErrors::Dimension(1); - CHECK_EQ(get(strides.at(0)), type_length) << ArrayInterfaceErrors::Contigious(); + CHECK_EQ(strides.size(), 1) << ArrayInterfaceErrors::Dimension(1); + CHECK_EQ(get(strides.at(0)), type_length) << ArrayInterfaceErrors::Contiguous(); } s_mask = {p_mask, span_size}; @@ -195,96 +181,225 @@ class ArrayInterfaceHandler { } return 0; } - - static std::pair ExtractShape( - std::map const& column) { - auto j_shape = get(column.at("shape")); - auto typestr = get(column.at("typestr")); - if (j_shape.size() == 1) { - return {static_cast(get(j_shape.at(0))), 1}; - } else { - CHECK_EQ(j_shape.size(), 2) << "Only 1-D and 2-D arrays are supported."; - return {static_cast(get(j_shape.at(0))), - static_cast(get(j_shape.at(1)))}; + /** + * \brief Handle vector inputs. For higher dimension, we require strictly correct shape. + */ + template + static void HandleRowVector(std::vector const &shape, std::vector *p_out) { + auto &out = *p_out; + if (shape.size() == 2 && D == 1) { + auto m = shape[0]; + auto n = shape[1]; + CHECK(m == 1 || n == 1); + if (m == 1) { + // keep the number of columns + out[0] = out[1]; + out.resize(1); + } else if (n == 1) { + // keep the number of rows. + out.resize(1); + } + // when both m and n are 1, above logic keeps the column. + // when neither m nor n is 1, caller should throw an error about Dimension. } } - static void ExtractStride(std::map const &column, - size_t *stride_r, size_t *stride_c, size_t rows, - size_t cols, size_t itemsize) { - auto strides_it = column.find("strides"); - if (strides_it == column.cend() || IsA(strides_it->second)) { - // default strides - *stride_r = cols; - *stride_c = 1; - } else { - // strides specified by the array interface - auto const &j_strides = get(strides_it->second); - CHECK_LE(j_strides.size(), 2) << ArrayInterfaceErrors::Dimension(2); - *stride_r = get(j_strides[0]) / itemsize; - size_t n = 1; - if (j_strides.size() == 2) { - n = get(j_strides[1]) / itemsize; - } - *stride_c = n; + template + static void ExtractShape(std::map const &array, size_t (&out_shape)[D]) { + auto const &j_shape = get(array.at("shape")); + std::vector shape_arr(j_shape.size(), 0); + std::transform(j_shape.cbegin(), j_shape.cend(), shape_arr.begin(), + [](Json in) { return get(in); }); + // handle column vector vs. row vector + HandleRowVector(shape_arr, &shape_arr); + // Copy shape. + size_t i; + for (i = 0; i < shape_arr.size(); ++i) { + CHECK_LT(i, D) << ArrayInterfaceErrors::Dimension(D); + out_shape[i] = shape_arr[i]; } + // Fill the remaining dimensions + std::fill(out_shape + i, out_shape + D, 1); + } - auto valid = rows * (*stride_r) + cols * (*stride_c) >= (rows * cols); - CHECK(valid) << "Invalid strides in array." - << " strides: (" << (*stride_r) << "," << (*stride_c) - << "), shape: (" << rows << ", " << cols << ")"; + template + static bool ExtractStride(std::map const &array, size_t itemsize, + size_t (&shape)[D], size_t (&stride)[D]) { + auto strides_it = array.find("strides"); + // No stride is provided + if (strides_it == array.cend() || IsA(strides_it->second)) { + // No stride is provided, we can calculate it from shape. + linalg::detail::CalcStride(shape, stride); + // Quote: + // + // strides: Either None to indicate a C-style contiguous array or a Tuple of + // strides which provides the number of bytes + return true; + } + // Get shape, we need to make changes to handle row vector, so some duplicated code + // from `ExtractShape` for copying out the shape. + auto const &j_shape = get(array.at("shape")); + std::vector shape_arr(j_shape.size(), 0); + std::transform(j_shape.cbegin(), j_shape.cend(), shape_arr.begin(), + [](Json in) { return get(in); }); + // Get stride + auto const &j_strides = get(strides_it->second); + CHECK_EQ(j_strides.size(), j_shape.size()) << "stride and shape don't match."; + std::vector stride_arr(j_strides.size(), 0); + std::transform(j_strides.cbegin(), j_strides.cend(), stride_arr.begin(), + [](Json in) { return get(in); }); + + // Handle column vector vs. row vector + HandleRowVector(shape_arr, &stride_arr); + size_t i; + for (i = 0; i < stride_arr.size(); ++i) { + // If one of the dim has shape 0 then total size is 0, stride is meaningless, but we + // set it to 0 here just to be consistent + CHECK_LT(i, D) << ArrayInterfaceErrors::Dimension(D); + // We use number of items instead of number of bytes + stride[i] = stride_arr[i] / itemsize; + } + std::fill(stride + i, stride + D, 1); + // If the stride can be calculated from shape then it's contiguous. + size_t stride_tmp[D]; + linalg::detail::CalcStride(shape, stride_tmp); + return std::equal(stride_tmp, stride_tmp + D, stride); } - static void* ExtractData(std::map const &column, - std::pair shape) { - Validate(column); - void* p_data = ArrayInterfaceHandler::GetPtrFromArrayData(column); + static void *ExtractData(std::map const &array, size_t size) { + Validate(array); + void *p_data = ArrayInterfaceHandler::GetPtrFromArrayData(array); if (!p_data) { - CHECK_EQ(shape.first * shape.second, 0) << "Empty data with non-zero shape."; + CHECK_EQ(size, 0) << "Empty data with non-zero shape."; } return p_data; } - + /** + * \brief Whether the ptr is allocated by CUDA. + */ + static bool IsCudaPtr(void const *ptr); + /** + * \brief Sync the CUDA stream. + */ static void SyncCudaStream(int64_t stream); }; +/** + * Dispatch compile time type to runtime type. + */ +template +struct ToDType; +// float +template <> +struct ToDType { + static constexpr ArrayInterfaceHandler::Type kType = ArrayInterfaceHandler::kF4; +}; +template <> +struct ToDType { + static constexpr ArrayInterfaceHandler::Type kType = ArrayInterfaceHandler::kF8; +}; +template +struct ToDType::value && sizeof(long double) == 16>> { + static constexpr ArrayInterfaceHandler::Type kType = ArrayInterfaceHandler::kF16; +}; +// uint +template <> +struct ToDType { + static constexpr ArrayInterfaceHandler::Type kType = ArrayInterfaceHandler::kU1; +}; +template <> +struct ToDType { + static constexpr ArrayInterfaceHandler::Type kType = ArrayInterfaceHandler::kU2; +}; +template <> +struct ToDType { + static constexpr ArrayInterfaceHandler::Type kType = ArrayInterfaceHandler::kU4; +}; +template <> +struct ToDType { + static constexpr ArrayInterfaceHandler::Type kType = ArrayInterfaceHandler::kU8; +}; +// int +template <> +struct ToDType { + static constexpr ArrayInterfaceHandler::Type kType = ArrayInterfaceHandler::kI1; +}; +template <> +struct ToDType { + static constexpr ArrayInterfaceHandler::Type kType = ArrayInterfaceHandler::kI2; +}; +template <> +struct ToDType { + static constexpr ArrayInterfaceHandler::Type kType = ArrayInterfaceHandler::kI4; +}; +template <> +struct ToDType { + static constexpr ArrayInterfaceHandler::Type kType = ArrayInterfaceHandler::kI8; +}; + #if !defined(XGBOOST_USE_CUDA) -inline void ArrayInterfaceHandler::SyncCudaStream(int64_t stream) { - common::AssertGPUSupport(); -} +inline void ArrayInterfaceHandler::SyncCudaStream(int64_t stream) { common::AssertGPUSupport(); } +inline bool ArrayInterfaceHandler::IsCudaPtr(void const* ptr) { return false; } #endif // !defined(XGBOOST_USE_CUDA) // A view over __array_interface__ +/** + * \brief A type erased view over __array_interface__ protocol defined by numpy + * + * numpy. + * + * \tparam D The number of maximum dimension. + + * User input array must have dim <= D for all non-trivial dimensions. During + * construction, the ctor can automatically remove those trivial dimensions. + * + * \tparam allow_mask Whether masked array is accepted. + * + * Currently this only supported for 1-dim vector, which is used by cuDF column + * (apache arrow format). For general masked array, as the time of writting, only + * numpy has the proper support even though it's in the __cuda_array_interface__ + * protocol defined by numba. + */ +template class ArrayInterface { - void Initialize(std::map const &array, - bool allow_mask = true) { + static_assert(D > 0, "Invalid dimension for array interface."); + + /** + * \brief Initialize the object, by extracting shape, stride and type. + * + * The function also perform some basic validation for input array. Lastly it will + * also remove trivial dimensions like converting a matrix with shape (n_samples, 1) + * to a vector of size n_samples. For for inputs like weights, this should be a 1 + * dimension column vector even though user might provide a matrix. + */ + void Initialize(std::map const &array) { ArrayInterfaceHandler::Validate(array); + auto typestr = get(array.at("typestr")); this->AssignType(StringView{typestr}); + ArrayInterfaceHandler::ExtractShape(array, shape); + size_t itemsize = typestr[2] - '0'; + is_contiguous = ArrayInterfaceHandler::ExtractStride(array, itemsize, shape, strides); + n = linalg::detail::CalcSize(shape); - std::tie(num_rows, num_cols) = ArrayInterfaceHandler::ExtractShape(array); - data = ArrayInterfaceHandler::ExtractData( - array, std::make_pair(num_rows, num_cols)); + data = ArrayInterfaceHandler::ExtractData(array, n); if (allow_mask) { + CHECK(D == 1) << "Masked array is not supported."; common::Span s_mask; size_t n_bits = ArrayInterfaceHandler::ExtractMask(array, &s_mask); valid = RBitField8(s_mask); if (s_mask.data()) { - CHECK_EQ(n_bits, num_rows) - << "Shape of bit mask doesn't match data shape. " - << "XGBoost doesn't support internal broadcasting."; + CHECK_EQ(n_bits, n) << "Shape of bit mask doesn't match data shape. " + << "XGBoost doesn't support internal broadcasting."; } } else { - CHECK(array.find("mask") == array.cend()) - << "Masked array is not yet supported."; + CHECK(array.find("mask") == array.cend()) << "Masked array is not yet supported."; } - ArrayInterfaceHandler::ExtractStride(array, &stride_row, &stride_col, - num_rows, num_cols, typestr[2] - '0'); - auto stream_it = array.find("stream"); if (stream_it != array.cend() && !IsA(stream_it->second)) { int64_t stream = get(stream_it->second); @@ -292,151 +407,149 @@ class ArrayInterface { } } - public: - enum Type : std::int8_t { kF4, kF8, kF16, kI1, kI2, kI4, kI8, kU1, kU2, kU4, kU8 }; - public: ArrayInterface() = default; - explicit ArrayInterface(std::string const &str, bool allow_mask = true) - : ArrayInterface{StringView{str.c_str(), str.size()}, allow_mask} {} - - explicit ArrayInterface(std::map const &column, - bool allow_mask = true) { - this->Initialize(column, allow_mask); + explicit ArrayInterface(std::map const &array) { + this->Initialize(array); } - explicit ArrayInterface(StringView str, bool allow_mask = true) { - auto jinterface = Json::Load(str); - if (IsA(jinterface)) { - this->Initialize(get(jinterface), allow_mask); + explicit ArrayInterface(Json const &array) { + if (IsA(array)) { + this->Initialize(get(array)); return; } - if (IsA(jinterface)) { - CHECK_EQ(get(jinterface).size(), 1) + if (IsA(array)) { + CHECK_EQ(get(array).size(), 1) << "Column: " << ArrayInterfaceErrors::Dimension(1); - this->Initialize(get(get(jinterface)[0]), allow_mask); + this->Initialize(get(get(array)[0])); return; } } - void AsColumnVector() { - CHECK(num_rows == 1 || num_cols == 1) << "Array should be a vector instead of matrix."; - num_rows = std::max(num_rows, static_cast(num_cols)); - num_cols = 1; + explicit ArrayInterface(std::string const &str) : ArrayInterface{StringView{str}} {} - stride_row = std::max(stride_row, stride_col); - stride_col = 1; - } + explicit ArrayInterface(StringView str) : ArrayInterface{Json::Load(str)} {} void AssignType(StringView typestr) { - if (typestr.size() == 4 && typestr[1] == 'f' && typestr[2] == '1' && - typestr[3] == '6') { - type = kF16; + using T = ArrayInterfaceHandler::Type; + if (typestr.size() == 4 && typestr[1] == 'f' && typestr[2] == '1' && typestr[3] == '6') { + type = T::kF16; CHECK(sizeof(long double) == 16) << "128-bit floating point is not supported on current platform."; } else if (typestr[1] == 'f' && typestr[2] == '4') { - type = kF4; + type = T::kF4; } else if (typestr[1] == 'f' && typestr[2] == '8') { - type = kF8; + type = T::kF8; } else if (typestr[1] == 'i' && typestr[2] == '1') { - type = kI1; + type = T::kI1; } else if (typestr[1] == 'i' && typestr[2] == '2') { - type = kI2; + type = T::kI2; } else if (typestr[1] == 'i' && typestr[2] == '4') { - type = kI4; + type = T::kI4; } else if (typestr[1] == 'i' && typestr[2] == '8') { - type = kI8; + type = T::kI8; } else if (typestr[1] == 'u' && typestr[2] == '1') { - type = kU1; + type = T::kU1; } else if (typestr[1] == 'u' && typestr[2] == '2') { - type = kU2; + type = T::kU2; } else if (typestr[1] == 'u' && typestr[2] == '4') { - type = kU4; + type = T::kU4; } else if (typestr[1] == 'u' && typestr[2] == '8') { - type = kU8; + type = T::kU8; } else { LOG(FATAL) << ArrayInterfaceErrors::UnSupportedType(typestr); return; } } + XGBOOST_DEVICE size_t Shape(size_t i) const { return shape[i]; } + XGBOOST_DEVICE size_t Stride(size_t i) const { return strides[i]; } + template - XGBOOST_HOST_DEV_INLINE decltype(auto) DispatchCall(Fn func) const { + XGBOOST_HOST_DEV_INLINE constexpr decltype(auto) DispatchCall(Fn func) const { + using T = ArrayInterfaceHandler::Type; switch (type) { - case kF4: - return func(reinterpret_cast(data)); - case kF8: - return func(reinterpret_cast(data)); + case T::kF4: + return func(reinterpret_cast(data)); + case T::kF8: + return func(reinterpret_cast(data)); #ifdef __CUDA_ARCH__ - case kF16: { - // CUDA device code doesn't support long double. - SPAN_CHECK(false); - return func(reinterpret_cast(data)); - } + case T::kF16: { + // CUDA device code doesn't support long double. + SPAN_CHECK(false); + return func(reinterpret_cast(data)); + } #else - case kF16: - return func(reinterpret_cast(data)); + case T::kF16: + return func(reinterpret_cast(data)); #endif - case kI1: - return func(reinterpret_cast(data)); - case kI2: - return func(reinterpret_cast(data)); - case kI4: - return func(reinterpret_cast(data)); - case kI8: - return func(reinterpret_cast(data)); - case kU1: - return func(reinterpret_cast(data)); - case kU2: - return func(reinterpret_cast(data)); - case kU4: - return func(reinterpret_cast(data)); - case kU8: - return func(reinterpret_cast(data)); + case T::kI1: + return func(reinterpret_cast(data)); + case T::kI2: + return func(reinterpret_cast(data)); + case T::kI4: + return func(reinterpret_cast(data)); + case T::kI8: + return func(reinterpret_cast(data)); + case T::kU1: + return func(reinterpret_cast(data)); + case T::kU2: + return func(reinterpret_cast(data)); + case T::kU4: + return func(reinterpret_cast(data)); + case T::kU8: + return func(reinterpret_cast(data)); } SPAN_CHECK(false); return func(reinterpret_cast(data)); } - XGBOOST_DEVICE size_t ElementSize() { - return this->DispatchCall([](auto* p_values) { - return sizeof(std::remove_pointer_t); - }); + XGBOOST_DEVICE size_t constexpr ElementSize() { + return this->DispatchCall( + [](auto *p_values) { return sizeof(std::remove_pointer_t); }); } - template - XGBOOST_DEVICE T GetElement(size_t r, size_t c) const { - return this->DispatchCall( - [=](auto *p_values) -> T { return p_values[stride_row * r + stride_col * c]; }); + template + XGBOOST_DEVICE T operator()(Index &&...index) const { + static_assert(sizeof...(index) <= D, "Invalid index."); + return this->DispatchCall([=](auto const *p_values) -> T { + size_t offset = linalg::detail::Offset<0ul>(strides, 0ul, index...); + return static_cast(p_values[offset]); + }); } + // Used only by columnar format. RBitField8 valid; - bst_row_t num_rows; - bst_feature_t num_cols; - size_t stride_row{0}; - size_t stride_col{0}; - void* data; - Type type; + // Array stride + size_t strides[D]{0}; + // Array shape + size_t shape[D]{0}; + // Type earsed pointer referencing the data. + void *data; + // Total number of items + size_t n; + // Whether the memory is c-contiguous + bool is_contiguous {false}; + // RTTI + ArrayInterfaceHandler::Type type; }; -template std::string MakeArrayInterface(T const *data, size_t n) { - Json arr{Object{}}; - arr["data"] = Array(std::vector{ - Json{Integer{reinterpret_cast(data)}}, Json{Boolean{false}}}); - arr["shape"] = Array{std::vector{Json{Integer{n}}, Json{Integer{1}}}}; - std::string typestr; - if (DMLC_LITTLE_ENDIAN) { - typestr.push_back('<'); - } else { - typestr.push_back('>'); +/** + * \brief Helper for type casting. + */ +template +struct TypedIndex { + ArrayInterface const &array; + template + XGBOOST_DEVICE T operator()(I &&...ind) { + static_assert(sizeof...(ind) <= D, "Invalid index."); + return array.template operator()(ind...); } - typestr.push_back(ArrayInterfaceHandler::TypeChar()); - typestr += std::to_string(sizeof(T)); - arr["typestr"] = typestr; - arr["version"] = 3; - std::string str; - Json::Dump(arr, &str); - return str; +}; + +template +inline void CheckArrayInterface(StringView key, ArrayInterface const &array) { + CHECK(!array.valid.Data()) << "Meta info " << key << " should be dense, found validity mask"; } } // namespace xgboost #endif // XGBOOST_DATA_ARRAY_INTERFACE_H_ diff --git a/src/data/data.cc b/src/data/data.cc index acf6e47b9d21..d9f489103f22 100644 --- a/src/data/data.cc +++ b/src/data/data.cc @@ -1,8 +1,9 @@ /*! - * Copyright 2015-2020 by Contributors + * Copyright 2015-2021 by Contributors * \file data.cc */ #include +#include #include #include "dmlc/io.h" @@ -12,10 +13,13 @@ #include "xgboost/logging.h" #include "xgboost/version_config.h" #include "xgboost/learner.h" +#include "xgboost/string_view.h" + #include "sparse_page_writer.h" #include "simple_dmatrix.h" #include "../common/io.h" +#include "../common/linalg_op.h" #include "../common/math.h" #include "../common/version.h" #include "../common/group_data.h" @@ -24,6 +28,7 @@ #include "../data/iterative_device_dmatrix.h" #include "file_iterator.h" +#include "validation.h" #include "./sparse_page_source.h" #include "./sparse_page_dmatrix.h" @@ -65,10 +70,22 @@ void SaveVectorField(dmlc::Stream* strm, const std::string& name, SaveVectorField(strm, name, type, shape, field.ConstHostVector()); } +template +void SaveTensorField(dmlc::Stream* strm, const std::string& name, xgboost::DataType type, + const xgboost::linalg::Tensor& field) { + strm->Write(name); + strm->Write(static_cast(type)); + strm->Write(false); // is_scalar=False + for (size_t i = 0; i < D; ++i) { + strm->Write(field.Shape(i)); + } + strm->Write(field.Data()->HostVector()); +} + template void LoadScalarField(dmlc::Stream* strm, const std::string& expected_name, xgboost::DataType expected_type, T* field) { - const std::string invalid {"MetaInfo: Invalid format. "}; + const std::string invalid{"MetaInfo: Invalid format for " + expected_name}; std::string name; xgboost::DataType type; bool is_scalar; @@ -90,7 +107,7 @@ void LoadScalarField(dmlc::Stream* strm, const std::string& expected_name, template void LoadVectorField(dmlc::Stream* strm, const std::string& expected_name, xgboost::DataType expected_type, std::vector* field) { - const std::string invalid {"MetaInfo: Invalid format. "}; + const std::string invalid{"MetaInfo: Invalid format for " + expected_name}; std::string name; xgboost::DataType type; bool is_scalar; @@ -123,6 +140,33 @@ void LoadVectorField(dmlc::Stream* strm, const std::string& expected_name, LoadVectorField(strm, expected_name, expected_type, &field->HostVector()); } +template +void LoadTensorField(dmlc::Stream* strm, std::string const& expected_name, + xgboost::DataType expected_type, xgboost::linalg::Tensor* p_out) { + const std::string invalid{"MetaInfo: Invalid format for " + expected_name}; + std::string name; + xgboost::DataType type; + bool is_scalar; + CHECK(strm->Read(&name)) << invalid; + CHECK_EQ(name, expected_name) << invalid << " Expected field: " << expected_name + << ", got: " << name; + uint8_t type_val; + CHECK(strm->Read(&type_val)) << invalid; + type = static_cast(type_val); + CHECK(type == expected_type) << invalid + << "Expected field of type: " << static_cast(expected_type) + << ", " + << "got field type: " << static_cast(type); + CHECK(strm->Read(&is_scalar)) << invalid; + CHECK(!is_scalar) << invalid << "Expected field " << expected_name + << " to be a tensor; got a scalar"; + std::array shape; + for (size_t i = 0; i < D; ++i) { + CHECK(strm->Read(&(shape[i]))); + } + auto& field = p_out->Data()->HostVector(); + CHECK(strm->Read(&field)) << invalid; +} } // anonymous namespace namespace xgboost { @@ -135,25 +179,26 @@ void MetaInfo::Clear() { labels_.HostVector().clear(); group_ptr_.clear(); weights_.HostVector().clear(); - base_margin_.HostVector().clear(); + base_margin_ = decltype(base_margin_){}; } /* * Binary serialization format for MetaInfo: * - * | name | type | is_scalar | num_row | num_col | value | - * |--------------------+----------+-----------+---------+---------+-------------------------| - * | num_row | kUInt64 | True | NA | NA | ${num_row_} | - * | num_col | kUInt64 | True | NA | NA | ${num_col_} | - * | num_nonzero | kUInt64 | True | NA | NA | ${num_nonzero_} | - * | labels | kFloat32 | False | ${size} | 1 | ${labels_} | - * | group_ptr | kUInt32 | False | ${size} | 1 | ${group_ptr_} | - * | weights | kFloat32 | False | ${size} | 1 | ${weights_} | - * | base_margin | kFloat32 | False | ${size} | 1 | ${base_margin_} | - * | labels_lower_bound | kFloat32 | False | ${size} | 1 | ${labels_lower_bound_} | - * | labels_upper_bound | kFloat32 | False | ${size} | 1 | ${labels_upper_bound_} | - * | feature_names | kStr | False | ${size} | 1 | ${feature_names} | - * | feature_types | kStr | False | ${size} | 1 | ${feature_types} | + * | name | type | is_scalar | num_row | num_col | dim3 | value | + * |--------------------+----------+-----------+-------------+-------------+-------------+------------------------| + * | num_row | kUInt64 | True | NA | NA | NA | ${num_row_} | + * | num_col | kUInt64 | True | NA | NA | NA | ${num_col_} | + * | num_nonzero | kUInt64 | True | NA | NA | NA | ${num_nonzero_} | + * | labels | kFloat32 | False | ${size} | 1 | NA | ${labels_} | + * | group_ptr | kUInt32 | False | ${size} | 1 | NA | ${group_ptr_} | + * | weights | kFloat32 | False | ${size} | 1 | NA | ${weights_} | + * | base_margin | kFloat32 | False | ${Shape(0)} | ${Shape(1)} | ${Shape(2)} | ${base_margin_} | + * | labels_lower_bound | kFloat32 | False | ${size} | 1 | NA | ${labels_lower_bound_} | + * | labels_upper_bound | kFloat32 | False | ${size} | 1 | NA | ${labels_upper_bound_} | + * | feature_names | kStr | False | ${size} | 1 | NA | ${feature_names} | + * | feature_types | kStr | False | ${size} | 1 | NA | ${feature_types} | + * | feature_types | kFloat32 | False | ${size} | 1 | NA | ${feature_weights} | * * Note that the scalar fields (is_scalar=True) will have num_row and num_col missing. * Also notice the difference between the saved name and the name used in `SetInfo': @@ -174,8 +219,7 @@ void MetaInfo::SaveBinary(dmlc::Stream *fo) const { {group_ptr_.size(), 1}, group_ptr_); ++field_cnt; SaveVectorField(fo, u8"weights", DataType::kFloat32, {weights_.Size(), 1}, weights_); ++field_cnt; - SaveVectorField(fo, u8"base_margin", DataType::kFloat32, - {base_margin_.Size(), 1}, base_margin_); ++field_cnt; + SaveTensorField(fo, u8"base_margin", DataType::kFloat32, base_margin_); ++field_cnt; SaveVectorField(fo, u8"labels_lower_bound", DataType::kFloat32, {labels_lower_bound_.Size(), 1}, labels_lower_bound_); ++field_cnt; SaveVectorField(fo, u8"labels_upper_bound", DataType::kFloat32, @@ -185,6 +229,9 @@ void MetaInfo::SaveBinary(dmlc::Stream *fo) const { {feature_names.size(), 1}, feature_names); ++field_cnt; SaveVectorField(fo, u8"feature_types", DataType::kStr, {feature_type_names.size(), 1}, feature_type_names); ++field_cnt; + SaveVectorField(fo, u8"feature_weights", DataType::kFloat32, {feature_weights.Size(), 1}, + feature_weights); + ++field_cnt; CHECK_EQ(field_cnt, kNumField) << "Wrong number of fields"; } @@ -213,10 +260,14 @@ void MetaInfo::LoadBinary(dmlc::Stream *fi) { auto major = std::get<0>(version); // MetaInfo is saved in `SparsePageSource'. So the version in MetaInfo represents the // version of DMatrix. - CHECK_EQ(major, 1) << "Binary DMatrix generated by XGBoost: " - << Version::String(version) << " is no longer supported. " - << "Please process and save your data in current version: " - << Version::String(Version::Self()) << " again."; + std::stringstream msg; + msg << "Binary DMatrix generated by XGBoost: " << Version::String(version) + << " is no longer supported. " + << "Please process and save your data in current version: " + << Version::String(Version::Self()) << " again."; + CHECK_EQ(major, 1) << msg.str(); + auto minor = std::get<1>(version); + CHECK_GE(minor, 6) << msg.str(); const uint64_t expected_num_field = kNumField; uint64_t num_field { 0 }; @@ -243,12 +294,13 @@ void MetaInfo::LoadBinary(dmlc::Stream *fi) { LoadVectorField(fi, u8"labels", DataType::kFloat32, &labels_); LoadVectorField(fi, u8"group_ptr", DataType::kUInt32, &group_ptr_); LoadVectorField(fi, u8"weights", DataType::kFloat32, &weights_); - LoadVectorField(fi, u8"base_margin", DataType::kFloat32, &base_margin_); + LoadTensorField(fi, u8"base_margin", DataType::kFloat32, &base_margin_); LoadVectorField(fi, u8"labels_lower_bound", DataType::kFloat32, &labels_lower_bound_); LoadVectorField(fi, u8"labels_upper_bound", DataType::kFloat32, &labels_upper_bound_); LoadVectorField(fi, u8"feature_names", DataType::kStr, &feature_names); LoadVectorField(fi, u8"feature_types", DataType::kStr, &feature_type_names); + LoadVectorField(fi, u8"feature_weights", DataType::kFloat32, &feature_weights); LoadFeatureType(feature_type_names, &feature_types.HostVector()); } @@ -291,10 +343,13 @@ MetaInfo MetaInfo::Slice(common::Span ridxs) const { if (this->base_margin_.Size() != this->num_row_) { CHECK_EQ(this->base_margin_.Size() % this->num_row_, 0) << "Incorrect size of base margin vector."; - size_t stride = this->base_margin_.Size() / this->num_row_; - out.base_margin_.HostVector() = Gather(this->base_margin_.HostVector(), ridxs, stride); + auto margin = this->base_margin_.View(this->base_margin_.Data()->DeviceIdx()); + out.base_margin_.Reshape(ridxs.size(), margin.Shape()[1], margin.Shape()[2]); + size_t stride = margin.Stride(0); + out.base_margin_.Data()->HostVector() = + Gather(this->base_margin_.Data()->HostVector(), ridxs, stride); } else { - out.base_margin_.HostVector() = Gather(this->base_margin_.HostVector(), ridxs); + out.base_margin_.Data()->HostVector() = Gather(this->base_margin_.Data()->HostVector(), ridxs); } out.feature_weights.Resize(this->feature_weights.Size()); @@ -337,116 +392,179 @@ inline bool MetaTryLoadFloatInfo(const std::string& fname, return true; } -void ValidateQueryGroup(std::vector const &group_ptr_) { - bool valid_query_group = true; - for (size_t i = 1; i < group_ptr_.size(); ++i) { - valid_query_group = valid_query_group && group_ptr_[i] >= group_ptr_[i - 1]; - if (!valid_query_group) { - break; - } +namespace { +template +void CopyTensorInfoImpl(Json arr_interface, linalg::Tensor* p_out) { + ArrayInterface array{arr_interface}; + if (array.n == 0) { + return; + } + CHECK(array.valid.Size() == 0) << "Meta info like label or weight can not have missing value."; + if (array.is_contiguous && array.type == ToDType::kType) { + // Handle contigious + p_out->ModifyInplace([&](HostDeviceVector* data, common::Span shape) { + // set shape + std::copy(array.shape, array.shape + D, shape.data()); + // set data + data->Resize(array.n); + std::memcpy(data->HostPointer(), array.data, array.n * sizeof(T)); + }); + return; } - CHECK(valid_query_group) << "Invalid group structure."; + p_out->Reshape(array.shape); + auto t = p_out->View(GenericParameter::kCpuId); + CHECK(t.Contiguous()); + // FIXME(jiamingy): Remove the use of this default thread. + linalg::ElementWiseKernelHost(t, common::OmpGetNumThreads(0), [&](auto i, auto) { + return linalg::detail::Apply(TypedIndex{array}, linalg::UnravelIndex(i, t.Shape())); + }); } +} // namespace + +void MetaInfo::SetInfo(StringView key, StringView interface_str) { + Json j_interface = Json::Load(interface_str); + bool is_cuda{false}; + if (IsA(j_interface)) { + auto const& array = get(j_interface); + CHECK_GE(array.size(), 0) << "Invalid " << key + << ", must have at least 1 column even if it's empty."; + auto const& first = get(array.front()); + auto ptr = ArrayInterfaceHandler::GetPtrFromArrayData(first); + is_cuda = ArrayInterfaceHandler::IsCudaPtr(ptr); + } else { + auto const& first = get(j_interface); + auto ptr = ArrayInterfaceHandler::GetPtrFromArrayData(first); + is_cuda = ArrayInterfaceHandler::IsCudaPtr(ptr); + } -// macro to dispatch according to specified pointer types -#define DISPATCH_CONST_PTR(dtype, old_ptr, cast_ptr, proc) \ - switch (dtype) { \ - case xgboost::DataType::kFloat32: { \ - auto cast_ptr = reinterpret_cast(old_ptr); proc; break; \ - } \ - case xgboost::DataType::kDouble: { \ - auto cast_ptr = reinterpret_cast(old_ptr); proc; break; \ - } \ - case xgboost::DataType::kUInt32: { \ - auto cast_ptr = reinterpret_cast(old_ptr); proc; break; \ - } \ - case xgboost::DataType::kUInt64: { \ - auto cast_ptr = reinterpret_cast(old_ptr); proc; break; \ - } \ - default: LOG(FATAL) << "Unknown data type" << static_cast(dtype); \ - } \ + if (is_cuda) { + this->SetInfoFromCUDA(key, j_interface); + } else { + this->SetInfoFromHost(key, j_interface); + } +} -void MetaInfo::SetInfo(const char* key, const void* dptr, DataType dtype, size_t num) { - if (!std::strcmp(key, "label")) { - auto& labels = labels_.HostVector(); - labels.resize(num); - DISPATCH_CONST_PTR(dtype, dptr, cast_dptr, - std::copy(cast_dptr, cast_dptr + num, labels.begin())); - auto valid = std::none_of(labels.cbegin(), labels.cend(), [](auto y) { - return std::isnan(y) || std::isinf(y); - }); - CHECK(valid) << "Label contains NaN, infinity or a value too large."; - } else if (!std::strcmp(key, "weight")) { - auto& weights = weights_.HostVector(); - weights.resize(num); - DISPATCH_CONST_PTR(dtype, dptr, cast_dptr, - std::copy(cast_dptr, cast_dptr + num, weights.begin())); - auto valid = std::none_of(weights.cbegin(), weights.cend(), [](float w) { - return w < 0 || std::isinf(w) || std::isnan(w); - }); - CHECK(valid) << "Weights must be positive values."; - } else if (!std::strcmp(key, "base_margin")) { - auto& base_margin = base_margin_.HostVector(); - base_margin.resize(num); - DISPATCH_CONST_PTR(dtype, dptr, cast_dptr, - std::copy(cast_dptr, cast_dptr + num, base_margin.begin())); - } else if (!std::strcmp(key, "group")) { - group_ptr_.clear(); group_ptr_.resize(num + 1, 0); - DISPATCH_CONST_PTR(dtype, dptr, cast_dptr, - std::copy(cast_dptr, cast_dptr + num, group_ptr_.begin() + 1)); - group_ptr_[0] = 0; - for (size_t i = 1; i < group_ptr_.size(); ++i) { - group_ptr_[i] = group_ptr_[i - 1] + group_ptr_[i]; +void MetaInfo::SetInfoFromHost(StringView key, Json arr) { + // multi-dim float info + if (key == "base_margin") { + CopyTensorInfoImpl<3>(arr, &this->base_margin_); + // FIXME(jiamingy): Remove the deprecated API and let all language bindings aware of + // input shape. This issue is CPU only since CUDA uses array interface from day 1. + // + // Python binding always understand the shape, so this condition should not occur for + // it. + if (this->num_row_ != 0 && this->base_margin_.Shape(0) != this->num_row_) { + // API functions that don't use array interface don't understand shape. + CHECK(this->base_margin_.Size() % this->num_row_ == 0) << "Incorrect size for base margin."; + size_t n_groups = this->base_margin_.Size() / this->num_row_; + this->base_margin_.Reshape(this->num_row_, n_groups); } - ValidateQueryGroup(group_ptr_); - } else if (!std::strcmp(key, "qid")) { - std::vector query_ids(num, 0); - DISPATCH_CONST_PTR(dtype, dptr, cast_dptr, - std::copy(cast_dptr, cast_dptr + num, query_ids.begin())); + return; + } + // uint info + if (key == "group") { + linalg::Tensor t; + CopyTensorInfoImpl(arr, &t); + auto const& h_groups = t.Data()->HostVector(); + group_ptr_.clear(); + group_ptr_.resize(h_groups.size() + 1, 0); + group_ptr_[0] = 0; + std::partial_sum(h_groups.cbegin(), h_groups.cend(), group_ptr_.begin() + 1); + data::ValidateQueryGroup(group_ptr_); + return; + } else if (key == "qid") { + linalg::Tensor t; + CopyTensorInfoImpl(arr, &t); bool non_dec = true; + auto const& query_ids = t.Data()->HostVector(); for (size_t i = 1; i < query_ids.size(); ++i) { - if (query_ids[i] < query_ids[i-1]) { + if (query_ids[i] < query_ids[i - 1]) { non_dec = false; break; } } CHECK(non_dec) << "`qid` must be sorted in non-decreasing order along with data."; - group_ptr_.clear(); group_ptr_.push_back(0); + group_ptr_.clear(); + group_ptr_.push_back(0); for (size_t i = 1; i < query_ids.size(); ++i) { - if (query_ids[i] != query_ids[i-1]) { + if (query_ids[i] != query_ids[i - 1]) { group_ptr_.push_back(i); } } if (group_ptr_.back() != query_ids.size()) { group_ptr_.push_back(query_ids.size()); } - } else if (!std::strcmp(key, "label_lower_bound")) { - auto& labels = labels_lower_bound_.HostVector(); - labels.resize(num); - DISPATCH_CONST_PTR(dtype, dptr, cast_dptr, - std::copy(cast_dptr, cast_dptr + num, labels.begin())); - } else if (!std::strcmp(key, "label_upper_bound")) { - auto& labels = labels_upper_bound_.HostVector(); - labels.resize(num); - DISPATCH_CONST_PTR(dtype, dptr, cast_dptr, - std::copy(cast_dptr, cast_dptr + num, labels.begin())); - } else if (!std::strcmp(key, "feature_weights")) { - auto &h_feature_weights = feature_weights.HostVector(); - h_feature_weights.resize(num); - DISPATCH_CONST_PTR( - dtype, dptr, cast_dptr, - std::copy(cast_dptr, cast_dptr + num, h_feature_weights.begin())); + data::ValidateQueryGroup(group_ptr_); + return; + } + // float info + linalg::Tensor t; + CopyTensorInfoImpl<1>(arr, &t); + if (key == "label") { + this->labels_ = std::move(*t.Data()); + auto const& h_labels = labels_.ConstHostVector(); + auto valid = std::none_of(h_labels.cbegin(), h_labels.cend(), data::LabelsCheck{}); + CHECK(valid) << "Label contains NaN, infinity or a value too large."; + } else if (key == "weight") { + this->weights_ = std::move(*t.Data()); + auto const& h_weights = this->weights_.ConstHostVector(); + auto valid = std::none_of(h_weights.cbegin(), h_weights.cend(), + [](float w) { return w < 0 || std::isinf(w) || std::isnan(w); }); + CHECK(valid) << "Weights must be positive values."; + } else if (key == "label_lower_bound") { + this->labels_lower_bound_ = std::move(*t.Data()); + } else if (key == "label_upper_bound") { + this->labels_upper_bound_ = std::move(*t.Data()); + } else if (key == "feature_weights") { + this->feature_weights = std::move(*t.Data()); + auto const& h_feature_weights = feature_weights.ConstHostVector(); bool valid = - std::none_of(h_feature_weights.cbegin(), h_feature_weights.cend(), - [](float w) { return w < 0; }); + std::none_of(h_feature_weights.cbegin(), h_feature_weights.cend(), data::WeightsCheck{}); CHECK(valid) << "Feature weight must be greater than 0."; } else { LOG(FATAL) << "Unknown key for MetaInfo: " << key; } } -void MetaInfo::GetInfo(char const *key, bst_ulong *out_len, DataType dtype, - const void **out_dptr) const { +void MetaInfo::SetInfo(const char* key, const void* dptr, DataType dtype, size_t num) { + auto proc = [&](auto cast_d_ptr) { + using T = std::remove_pointer_t; + auto t = + linalg::TensorView(common::Span{cast_d_ptr, num}, {num}, GenericParameter::kCpuId); + CHECK(t.Contiguous()); + Json interface { t.ArrayInterface() }; + assert(ArrayInterface<1>{interface}.is_contiguous); + return interface; + }; + // Legacy code using XGBoost dtype, which is a small subset of array interface types. + switch (dtype) { + case xgboost::DataType::kFloat32: { + auto cast_ptr = reinterpret_cast(dptr); + this->SetInfoFromHost(key, proc(cast_ptr)); + break; + } + case xgboost::DataType::kDouble: { + auto cast_ptr = reinterpret_cast(dptr); + this->SetInfoFromHost(key, proc(cast_ptr)); + break; + } + case xgboost::DataType::kUInt32: { + auto cast_ptr = reinterpret_cast(dptr); + this->SetInfoFromHost(key, proc(cast_ptr)); + break; + } + case xgboost::DataType::kUInt64: { + auto cast_ptr = reinterpret_cast(dptr); + this->SetInfoFromHost(key, proc(cast_ptr)); + break; + } + default: + LOG(FATAL) << "Unknown data type" << static_cast(dtype); + } +} + +void MetaInfo::GetInfo(char const* key, bst_ulong* out_len, DataType dtype, + const void** out_dptr) const { if (dtype == DataType::kFloat32) { const std::vector* vec = nullptr; if (!std::strcmp(key, "label")) { @@ -454,7 +572,7 @@ void MetaInfo::GetInfo(char const *key, bst_ulong *out_len, DataType dtype, } else if (!std::strcmp(key, "weight")) { vec = &this->weights_.HostVector(); } else if (!std::strcmp(key, "base_margin")) { - vec = &this->base_margin_.HostVector(); + vec = &this->base_margin_.Data()->HostVector(); } else if (!std::strcmp(key, "label_lower_bound")) { vec = &this->labels_lower_bound_.HostVector(); } else if (!std::strcmp(key, "label_upper_bound")) { @@ -543,8 +661,7 @@ void MetaInfo::Extend(MetaInfo const& that, bool accumulate_rows, bool check_col this->labels_upper_bound_.SetDevice(that.labels_upper_bound_.DeviceIdx()); this->labels_upper_bound_.Extend(that.labels_upper_bound_); - this->base_margin_.SetDevice(that.base_margin_.DeviceIdx()); - this->base_margin_.Extend(that.base_margin_); + linalg::Stack(&this->base_margin_, that.base_margin_); if (this->group_ptr_.size() == 0) { this->group_ptr_ = that.group_ptr_; @@ -627,12 +744,12 @@ void MetaInfo::Validate(int32_t device) const { if (base_margin_.Size() != 0) { CHECK_EQ(base_margin_.Size() % num_row_, 0) << "Size of base margin must be a multiple of number of rows."; - check_device(base_margin_); + check_device(*base_margin_.Data()); } } #if !defined(XGBOOST_USE_CUDA) -void MetaInfo::SetInfo(const char * c_key, std::string const& interface_str) { +void MetaInfo::SetInfoFromCUDA(StringView key, Json arr) { common::AssertGPUSupport(); } #endif // !defined(XGBOOST_USE_CUDA) @@ -788,10 +905,10 @@ DMatrix* DMatrix::Load(const std::string& uri, LOG(CONSOLE) << info.group_ptr_.size() - 1 << " groups are loaded from " << fname << ".group"; } - if (MetaTryLoadFloatInfo - (fname + ".base_margin", &info.base_margin_.HostVector()) && !silent) { - LOG(CONSOLE) << info.base_margin_.Size() - << " base_margin are loaded from " << fname << ".base_margin"; + if (MetaTryLoadFloatInfo(fname + ".base_margin", &info.base_margin_.Data()->HostVector()) && + !silent) { + LOG(CONSOLE) << info.base_margin_.Size() << " base_margin are loaded from " << fname + << ".base_margin"; } if (MetaTryLoadFloatInfo (fname + ".weight", &info.weights_.HostVector()) && !silent) { diff --git a/src/data/data.cu b/src/data/data.cu index aee62d1b7fad..b3bb2ccc6d61 100644 --- a/src/data/data.cu +++ b/src/data/data.cu @@ -9,84 +9,79 @@ #include "xgboost/json.h" #include "array_interface.h" #include "../common/device_helpers.cuh" +#include "../common/linalg_op.cuh" #include "device_adapter.cuh" #include "simple_dmatrix.h" +#include "validation.h" namespace xgboost { - -void CopyInfoImpl(ArrayInterface column, HostDeviceVector* out) { - auto SetDeviceToPtr = [](void* ptr) { - cudaPointerAttributes attr; - dh::safe_cuda(cudaPointerGetAttributes(&attr, ptr)); - int32_t ptr_device = attr.device; - if (ptr_device >= 0) { - dh::safe_cuda(cudaSetDevice(ptr_device)); - } - return ptr_device; - }; - auto ptr_device = SetDeviceToPtr(column.data); - - if (column.num_rows == 0) { - return; - } - out->SetDevice(ptr_device); - - size_t size = column.num_rows * column.num_cols; - CHECK_NE(size, 0); - out->Resize(size); - - auto p_dst = thrust::device_pointer_cast(out->DevicePointer()); - dh::LaunchN(size, [=] __device__(size_t idx) { - size_t ridx = idx / column.num_cols; - size_t cidx = idx - (ridx * column.num_cols); - p_dst[idx] = column.GetElement(ridx, cidx); - }); -} - namespace { -auto SetDeviceToPtr(void *ptr) { +auto SetDeviceToPtr(void* ptr) { cudaPointerAttributes attr; dh::safe_cuda(cudaPointerGetAttributes(&attr, ptr)); int32_t ptr_device = attr.device; dh::safe_cuda(cudaSetDevice(ptr_device)); return ptr_device; } -} // anonymous namespace -void CopyGroupInfoImpl(ArrayInterface column, std::vector* out) { - CHECK(column.type != ArrayInterface::kF4 && column.type != ArrayInterface::kF8) +template +void CopyTensorInfoImpl(Json arr_interface, linalg::Tensor* p_out) { + ArrayInterface array(arr_interface); + if (array.n == 0) { + return; + } + CHECK(array.valid.Size() == 0) << "Meta info like label or weight can not have missing value."; + auto ptr_device = SetDeviceToPtr(array.data); + + if (array.is_contiguous && array.type == ToDType::kType) { + p_out->ModifyInplace([&](HostDeviceVector* data, common::Span shape) { + // set shape + std::copy(array.shape, array.shape + D, shape.data()); + // set data + data->SetDevice(ptr_device); + data->Resize(array.n); + dh::safe_cuda(cudaMemcpyAsync(data->DevicePointer(), array.data, array.n * sizeof(T), + cudaMemcpyDefault)); + }); + return; + } + p_out->SetDevice(ptr_device); + p_out->Reshape(array.shape); + auto t = p_out->View(ptr_device); + linalg::ElementWiseKernelDevice(t, [=] __device__(size_t i, T) { + return linalg::detail::Apply(TypedIndex{array}, linalg::UnravelIndex(i, array.shape)); + }); +} + +void CopyGroupInfoImpl(ArrayInterface<1> column, std::vector* out) { + CHECK(column.type != ArrayInterfaceHandler::kF4 && column.type != ArrayInterfaceHandler::kF8) << "Expected integer for group info."; auto ptr_device = SetDeviceToPtr(column.data); CHECK_EQ(ptr_device, dh::CurrentDevice()); - dh::TemporaryArray temp(column.num_rows); - auto d_tmp = temp.data(); + dh::TemporaryArray temp(column.Shape(0)); + auto d_tmp = temp.data().get(); - dh::LaunchN(column.num_rows, [=] __device__(size_t idx) { - d_tmp[idx] = column.GetElement(idx, 0); - }); - auto length = column.num_rows; + dh::LaunchN(column.Shape(0), + [=] __device__(size_t idx) { d_tmp[idx] = column.operator()(idx); }); + auto length = column.Shape(0); out->resize(length + 1); out->at(0) = 0; thrust::copy(temp.data(), temp.data() + length, out->begin() + 1); std::partial_sum(out->begin(), out->end(), out->begin()); } -void CopyQidImpl(ArrayInterface array_interface, - std::vector *p_group_ptr) { +void CopyQidImpl(ArrayInterface<1> array_interface, std::vector* p_group_ptr) { auto &group_ptr_ = *p_group_ptr; auto it = dh::MakeTransformIterator( thrust::make_counting_iterator(0ul), - [array_interface] __device__(size_t i) { - return array_interface.GetElement(i, 0); - }); + [array_interface] __device__(size_t i) { return array_interface.operator()(i); }); dh::caching_device_vector flag(1); auto d_flag = dh::ToSpan(flag); auto d = SetDeviceToPtr(array_interface.data); dh::LaunchN(1, [=] __device__(size_t) { d_flag[0] = true; }); - dh::LaunchN(array_interface.num_rows - 1, [=] __device__(size_t i) { - if (array_interface.GetElement(i, 0) > - array_interface.GetElement(i + 1, 0)) { + dh::LaunchN(array_interface.Shape(0) - 1, [=] __device__(size_t i) { + if (array_interface.operator()(i) > array_interface.operator()(i + 1)) { d_flag[0] = false; } }); @@ -95,16 +90,16 @@ void CopyQidImpl(ArrayInterface array_interface, cudaMemcpyDeviceToHost)); CHECK(non_dec) << "`qid` must be sorted in increasing order along with data."; size_t bytes = 0; - dh::caching_device_vector out(array_interface.num_rows); - dh::caching_device_vector cnt(array_interface.num_rows); + dh::caching_device_vector out(array_interface.Shape(0)); + dh::caching_device_vector cnt(array_interface.Shape(0)); HostDeviceVector d_num_runs_out(1, 0, d); cub::DeviceRunLengthEncode::Encode( nullptr, bytes, it, out.begin(), cnt.begin(), - d_num_runs_out.DevicePointer(), array_interface.num_rows); + d_num_runs_out.DevicePointer(), array_interface.Shape(0)); dh::caching_device_vector tmp(bytes); cub::DeviceRunLengthEncode::Encode( tmp.data().get(), bytes, it, out.begin(), cnt.begin(), - d_num_runs_out.DevicePointer(), array_interface.num_rows); + d_num_runs_out.DevicePointer(), array_interface.Shape(0)); auto h_num_runs_out = d_num_runs_out.HostSpan()[0]; group_ptr_.clear(); @@ -115,77 +110,52 @@ void CopyQidImpl(ArrayInterface array_interface, thrust::copy(cnt.begin(), cnt.begin() + h_num_runs_out, group_ptr_.begin() + 1); } +} // namespace -namespace { -// thrust::all_of tries to copy lambda function. -struct LabelsCheck { - __device__ bool operator()(float y) { return ::isnan(y) || ::isinf(y); } -}; -struct WeightsCheck { - __device__ bool operator()(float w) { return LabelsCheck{}(w) || w < 0; } // NOLINT -}; -} // anonymous namespace - -void ValidateQueryGroup(std::vector const &group_ptr_); - -void MetaInfo::SetInfo(const char * c_key, std::string const& interface_str) { - Json j_interface = Json::Load({interface_str.c_str(), interface_str.size()}); - ArrayInterface array_interface(interface_str); - std::string key{c_key}; - - CHECK(!array_interface.valid.Data()) - << "Meta info " << key << " should be dense, found validity mask"; - if (array_interface.num_rows == 0) { - return; - } - +void MetaInfo::SetInfoFromCUDA(StringView key, Json array) { + // multi-dim float info if (key == "base_margin") { - CopyInfoImpl(array_interface, &base_margin_); + CopyTensorInfoImpl(array, &base_margin_); return; } - - CHECK(array_interface.num_cols == 1 || array_interface.num_rows == 1) - << "MetaInfo: " << c_key << " has invalid shape"; - if (!((array_interface.num_cols == 1 && array_interface.num_rows == 0) || - (array_interface.num_cols == 0 && array_interface.num_rows == 1))) { - // Not an empty column, transform it. - array_interface.AsColumnVector(); + // uint info + if (key == "group") { + auto array_interface{ArrayInterface<1>(array)}; + CopyGroupInfoImpl(array_interface, &group_ptr_); + data::ValidateQueryGroup(group_ptr_); + return; + } else if (key == "qid") { + auto array_interface{ArrayInterface<1>(array)}; + CopyQidImpl(array_interface, &group_ptr_); + data::ValidateQueryGroup(group_ptr_); + return; } + // float info + linalg::Tensor t; + CopyTensorInfoImpl(array, &t); if (key == "label") { - CopyInfoImpl(array_interface, &labels_); + this->labels_ = std::move(*t.Data()); auto ptr = labels_.ConstDevicePointer(); - auto valid = thrust::none_of(thrust::device, ptr, ptr + labels_.Size(), - LabelsCheck{}); + auto valid = thrust::none_of(thrust::device, ptr, ptr + labels_.Size(), data::LabelsCheck{}); CHECK(valid) << "Label contains NaN, infinity or a value too large."; } else if (key == "weight") { - CopyInfoImpl(array_interface, &weights_); + this->weights_ = std::move(*t.Data()); auto ptr = weights_.ConstDevicePointer(); - auto valid = thrust::none_of(thrust::device, ptr, ptr + weights_.Size(), - WeightsCheck{}); + auto valid = thrust::none_of(thrust::device, ptr, ptr + weights_.Size(), data::WeightsCheck{}); CHECK(valid) << "Weights must be positive values."; - } else if (key == "group") { - CopyGroupInfoImpl(array_interface, &group_ptr_); - ValidateQueryGroup(group_ptr_); - return; - } else if (key == "qid") { - CopyQidImpl(array_interface, &group_ptr_); - return; } else if (key == "label_lower_bound") { - CopyInfoImpl(array_interface, &labels_lower_bound_); - return; + this->labels_lower_bound_ = std::move(*t.Data()); } else if (key == "label_upper_bound") { - CopyInfoImpl(array_interface, &labels_upper_bound_); - return; + this->labels_upper_bound_ = std::move(*t.Data()); } else if (key == "feature_weights") { - CopyInfoImpl(array_interface, &feature_weights); + this->feature_weights = std::move(*t.Data()); auto d_feature_weights = feature_weights.ConstDeviceSpan(); - auto valid = thrust::none_of( - thrust::device, d_feature_weights.data(), - d_feature_weights.data() + d_feature_weights.size(), WeightsCheck{}); + auto valid = + thrust::none_of(thrust::device, d_feature_weights.data(), + d_feature_weights.data() + d_feature_weights.size(), data::WeightsCheck{}); CHECK(valid) << "Feature weight must be greater than 0."; - return; } else { - LOG(FATAL) << "Unknown metainfo: " << key; + LOG(FATAL) << "Unknown key for MetaInfo: " << key; } } diff --git a/src/data/device_adapter.cuh b/src/data/device_adapter.cuh index 628878f319f1..d1bda280a7d5 100644 --- a/src/data/device_adapter.cuh +++ b/src/data/device_adapter.cuh @@ -20,7 +20,7 @@ class CudfAdapterBatch : public detail::NoMetaInfo { public: CudfAdapterBatch() = default; - CudfAdapterBatch(common::Span columns, size_t num_rows) + CudfAdapterBatch(common::Span> columns, size_t num_rows) : columns_(columns), num_rows_(num_rows) {} size_t Size() const { return num_rows_ * columns_.size(); } @@ -29,7 +29,7 @@ class CudfAdapterBatch : public detail::NoMetaInfo { size_t row_idx = idx / columns_.size(); auto const& column = columns_[column_idx]; float value = column.valid.Data() == nullptr || column.valid.Check(row_idx) - ? column.GetElement(row_idx, 0) + ? column(row_idx) : std::numeric_limits::quiet_NaN(); return {row_idx, column_idx, value}; } @@ -38,7 +38,7 @@ class CudfAdapterBatch : public detail::NoMetaInfo { XGBOOST_DEVICE bst_row_t NumCols() const { return columns_.size(); } private: - common::Span columns_; + common::Span> columns_; size_t num_rows_; }; @@ -101,9 +101,9 @@ class CudfAdapter : public detail::SingleBatchDataIter { auto const& typestr = get(json_columns[0]["typestr"]); CHECK_EQ(typestr.size(), 3) << ArrayInterfaceErrors::TypestrFormat(); - std::vector columns; - auto first_column = ArrayInterface(get(json_columns[0])); - num_rows_ = first_column.num_rows; + std::vector> columns; + auto first_column = ArrayInterface<1>(get(json_columns[0])); + num_rows_ = first_column.Shape(0); if (num_rows_ == 0) { return; } @@ -112,13 +112,12 @@ class CudfAdapter : public detail::SingleBatchDataIter { CHECK_NE(device_idx_, -1); dh::safe_cuda(cudaSetDevice(device_idx_)); for (auto& json_col : json_columns) { - auto column = ArrayInterface(get(json_col)); + auto column = ArrayInterface<1>(get(json_col)); columns.push_back(column); - CHECK_EQ(column.num_cols, 1); - num_rows_ = std::max(num_rows_, size_t(column.num_rows)); + num_rows_ = std::max(num_rows_, size_t(column.Shape(0))); CHECK_EQ(device_idx_, dh::CudaGetPointerDevice(column.data)) << "All columns should use the same device."; - CHECK_EQ(num_rows_, column.num_rows) + CHECK_EQ(num_rows_, column.Shape(0)) << "All columns should have same number of rows."; } columns_ = columns; @@ -135,7 +134,7 @@ class CudfAdapter : public detail::SingleBatchDataIter { private: CudfAdapterBatch batch_; - dh::device_vector columns_; + dh::device_vector> columns_; size_t num_rows_{0}; int device_idx_; }; @@ -143,23 +142,23 @@ class CudfAdapter : public detail::SingleBatchDataIter { class CupyAdapterBatch : public detail::NoMetaInfo { public: CupyAdapterBatch() = default; - explicit CupyAdapterBatch(ArrayInterface array_interface) + explicit CupyAdapterBatch(ArrayInterface<2> array_interface) : array_interface_(std::move(array_interface)) {} size_t Size() const { - return array_interface_.num_rows * array_interface_.num_cols; + return array_interface_.Shape(0) * array_interface_.Shape(1); } __device__ COOTuple GetElement(size_t idx) const { - size_t column_idx = idx % array_interface_.num_cols; - size_t row_idx = idx / array_interface_.num_cols; - float value = array_interface_.GetElement(row_idx, column_idx); + size_t column_idx = idx % array_interface_.Shape(1); + size_t row_idx = idx / array_interface_.Shape(1); + float value = array_interface_(row_idx, column_idx); return {row_idx, column_idx, value}; } - XGBOOST_DEVICE bst_row_t NumRows() const { return array_interface_.num_rows; } - XGBOOST_DEVICE bst_row_t NumCols() const { return array_interface_.num_cols; } + XGBOOST_DEVICE bst_row_t NumRows() const { return array_interface_.Shape(0); } + XGBOOST_DEVICE bst_row_t NumCols() const { return array_interface_.Shape(1); } private: - ArrayInterface array_interface_; + ArrayInterface<2> array_interface_; }; class CupyAdapter : public detail::SingleBatchDataIter { @@ -167,9 +166,9 @@ class CupyAdapter : public detail::SingleBatchDataIter { explicit CupyAdapter(std::string cuda_interface_str) { Json json_array_interface = Json::Load({cuda_interface_str.c_str(), cuda_interface_str.size()}); - array_interface_ = ArrayInterface(get(json_array_interface), false); + array_interface_ = ArrayInterface<2>(get(json_array_interface)); batch_ = CupyAdapterBatch(array_interface_); - if (array_interface_.num_rows == 0) { + if (array_interface_.Shape(0) == 0) { return; } device_idx_ = dh::CudaGetPointerDevice(array_interface_.data); @@ -177,12 +176,12 @@ class CupyAdapter : public detail::SingleBatchDataIter { } const CupyAdapterBatch& Value() const override { return batch_; } - size_t NumRows() const { return array_interface_.num_rows; } - size_t NumColumns() const { return array_interface_.num_cols; } + size_t NumRows() const { return array_interface_.Shape(0); } + size_t NumColumns() const { return array_interface_.Shape(1); } int32_t DeviceIdx() const { return device_idx_; } private: - ArrayInterface array_interface_; + ArrayInterface<2> array_interface_; CupyAdapterBatch batch_; int32_t device_idx_ {-1}; }; diff --git a/src/data/file_iterator.h b/src/data/file_iterator.h index 6d6adb62b008..70a5d51c30b9 100644 --- a/src/data/file_iterator.h +++ b/src/data/file_iterator.h @@ -12,6 +12,7 @@ #include "dmlc/data.h" #include "xgboost/c_api.h" #include "xgboost/json.h" +#include "xgboost/linalg.h" #include "array_interface.h" namespace xgboost { @@ -58,16 +59,14 @@ class FileIterator { CHECK(parser_); if (parser_->Next()) { row_block_ = parser_->Value(); + using linalg::MakeVec; - indptr_ = MakeArrayInterface(row_block_.offset, row_block_.size + 1); - values_ = MakeArrayInterface(row_block_.value, - row_block_.offset[row_block_.size]); - indices_ = MakeArrayInterface(row_block_.index, - row_block_.offset[row_block_.size]); + indptr_ = MakeVec(row_block_.offset, row_block_.size + 1).ArrayInterfaceStr(); + values_ = MakeVec(row_block_.value, row_block_.offset[row_block_.size]).ArrayInterfaceStr(); + indices_ = MakeVec(row_block_.index, row_block_.offset[row_block_.size]).ArrayInterfaceStr(); - size_t n_columns = *std::max_element( - row_block_.index, - row_block_.index + row_block_.offset[row_block_.size]); + size_t n_columns = *std::max_element(row_block_.index, + row_block_.index + row_block_.offset[row_block_.size]); // dmlc parser converts 1-based indexing back to 0-based indexing so we can ignore // this condition and just add 1 to n_columns n_columns += 1; diff --git a/src/data/simple_dmatrix.cc b/src/data/simple_dmatrix.cc index 44a8a3f8fe7c..e83559d3958a 100644 --- a/src/data/simple_dmatrix.cc +++ b/src/data/simple_dmatrix.cc @@ -137,9 +137,10 @@ SimpleDMatrix::SimpleDMatrix(AdapterT* adapter, float missing, int nthread) { batch.Weights() + batch.Size()); } if (batch.BaseMargin() != nullptr) { - auto& base_margin = info_.base_margin_.HostVector(); - base_margin.insert(base_margin.end(), batch.BaseMargin(), - batch.BaseMargin() + batch.Size()); + info_.base_margin_ = linalg::Tensor{batch.BaseMargin(), + batch.BaseMargin() + batch.Size(), + {batch.Size()}, + GenericParameter::kCpuId}; } if (batch.Qid() != nullptr) { qids.insert(qids.end(), batch.Qid(), batch.Qid() + batch.Size()); diff --git a/src/data/validation.h b/src/data/validation.h new file mode 100644 index 000000000000..6d3701114886 --- /dev/null +++ b/src/data/validation.h @@ -0,0 +1,40 @@ +/*! + * Copyright 2021 by XGBoost Contributors + */ +#ifndef XGBOOST_DATA_VALIDATION_H_ +#define XGBOOST_DATA_VALIDATION_H_ +#include +#include + +#include "xgboost/base.h" +#include "xgboost/logging.h" + +namespace xgboost { +namespace data { +struct LabelsCheck { + XGBOOST_DEVICE bool operator()(float y) { +#if defined(__CUDA_ARCH__) + return ::isnan(y) || ::isinf(y); +#else + return std::isnan(y) || std::isinf(y); +#endif + } +}; + +struct WeightsCheck { + XGBOOST_DEVICE bool operator()(float w) { return LabelsCheck{}(w) || w < 0; } // NOLINT +}; + +inline void ValidateQueryGroup(std::vector const &group_ptr_) { + bool valid_query_group = true; + for (size_t i = 1; i < group_ptr_.size(); ++i) { + valid_query_group = valid_query_group && group_ptr_[i] >= group_ptr_[i - 1]; + if (XGBOOST_EXPECT(!valid_query_group, false)) { + break; + } + } + CHECK(valid_query_group) << "Invalid group structure."; +} +} // namespace data +} // namespace xgboost +#endif // XGBOOST_DATA_VALIDATION_H_ diff --git a/src/gbm/gblinear.cc b/src/gbm/gblinear.cc index e5f2a457214c..e5f5916211cc 100644 --- a/src/gbm/gblinear.cc +++ b/src/gbm/gblinear.cc @@ -178,7 +178,7 @@ class GBLinear : public GradientBooster { unsigned layer_begin, unsigned layer_end, bool, int, unsigned) override { model_.LazyInitModel(); LinearCheckLayer(layer_begin, layer_end); - const auto& base_margin = p_fmat->Info().base_margin_.ConstHostVector(); + auto base_margin = p_fmat->Info().base_margin_.View(GenericParameter::kCpuId); const int ngroup = model_.learner_model_param->num_output_group; const size_t ncolumns = model_.learner_model_param->num_feature + 1; // allocate space for (#features + bias) times #groups times #rows @@ -203,9 +203,9 @@ class GBLinear : public GradientBooster { p_contribs[ins.index] = ins.fvalue * model_[ins.index][gid]; } // add base margin to BIAS - p_contribs[ncolumns - 1] = model_.Bias()[gid] + - ((base_margin.size() != 0) ? base_margin[row_idx * ngroup + gid] : - learner_model_param_->base_score); + p_contribs[ncolumns - 1] = + model_.Bias()[gid] + ((base_margin.Size() != 0) ? base_margin(row_idx, gid) + : learner_model_param_->base_score); } }); } @@ -270,7 +270,7 @@ class GBLinear : public GradientBooster { monitor_.Start("PredictBatchInternal"); model_.LazyInitModel(); std::vector &preds = *out_preds; - const auto& base_margin = p_fmat->Info().base_margin_.ConstHostVector(); + auto base_margin = p_fmat->Info().base_margin_.View(GenericParameter::kCpuId); // start collecting the prediction const int ngroup = model_.learner_model_param->num_output_group; preds.resize(p_fmat->Info().num_row_ * ngroup); @@ -280,16 +280,15 @@ class GBLinear : public GradientBooster { // k is number of group // parallel over local batch const auto nsize = static_cast(batch.Size()); - if (base_margin.size() != 0) { - CHECK_EQ(base_margin.size(), nsize * ngroup); + if (base_margin.Size() != 0) { + CHECK_EQ(base_margin.Size(), nsize * ngroup); } common::ParallelFor(nsize, [&](omp_ulong i) { const size_t ridx = page.base_rowid + i; // loop over output groups for (int gid = 0; gid < ngroup; ++gid) { - bst_float margin = - (base_margin.size() != 0) ? - base_margin[ridx * ngroup + gid] : learner_model_param_->base_score; + float margin = + (base_margin.Size() != 0) ? base_margin(ridx, gid) : learner_model_param_->base_score; this->Pred(batch[i], &preds[ridx * ngroup], gid, margin); } }); diff --git a/src/predictor/cpu_predictor.cc b/src/predictor/cpu_predictor.cc index d581f64a1d56..92797235d16b 100644 --- a/src/predictor/cpu_predictor.cc +++ b/src/predictor/cpu_predictor.cc @@ -282,27 +282,6 @@ class CPUPredictor : public Predictor { } } - void InitOutPredictions(const MetaInfo& info, - HostDeviceVector* out_preds, - const gbm::GBTreeModel& model) const override { - CHECK_NE(model.learner_model_param->num_output_group, 0); - size_t n = model.learner_model_param->num_output_group * info.num_row_; - const auto& base_margin = info.base_margin_.HostVector(); - out_preds->Resize(n); - std::vector& out_preds_h = out_preds->HostVector(); - if (base_margin.empty()) { - std::fill(out_preds_h.begin(), out_preds_h.end(), - model.learner_model_param->base_score); - } else { - std::string expected{ - "(" + std::to_string(info.num_row_) + ", " + - std::to_string(model.learner_model_param->num_output_group) + ")"}; - CHECK_EQ(base_margin.size(), n) - << "Invalid shape of base_margin. Expected:" << expected; - std::copy(base_margin.begin(), base_margin.end(), out_preds_h.begin()); - } - } - public: explicit CPUPredictor(GenericParameter const* generic_param) : Predictor::Predictor{generic_param} {} @@ -456,7 +435,7 @@ class CPUPredictor : public Predictor { common::ParallelFor(bst_omp_uint(ntree_limit), [&](bst_omp_uint i) { FillNodeMeanValues(model.trees[i].get(), &(mean_values[i])); }); - const std::vector& base_margin = info.base_margin_.HostVector(); + auto base_margin = info.base_margin_.View(GenericParameter::kCpuId); // start collecting the contributions for (const auto &batch : p_fmat->GetBatches()) { auto page = batch.GetView(); @@ -496,8 +475,9 @@ class CPUPredictor : public Predictor { } feats.Drop(page[i]); // add base margin to BIAS - if (base_margin.size() != 0) { - p_contribs[ncolumns - 1] += base_margin[row_idx * ngroup + gid]; + if (base_margin.Size() != 0) { + CHECK_EQ(base_margin.Shape(1), ngroup); + p_contribs[ncolumns - 1] += base_margin(row_idx, gid); } else { p_contribs[ncolumns - 1] += model.learner_model_param->base_score; } diff --git a/src/predictor/gpu_predictor.cu b/src/predictor/gpu_predictor.cu index 51674237e973..025a1c495c52 100644 --- a/src/predictor/gpu_predictor.cu +++ b/src/predictor/gpu_predictor.cu @@ -854,8 +854,8 @@ class GPUPredictor : public xgboost::Predictor { dh::tend(phis)); } // Add the base margin term to last column - p_fmat->Info().base_margin_.SetDevice(generic_param_->gpu_id); - const auto margin = p_fmat->Info().base_margin_.ConstDeviceSpan(); + p_fmat->Info().base_margin_.Data()->SetDevice(generic_param_->gpu_id); + const auto margin = p_fmat->Info().base_margin_.Data()->ConstDeviceSpan(); float base_score = model.learner_model_param->base_score; dh::LaunchN( p_fmat->Info().num_row_ * model.learner_model_param->num_output_group, @@ -913,8 +913,8 @@ class GPUPredictor : public xgboost::Predictor { dh::tend(phis)); } // Add the base margin term to last column - p_fmat->Info().base_margin_.SetDevice(generic_param_->gpu_id); - const auto margin = p_fmat->Info().base_margin_.ConstDeviceSpan(); + p_fmat->Info().base_margin_.Data()->SetDevice(generic_param_->gpu_id); + const auto margin = p_fmat->Info().base_margin_.Data()->ConstDeviceSpan(); float base_score = model.learner_model_param->base_score; size_t n_features = model.learner_model_param->num_feature; dh::LaunchN( @@ -928,27 +928,6 @@ class GPUPredictor : public xgboost::Predictor { }); } - protected: - void InitOutPredictions(const MetaInfo& info, - HostDeviceVector* out_preds, - const gbm::GBTreeModel& model) const override { - size_t n_classes = model.learner_model_param->num_output_group; - size_t n = n_classes * info.num_row_; - const HostDeviceVector& base_margin = info.base_margin_; - out_preds->SetDevice(generic_param_->gpu_id); - out_preds->Resize(n); - if (base_margin.Size() != 0) { - std::string expected{ - "(" + std::to_string(info.num_row_) + ", " + - std::to_string(model.learner_model_param->num_output_group) + ")"}; - CHECK_EQ(base_margin.Size(), n) - << "Invalid shape of base_margin. Expected:" << expected; - out_preds->Copy(base_margin); - } else { - out_preds->Fill(model.learner_model_param->base_score); - } - } - void PredictInstance(const SparsePage::Inst&, std::vector*, const gbm::GBTreeModel&, unsigned) const override { diff --git a/src/predictor/predictor.cc b/src/predictor/predictor.cc index 9aa18b19ce30..e9e10a632a7f 100644 --- a/src/predictor/predictor.cc +++ b/src/predictor/predictor.cc @@ -1,5 +1,5 @@ /*! - * Copyright 2017-2020 by Contributors + * Copyright 2017-2021 by Contributors */ #include #include @@ -8,6 +8,8 @@ #include "xgboost/data.h" #include "xgboost/generic_parameters.h" +#include "../gbm/gbtree.h" + namespace dmlc { DMLC_REGISTRY_ENABLE(::xgboost::PredictorReg); } // namespace dmlc @@ -58,6 +60,33 @@ Predictor* Predictor::Create( auto p_predictor = (e->body)(generic_param); return p_predictor; } + +void ValidateBaseMarginShape(linalg::Tensor const& margin, bst_row_t n_samples, + bst_group_t n_groups) { + // FIXME: Bindings other than Python doesn't have shape. + std::string expected{"Invalid shape of base_margin. Expected: (" + std::to_string(n_samples) + + ", " + std::to_string(n_groups) + ")"}; + CHECK_EQ(margin.Shape(0), n_samples) << expected; + CHECK_EQ(margin.Shape(1), n_groups) << expected; +} + +void Predictor::InitOutPredictions(const MetaInfo& info, HostDeviceVector* out_preds, + const gbm::GBTreeModel& model) const { + CHECK_NE(model.learner_model_param->num_output_group, 0); + size_t n_classes = model.learner_model_param->num_output_group; + size_t n = n_classes * info.num_row_; + const HostDeviceVector* base_margin = info.base_margin_.Data(); + if (generic_param_->gpu_id >= 0) { + out_preds->SetDevice(generic_param_->gpu_id); + } + out_preds->Resize(n); + if (base_margin->Size() != 0) { + ValidateBaseMarginShape(info.base_margin_, info.num_row_, n_classes); + out_preds->Copy(*base_margin); + } else { + out_preds->Fill(model.learner_model_param->base_score); + } +} } // namespace xgboost namespace xgboost { diff --git a/tests/cpp/common/test_hist_util.h b/tests/cpp/common/test_hist_util.h index c7fc4972d3ee..49952d202706 100644 --- a/tests/cpp/common/test_hist_util.h +++ b/tests/cpp/common/test_hist_util.h @@ -57,7 +57,7 @@ inline data::CupyAdapter AdapterFromData(const thrust::device_vector &x, Json(Integer(reinterpret_cast(x.data().get()))), Json(Boolean(false))}; array_interface["data"] = j_data; - array_interface["version"] = Integer(static_cast(1)); + array_interface["version"] = 3; array_interface["typestr"] = String(" indices; size_t n_features = 100, n_samples = 10; RandomDataGenerator{n_samples, n_features, 0.5}.GenerateCSR(&values, &indptr, &indices); - auto indptr_arr = MakeArrayInterface(indptr.HostPointer(), indptr.Size()); - auto values_arr = MakeArrayInterface(values.HostPointer(), values.Size()); - auto indices_arr = MakeArrayInterface(indices.HostPointer(), indices.Size()); + using linalg::MakeVec; + auto indptr_arr = MakeVec(indptr.HostPointer(), indptr.Size()).ArrayInterfaceStr(); + auto values_arr = MakeVec(values.HostPointer(), values.Size()).ArrayInterfaceStr(); + auto indices_arr = MakeVec(indices.HostPointer(), indices.Size()).ArrayInterfaceStr(); auto adapter = data::CSRArrayAdapter( StringView{indptr_arr.c_str(), indptr_arr.size()}, StringView{values_arr.c_str(), values_arr.size()}, diff --git a/tests/cpp/data/test_array_interface.cc b/tests/cpp/data/test_array_interface.cc index 875858855ed5..f438593b42a4 100644 --- a/tests/cpp/data/test_array_interface.cc +++ b/tests/cpp/data/test_array_interface.cc @@ -11,21 +11,22 @@ TEST(ArrayInterface, Initialize) { size_t constexpr kRows = 10, kCols = 10; HostDeviceVector storage; auto array = RandomDataGenerator{kRows, kCols, 0}.GenerateArrayInterface(&storage); - auto arr_interface = ArrayInterface(array); - ASSERT_EQ(arr_interface.num_rows, kRows); - ASSERT_EQ(arr_interface.num_cols, kCols); + auto arr_interface = ArrayInterface<2>(StringView{array}); + ASSERT_EQ(arr_interface.Shape(0), kRows); + ASSERT_EQ(arr_interface.Shape(1), kCols); ASSERT_EQ(arr_interface.data, storage.ConstHostPointer()); ASSERT_EQ(arr_interface.ElementSize(), 4); - ASSERT_EQ(arr_interface.type, ArrayInterface::kF4); + ASSERT_EQ(arr_interface.type, ArrayInterfaceHandler::kF4); HostDeviceVector u64_storage(storage.Size()); - std::string u64_arr_str; - Json::Dump(GetArrayInterface(&u64_storage, kRows, kCols), &u64_arr_str); + std::string u64_arr_str{linalg::TensorView{ + u64_storage.ConstHostSpan(), {kRows, kCols}, GenericParameter::kCpuId} + .ArrayInterfaceStr()}; std::copy(storage.ConstHostVector().cbegin(), storage.ConstHostVector().cend(), u64_storage.HostSpan().begin()); - auto u64_arr = ArrayInterface{u64_arr_str}; + auto u64_arr = ArrayInterface<2>{u64_arr_str}; ASSERT_EQ(u64_arr.ElementSize(), 8); - ASSERT_EQ(u64_arr.type, ArrayInterface::kU8); + ASSERT_EQ(u64_arr.type, ArrayInterfaceHandler::kU8); } TEST(ArrayInterface, Error) { @@ -38,23 +39,22 @@ TEST(ArrayInterface, Error) { Json(Boolean(false))}; auto const& column_obj = get(column); - std::pair shape{kRows, kCols}; std::string typestr{"(1)); + EXPECT_THROW(ArrayInterfaceHandler::ExtractData(column_obj, n), dmlc::Error); + column["version"] = 3; // missing data - EXPECT_THROW(ArrayInterfaceHandler::ExtractData(column_obj, shape), + EXPECT_THROW(ArrayInterfaceHandler::ExtractData(column_obj, n), dmlc::Error); column["data"] = j_data; // missing typestr - EXPECT_THROW(ArrayInterfaceHandler::ExtractData(column_obj, shape), + EXPECT_THROW(ArrayInterfaceHandler::ExtractData(column_obj, n), dmlc::Error); column["typestr"] = String(" storage; @@ -63,22 +63,41 @@ TEST(ArrayInterface, Error) { Json(Integer(reinterpret_cast(storage.ConstHostPointer()))), Json(Boolean(false))}; column["data"] = j_data; - EXPECT_NO_THROW(ArrayInterfaceHandler::ExtractData(column_obj, shape)); + EXPECT_NO_THROW(ArrayInterfaceHandler::ExtractData(column_obj, n)); } TEST(ArrayInterface, GetElement) { size_t kRows = 4, kCols = 2; HostDeviceVector storage; auto intefrace_str = RandomDataGenerator{kRows, kCols, 0}.GenerateArrayInterface(&storage); - ArrayInterface array_interface{intefrace_str}; + ArrayInterface<2> array_interface{intefrace_str}; auto const& h_storage = storage.ConstHostVector(); for (size_t i = 0; i < kRows; ++i) { for (size_t j = 0; j < kCols; ++j) { - float v0 = array_interface.GetElement(i, j); + float v0 = array_interface(i, j); float v1 = h_storage.at(i * kCols + j); ASSERT_EQ(v0, v1); } } } + +TEST(ArrayInterface, TrivialDim) { + size_t kRows{1000}, kCols = 1; + HostDeviceVector storage; + auto interface_str = RandomDataGenerator{kRows, kCols, 0}.GenerateArrayInterface(&storage); + { + ArrayInterface<1> arr_i{interface_str}; + ASSERT_EQ(arr_i.n, kRows); + ASSERT_EQ(arr_i.Shape(0), kRows); + } + + std::swap(kRows, kCols); + interface_str = RandomDataGenerator{kRows, kCols, 0}.GenerateArrayInterface(&storage); + { + ArrayInterface<1> arr_i{interface_str}; + ASSERT_EQ(arr_i.n, kCols); + ASSERT_EQ(arr_i.Shape(0), kCols); + } +} } // namespace xgboost diff --git a/tests/cpp/data/test_array_interface.cu b/tests/cpp/data/test_array_interface.cu index 75923e74ba1a..c8e07852534b 100644 --- a/tests/cpp/data/test_array_interface.cu +++ b/tests/cpp/data/test_array_interface.cu @@ -32,11 +32,24 @@ TEST(ArrayInterface, Stream) { dh::caching_device_vector out(1, 0); uint64_t dur = 1e9; dh::LaunchKernel{1, 1, 0, stream}(SleepForTest, out.data().get(), dur); - ArrayInterface arr(arr_str); + ArrayInterface<2> arr(arr_str); auto t = out[0]; CHECK_GE(t, dur); cudaStreamDestroy(stream); } + +TEST(ArrayInterface, Ptr) { + std::vector h_data(10); + ASSERT_FALSE(ArrayInterfaceHandler::IsCudaPtr(h_data.data())); + dh::safe_cuda(cudaGetLastError()); + + dh::device_vector d_data(10); + ASSERT_TRUE(ArrayInterfaceHandler::IsCudaPtr(d_data.data().get())); + dh::safe_cuda(cudaGetLastError()); + + ASSERT_FALSE(ArrayInterfaceHandler::IsCudaPtr(nullptr)); + dh::safe_cuda(cudaGetLastError()); +} } // namespace xgboost diff --git a/tests/cpp/data/test_array_interface.h b/tests/cpp/data/test_array_interface.h index 7872a9507aa5..78bce76f53e7 100644 --- a/tests/cpp/data/test_array_interface.h +++ b/tests/cpp/data/test_array_interface.h @@ -19,6 +19,7 @@ Json GenerateDenseColumn(std::string const& typestr, size_t kRows, std::vector j_shape {Json(Integer(static_cast(kRows)))}; column["shape"] = Array(j_shape); column["strides"] = Array(std::vector{Json(Integer(static_cast(sizeof(T))))}); + column["stream"] = nullptr; d_data.resize(kRows); thrust::sequence(thrust::device, d_data.begin(), d_data.end(), 0.0f, 2.0f); @@ -30,7 +31,7 @@ Json GenerateDenseColumn(std::string const& typestr, size_t kRows, Json(Boolean(false))}; column["data"] = j_data; - column["version"] = Integer(static_cast(1)); + column["version"] = 3; column["typestr"] = String(typestr); return column; } @@ -43,6 +44,7 @@ Json GenerateSparseColumn(std::string const& typestr, size_t kRows, std::vector j_shape {Json(Integer(static_cast(kRows)))}; column["shape"] = Array(j_shape); column["strides"] = Array(std::vector{Json(Integer(static_cast(sizeof(T))))}); + column["stream"] = nullptr; d_data.resize(kRows); for (size_t i = 0; i < d_data.size(); ++i) { @@ -56,7 +58,7 @@ Json GenerateSparseColumn(std::string const& typestr, size_t kRows, Json(Boolean(false))}; column["data"] = j_data; - column["version"] = Integer(static_cast(1)); + column["version"] = 3; column["typestr"] = String(typestr); return column; } @@ -75,9 +77,9 @@ Json Generate2dArrayInterface(int rows, int cols, std::string typestr, Json(Integer(reinterpret_cast(data.data().get()))), Json(Boolean(false))}; array_interface["data"] = j_data; - array_interface["version"] = Integer(static_cast(1)); + array_interface["version"] = 3; array_interface["typestr"] = String(typestr); + array_interface["stream"] = nullptr; return array_interface; } - } // namespace xgboost diff --git a/tests/cpp/data/test_iterative_device_dmatrix.cu b/tests/cpp/data/test_iterative_device_dmatrix.cu index cb64a3b5cdb2..27f6b0b3ffe9 100644 --- a/tests/cpp/data/test_iterative_device_dmatrix.cu +++ b/tests/cpp/data/test_iterative_device_dmatrix.cu @@ -103,7 +103,7 @@ TEST(IterativeDeviceDMatrix, RowMajor) { auto j_interface = Json::Load({interface_str.c_str(), interface_str.size()}); - ArrayInterface loaded {get(j_interface)}; + ArrayInterface<2> loaded {get(j_interface)}; std::vector h_data(cols * rows); common::Span s_data{static_cast(loaded.data), cols * rows}; dh::CopyDeviceSpanToVector(&h_data, s_data); @@ -128,7 +128,7 @@ TEST(IterativeDeviceDMatrix, RowMajorMissing) { std::string interface_str = iter.AsArray(); auto j_interface = Json::Load({interface_str.c_str(), interface_str.size()}); - ArrayInterface loaded {get(j_interface)}; + ArrayInterface<2> loaded {get(j_interface)}; std::vector h_data(cols * rows); common::Span s_data{static_cast(loaded.data), cols * rows}; dh::CopyDeviceSpanToVector(&h_data, s_data); diff --git a/tests/cpp/data/test_metainfo.cc b/tests/cpp/data/test_metainfo.cc index bb5452a56d28..4f379f4fdd25 100644 --- a/tests/cpp/data/test_metainfo.cc +++ b/tests/cpp/data/test_metainfo.cc @@ -1,4 +1,6 @@ -// Copyright 2016-2020 by Contributors +// Copyright 2016-2021 by Contributors +#include "test_metainfo.h" + #include #include #include @@ -122,7 +124,10 @@ TEST(MetaInfo, SaveLoadBinary) { EXPECT_EQ(inforead.labels_.HostVector(), info.labels_.HostVector()); EXPECT_EQ(inforead.group_ptr_, info.group_ptr_); EXPECT_EQ(inforead.weights_.HostVector(), info.weights_.HostVector()); - EXPECT_EQ(inforead.base_margin_.HostVector(), info.base_margin_.HostVector()); + + auto orig_margin = info.base_margin_.View(xgboost::GenericParameter::kCpuId); + auto read_margin = inforead.base_margin_.View(xgboost::GenericParameter::kCpuId); + EXPECT_TRUE(std::equal(orig_margin.cbegin(), orig_margin.cend(), read_margin.cbegin())); EXPECT_EQ(inforead.feature_type_names.size(), kCols); EXPECT_EQ(inforead.feature_types.Size(), kCols); @@ -254,10 +259,10 @@ TEST(MetaInfo, Validate) { xgboost::HostDeviceVector d_groups{groups}; d_groups.SetDevice(0); d_groups.DevicePointer(); // pull to device - auto arr_interface = xgboost::GetArrayInterface(&d_groups, 64, 1); - std::string arr_interface_str; - xgboost::Json::Dump(arr_interface, &arr_interface_str); - EXPECT_THROW(info.SetInfo("group", arr_interface_str), dmlc::Error); + std::string arr_interface_str{ + xgboost::linalg::MakeVec(d_groups.ConstDevicePointer(), d_groups.Size(), 0) + .ArrayInterfaceStr()}; + EXPECT_THROW(info.SetInfo("group", xgboost::StringView{arr_interface_str}), dmlc::Error); #endif // defined(XGBOOST_USE_CUDA) } @@ -292,3 +297,7 @@ TEST(MetaInfo, HostExtend) { ASSERT_EQ(lhs.group_ptr_.at(i), per_group * i); } } + +namespace xgboost { +TEST(MetaInfo, CPUStridedData) { TestMetaInfoStridedData(GenericParameter::kCpuId); } +} // namespace xgboost diff --git a/tests/cpp/data/test_metainfo.cu b/tests/cpp/data/test_metainfo.cu index 090374b913d6..d57e9ceda807 100644 --- a/tests/cpp/data/test_metainfo.cu +++ b/tests/cpp/data/test_metainfo.cu @@ -1,12 +1,14 @@ -/*! Copyright 2019 by Contributors */ - +/*! Copyright 2019-2021 by XGBoost Contributors */ #include #include #include +#include #include #include "test_array_interface.h" #include "../../../src/common/device_helpers.cuh" +#include "test_metainfo.h" + namespace xgboost { template @@ -23,7 +25,7 @@ std::string PrepareData(std::string typestr, thrust::device_vector* out, cons std::vector j_shape {Json(Integer(static_cast(kRows)))}; column["shape"] = Array(j_shape); column["strides"] = Array(std::vector{Json(Integer(static_cast(sizeof(T))))}); - column["version"] = Integer(static_cast(1)); + column["version"] = 3; column["typestr"] = String(typestr); auto p_d_data = d_data.data().get(); @@ -31,6 +33,7 @@ std::string PrepareData(std::string typestr, thrust::device_vector* out, cons Json(Integer(reinterpret_cast(p_d_data))), Json(Boolean(false))}; column["data"] = j_data; + column["stream"] = nullptr; Json array(std::vector{column}); std::string str; @@ -49,6 +52,7 @@ TEST(MetaInfo, FromInterface) { info.SetInfo("label", str.c_str()); auto const& h_label = info.labels_.HostVector(); + ASSERT_EQ(h_label.size(), d_data.size()); for (size_t i = 0; i < d_data.size(); ++i) { ASSERT_EQ(h_label[i], d_data[i]); } @@ -60,9 +64,10 @@ TEST(MetaInfo, FromInterface) { } info.SetInfo("base_margin", str.c_str()); - auto const& h_base_margin = info.base_margin_.HostVector(); + auto const h_base_margin = info.base_margin_.View(GenericParameter::kCpuId); + ASSERT_EQ(h_base_margin.Size(), d_data.size()); for (size_t i = 0; i < d_data.size(); ++i) { - ASSERT_EQ(h_base_margin[i], d_data[i]); + ASSERT_EQ(h_base_margin(i), d_data[i]); } thrust::device_vector d_group_data; @@ -76,6 +81,10 @@ TEST(MetaInfo, FromInterface) { EXPECT_EQ(info.group_ptr_, expected_group_ptr); } +TEST(MetaInfo, GPUStridedData) { + TestMetaInfoStridedData(0); +} + TEST(MetaInfo, Group) { cudaSetDevice(0); MetaInfo info; diff --git a/tests/cpp/data/test_metainfo.h b/tests/cpp/data/test_metainfo.h new file mode 100644 index 000000000000..67da633d4be5 --- /dev/null +++ b/tests/cpp/data/test_metainfo.h @@ -0,0 +1,82 @@ +/*! + * Copyright 2021 by XGBoost Contributors + */ +#ifndef XGBOOST_TESTS_CPP_DATA_TEST_METAINFO_H_ +#define XGBOOST_TESTS_CPP_DATA_TEST_METAINFO_H_ +#include +#include +#include +#include + +#include +#include "../../../src/data/array_interface.h" +#include "../../../src/common/linalg_op.h" + +namespace xgboost { +inline void TestMetaInfoStridedData(int32_t device) { + MetaInfo info; + { + // label + HostDeviceVector labels; + labels.Resize(64); + auto& h_labels = labels.HostVector(); + std::iota(h_labels.begin(), h_labels.end(), 0.0f); + bool is_gpu = device >= 0; + if (is_gpu) { + labels.SetDevice(0); + } + + auto t = linalg::TensorView{ + is_gpu ? labels.ConstDeviceSpan() : labels.ConstHostSpan(), {32, 2}, device}; + auto s = t.Slice(linalg::All(), 0); + + auto str = s.ArrayInterfaceStr(); + ASSERT_EQ(s.Size(), 32); + + info.SetInfo("label", StringView{str}); + auto const& h_result = info.labels_.HostVector(); + ASSERT_EQ(h_result.size(), 32); + + for (auto v : h_result) { + ASSERT_EQ(static_cast(v) % 2, 0); + } + } + { + // qid + linalg::Tensor qid; + qid.Reshape(32, 2); + auto& h_qid = qid.Data()->HostVector(); + std::iota(h_qid.begin(), h_qid.end(), 0); + auto s = qid.View(device).Slice(linalg::All(), 0); + auto str = s.ArrayInterfaceStr(); + info.SetInfo("qid", StringView{str}); + auto const& h_result = info.group_ptr_; + ASSERT_EQ(h_result.size(), s.Size() + 1); + } + { + // base margin + linalg::Tensor base_margin; + base_margin.Reshape(4, 3, 2, 3); + auto& h_margin = base_margin.Data()->HostVector(); + std::iota(h_margin.begin(), h_margin.end(), 0.0); + auto t_margin = base_margin.View(device).Slice(linalg::All(), linalg::All(), 0, linalg::All()); + ASSERT_EQ(t_margin.Shape().size(), 3); + + info.SetInfo("base_margin", StringView{t_margin.ArrayInterfaceStr()}); + auto const& h_result = info.base_margin_.View(-1); + ASSERT_EQ(h_result.Shape().size(), 3); + auto in_margin = base_margin.View(-1); + linalg::ElementWiseKernelHost(h_result, omp_get_max_threads(), [&](size_t i, float v_0) { + auto tup = linalg::UnravelIndex(i, h_result.Shape()); + auto i0 = std::get<0>(tup); + auto i1 = std::get<1>(tup); + auto i2 = std::get<2>(tup); + // Sliced at 3^th dimension. + auto v_1 = in_margin(i0, i1, 0, i2); + CHECK_EQ(v_0, v_1); + return v_0; + }); + } +} +} // namespace xgboost +#endif // XGBOOST_TESTS_CPP_DATA_TEST_METAINFO_H_ diff --git a/tests/cpp/data/test_simple_dmatrix.cc b/tests/cpp/data/test_simple_dmatrix.cc index f777b00e244c..c25e877079d2 100644 --- a/tests/cpp/data/test_simple_dmatrix.cc +++ b/tests/cpp/data/test_simple_dmatrix.cc @@ -253,8 +253,8 @@ TEST(SimpleDMatrix, Slice) { std::iota(lower.begin(), lower.end(), 0.0f); std::iota(upper.begin(), upper.end(), 1.0f); - auto& margin = p_m->Info().base_margin_.HostVector(); - margin.resize(kRows * kClasses); + auto& margin = p_m->Info().base_margin_; + margin = linalg::Tensor{{kRows, kClasses}, GenericParameter::kCpuId}; std::array ridxs {1, 3, 5}; std::unique_ptr out { p_m->Slice(ridxs) }; @@ -284,10 +284,10 @@ TEST(SimpleDMatrix, Slice) { ASSERT_EQ(p_m->Info().weights_.HostVector().at(ridx), out->Info().weights_.HostVector().at(i)); - auto& out_margin = out->Info().base_margin_.HostVector(); + auto out_margin = out->Info().base_margin_.View(GenericParameter::kCpuId); + auto in_margin = margin.View(GenericParameter::kCpuId); for (size_t j = 0; j < kClasses; ++j) { - auto in_beg = ridx * kClasses; - ASSERT_EQ(out_margin.at(i * kClasses + j), margin.at(in_beg + j)); + ASSERT_EQ(out_margin(i, j), in_margin(ridx, j)); } } } diff --git a/tests/cpp/data/test_simple_dmatrix.cu b/tests/cpp/data/test_simple_dmatrix.cu index d74f5b150087..19f13b1fddfd 100644 --- a/tests/cpp/data/test_simple_dmatrix.cu +++ b/tests/cpp/data/test_simple_dmatrix.cu @@ -122,13 +122,13 @@ TEST(SimpleDMatrix, FromColumnarWithEmptyRows) { col["data"] = j_data; std::vector j_shape{Json(Integer(static_cast(kRows)))}; col["shape"] = Array(j_shape); - col["version"] = Integer(static_cast(1)); + col["version"] = 3; col["typestr"] = String("(1)); + j_mask["version"] = 3; auto& mask_storage = column_bitfields[i]; mask_storage.resize(16); // 16 bytes @@ -220,7 +220,7 @@ TEST(SimpleCSRSource, FromColumnarSparse) { for (size_t c = 0; c < kCols; ++c) { auto& column = j_columns[c]; column = Object(); - column["version"] = Integer(static_cast(1)); + column["version"] = 3; column["typestr"] = String(" j_data { @@ -229,12 +229,12 @@ TEST(SimpleCSRSource, FromColumnarSparse) { column["data"] = j_data; std::vector j_shape {Json(Integer(static_cast(kRows)))}; column["shape"] = Array(j_shape); - column["version"] = Integer(static_cast(1)); + column["version"] = 3; column["typestr"] = String("(1)); + j_mask["version"] = 3; j_mask["data"] = std::vector{ Json(Integer(reinterpret_cast(column_bitfields[c].data().get()))), Json(Boolean(false))}; diff --git a/tests/cpp/helpers.cc b/tests/cpp/helpers.cc index db4454c9eb1e..0906d9ed87df 100644 --- a/tests/cpp/helpers.cc +++ b/tests/cpp/helpers.cc @@ -228,6 +228,7 @@ RandomDataGenerator::GenerateArrayInterfaceBatch( if (device_ >= 0) { array_interface["data"][0] = Integer(reinterpret_cast(storage->DevicePointer() + offset)); + array_interface["stream"] = Null{}; } else { array_interface["data"][0] = Integer(reinterpret_cast(storage->HostPointer() + offset)); @@ -240,7 +241,7 @@ RandomDataGenerator::GenerateArrayInterfaceBatch( array_interface["shape"][1] = cols_; array_interface["typestr"] = String(" *storage, size_t rows, size_t cols) { if (storage->DeviceCanRead()) { array_interface["data"][0] = Integer(reinterpret_cast(storage->ConstDevicePointer())); + array_interface["stream"] = nullptr; } else { array_interface["data"][0] = Integer(reinterpret_cast(storage->ConstHostPointer())); @@ -198,9 +199,9 @@ Json GetArrayInterface(HostDeviceVector *storage, size_t rows, size_t cols) { array_interface["shape"][0] = rows; array_interface["shape"][1] = cols; - char t = ArrayInterfaceHandler::TypeChar(); + char t = linalg::detail::ArrayInterfaceHandler::TypeChar(); array_interface["typestr"] = String(std::string{"<"} + t + std::to_string(sizeof(T))); - array_interface["version"] = 1; + array_interface["version"] = 3; return array_interface; } diff --git a/tests/cpp/predictor/test_gpu_predictor.cu b/tests/cpp/predictor/test_gpu_predictor.cu index ad1083f9161b..b36df742da4f 100644 --- a/tests/cpp/predictor/test_gpu_predictor.cu +++ b/tests/cpp/predictor/test_gpu_predictor.cu @@ -108,7 +108,9 @@ TEST(GPUPredictor, ExternalMemoryTest) { dmats.push_back(CreateSparsePageDMatrix(8000)); for (const auto& dmat: dmats) { - dmat->Info().base_margin_.Resize(dmat->Info().num_row_ * n_classes, 0.5); + dmat->Info().base_margin_ = + linalg::Tensor{{dmat->Info().num_row_, static_cast(n_classes)}, 0}; + dmat->Info().base_margin_.Data()->Fill(0.5); PredictionCacheEntry out_predictions; gpu_predictor->InitOutPredictions(dmat->Info(), &out_predictions.predictions, model); gpu_predictor->PredictBatch(dmat.get(), &out_predictions, model, 0); diff --git a/tests/python/test_dmatrix.py b/tests/python/test_dmatrix.py index 1b5ad266d3d9..9dcaa1c0d2c1 100644 --- a/tests/python/test_dmatrix.py +++ b/tests/python/test_dmatrix.py @@ -141,6 +141,7 @@ def test_slice(self): # base margin is per-class in multi-class classifier base_margin = rng.randn(100, 3).astype(np.float32) d.set_base_margin(base_margin) + np.testing.assert_allclose(d.get_base_margin().reshape(100, 3), base_margin) ridxs = [1, 2, 3, 4, 5, 6] sliced = d.slice(ridxs) @@ -154,7 +155,7 @@ def test_slice(self): # Slicing a DMatrix results into a DMatrix that's equivalent to a DMatrix that's # constructed from the corresponding NumPy slice d2 = xgb.DMatrix(X[1:7, :], y[1:7]) - d2.set_base_margin(base_margin[1:7, :].flatten()) + d2.set_base_margin(base_margin[1:7, :]) eval_res = {} _ = xgb.train( {'num_class': 3, 'objective': 'multi:softprob', @@ -280,7 +281,7 @@ def test_feature_weights(self): m.set_info(feature_weights=fw) np.testing.assert_allclose(fw, m.get_float_info('feature_weights')) # Handle empty - m.set_info(feature_weights=np.empty((0, 0))) + m.set_info(feature_weights=np.empty((0, ))) assert m.get_float_info('feature_weights').shape[0] == 0