diff --git a/doc/contrib/coding_guide.rst b/doc/contrib/coding_guide.rst index 6d407ba129e9..b4880803cb0c 100644 --- a/doc/contrib/coding_guide.rst +++ b/doc/contrib/coding_guide.rst @@ -134,3 +134,49 @@ Similarly, if you want to exclude C++ source from linting: cd /path/to/xgboost/ python3 tests/ci_build/tidy.py --cpp=0 +********************************** +Guide for handling user input data +********************************** + +This is an in-comprehensive guide for handling user input data. XGBoost has wide verity +of native supported data structures, mostly come from higher level language bindings. The +inputs ranges from basic contiguous 1 dimension memory buffer to more sophisticated data +structures like columnar data with validity mask. Raw input data can be used in 2 places, +firstly it's the construction of various ``DMatrix``, secondly it's the in-place +prediction. For plain memory buffer, there's not much to discuss since it's just a +pointer with a size. But for general n-dimension array and columnar data, there are many +subtleties. XGBoost has 3 different data structures for handling optionally masked arrays +(tensors), for consuming user inputs ``ArrayInterface`` should be chosen. There are many +existing functions that accept only plain pointer due to legacy reasons (XGBoost started +as a much simpler library and didn't care about memory usage that much back then). The +``ArrayInterface`` is a in memory representation of ``__array_interface__`` protocol +defined by numpy or the ``__cuda_array_interface__`` defined by numba. Following is a +check list of things to have in mind when accepting related user inputs: + +- [ ] Is it strided? (identified by the ``strides`` field) +- [ ] If it's a vector, is it row vector or column vector? (Identified by both ``shape`` + and ``strides``). +- [ ] Is the data type supported? Half type and 128 integer types should be converted + before going into XGBoost. +- [ ] Does it have higher than 1 dimension? (identified by ``shape`` field) +- [ ] Are some of dimensions trivial? (shape[dim] <= 1) +- [ ] Does it have mask? (identified by ``mask`` field) +- [ ] Can the mask be broadcasted? (unsupported at the moment) +- [ ] Is it on CUDA memory? (identified by ``data`` field, and optionally ``stream``) + +Most of the checks are handled by the ``ArrayInterface`` during construction, except for +the data type issue since it doesn't know how to cast such pointers with C builtin types. +But for safety reason one should still try to write related tests for the all items. The +data type issue should be taken care of in language binding for each of the specific data +input. For single-chunk columnar format, it's just a masked array for each column so it +should be treated uniformly as normal array. For input predictor ``X``, we have adapters +for each type of input. Some are composition of the others. For instance, CSR matrix has 3 +potentially strided arrays for ``indptr``, ``indices`` and ``values``. No assumption +should be made to these components (all the check boxes should be considered). Slicing row +of CSR matrix should calculate the offset of each field based on respective strides. + +For meta info like labels, which is growing both in size and complexity, we accept only +masked array at the moment (no specialized adapter). One should be careful about the +input data shape. For base margin it can be 2 dim or higher if we have multiple targets in +the future. The getters in ``DMatrix`` returns only 1 dimension flatten vectors at the +moment, which can be improved in the future when it's needed. diff --git a/include/xgboost/c_api.h b/include/xgboost/c_api.h index 8d929593c1a4..9d657872e26b 100644 --- a/include/xgboost/c_api.h +++ b/include/xgboost/c_api.h @@ -249,7 +249,7 @@ XGB_DLL int XGDMatrixCreateFromCudaArrayInterface(char const *data, char const* json_config, DMatrixHandle *out); -/* +/** * ========================== Begin data callback APIs ========================= * * Short notes for data callback @@ -258,9 +258,9 @@ XGB_DLL int XGDMatrixCreateFromCudaArrayInterface(char const *data, * used by JVM packages. It uses `XGBoostBatchCSR` to accept batches for CSR formated * input, and concatenate them into 1 final big CSR. The related functions are: * - * - XGBCallbackSetData - * - XGBCallbackDataIterNext - * - XGDMatrixCreateFromDataIter + * - \ref XGBCallbackSetData + * - \ref XGBCallbackDataIterNext + * - \ref XGDMatrixCreateFromDataIter * * Another set is used by external data iterator. It accept foreign data iterators as * callbacks. There are 2 different senarios where users might want to pass in callbacks @@ -276,17 +276,17 @@ XGB_DLL int XGDMatrixCreateFromCudaArrayInterface(char const *data, * Related functions are: * * # Factory functions - * - `XGDMatrixCreateFromCallback` for external memory - * - `XGDeviceQuantileDMatrixCreateFromCallback` for quantile DMatrix + * - \ref XGDMatrixCreateFromCallback for external memory + * - \ref XGDeviceQuantileDMatrixCreateFromCallback for quantile DMatrix * * # Proxy that callers can use to pass data to XGBoost - * - XGProxyDMatrixCreate - * - XGDMatrixCallbackNext - * - DataIterResetCallback - * - XGProxyDMatrixSetDataCudaArrayInterface - * - XGProxyDMatrixSetDataCudaColumnar - * - XGProxyDMatrixSetDataDense - * - XGProxyDMatrixSetDataCSR + * - \ref XGProxyDMatrixCreate + * - \ref XGDMatrixCallbackNext + * - \ref DataIterResetCallback + * - \ref XGProxyDMatrixSetDataCudaArrayInterface + * - \ref XGProxyDMatrixSetDataCudaColumnar + * - \ref XGProxyDMatrixSetDataDense + * - \ref XGProxyDMatrixSetDataCSR * - ... (data setters) */ @@ -411,7 +411,7 @@ XGB_EXTERN_C typedef void DataIterResetCallback(DataIterHandle handle); // NOLIN * - cache_prefix: The path of cache file, caller must initialize all the directories in this path. * - nthread (optional): Number of threads used for initializing DMatrix. * - * \param out The created external memory DMatrix + * \param[out] out The created external memory DMatrix * * \return 0 when success, -1 when failure happens */ @@ -605,7 +605,8 @@ XGB_DLL int XGDMatrixSetUIntInfo(DMatrixHandle handle, * char const* feat_names [] {"feat_0", "feat_1"}; * XGDMatrixSetStrFeatureInfo(handle, "feature_name", feat_names, 2); * - * // i for integer, q for quantitive. Similarly "int" and "float" are also recognized. + * // i for integer, q for quantitive, c for categorical. Similarly "int" and "float" + * // are also recognized. * char const* feat_types [] {"i", "q"}; * XGDMatrixSetStrFeatureInfo(handle, "feature_type", feat_types, 2); * diff --git a/include/xgboost/data.h b/include/xgboost/data.h index 1678f4b1f4f1..90cbc1373713 100644 --- a/include/xgboost/data.h +++ b/include/xgboost/data.h @@ -11,12 +11,14 @@ #include #include #include -#include #include +#include +#include +#include +#include #include #include -#include #include #include #include @@ -45,7 +47,7 @@ enum class FeatureType : uint8_t { class MetaInfo { public: /*! \brief number of data fields in MetaInfo */ - static constexpr uint64_t kNumField = 11; + static constexpr uint64_t kNumField = 12; /*! \brief number of rows in the data */ uint64_t num_row_{0}; // NOLINT @@ -67,7 +69,7 @@ class MetaInfo { * if specified, xgboost will start from this init margin * can be used to specify initial prediction to boost from. */ - HostDeviceVector base_margin_; // NOLINT + linalg::Tensor base_margin_; // NOLINT /*! * \brief lower bound of the label, to be used for survival analysis (censored regression) */ @@ -157,7 +159,8 @@ class MetaInfo { * * Right now only 1 column is permitted. */ - void SetInfo(const char* key, std::string const& interface_str); + + void SetInfo(StringView key, StringView interface_str); void GetInfo(char const* key, bst_ulong* out_len, DataType dtype, const void** out_dptr) const; @@ -179,6 +182,9 @@ class MetaInfo { void Extend(MetaInfo const& that, bool accumulate_rows, bool check_column); private: + void SetInfoFromHost(StringView key, Json arr); + void SetInfoFromCUDA(StringView key, Json arr); + /*! \brief argsort of labels */ mutable std::vector label_order_cache_; }; @@ -477,7 +483,7 @@ class DMatrix { this->Info().SetInfo(key, dptr, dtype, num); } virtual void SetInfo(const char* key, std::string const& interface_str) { - this->Info().SetInfo(key, interface_str); + this->Info().SetInfo(key, StringView{interface_str}); } /*! \brief meta information of the dataset */ virtual const MetaInfo& Info() const = 0; diff --git a/include/xgboost/intrusive_ptr.h b/include/xgboost/intrusive_ptr.h index 1c58704c4ca3..df7ae30213d6 100644 --- a/include/xgboost/intrusive_ptr.h +++ b/include/xgboost/intrusive_ptr.h @@ -19,7 +19,7 @@ namespace xgboost { */ class IntrusivePtrCell { private: - std::atomic count_; + std::atomic count_ {0}; template friend class IntrusivePtr; std::int32_t IncRef() noexcept { @@ -31,7 +31,7 @@ class IntrusivePtrCell { bool IsZero() const { return Count() == 0; } public: - IntrusivePtrCell() noexcept : count_{0} {} + IntrusivePtrCell() noexcept = default; int32_t Count() const { return count_.load(std::memory_order_relaxed); } }; diff --git a/include/xgboost/predictor.h b/include/xgboost/predictor.h index dd87c601f978..859669a47ada 100644 --- a/include/xgboost/predictor.h +++ b/include/xgboost/predictor.h @@ -128,7 +128,7 @@ class Predictor { */ virtual void InitOutPredictions(const MetaInfo &info, HostDeviceVector *out_predt, - const gbm::GBTreeModel &model) const = 0; + const gbm::GBTreeModel &model) const; /** * \brief Generate batch predictions for a given feature matrix. May use diff --git a/jvm-packages/xgboost4j-gpu/src/native/xgboost4j-gpu.cu b/jvm-packages/xgboost4j-gpu/src/native/xgboost4j-gpu.cu index de7a1fc41495..4ecf8b0f1da1 100644 --- a/jvm-packages/xgboost4j-gpu/src/native/xgboost4j-gpu.cu +++ b/jvm-packages/xgboost4j-gpu/src/native/xgboost4j-gpu.cu @@ -35,13 +35,12 @@ template T CheckJvmCall(T const &v, JNIEnv *jenv) { } template -void CopyColumnMask(xgboost::ArrayInterface const &interface, +void CopyColumnMask(xgboost::ArrayInterface<1> const &interface, std::vector const &columns, cudaMemcpyKind kind, size_t c, VCont *p_mask, Json *p_out, cudaStream_t stream) { auto &mask = *p_mask; auto &out = *p_out; - auto size = sizeof(typename VCont::value_type) * interface.num_rows * - interface.num_cols; + auto size = sizeof(typename VCont::value_type) * interface.n; mask.resize(size); CHECK(RawPtr(mask)); CHECK(size); @@ -67,11 +66,11 @@ void CopyColumnMask(xgboost::ArrayInterface const &interface, LOG(FATAL) << "Invalid shape of mask"; } out["mask"]["typestr"] = String(" -void CopyInterface(std::vector &interface_arr, +void CopyInterface(std::vector> &interface_arr, std::vector const &columns, cudaMemcpyKind kind, std::vector *p_data, std::vector *p_mask, std::vector *p_out, cudaStream_t stream) { @@ -81,7 +80,7 @@ void CopyInterface(std::vector &interface_arr, for (size_t c = 0; c < interface_arr.size(); ++c) { auto &interface = interface_arr.at(c); size_t element_size = interface.ElementSize(); - size_t size = element_size * interface.num_rows * interface.num_cols; + size_t size = element_size * interface.n; auto &data = (*p_data)[c]; auto &mask = (*p_mask)[c]; @@ -95,14 +94,13 @@ void CopyInterface(std::vector &interface_arr, Json{Boolean{false}}}; out["data"] = Array(std::move(j_data)); - out["shape"] = Array(std::vector{Json(Integer(interface.num_rows)), - Json(Integer(interface.num_cols))}); + out["shape"] = Array(std::vector{Json(Integer(interface.Shape(0)))}); if (interface.valid.Data()) { CopyColumnMask(interface, columns, kind, c, &mask, &out, stream); } out["typestr"] = String(" *out, cudaStream_t auto &j_interface = *p_interface; CHECK_EQ(get(j_interface).size(), 1); auto object = get(get(j_interface)[0]); - ArrayInterface interface(object); - out->resize(interface.num_rows); + ArrayInterface<1> interface(object); + out->resize(interface.Shape(0)); size_t element_size = interface.ElementSize(); - size_t size = element_size * interface.num_rows; + size_t size = element_size * interface.n; dh::safe_cuda(cudaMemcpyAsync(RawPtr(*out), interface.data, size, cudaMemcpyDeviceToDevice, stream)); j_interface[0]["data"][0] = reinterpret_cast(RawPtr(*out)); @@ -285,11 +283,11 @@ class DataIteratorProxy { Json features = json_interface["features_str"]; auto json_columns = get(features); - std::vector interfaces; + std::vector> interfaces; // Stage the data for (auto &json_col : json_columns) { - auto column = ArrayInterface(get(json_col)); + auto column = ArrayInterface<1>(get(json_col)); interfaces.emplace_back(column); } Json::Dump(features, &interface_str); @@ -342,9 +340,9 @@ class DataIteratorProxy { // Data auto const &json_interface = host_columns_.at(it_)->interfaces; - std::vector in; + std::vector> in; for (auto interface : json_interface) { - auto column = ArrayInterface(get(interface)); + auto column = ArrayInterface<1>(get(interface)); in.emplace_back(column); } std::vector out; diff --git a/python-package/xgboost/data.py b/python-package/xgboost/data.py index 285b2e840f59..c739f9267af5 100644 --- a/python-package/xgboost/data.py +++ b/python-package/xgboost/data.py @@ -5,7 +5,7 @@ import json import warnings import os -from typing import Any, Tuple, Callable, Optional, List +from typing import Any, Tuple, Callable, Optional, List, Union import numpy as np @@ -138,14 +138,14 @@ def _is_numpy_array(data): return isinstance(data, (np.ndarray, np.matrix)) -def _ensure_np_dtype(data, dtype): +def _ensure_np_dtype(data, dtype) -> Tuple[np.ndarray, np.dtype]: if data.dtype.hasobject or data.dtype in [np.float16, np.bool_]: data = data.astype(np.float32, copy=False) dtype = np.float32 return data, dtype -def _maybe_np_slice(data, dtype): +def _maybe_np_slice(data: np.ndarray, dtype) -> np.ndarray: '''Handle numpy slice. This can be removed if we use __array_interface__. ''' try: @@ -852,23 +852,17 @@ def _validate_meta_shape(data: Any, name: str) -> None: def _meta_from_numpy( - data: np.ndarray, field: str, dtype, handle: ctypes.c_void_p + data: np.ndarray, + field: str, + dtype: Optional[Union[np.dtype, str]], + handle: ctypes.c_void_p, ) -> None: - data = _maybe_np_slice(data, dtype) + data, dtype = _ensure_np_dtype(data, dtype) interface = data.__array_interface__ - assert interface.get('mask', None) is None, 'Masked array is not supported' - size = data.size - - c_type = _to_data_type(str(data.dtype), field) - ptr = interface['data'][0] - ptr = ctypes.c_void_p(ptr) - _check_call(_LIB.XGDMatrixSetDenseInfo( - handle, - c_str(field), - ptr, - c_bst_ulong(size), - c_type - )) + if interface.get("mask", None) is not None: + raise ValueError("Masked array is not supported.") + interface_str = _array_interface(data) + _check_call(_LIB.XGDMatrixSetInfoFromInterface(handle, c_str(field), interface_str)) def _meta_from_list(data, field, dtype, handle): @@ -911,7 +905,9 @@ def _meta_from_dt(data, field: str, dtype, handle: ctypes.c_void_p): _meta_from_numpy(data, field, dtype, handle) -def dispatch_meta_backend(matrix: DMatrix, data, name: str, dtype: str = None): +def dispatch_meta_backend( + matrix: DMatrix, data, name: str, dtype: Optional[Union[str, np.dtype]] = None +): '''Dispatch for meta info.''' handle = matrix.handle assert handle is not None diff --git a/src/common/common.cu b/src/common/common.cu index d30fbc0aeb88..4636a4cdcb7c 100644 --- a/src/common/common.cu +++ b/src/common/common.cu @@ -12,7 +12,8 @@ int AllVisibleGPUs() { // When compiled with CUDA but running on CPU only device, // cudaGetDeviceCount will fail. dh::safe_cuda(cudaGetDeviceCount(&n_visgpus)); - } catch(const dmlc::Error &except) { + } catch (const dmlc::Error &) { + cudaGetLastError(); // reset error. return 0; } return n_visgpus; diff --git a/src/data/adapter.h b/src/data/adapter.h index 27da8c6e3b36..c4253b7e7a0e 100644 --- a/src/data/adapter.h +++ b/src/data/adapter.h @@ -254,20 +254,20 @@ class ArrayAdapterBatch : public detail::NoMetaInfo { static constexpr bool kIsRowMajor = true; private: - ArrayInterface array_interface_; + ArrayInterface<2> array_interface_; class Line { - ArrayInterface array_interface_; + ArrayInterface<2> array_interface_; size_t ridx_; public: - Line(ArrayInterface array_interface, size_t ridx) + Line(ArrayInterface<2> array_interface, size_t ridx) : array_interface_{std::move(array_interface)}, ridx_{ridx} {} - size_t Size() const { return array_interface_.num_cols; } + size_t Size() const { return array_interface_.Shape(1); } COOTuple GetElement(size_t idx) const { - return {ridx_, idx, array_interface_.GetElement(ridx_, idx)}; + return {ridx_, idx, array_interface_(ridx_, idx)}; } }; @@ -277,11 +277,11 @@ class ArrayAdapterBatch : public detail::NoMetaInfo { return Line{array_interface_, idx}; } - size_t NumRows() const { return array_interface_.num_rows; } - size_t NumCols() const { return array_interface_.num_cols; } + size_t NumRows() const { return array_interface_.Shape(0); } + size_t NumCols() const { return array_interface_.Shape(1); } size_t Size() const { return this->NumRows(); } - explicit ArrayAdapterBatch(ArrayInterface array_interface) + explicit ArrayAdapterBatch(ArrayInterface<2> array_interface) : array_interface_{std::move(array_interface)} {} }; @@ -294,43 +294,42 @@ class ArrayAdapter : public detail::SingleBatchDataIter { public: explicit ArrayAdapter(StringView array_interface) { auto j = Json::Load(array_interface); - array_interface_ = ArrayInterface(get(j)); + array_interface_ = ArrayInterface<2>(get(j)); batch_ = ArrayAdapterBatch{array_interface_}; } ArrayAdapterBatch const& Value() const override { return batch_; } - size_t NumRows() const { return array_interface_.num_rows; } - size_t NumColumns() const { return array_interface_.num_cols; } + size_t NumRows() const { return array_interface_.Shape(0); } + size_t NumColumns() const { return array_interface_.Shape(1); } private: ArrayAdapterBatch batch_; - ArrayInterface array_interface_; + ArrayInterface<2> array_interface_; }; class CSRArrayAdapterBatch : public detail::NoMetaInfo { - ArrayInterface indptr_; - ArrayInterface indices_; - ArrayInterface values_; + ArrayInterface<1> indptr_; + ArrayInterface<1> indices_; + ArrayInterface<1> values_; bst_feature_t n_features_; class Line { - ArrayInterface indices_; - ArrayInterface values_; + ArrayInterface<1> indices_; + ArrayInterface<1> values_; size_t ridx_; size_t offset_; public: - Line(ArrayInterface indices, ArrayInterface values, size_t ridx, + Line(ArrayInterface<1> indices, ArrayInterface<1> values, size_t ridx, size_t offset) : indices_{std::move(indices)}, values_{std::move(values)}, ridx_{ridx}, offset_{offset} {} COOTuple GetElement(size_t idx) const { - return {ridx_, indices_.GetElement(offset_ + idx, 0), - values_.GetElement(offset_ + idx, 0)}; + return {ridx_, indices_.operator()(offset_ + idx), values_(offset_ + idx)}; } size_t Size() const { - return values_.num_rows * values_.num_cols; + return values_.Shape(0); } }; @@ -339,17 +338,16 @@ class CSRArrayAdapterBatch : public detail::NoMetaInfo { public: CSRArrayAdapterBatch() = default; - CSRArrayAdapterBatch(ArrayInterface indptr, ArrayInterface indices, - ArrayInterface values, bst_feature_t n_features) - : indptr_{std::move(indptr)}, indices_{std::move(indices)}, - values_{std::move(values)}, n_features_{n_features} { - indptr_.AsColumnVector(); - values_.AsColumnVector(); - indices_.AsColumnVector(); + CSRArrayAdapterBatch(ArrayInterface<1> indptr, ArrayInterface<1> indices, + ArrayInterface<1> values, bst_feature_t n_features) + : indptr_{std::move(indptr)}, + indices_{std::move(indices)}, + values_{std::move(values)}, + n_features_{n_features} { } size_t NumRows() const { - size_t size = indptr_.num_rows * indptr_.num_cols; + size_t size = indptr_.Shape(0); size = size == 0 ? 0 : size - 1; return size; } @@ -357,19 +355,19 @@ class CSRArrayAdapterBatch : public detail::NoMetaInfo { size_t Size() const { return this->NumRows(); } Line const GetLine(size_t idx) const { - auto begin_offset = indptr_.GetElement(idx, 0); - auto end_offset = indptr_.GetElement(idx + 1, 0); + auto begin_no_stride = indptr_.operator()(idx); + auto end_no_stride = indptr_.operator()(idx + 1); auto indices = indices_; auto values = values_; + // Slice indices and values, stride remains unchanged since this is slicing by + // specific index. + auto offset = indices.strides[0] * begin_no_stride; - values.num_cols = end_offset - begin_offset; - values.num_rows = 1; + indices.shape[0] = end_no_stride - begin_no_stride; + values.shape[0] = end_no_stride - begin_no_stride; - indices.num_cols = values.num_cols; - indices.num_rows = values.num_rows; - - return Line{indices, values, idx, begin_offset}; + return Line{indices, values, idx, offset}; } }; @@ -391,7 +389,7 @@ class CSRArrayAdapter : public detail::SingleBatchDataIter return batch_; } size_t NumRows() const { - size_t size = indptr_.num_cols * indptr_.num_rows; + size_t size = indptr_.Shape(0); size = size == 0 ? 0 : size - 1; return size; } @@ -399,9 +397,9 @@ class CSRArrayAdapter : public detail::SingleBatchDataIter private: CSRArrayAdapterBatch batch_; - ArrayInterface indptr_; - ArrayInterface indices_; - ArrayInterface values_; + ArrayInterface<1> indptr_; + ArrayInterface<1> indices_; + ArrayInterface<1> values_; size_t num_cols_; }; diff --git a/src/data/array_interface.cu b/src/data/array_interface.cu index def4de195523..08dbfaefb647 100644 --- a/src/data/array_interface.cu +++ b/src/data/array_interface.cu @@ -7,15 +7,50 @@ namespace xgboost { void ArrayInterfaceHandler::SyncCudaStream(int64_t stream) { switch (stream) { - case 0: - LOG(FATAL) << "Invalid stream ID in array interface: " << stream; - case 1: - // default legacy stream - break; - case 2: - // default per-thread stream - default: - dh::safe_cuda(cudaStreamSynchronize(reinterpret_cast(stream))); + case 0: + /** + * disallowed by the `__cuda_array_interface__`. Quote: + * + * This is disallowed as it would be ambiguous between None and the default + * stream, and also between the legacy and per-thread default streams. Any use + * case where 0 might be given should either use None, 1, or 2 instead for + * clarity. + */ + LOG(FATAL) << "Invalid stream ID in array interface: " << stream; + case 1: + // default legacy stream + break; + case 2: + // default per-thread stream + default: + dh::safe_cuda(cudaStreamSynchronize(reinterpret_cast(stream))); + } +} + +bool ArrayInterfaceHandler::IsCudaPtr(void const* ptr) { + if (!ptr) { + return false; + } + cudaPointerAttributes attr; + auto err = cudaPointerGetAttributes(&attr, ptr); + // reset error + CHECK_EQ(err, cudaGetLastError()); + if (err == cudaErrorInvalidValue) { + // CUDA < 11 + return false; + } else if (err == cudaSuccess) { + // CUDA >= 11 + switch (attr.type) { + case cudaMemoryTypeUnregistered: + case cudaMemoryTypeHost: + return false; + default: + return true; + } + return true; + } else { + // other errors, no `cudaErrorNoDevice`, `cudaErrorInsufficientDriver` etc. + return false; } } } // namespace xgboost diff --git a/src/data/array_interface.h b/src/data/array_interface.h index 6524f4512407..25d2361a9182 100644 --- a/src/data/array_interface.h +++ b/src/data/array_interface.h @@ -13,24 +13,23 @@ #include #include +#include "../common/bitfield.h" +#include "../common/common.h" #include "xgboost/base.h" #include "xgboost/data.h" #include "xgboost/json.h" +#include "xgboost/linalg.h" #include "xgboost/logging.h" #include "xgboost/span.h" -#include "../common/bitfield.h" -#include "../common/common.h" namespace xgboost { // Common errors in parsing columnar format. struct ArrayInterfaceErrors { - static char const* Contigious() { - return "Memory should be contigious."; - } - static char const* TypestrFormat() { + static char const *Contiguous() { return "Memory should be contiguous."; } + static char const *TypestrFormat() { return "`typestr' should be of format ."; } - static char const* Dimension(int32_t d) { + static char const *Dimension(int32_t d) { static std::string str; str.clear(); str += "Only "; @@ -38,11 +37,11 @@ struct ArrayInterfaceErrors { str += " dimensional array is valid."; return str.c_str(); } - static char const* Version() { - return "Only version <= 3 of " - "`__cuda_array_interface__/__array_interface__' are supported."; + static char const *Version() { + return "Only version <= 3 of `__cuda_array_interface__' and `__array_interface__' are " + "supported."; } - static char const* OfType(std::string const& type) { + static char const *OfType(std::string const &type) { static std::string str; str.clear(); str += " should be of "; @@ -96,38 +95,25 @@ struct ArrayInterfaceErrors { // object and turn it into an array (for cupy and numba). class ArrayInterfaceHandler { public: - template - static constexpr char TypeChar() { - return - (std::is_floating_point::value ? 'f' : - (std::is_integral::value ? - (std::is_signed::value ? 'i' : 'u') : '\0')); - } + enum Type : std::int8_t { kF4, kF8, kF16, kI1, kI2, kI4, kI8, kU1, kU2, kU4, kU8 }; template - static PtrType GetPtrFromArrayData(std::map const& obj) { + static PtrType GetPtrFromArrayData(std::map const &obj) { if (obj.find("data") == obj.cend()) { LOG(FATAL) << "Empty data passed in."; } - auto p_data = reinterpret_cast(static_cast( - get( - get( - obj.at("data")) - .at(0)))); + auto p_data = reinterpret_cast( + static_cast(get(get(obj.at("data")).at(0)))); return p_data; } - static void Validate(std::map const& array) { + static void Validate(std::map const &array) { auto version_it = array.find("version"); if (version_it == array.cend()) { LOG(FATAL) << "Missing `version' field for array interface"; } - auto stream_it = array.find("stream"); - if (stream_it != array.cend() && !IsA(stream_it->second)) { - // is cuda, check the version. - if (get(version_it->second) > 3) { - LOG(FATAL) << ArrayInterfaceErrors::Version(); - } + if (get(version_it->second) > 3) { + LOG(FATAL) << ArrayInterfaceErrors::Version(); } if (array.find("typestr") == array.cend()) { @@ -149,12 +135,12 @@ class ArrayInterfaceHandler { // Mask object is also an array interface, but with different requirements. static size_t ExtractMask(std::map const &column, common::Span *p_out) { - auto& s_mask = *p_out; + auto &s_mask = *p_out; if (column.find("mask") != column.cend()) { - auto const& j_mask = get(column.at("mask")); + auto const &j_mask = get(column.at("mask")); Validate(j_mask); - auto p_mask = GetPtrFromArrayData(j_mask); + auto p_mask = GetPtrFromArrayData(j_mask); auto j_shape = get(j_mask.at("shape")); CHECK_EQ(j_shape.size(), 1) << ArrayInterfaceErrors::Dimension(1); @@ -186,8 +172,8 @@ class ArrayInterfaceHandler { if (j_mask.find("strides") != j_mask.cend()) { auto strides = get(column.at("strides")); - CHECK_EQ(strides.size(), 1) << ArrayInterfaceErrors::Dimension(1); - CHECK_EQ(get(strides.at(0)), type_length) << ArrayInterfaceErrors::Contigious(); + CHECK_EQ(strides.size(), 1) << ArrayInterfaceErrors::Dimension(1); + CHECK_EQ(get(strides.at(0)), type_length) << ArrayInterfaceErrors::Contiguous(); } s_mask = {p_mask, span_size}; @@ -195,96 +181,225 @@ class ArrayInterfaceHandler { } return 0; } - - static std::pair ExtractShape( - std::map const& column) { - auto j_shape = get(column.at("shape")); - auto typestr = get(column.at("typestr")); - if (j_shape.size() == 1) { - return {static_cast(get(j_shape.at(0))), 1}; - } else { - CHECK_EQ(j_shape.size(), 2) << "Only 1-D and 2-D arrays are supported."; - return {static_cast(get(j_shape.at(0))), - static_cast(get(j_shape.at(1)))}; + /** + * \brief Handle vector inputs. For higher dimension, we require strictly correct shape. + */ + template + static void HandleRowVector(std::vector const &shape, std::vector *p_out) { + auto &out = *p_out; + if (shape.size() == 2 && D == 1) { + auto m = shape[0]; + auto n = shape[1]; + CHECK(m == 1 || n == 1); + if (m == 1) { + // keep the number of columns + out[0] = out[1]; + out.resize(1); + } else if (n == 1) { + // keep the number of rows. + out.resize(1); + } + // when both m and n are 1, above logic keeps the column. + // when neither m nor n is 1, caller should throw an error about Dimension. } } - static void ExtractStride(std::map const &column, - size_t *stride_r, size_t *stride_c, size_t rows, - size_t cols, size_t itemsize) { - auto strides_it = column.find("strides"); - if (strides_it == column.cend() || IsA(strides_it->second)) { - // default strides - *stride_r = cols; - *stride_c = 1; - } else { - // strides specified by the array interface - auto const &j_strides = get(strides_it->second); - CHECK_LE(j_strides.size(), 2) << ArrayInterfaceErrors::Dimension(2); - *stride_r = get(j_strides[0]) / itemsize; - size_t n = 1; - if (j_strides.size() == 2) { - n = get(j_strides[1]) / itemsize; - } - *stride_c = n; + template + static void ExtractShape(std::map const &array, size_t (&out_shape)[D]) { + auto const &j_shape = get(array.at("shape")); + std::vector shape_arr(j_shape.size(), 0); + std::transform(j_shape.cbegin(), j_shape.cend(), shape_arr.begin(), + [](Json in) { return get(in); }); + // handle column vector vs. row vector + HandleRowVector(shape_arr, &shape_arr); + // Copy shape. + size_t i; + for (i = 0; i < shape_arr.size(); ++i) { + CHECK_LT(i, D) << ArrayInterfaceErrors::Dimension(D); + out_shape[i] = shape_arr[i]; } + // Fill the remaining dimensions + std::fill(out_shape + i, out_shape + D, 1); + } - auto valid = rows * (*stride_r) + cols * (*stride_c) >= (rows * cols); - CHECK(valid) << "Invalid strides in array." - << " strides: (" << (*stride_r) << "," << (*stride_c) - << "), shape: (" << rows << ", " << cols << ")"; + template + static bool ExtractStride(std::map const &array, size_t itemsize, + size_t (&shape)[D], size_t (&stride)[D]) { + auto strides_it = array.find("strides"); + // No stride is provided + if (strides_it == array.cend() || IsA(strides_it->second)) { + // No stride is provided, we can calculate it from shape. + linalg::detail::CalcStride(shape, stride); + // Quote: + // + // strides: Either None to indicate a C-style contiguous array or a Tuple of + // strides which provides the number of bytes + return true; + } + // Get shape, we need to make changes to handle row vector, so some duplicated code + // from `ExtractShape` for copying out the shape. + auto const &j_shape = get(array.at("shape")); + std::vector shape_arr(j_shape.size(), 0); + std::transform(j_shape.cbegin(), j_shape.cend(), shape_arr.begin(), + [](Json in) { return get(in); }); + // Get stride + auto const &j_strides = get(strides_it->second); + CHECK_EQ(j_strides.size(), j_shape.size()) << "stride and shape don't match."; + std::vector stride_arr(j_strides.size(), 0); + std::transform(j_strides.cbegin(), j_strides.cend(), stride_arr.begin(), + [](Json in) { return get(in); }); + + // Handle column vector vs. row vector + HandleRowVector(shape_arr, &stride_arr); + size_t i; + for (i = 0; i < stride_arr.size(); ++i) { + // If one of the dim has shape 0 then total size is 0, stride is meaningless, but we + // set it to 0 here just to be consistent + CHECK_LT(i, D) << ArrayInterfaceErrors::Dimension(D); + // We use number of items instead of number of bytes + stride[i] = stride_arr[i] / itemsize; + } + std::fill(stride + i, stride + D, 1); + // If the stride can be calculated from shape then it's contiguous. + size_t stride_tmp[D]; + linalg::detail::CalcStride(shape, stride_tmp); + return std::equal(stride_tmp, stride_tmp + D, stride); } - static void* ExtractData(std::map const &column, - std::pair shape) { - Validate(column); - void* p_data = ArrayInterfaceHandler::GetPtrFromArrayData(column); + static void *ExtractData(std::map const &array, size_t size) { + Validate(array); + void *p_data = ArrayInterfaceHandler::GetPtrFromArrayData(array); if (!p_data) { - CHECK_EQ(shape.first * shape.second, 0) << "Empty data with non-zero shape."; + CHECK_EQ(size, 0) << "Empty data with non-zero shape."; } return p_data; } - + /** + * \brief Whether the ptr is allocated by CUDA. + */ + static bool IsCudaPtr(void const *ptr); + /** + * \brief Sync the CUDA stream. + */ static void SyncCudaStream(int64_t stream); }; +/** + * Dispatch compile time type to runtime type. + */ +template +struct ToDType; +// float +template <> +struct ToDType { + static constexpr ArrayInterfaceHandler::Type kType = ArrayInterfaceHandler::kF4; +}; +template <> +struct ToDType { + static constexpr ArrayInterfaceHandler::Type kType = ArrayInterfaceHandler::kF8; +}; +template +struct ToDType::value && sizeof(long double) == 16>> { + static constexpr ArrayInterfaceHandler::Type kType = ArrayInterfaceHandler::kF16; +}; +// uint +template <> +struct ToDType { + static constexpr ArrayInterfaceHandler::Type kType = ArrayInterfaceHandler::kU1; +}; +template <> +struct ToDType { + static constexpr ArrayInterfaceHandler::Type kType = ArrayInterfaceHandler::kU2; +}; +template <> +struct ToDType { + static constexpr ArrayInterfaceHandler::Type kType = ArrayInterfaceHandler::kU4; +}; +template <> +struct ToDType { + static constexpr ArrayInterfaceHandler::Type kType = ArrayInterfaceHandler::kU8; +}; +// int +template <> +struct ToDType { + static constexpr ArrayInterfaceHandler::Type kType = ArrayInterfaceHandler::kI1; +}; +template <> +struct ToDType { + static constexpr ArrayInterfaceHandler::Type kType = ArrayInterfaceHandler::kI2; +}; +template <> +struct ToDType { + static constexpr ArrayInterfaceHandler::Type kType = ArrayInterfaceHandler::kI4; +}; +template <> +struct ToDType { + static constexpr ArrayInterfaceHandler::Type kType = ArrayInterfaceHandler::kI8; +}; + #if !defined(XGBOOST_USE_CUDA) -inline void ArrayInterfaceHandler::SyncCudaStream(int64_t stream) { - common::AssertGPUSupport(); -} +inline void ArrayInterfaceHandler::SyncCudaStream(int64_t stream) { common::AssertGPUSupport(); } +inline bool ArrayInterfaceHandler::IsCudaPtr(void const* ptr) { return false; } #endif // !defined(XGBOOST_USE_CUDA) // A view over __array_interface__ +/** + * \brief A type erased view over __array_interface__ protocol defined by numpy + * + * numpy. + * + * \tparam D The number of maximum dimension. + + * User input array must have dim <= D for all non-trivial dimensions. During + * construction, the ctor can automatically remove those trivial dimensions. + * + * \tparam allow_mask Whether masked array is accepted. + * + * Currently this only supported for 1-dim vector, which is used by cuDF column + * (apache arrow format). For general masked array, as the time of writting, only + * numpy has the proper support even though it's in the __cuda_array_interface__ + * protocol defined by numba. + */ +template class ArrayInterface { - void Initialize(std::map const &array, - bool allow_mask = true) { + static_assert(D > 0, "Invalid dimension for array interface."); + + /** + * \brief Initialize the object, by extracting shape, stride and type. + * + * The function also perform some basic validation for input array. Lastly it will + * also remove trivial dimensions like converting a matrix with shape (n_samples, 1) + * to a vector of size n_samples. For for inputs like weights, this should be a 1 + * dimension column vector even though user might provide a matrix. + */ + void Initialize(std::map const &array) { ArrayInterfaceHandler::Validate(array); + auto typestr = get(array.at("typestr")); this->AssignType(StringView{typestr}); + ArrayInterfaceHandler::ExtractShape(array, shape); + size_t itemsize = typestr[2] - '0'; + is_contiguous = ArrayInterfaceHandler::ExtractStride(array, itemsize, shape, strides); + n = linalg::detail::CalcSize(shape); - std::tie(num_rows, num_cols) = ArrayInterfaceHandler::ExtractShape(array); - data = ArrayInterfaceHandler::ExtractData( - array, std::make_pair(num_rows, num_cols)); + data = ArrayInterfaceHandler::ExtractData(array, n); if (allow_mask) { + CHECK(D == 1) << "Masked array is not supported."; common::Span s_mask; size_t n_bits = ArrayInterfaceHandler::ExtractMask(array, &s_mask); valid = RBitField8(s_mask); if (s_mask.data()) { - CHECK_EQ(n_bits, num_rows) - << "Shape of bit mask doesn't match data shape. " - << "XGBoost doesn't support internal broadcasting."; + CHECK_EQ(n_bits, n) << "Shape of bit mask doesn't match data shape. " + << "XGBoost doesn't support internal broadcasting."; } } else { - CHECK(array.find("mask") == array.cend()) - << "Masked array is not yet supported."; + CHECK(array.find("mask") == array.cend()) << "Masked array is not yet supported."; } - ArrayInterfaceHandler::ExtractStride(array, &stride_row, &stride_col, - num_rows, num_cols, typestr[2] - '0'); - auto stream_it = array.find("stream"); if (stream_it != array.cend() && !IsA(stream_it->second)) { int64_t stream = get(stream_it->second); @@ -292,151 +407,149 @@ class ArrayInterface { } } - public: - enum Type : std::int8_t { kF4, kF8, kF16, kI1, kI2, kI4, kI8, kU1, kU2, kU4, kU8 }; - public: ArrayInterface() = default; - explicit ArrayInterface(std::string const &str, bool allow_mask = true) - : ArrayInterface{StringView{str.c_str(), str.size()}, allow_mask} {} - - explicit ArrayInterface(std::map const &column, - bool allow_mask = true) { - this->Initialize(column, allow_mask); + explicit ArrayInterface(std::map const &array) { + this->Initialize(array); } - explicit ArrayInterface(StringView str, bool allow_mask = true) { - auto jinterface = Json::Load(str); - if (IsA(jinterface)) { - this->Initialize(get(jinterface), allow_mask); + explicit ArrayInterface(Json const &array) { + if (IsA(array)) { + this->Initialize(get(array)); return; } - if (IsA(jinterface)) { - CHECK_EQ(get(jinterface).size(), 1) + if (IsA(array)) { + CHECK_EQ(get(array).size(), 1) << "Column: " << ArrayInterfaceErrors::Dimension(1); - this->Initialize(get(get(jinterface)[0]), allow_mask); + this->Initialize(get(get(array)[0])); return; } } - void AsColumnVector() { - CHECK(num_rows == 1 || num_cols == 1) << "Array should be a vector instead of matrix."; - num_rows = std::max(num_rows, static_cast(num_cols)); - num_cols = 1; + explicit ArrayInterface(std::string const &str) : ArrayInterface{StringView{str}} {} - stride_row = std::max(stride_row, stride_col); - stride_col = 1; - } + explicit ArrayInterface(StringView str) : ArrayInterface{Json::Load(str)} {} void AssignType(StringView typestr) { - if (typestr.size() == 4 && typestr[1] == 'f' && typestr[2] == '1' && - typestr[3] == '6') { - type = kF16; + using T = ArrayInterfaceHandler::Type; + if (typestr.size() == 4 && typestr[1] == 'f' && typestr[2] == '1' && typestr[3] == '6') { + type = T::kF16; CHECK(sizeof(long double) == 16) << "128-bit floating point is not supported on current platform."; } else if (typestr[1] == 'f' && typestr[2] == '4') { - type = kF4; + type = T::kF4; } else if (typestr[1] == 'f' && typestr[2] == '8') { - type = kF8; + type = T::kF8; } else if (typestr[1] == 'i' && typestr[2] == '1') { - type = kI1; + type = T::kI1; } else if (typestr[1] == 'i' && typestr[2] == '2') { - type = kI2; + type = T::kI2; } else if (typestr[1] == 'i' && typestr[2] == '4') { - type = kI4; + type = T::kI4; } else if (typestr[1] == 'i' && typestr[2] == '8') { - type = kI8; + type = T::kI8; } else if (typestr[1] == 'u' && typestr[2] == '1') { - type = kU1; + type = T::kU1; } else if (typestr[1] == 'u' && typestr[2] == '2') { - type = kU2; + type = T::kU2; } else if (typestr[1] == 'u' && typestr[2] == '4') { - type = kU4; + type = T::kU4; } else if (typestr[1] == 'u' && typestr[2] == '8') { - type = kU8; + type = T::kU8; } else { LOG(FATAL) << ArrayInterfaceErrors::UnSupportedType(typestr); return; } } + XGBOOST_DEVICE size_t Shape(size_t i) const { return shape[i]; } + XGBOOST_DEVICE size_t Stride(size_t i) const { return strides[i]; } + template - XGBOOST_HOST_DEV_INLINE decltype(auto) DispatchCall(Fn func) const { + XGBOOST_HOST_DEV_INLINE constexpr decltype(auto) DispatchCall(Fn func) const { + using T = ArrayInterfaceHandler::Type; switch (type) { - case kF4: - return func(reinterpret_cast(data)); - case kF8: - return func(reinterpret_cast(data)); + case T::kF4: + return func(reinterpret_cast(data)); + case T::kF8: + return func(reinterpret_cast(data)); #ifdef __CUDA_ARCH__ - case kF16: { - // CUDA device code doesn't support long double. - SPAN_CHECK(false); - return func(reinterpret_cast(data)); - } + case T::kF16: { + // CUDA device code doesn't support long double. + SPAN_CHECK(false); + return func(reinterpret_cast(data)); + } #else - case kF16: - return func(reinterpret_cast(data)); + case T::kF16: + return func(reinterpret_cast(data)); #endif - case kI1: - return func(reinterpret_cast(data)); - case kI2: - return func(reinterpret_cast(data)); - case kI4: - return func(reinterpret_cast(data)); - case kI8: - return func(reinterpret_cast(data)); - case kU1: - return func(reinterpret_cast(data)); - case kU2: - return func(reinterpret_cast(data)); - case kU4: - return func(reinterpret_cast(data)); - case kU8: - return func(reinterpret_cast(data)); + case T::kI1: + return func(reinterpret_cast(data)); + case T::kI2: + return func(reinterpret_cast(data)); + case T::kI4: + return func(reinterpret_cast(data)); + case T::kI8: + return func(reinterpret_cast(data)); + case T::kU1: + return func(reinterpret_cast(data)); + case T::kU2: + return func(reinterpret_cast(data)); + case T::kU4: + return func(reinterpret_cast(data)); + case T::kU8: + return func(reinterpret_cast(data)); } SPAN_CHECK(false); return func(reinterpret_cast(data)); } - XGBOOST_DEVICE size_t ElementSize() { - return this->DispatchCall([](auto* p_values) { - return sizeof(std::remove_pointer_t); - }); + XGBOOST_DEVICE size_t constexpr ElementSize() { + return this->DispatchCall( + [](auto *p_values) { return sizeof(std::remove_pointer_t); }); } - template - XGBOOST_DEVICE T GetElement(size_t r, size_t c) const { - return this->DispatchCall( - [=](auto *p_values) -> T { return p_values[stride_row * r + stride_col * c]; }); + template + XGBOOST_DEVICE T operator()(Index &&...index) const { + static_assert(sizeof...(index) <= D, "Invalid index."); + return this->DispatchCall([=](auto const *p_values) -> T { + size_t offset = linalg::detail::Offset<0ul>(strides, 0ul, index...); + return static_cast(p_values[offset]); + }); } + // Used only by columnar format. RBitField8 valid; - bst_row_t num_rows; - bst_feature_t num_cols; - size_t stride_row{0}; - size_t stride_col{0}; - void* data; - Type type; + // Array stride + size_t strides[D]{0}; + // Array shape + size_t shape[D]{0}; + // Type earsed pointer referencing the data. + void *data; + // Total number of items + size_t n; + // Whether the memory is c-contiguous + bool is_contiguous {false}; + // RTTI + ArrayInterfaceHandler::Type type; }; -template std::string MakeArrayInterface(T const *data, size_t n) { - Json arr{Object{}}; - arr["data"] = Array(std::vector{ - Json{Integer{reinterpret_cast(data)}}, Json{Boolean{false}}}); - arr["shape"] = Array{std::vector{Json{Integer{n}}, Json{Integer{1}}}}; - std::string typestr; - if (DMLC_LITTLE_ENDIAN) { - typestr.push_back('<'); - } else { - typestr.push_back('>'); +/** + * \brief Helper for type casting. + */ +template +struct TypedIndex { + ArrayInterface const &array; + template + XGBOOST_DEVICE T operator()(I &&...ind) { + static_assert(sizeof...(ind) <= D, "Invalid index."); + return array.template operator()(ind...); } - typestr.push_back(ArrayInterfaceHandler::TypeChar()); - typestr += std::to_string(sizeof(T)); - arr["typestr"] = typestr; - arr["version"] = 3; - std::string str; - Json::Dump(arr, &str); - return str; +}; + +template +inline void CheckArrayInterface(StringView key, ArrayInterface const &array) { + CHECK(!array.valid.Data()) << "Meta info " << key << " should be dense, found validity mask"; } } // namespace xgboost #endif // XGBOOST_DATA_ARRAY_INTERFACE_H_ diff --git a/src/data/data.cc b/src/data/data.cc index acf6e47b9d21..d9f489103f22 100644 --- a/src/data/data.cc +++ b/src/data/data.cc @@ -1,8 +1,9 @@ /*! - * Copyright 2015-2020 by Contributors + * Copyright 2015-2021 by Contributors * \file data.cc */ #include +#include #include #include "dmlc/io.h" @@ -12,10 +13,13 @@ #include "xgboost/logging.h" #include "xgboost/version_config.h" #include "xgboost/learner.h" +#include "xgboost/string_view.h" + #include "sparse_page_writer.h" #include "simple_dmatrix.h" #include "../common/io.h" +#include "../common/linalg_op.h" #include "../common/math.h" #include "../common/version.h" #include "../common/group_data.h" @@ -24,6 +28,7 @@ #include "../data/iterative_device_dmatrix.h" #include "file_iterator.h" +#include "validation.h" #include "./sparse_page_source.h" #include "./sparse_page_dmatrix.h" @@ -65,10 +70,22 @@ void SaveVectorField(dmlc::Stream* strm, const std::string& name, SaveVectorField(strm, name, type, shape, field.ConstHostVector()); } +template +void SaveTensorField(dmlc::Stream* strm, const std::string& name, xgboost::DataType type, + const xgboost::linalg::Tensor& field) { + strm->Write(name); + strm->Write(static_cast(type)); + strm->Write(false); // is_scalar=False + for (size_t i = 0; i < D; ++i) { + strm->Write(field.Shape(i)); + } + strm->Write(field.Data()->HostVector()); +} + template void LoadScalarField(dmlc::Stream* strm, const std::string& expected_name, xgboost::DataType expected_type, T* field) { - const std::string invalid {"MetaInfo: Invalid format. "}; + const std::string invalid{"MetaInfo: Invalid format for " + expected_name}; std::string name; xgboost::DataType type; bool is_scalar; @@ -90,7 +107,7 @@ void LoadScalarField(dmlc::Stream* strm, const std::string& expected_name, template void LoadVectorField(dmlc::Stream* strm, const std::string& expected_name, xgboost::DataType expected_type, std::vector* field) { - const std::string invalid {"MetaInfo: Invalid format. "}; + const std::string invalid{"MetaInfo: Invalid format for " + expected_name}; std::string name; xgboost::DataType type; bool is_scalar; @@ -123,6 +140,33 @@ void LoadVectorField(dmlc::Stream* strm, const std::string& expected_name, LoadVectorField(strm, expected_name, expected_type, &field->HostVector()); } +template +void LoadTensorField(dmlc::Stream* strm, std::string const& expected_name, + xgboost::DataType expected_type, xgboost::linalg::Tensor* p_out) { + const std::string invalid{"MetaInfo: Invalid format for " + expected_name}; + std::string name; + xgboost::DataType type; + bool is_scalar; + CHECK(strm->Read(&name)) << invalid; + CHECK_EQ(name, expected_name) << invalid << " Expected field: " << expected_name + << ", got: " << name; + uint8_t type_val; + CHECK(strm->Read(&type_val)) << invalid; + type = static_cast(type_val); + CHECK(type == expected_type) << invalid + << "Expected field of type: " << static_cast(expected_type) + << ", " + << "got field type: " << static_cast(type); + CHECK(strm->Read(&is_scalar)) << invalid; + CHECK(!is_scalar) << invalid << "Expected field " << expected_name + << " to be a tensor; got a scalar"; + std::array shape; + for (size_t i = 0; i < D; ++i) { + CHECK(strm->Read(&(shape[i]))); + } + auto& field = p_out->Data()->HostVector(); + CHECK(strm->Read(&field)) << invalid; +} } // anonymous namespace namespace xgboost { @@ -135,25 +179,26 @@ void MetaInfo::Clear() { labels_.HostVector().clear(); group_ptr_.clear(); weights_.HostVector().clear(); - base_margin_.HostVector().clear(); + base_margin_ = decltype(base_margin_){}; } /* * Binary serialization format for MetaInfo: * - * | name | type | is_scalar | num_row | num_col | value | - * |--------------------+----------+-----------+---------+---------+-------------------------| - * | num_row | kUInt64 | True | NA | NA | ${num_row_} | - * | num_col | kUInt64 | True | NA | NA | ${num_col_} | - * | num_nonzero | kUInt64 | True | NA | NA | ${num_nonzero_} | - * | labels | kFloat32 | False | ${size} | 1 | ${labels_} | - * | group_ptr | kUInt32 | False | ${size} | 1 | ${group_ptr_} | - * | weights | kFloat32 | False | ${size} | 1 | ${weights_} | - * | base_margin | kFloat32 | False | ${size} | 1 | ${base_margin_} | - * | labels_lower_bound | kFloat32 | False | ${size} | 1 | ${labels_lower_bound_} | - * | labels_upper_bound | kFloat32 | False | ${size} | 1 | ${labels_upper_bound_} | - * | feature_names | kStr | False | ${size} | 1 | ${feature_names} | - * | feature_types | kStr | False | ${size} | 1 | ${feature_types} | + * | name | type | is_scalar | num_row | num_col | dim3 | value | + * |--------------------+----------+-----------+-------------+-------------+-------------+------------------------| + * | num_row | kUInt64 | True | NA | NA | NA | ${num_row_} | + * | num_col | kUInt64 | True | NA | NA | NA | ${num_col_} | + * | num_nonzero | kUInt64 | True | NA | NA | NA | ${num_nonzero_} | + * | labels | kFloat32 | False | ${size} | 1 | NA | ${labels_} | + * | group_ptr | kUInt32 | False | ${size} | 1 | NA | ${group_ptr_} | + * | weights | kFloat32 | False | ${size} | 1 | NA | ${weights_} | + * | base_margin | kFloat32 | False | ${Shape(0)} | ${Shape(1)} | ${Shape(2)} | ${base_margin_} | + * | labels_lower_bound | kFloat32 | False | ${size} | 1 | NA | ${labels_lower_bound_} | + * | labels_upper_bound | kFloat32 | False | ${size} | 1 | NA | ${labels_upper_bound_} | + * | feature_names | kStr | False | ${size} | 1 | NA | ${feature_names} | + * | feature_types | kStr | False | ${size} | 1 | NA | ${feature_types} | + * | feature_types | kFloat32 | False | ${size} | 1 | NA | ${feature_weights} | * * Note that the scalar fields (is_scalar=True) will have num_row and num_col missing. * Also notice the difference between the saved name and the name used in `SetInfo': @@ -174,8 +219,7 @@ void MetaInfo::SaveBinary(dmlc::Stream *fo) const { {group_ptr_.size(), 1}, group_ptr_); ++field_cnt; SaveVectorField(fo, u8"weights", DataType::kFloat32, {weights_.Size(), 1}, weights_); ++field_cnt; - SaveVectorField(fo, u8"base_margin", DataType::kFloat32, - {base_margin_.Size(), 1}, base_margin_); ++field_cnt; + SaveTensorField(fo, u8"base_margin", DataType::kFloat32, base_margin_); ++field_cnt; SaveVectorField(fo, u8"labels_lower_bound", DataType::kFloat32, {labels_lower_bound_.Size(), 1}, labels_lower_bound_); ++field_cnt; SaveVectorField(fo, u8"labels_upper_bound", DataType::kFloat32, @@ -185,6 +229,9 @@ void MetaInfo::SaveBinary(dmlc::Stream *fo) const { {feature_names.size(), 1}, feature_names); ++field_cnt; SaveVectorField(fo, u8"feature_types", DataType::kStr, {feature_type_names.size(), 1}, feature_type_names); ++field_cnt; + SaveVectorField(fo, u8"feature_weights", DataType::kFloat32, {feature_weights.Size(), 1}, + feature_weights); + ++field_cnt; CHECK_EQ(field_cnt, kNumField) << "Wrong number of fields"; } @@ -213,10 +260,14 @@ void MetaInfo::LoadBinary(dmlc::Stream *fi) { auto major = std::get<0>(version); // MetaInfo is saved in `SparsePageSource'. So the version in MetaInfo represents the // version of DMatrix. - CHECK_EQ(major, 1) << "Binary DMatrix generated by XGBoost: " - << Version::String(version) << " is no longer supported. " - << "Please process and save your data in current version: " - << Version::String(Version::Self()) << " again."; + std::stringstream msg; + msg << "Binary DMatrix generated by XGBoost: " << Version::String(version) + << " is no longer supported. " + << "Please process and save your data in current version: " + << Version::String(Version::Self()) << " again."; + CHECK_EQ(major, 1) << msg.str(); + auto minor = std::get<1>(version); + CHECK_GE(minor, 6) << msg.str(); const uint64_t expected_num_field = kNumField; uint64_t num_field { 0 }; @@ -243,12 +294,13 @@ void MetaInfo::LoadBinary(dmlc::Stream *fi) { LoadVectorField(fi, u8"labels", DataType::kFloat32, &labels_); LoadVectorField(fi, u8"group_ptr", DataType::kUInt32, &group_ptr_); LoadVectorField(fi, u8"weights", DataType::kFloat32, &weights_); - LoadVectorField(fi, u8"base_margin", DataType::kFloat32, &base_margin_); + LoadTensorField(fi, u8"base_margin", DataType::kFloat32, &base_margin_); LoadVectorField(fi, u8"labels_lower_bound", DataType::kFloat32, &labels_lower_bound_); LoadVectorField(fi, u8"labels_upper_bound", DataType::kFloat32, &labels_upper_bound_); LoadVectorField(fi, u8"feature_names", DataType::kStr, &feature_names); LoadVectorField(fi, u8"feature_types", DataType::kStr, &feature_type_names); + LoadVectorField(fi, u8"feature_weights", DataType::kFloat32, &feature_weights); LoadFeatureType(feature_type_names, &feature_types.HostVector()); } @@ -291,10 +343,13 @@ MetaInfo MetaInfo::Slice(common::Span ridxs) const { if (this->base_margin_.Size() != this->num_row_) { CHECK_EQ(this->base_margin_.Size() % this->num_row_, 0) << "Incorrect size of base margin vector."; - size_t stride = this->base_margin_.Size() / this->num_row_; - out.base_margin_.HostVector() = Gather(this->base_margin_.HostVector(), ridxs, stride); + auto margin = this->base_margin_.View(this->base_margin_.Data()->DeviceIdx()); + out.base_margin_.Reshape(ridxs.size(), margin.Shape()[1], margin.Shape()[2]); + size_t stride = margin.Stride(0); + out.base_margin_.Data()->HostVector() = + Gather(this->base_margin_.Data()->HostVector(), ridxs, stride); } else { - out.base_margin_.HostVector() = Gather(this->base_margin_.HostVector(), ridxs); + out.base_margin_.Data()->HostVector() = Gather(this->base_margin_.Data()->HostVector(), ridxs); } out.feature_weights.Resize(this->feature_weights.Size()); @@ -337,116 +392,179 @@ inline bool MetaTryLoadFloatInfo(const std::string& fname, return true; } -void ValidateQueryGroup(std::vector const &group_ptr_) { - bool valid_query_group = true; - for (size_t i = 1; i < group_ptr_.size(); ++i) { - valid_query_group = valid_query_group && group_ptr_[i] >= group_ptr_[i - 1]; - if (!valid_query_group) { - break; - } +namespace { +template +void CopyTensorInfoImpl(Json arr_interface, linalg::Tensor* p_out) { + ArrayInterface array{arr_interface}; + if (array.n == 0) { + return; + } + CHECK(array.valid.Size() == 0) << "Meta info like label or weight can not have missing value."; + if (array.is_contiguous && array.type == ToDType::kType) { + // Handle contigious + p_out->ModifyInplace([&](HostDeviceVector* data, common::Span shape) { + // set shape + std::copy(array.shape, array.shape + D, shape.data()); + // set data + data->Resize(array.n); + std::memcpy(data->HostPointer(), array.data, array.n * sizeof(T)); + }); + return; } - CHECK(valid_query_group) << "Invalid group structure."; + p_out->Reshape(array.shape); + auto t = p_out->View(GenericParameter::kCpuId); + CHECK(t.Contiguous()); + // FIXME(jiamingy): Remove the use of this default thread. + linalg::ElementWiseKernelHost(t, common::OmpGetNumThreads(0), [&](auto i, auto) { + return linalg::detail::Apply(TypedIndex{array}, linalg::UnravelIndex(i, t.Shape())); + }); } +} // namespace + +void MetaInfo::SetInfo(StringView key, StringView interface_str) { + Json j_interface = Json::Load(interface_str); + bool is_cuda{false}; + if (IsA(j_interface)) { + auto const& array = get(j_interface); + CHECK_GE(array.size(), 0) << "Invalid " << key + << ", must have at least 1 column even if it's empty."; + auto const& first = get(array.front()); + auto ptr = ArrayInterfaceHandler::GetPtrFromArrayData(first); + is_cuda = ArrayInterfaceHandler::IsCudaPtr(ptr); + } else { + auto const& first = get(j_interface); + auto ptr = ArrayInterfaceHandler::GetPtrFromArrayData(first); + is_cuda = ArrayInterfaceHandler::IsCudaPtr(ptr); + } -// macro to dispatch according to specified pointer types -#define DISPATCH_CONST_PTR(dtype, old_ptr, cast_ptr, proc) \ - switch (dtype) { \ - case xgboost::DataType::kFloat32: { \ - auto cast_ptr = reinterpret_cast(old_ptr); proc; break; \ - } \ - case xgboost::DataType::kDouble: { \ - auto cast_ptr = reinterpret_cast(old_ptr); proc; break; \ - } \ - case xgboost::DataType::kUInt32: { \ - auto cast_ptr = reinterpret_cast(old_ptr); proc; break; \ - } \ - case xgboost::DataType::kUInt64: { \ - auto cast_ptr = reinterpret_cast(old_ptr); proc; break; \ - } \ - default: LOG(FATAL) << "Unknown data type" << static_cast(dtype); \ - } \ + if (is_cuda) { + this->SetInfoFromCUDA(key, j_interface); + } else { + this->SetInfoFromHost(key, j_interface); + } +} -void MetaInfo::SetInfo(const char* key, const void* dptr, DataType dtype, size_t num) { - if (!std::strcmp(key, "label")) { - auto& labels = labels_.HostVector(); - labels.resize(num); - DISPATCH_CONST_PTR(dtype, dptr, cast_dptr, - std::copy(cast_dptr, cast_dptr + num, labels.begin())); - auto valid = std::none_of(labels.cbegin(), labels.cend(), [](auto y) { - return std::isnan(y) || std::isinf(y); - }); - CHECK(valid) << "Label contains NaN, infinity or a value too large."; - } else if (!std::strcmp(key, "weight")) { - auto& weights = weights_.HostVector(); - weights.resize(num); - DISPATCH_CONST_PTR(dtype, dptr, cast_dptr, - std::copy(cast_dptr, cast_dptr + num, weights.begin())); - auto valid = std::none_of(weights.cbegin(), weights.cend(), [](float w) { - return w < 0 || std::isinf(w) || std::isnan(w); - }); - CHECK(valid) << "Weights must be positive values."; - } else if (!std::strcmp(key, "base_margin")) { - auto& base_margin = base_margin_.HostVector(); - base_margin.resize(num); - DISPATCH_CONST_PTR(dtype, dptr, cast_dptr, - std::copy(cast_dptr, cast_dptr + num, base_margin.begin())); - } else if (!std::strcmp(key, "group")) { - group_ptr_.clear(); group_ptr_.resize(num + 1, 0); - DISPATCH_CONST_PTR(dtype, dptr, cast_dptr, - std::copy(cast_dptr, cast_dptr + num, group_ptr_.begin() + 1)); - group_ptr_[0] = 0; - for (size_t i = 1; i < group_ptr_.size(); ++i) { - group_ptr_[i] = group_ptr_[i - 1] + group_ptr_[i]; +void MetaInfo::SetInfoFromHost(StringView key, Json arr) { + // multi-dim float info + if (key == "base_margin") { + CopyTensorInfoImpl<3>(arr, &this->base_margin_); + // FIXME(jiamingy): Remove the deprecated API and let all language bindings aware of + // input shape. This issue is CPU only since CUDA uses array interface from day 1. + // + // Python binding always understand the shape, so this condition should not occur for + // it. + if (this->num_row_ != 0 && this->base_margin_.Shape(0) != this->num_row_) { + // API functions that don't use array interface don't understand shape. + CHECK(this->base_margin_.Size() % this->num_row_ == 0) << "Incorrect size for base margin."; + size_t n_groups = this->base_margin_.Size() / this->num_row_; + this->base_margin_.Reshape(this->num_row_, n_groups); } - ValidateQueryGroup(group_ptr_); - } else if (!std::strcmp(key, "qid")) { - std::vector query_ids(num, 0); - DISPATCH_CONST_PTR(dtype, dptr, cast_dptr, - std::copy(cast_dptr, cast_dptr + num, query_ids.begin())); + return; + } + // uint info + if (key == "group") { + linalg::Tensor t; + CopyTensorInfoImpl(arr, &t); + auto const& h_groups = t.Data()->HostVector(); + group_ptr_.clear(); + group_ptr_.resize(h_groups.size() + 1, 0); + group_ptr_[0] = 0; + std::partial_sum(h_groups.cbegin(), h_groups.cend(), group_ptr_.begin() + 1); + data::ValidateQueryGroup(group_ptr_); + return; + } else if (key == "qid") { + linalg::Tensor t; + CopyTensorInfoImpl(arr, &t); bool non_dec = true; + auto const& query_ids = t.Data()->HostVector(); for (size_t i = 1; i < query_ids.size(); ++i) { - if (query_ids[i] < query_ids[i-1]) { + if (query_ids[i] < query_ids[i - 1]) { non_dec = false; break; } } CHECK(non_dec) << "`qid` must be sorted in non-decreasing order along with data."; - group_ptr_.clear(); group_ptr_.push_back(0); + group_ptr_.clear(); + group_ptr_.push_back(0); for (size_t i = 1; i < query_ids.size(); ++i) { - if (query_ids[i] != query_ids[i-1]) { + if (query_ids[i] != query_ids[i - 1]) { group_ptr_.push_back(i); } } if (group_ptr_.back() != query_ids.size()) { group_ptr_.push_back(query_ids.size()); } - } else if (!std::strcmp(key, "label_lower_bound")) { - auto& labels = labels_lower_bound_.HostVector(); - labels.resize(num); - DISPATCH_CONST_PTR(dtype, dptr, cast_dptr, - std::copy(cast_dptr, cast_dptr + num, labels.begin())); - } else if (!std::strcmp(key, "label_upper_bound")) { - auto& labels = labels_upper_bound_.HostVector(); - labels.resize(num); - DISPATCH_CONST_PTR(dtype, dptr, cast_dptr, - std::copy(cast_dptr, cast_dptr + num, labels.begin())); - } else if (!std::strcmp(key, "feature_weights")) { - auto &h_feature_weights = feature_weights.HostVector(); - h_feature_weights.resize(num); - DISPATCH_CONST_PTR( - dtype, dptr, cast_dptr, - std::copy(cast_dptr, cast_dptr + num, h_feature_weights.begin())); + data::ValidateQueryGroup(group_ptr_); + return; + } + // float info + linalg::Tensor t; + CopyTensorInfoImpl<1>(arr, &t); + if (key == "label") { + this->labels_ = std::move(*t.Data()); + auto const& h_labels = labels_.ConstHostVector(); + auto valid = std::none_of(h_labels.cbegin(), h_labels.cend(), data::LabelsCheck{}); + CHECK(valid) << "Label contains NaN, infinity or a value too large."; + } else if (key == "weight") { + this->weights_ = std::move(*t.Data()); + auto const& h_weights = this->weights_.ConstHostVector(); + auto valid = std::none_of(h_weights.cbegin(), h_weights.cend(), + [](float w) { return w < 0 || std::isinf(w) || std::isnan(w); }); + CHECK(valid) << "Weights must be positive values."; + } else if (key == "label_lower_bound") { + this->labels_lower_bound_ = std::move(*t.Data()); + } else if (key == "label_upper_bound") { + this->labels_upper_bound_ = std::move(*t.Data()); + } else if (key == "feature_weights") { + this->feature_weights = std::move(*t.Data()); + auto const& h_feature_weights = feature_weights.ConstHostVector(); bool valid = - std::none_of(h_feature_weights.cbegin(), h_feature_weights.cend(), - [](float w) { return w < 0; }); + std::none_of(h_feature_weights.cbegin(), h_feature_weights.cend(), data::WeightsCheck{}); CHECK(valid) << "Feature weight must be greater than 0."; } else { LOG(FATAL) << "Unknown key for MetaInfo: " << key; } } -void MetaInfo::GetInfo(char const *key, bst_ulong *out_len, DataType dtype, - const void **out_dptr) const { +void MetaInfo::SetInfo(const char* key, const void* dptr, DataType dtype, size_t num) { + auto proc = [&](auto cast_d_ptr) { + using T = std::remove_pointer_t; + auto t = + linalg::TensorView(common::Span{cast_d_ptr, num}, {num}, GenericParameter::kCpuId); + CHECK(t.Contiguous()); + Json interface { t.ArrayInterface() }; + assert(ArrayInterface<1>{interface}.is_contiguous); + return interface; + }; + // Legacy code using XGBoost dtype, which is a small subset of array interface types. + switch (dtype) { + case xgboost::DataType::kFloat32: { + auto cast_ptr = reinterpret_cast(dptr); + this->SetInfoFromHost(key, proc(cast_ptr)); + break; + } + case xgboost::DataType::kDouble: { + auto cast_ptr = reinterpret_cast(dptr); + this->SetInfoFromHost(key, proc(cast_ptr)); + break; + } + case xgboost::DataType::kUInt32: { + auto cast_ptr = reinterpret_cast(dptr); + this->SetInfoFromHost(key, proc(cast_ptr)); + break; + } + case xgboost::DataType::kUInt64: { + auto cast_ptr = reinterpret_cast(dptr); + this->SetInfoFromHost(key, proc(cast_ptr)); + break; + } + default: + LOG(FATAL) << "Unknown data type" << static_cast(dtype); + } +} + +void MetaInfo::GetInfo(char const* key, bst_ulong* out_len, DataType dtype, + const void** out_dptr) const { if (dtype == DataType::kFloat32) { const std::vector* vec = nullptr; if (!std::strcmp(key, "label")) { @@ -454,7 +572,7 @@ void MetaInfo::GetInfo(char const *key, bst_ulong *out_len, DataType dtype, } else if (!std::strcmp(key, "weight")) { vec = &this->weights_.HostVector(); } else if (!std::strcmp(key, "base_margin")) { - vec = &this->base_margin_.HostVector(); + vec = &this->base_margin_.Data()->HostVector(); } else if (!std::strcmp(key, "label_lower_bound")) { vec = &this->labels_lower_bound_.HostVector(); } else if (!std::strcmp(key, "label_upper_bound")) { @@ -543,8 +661,7 @@ void MetaInfo::Extend(MetaInfo const& that, bool accumulate_rows, bool check_col this->labels_upper_bound_.SetDevice(that.labels_upper_bound_.DeviceIdx()); this->labels_upper_bound_.Extend(that.labels_upper_bound_); - this->base_margin_.SetDevice(that.base_margin_.DeviceIdx()); - this->base_margin_.Extend(that.base_margin_); + linalg::Stack(&this->base_margin_, that.base_margin_); if (this->group_ptr_.size() == 0) { this->group_ptr_ = that.group_ptr_; @@ -627,12 +744,12 @@ void MetaInfo::Validate(int32_t device) const { if (base_margin_.Size() != 0) { CHECK_EQ(base_margin_.Size() % num_row_, 0) << "Size of base margin must be a multiple of number of rows."; - check_device(base_margin_); + check_device(*base_margin_.Data()); } } #if !defined(XGBOOST_USE_CUDA) -void MetaInfo::SetInfo(const char * c_key, std::string const& interface_str) { +void MetaInfo::SetInfoFromCUDA(StringView key, Json arr) { common::AssertGPUSupport(); } #endif // !defined(XGBOOST_USE_CUDA) @@ -788,10 +905,10 @@ DMatrix* DMatrix::Load(const std::string& uri, LOG(CONSOLE) << info.group_ptr_.size() - 1 << " groups are loaded from " << fname << ".group"; } - if (MetaTryLoadFloatInfo - (fname + ".base_margin", &info.base_margin_.HostVector()) && !silent) { - LOG(CONSOLE) << info.base_margin_.Size() - << " base_margin are loaded from " << fname << ".base_margin"; + if (MetaTryLoadFloatInfo(fname + ".base_margin", &info.base_margin_.Data()->HostVector()) && + !silent) { + LOG(CONSOLE) << info.base_margin_.Size() << " base_margin are loaded from " << fname + << ".base_margin"; } if (MetaTryLoadFloatInfo (fname + ".weight", &info.weights_.HostVector()) && !silent) { diff --git a/src/data/data.cu b/src/data/data.cu index aee62d1b7fad..b3bb2ccc6d61 100644 --- a/src/data/data.cu +++ b/src/data/data.cu @@ -9,84 +9,79 @@ #include "xgboost/json.h" #include "array_interface.h" #include "../common/device_helpers.cuh" +#include "../common/linalg_op.cuh" #include "device_adapter.cuh" #include "simple_dmatrix.h" +#include "validation.h" namespace xgboost { - -void CopyInfoImpl(ArrayInterface column, HostDeviceVector* out) { - auto SetDeviceToPtr = [](void* ptr) { - cudaPointerAttributes attr; - dh::safe_cuda(cudaPointerGetAttributes(&attr, ptr)); - int32_t ptr_device = attr.device; - if (ptr_device >= 0) { - dh::safe_cuda(cudaSetDevice(ptr_device)); - } - return ptr_device; - }; - auto ptr_device = SetDeviceToPtr(column.data); - - if (column.num_rows == 0) { - return; - } - out->SetDevice(ptr_device); - - size_t size = column.num_rows * column.num_cols; - CHECK_NE(size, 0); - out->Resize(size); - - auto p_dst = thrust::device_pointer_cast(out->DevicePointer()); - dh::LaunchN(size, [=] __device__(size_t idx) { - size_t ridx = idx / column.num_cols; - size_t cidx = idx - (ridx * column.num_cols); - p_dst[idx] = column.GetElement(ridx, cidx); - }); -} - namespace { -auto SetDeviceToPtr(void *ptr) { +auto SetDeviceToPtr(void* ptr) { cudaPointerAttributes attr; dh::safe_cuda(cudaPointerGetAttributes(&attr, ptr)); int32_t ptr_device = attr.device; dh::safe_cuda(cudaSetDevice(ptr_device)); return ptr_device; } -} // anonymous namespace -void CopyGroupInfoImpl(ArrayInterface column, std::vector* out) { - CHECK(column.type != ArrayInterface::kF4 && column.type != ArrayInterface::kF8) +template +void CopyTensorInfoImpl(Json arr_interface, linalg::Tensor* p_out) { + ArrayInterface array(arr_interface); + if (array.n == 0) { + return; + } + CHECK(array.valid.Size() == 0) << "Meta info like label or weight can not have missing value."; + auto ptr_device = SetDeviceToPtr(array.data); + + if (array.is_contiguous && array.type == ToDType::kType) { + p_out->ModifyInplace([&](HostDeviceVector* data, common::Span shape) { + // set shape + std::copy(array.shape, array.shape + D, shape.data()); + // set data + data->SetDevice(ptr_device); + data->Resize(array.n); + dh::safe_cuda(cudaMemcpyAsync(data->DevicePointer(), array.data, array.n * sizeof(T), + cudaMemcpyDefault)); + }); + return; + } + p_out->SetDevice(ptr_device); + p_out->Reshape(array.shape); + auto t = p_out->View(ptr_device); + linalg::ElementWiseKernelDevice(t, [=] __device__(size_t i, T) { + return linalg::detail::Apply(TypedIndex{array}, linalg::UnravelIndex(i, array.shape)); + }); +} + +void CopyGroupInfoImpl(ArrayInterface<1> column, std::vector* out) { + CHECK(column.type != ArrayInterfaceHandler::kF4 && column.type != ArrayInterfaceHandler::kF8) << "Expected integer for group info."; auto ptr_device = SetDeviceToPtr(column.data); CHECK_EQ(ptr_device, dh::CurrentDevice()); - dh::TemporaryArray temp(column.num_rows); - auto d_tmp = temp.data(); + dh::TemporaryArray temp(column.Shape(0)); + auto d_tmp = temp.data().get(); - dh::LaunchN(column.num_rows, [=] __device__(size_t idx) { - d_tmp[idx] = column.GetElement(idx, 0); - }); - auto length = column.num_rows; + dh::LaunchN(column.Shape(0), + [=] __device__(size_t idx) { d_tmp[idx] = column.operator()(idx); }); + auto length = column.Shape(0); out->resize(length + 1); out->at(0) = 0; thrust::copy(temp.data(), temp.data() + length, out->begin() + 1); std::partial_sum(out->begin(), out->end(), out->begin()); } -void CopyQidImpl(ArrayInterface array_interface, - std::vector *p_group_ptr) { +void CopyQidImpl(ArrayInterface<1> array_interface, std::vector* p_group_ptr) { auto &group_ptr_ = *p_group_ptr; auto it = dh::MakeTransformIterator( thrust::make_counting_iterator(0ul), - [array_interface] __device__(size_t i) { - return array_interface.GetElement(i, 0); - }); + [array_interface] __device__(size_t i) { return array_interface.operator()(i); }); dh::caching_device_vector flag(1); auto d_flag = dh::ToSpan(flag); auto d = SetDeviceToPtr(array_interface.data); dh::LaunchN(1, [=] __device__(size_t) { d_flag[0] = true; }); - dh::LaunchN(array_interface.num_rows - 1, [=] __device__(size_t i) { - if (array_interface.GetElement(i, 0) > - array_interface.GetElement(i + 1, 0)) { + dh::LaunchN(array_interface.Shape(0) - 1, [=] __device__(size_t i) { + if (array_interface.operator()(i) > array_interface.operator()(i + 1)) { d_flag[0] = false; } }); @@ -95,16 +90,16 @@ void CopyQidImpl(ArrayInterface array_interface, cudaMemcpyDeviceToHost)); CHECK(non_dec) << "`qid` must be sorted in increasing order along with data."; size_t bytes = 0; - dh::caching_device_vector out(array_interface.num_rows); - dh::caching_device_vector cnt(array_interface.num_rows); + dh::caching_device_vector out(array_interface.Shape(0)); + dh::caching_device_vector cnt(array_interface.Shape(0)); HostDeviceVector d_num_runs_out(1, 0, d); cub::DeviceRunLengthEncode::Encode( nullptr, bytes, it, out.begin(), cnt.begin(), - d_num_runs_out.DevicePointer(), array_interface.num_rows); + d_num_runs_out.DevicePointer(), array_interface.Shape(0)); dh::caching_device_vector tmp(bytes); cub::DeviceRunLengthEncode::Encode( tmp.data().get(), bytes, it, out.begin(), cnt.begin(), - d_num_runs_out.DevicePointer(), array_interface.num_rows); + d_num_runs_out.DevicePointer(), array_interface.Shape(0)); auto h_num_runs_out = d_num_runs_out.HostSpan()[0]; group_ptr_.clear(); @@ -115,77 +110,52 @@ void CopyQidImpl(ArrayInterface array_interface, thrust::copy(cnt.begin(), cnt.begin() + h_num_runs_out, group_ptr_.begin() + 1); } +} // namespace -namespace { -// thrust::all_of tries to copy lambda function. -struct LabelsCheck { - __device__ bool operator()(float y) { return ::isnan(y) || ::isinf(y); } -}; -struct WeightsCheck { - __device__ bool operator()(float w) { return LabelsCheck{}(w) || w < 0; } // NOLINT -}; -} // anonymous namespace - -void ValidateQueryGroup(std::vector const &group_ptr_); - -void MetaInfo::SetInfo(const char * c_key, std::string const& interface_str) { - Json j_interface = Json::Load({interface_str.c_str(), interface_str.size()}); - ArrayInterface array_interface(interface_str); - std::string key{c_key}; - - CHECK(!array_interface.valid.Data()) - << "Meta info " << key << " should be dense, found validity mask"; - if (array_interface.num_rows == 0) { - return; - } - +void MetaInfo::SetInfoFromCUDA(StringView key, Json array) { + // multi-dim float info if (key == "base_margin") { - CopyInfoImpl(array_interface, &base_margin_); + CopyTensorInfoImpl(array, &base_margin_); return; } - - CHECK(array_interface.num_cols == 1 || array_interface.num_rows == 1) - << "MetaInfo: " << c_key << " has invalid shape"; - if (!((array_interface.num_cols == 1 && array_interface.num_rows == 0) || - (array_interface.num_cols == 0 && array_interface.num_rows == 1))) { - // Not an empty column, transform it. - array_interface.AsColumnVector(); + // uint info + if (key == "group") { + auto array_interface{ArrayInterface<1>(array)}; + CopyGroupInfoImpl(array_interface, &group_ptr_); + data::ValidateQueryGroup(group_ptr_); + return; + } else if (key == "qid") { + auto array_interface{ArrayInterface<1>(array)}; + CopyQidImpl(array_interface, &group_ptr_); + data::ValidateQueryGroup(group_ptr_); + return; } + // float info + linalg::Tensor t; + CopyTensorInfoImpl(array, &t); if (key == "label") { - CopyInfoImpl(array_interface, &labels_); + this->labels_ = std::move(*t.Data()); auto ptr = labels_.ConstDevicePointer(); - auto valid = thrust::none_of(thrust::device, ptr, ptr + labels_.Size(), - LabelsCheck{}); + auto valid = thrust::none_of(thrust::device, ptr, ptr + labels_.Size(), data::LabelsCheck{}); CHECK(valid) << "Label contains NaN, infinity or a value too large."; } else if (key == "weight") { - CopyInfoImpl(array_interface, &weights_); + this->weights_ = std::move(*t.Data()); auto ptr = weights_.ConstDevicePointer(); - auto valid = thrust::none_of(thrust::device, ptr, ptr + weights_.Size(), - WeightsCheck{}); + auto valid = thrust::none_of(thrust::device, ptr, ptr + weights_.Size(), data::WeightsCheck{}); CHECK(valid) << "Weights must be positive values."; - } else if (key == "group") { - CopyGroupInfoImpl(array_interface, &group_ptr_); - ValidateQueryGroup(group_ptr_); - return; - } else if (key == "qid") { - CopyQidImpl(array_interface, &group_ptr_); - return; } else if (key == "label_lower_bound") { - CopyInfoImpl(array_interface, &labels_lower_bound_); - return; + this->labels_lower_bound_ = std::move(*t.Data()); } else if (key == "label_upper_bound") { - CopyInfoImpl(array_interface, &labels_upper_bound_); - return; + this->labels_upper_bound_ = std::move(*t.Data()); } else if (key == "feature_weights") { - CopyInfoImpl(array_interface, &feature_weights); + this->feature_weights = std::move(*t.Data()); auto d_feature_weights = feature_weights.ConstDeviceSpan(); - auto valid = thrust::none_of( - thrust::device, d_feature_weights.data(), - d_feature_weights.data() + d_feature_weights.size(), WeightsCheck{}); + auto valid = + thrust::none_of(thrust::device, d_feature_weights.data(), + d_feature_weights.data() + d_feature_weights.size(), data::WeightsCheck{}); CHECK(valid) << "Feature weight must be greater than 0."; - return; } else { - LOG(FATAL) << "Unknown metainfo: " << key; + LOG(FATAL) << "Unknown key for MetaInfo: " << key; } } diff --git a/src/data/device_adapter.cuh b/src/data/device_adapter.cuh index 628878f319f1..d1bda280a7d5 100644 --- a/src/data/device_adapter.cuh +++ b/src/data/device_adapter.cuh @@ -20,7 +20,7 @@ class CudfAdapterBatch : public detail::NoMetaInfo { public: CudfAdapterBatch() = default; - CudfAdapterBatch(common::Span columns, size_t num_rows) + CudfAdapterBatch(common::Span> columns, size_t num_rows) : columns_(columns), num_rows_(num_rows) {} size_t Size() const { return num_rows_ * columns_.size(); } @@ -29,7 +29,7 @@ class CudfAdapterBatch : public detail::NoMetaInfo { size_t row_idx = idx / columns_.size(); auto const& column = columns_[column_idx]; float value = column.valid.Data() == nullptr || column.valid.Check(row_idx) - ? column.GetElement(row_idx, 0) + ? column(row_idx) : std::numeric_limits::quiet_NaN(); return {row_idx, column_idx, value}; } @@ -38,7 +38,7 @@ class CudfAdapterBatch : public detail::NoMetaInfo { XGBOOST_DEVICE bst_row_t NumCols() const { return columns_.size(); } private: - common::Span columns_; + common::Span> columns_; size_t num_rows_; }; @@ -101,9 +101,9 @@ class CudfAdapter : public detail::SingleBatchDataIter { auto const& typestr = get(json_columns[0]["typestr"]); CHECK_EQ(typestr.size(), 3) << ArrayInterfaceErrors::TypestrFormat(); - std::vector columns; - auto first_column = ArrayInterface(get(json_columns[0])); - num_rows_ = first_column.num_rows; + std::vector> columns; + auto first_column = ArrayInterface<1>(get(json_columns[0])); + num_rows_ = first_column.Shape(0); if (num_rows_ == 0) { return; } @@ -112,13 +112,12 @@ class CudfAdapter : public detail::SingleBatchDataIter { CHECK_NE(device_idx_, -1); dh::safe_cuda(cudaSetDevice(device_idx_)); for (auto& json_col : json_columns) { - auto column = ArrayInterface(get(json_col)); + auto column = ArrayInterface<1>(get(json_col)); columns.push_back(column); - CHECK_EQ(column.num_cols, 1); - num_rows_ = std::max(num_rows_, size_t(column.num_rows)); + num_rows_ = std::max(num_rows_, size_t(column.Shape(0))); CHECK_EQ(device_idx_, dh::CudaGetPointerDevice(column.data)) << "All columns should use the same device."; - CHECK_EQ(num_rows_, column.num_rows) + CHECK_EQ(num_rows_, column.Shape(0)) << "All columns should have same number of rows."; } columns_ = columns; @@ -135,7 +134,7 @@ class CudfAdapter : public detail::SingleBatchDataIter { private: CudfAdapterBatch batch_; - dh::device_vector columns_; + dh::device_vector> columns_; size_t num_rows_{0}; int device_idx_; }; @@ -143,23 +142,23 @@ class CudfAdapter : public detail::SingleBatchDataIter { class CupyAdapterBatch : public detail::NoMetaInfo { public: CupyAdapterBatch() = default; - explicit CupyAdapterBatch(ArrayInterface array_interface) + explicit CupyAdapterBatch(ArrayInterface<2> array_interface) : array_interface_(std::move(array_interface)) {} size_t Size() const { - return array_interface_.num_rows * array_interface_.num_cols; + return array_interface_.Shape(0) * array_interface_.Shape(1); } __device__ COOTuple GetElement(size_t idx) const { - size_t column_idx = idx % array_interface_.num_cols; - size_t row_idx = idx / array_interface_.num_cols; - float value = array_interface_.GetElement(row_idx, column_idx); + size_t column_idx = idx % array_interface_.Shape(1); + size_t row_idx = idx / array_interface_.Shape(1); + float value = array_interface_(row_idx, column_idx); return {row_idx, column_idx, value}; } - XGBOOST_DEVICE bst_row_t NumRows() const { return array_interface_.num_rows; } - XGBOOST_DEVICE bst_row_t NumCols() const { return array_interface_.num_cols; } + XGBOOST_DEVICE bst_row_t NumRows() const { return array_interface_.Shape(0); } + XGBOOST_DEVICE bst_row_t NumCols() const { return array_interface_.Shape(1); } private: - ArrayInterface array_interface_; + ArrayInterface<2> array_interface_; }; class CupyAdapter : public detail::SingleBatchDataIter { @@ -167,9 +166,9 @@ class CupyAdapter : public detail::SingleBatchDataIter { explicit CupyAdapter(std::string cuda_interface_str) { Json json_array_interface = Json::Load({cuda_interface_str.c_str(), cuda_interface_str.size()}); - array_interface_ = ArrayInterface(get(json_array_interface), false); + array_interface_ = ArrayInterface<2>(get(json_array_interface)); batch_ = CupyAdapterBatch(array_interface_); - if (array_interface_.num_rows == 0) { + if (array_interface_.Shape(0) == 0) { return; } device_idx_ = dh::CudaGetPointerDevice(array_interface_.data); @@ -177,12 +176,12 @@ class CupyAdapter : public detail::SingleBatchDataIter { } const CupyAdapterBatch& Value() const override { return batch_; } - size_t NumRows() const { return array_interface_.num_rows; } - size_t NumColumns() const { return array_interface_.num_cols; } + size_t NumRows() const { return array_interface_.Shape(0); } + size_t NumColumns() const { return array_interface_.Shape(1); } int32_t DeviceIdx() const { return device_idx_; } private: - ArrayInterface array_interface_; + ArrayInterface<2> array_interface_; CupyAdapterBatch batch_; int32_t device_idx_ {-1}; }; diff --git a/src/data/file_iterator.h b/src/data/file_iterator.h index 6d6adb62b008..70a5d51c30b9 100644 --- a/src/data/file_iterator.h +++ b/src/data/file_iterator.h @@ -12,6 +12,7 @@ #include "dmlc/data.h" #include "xgboost/c_api.h" #include "xgboost/json.h" +#include "xgboost/linalg.h" #include "array_interface.h" namespace xgboost { @@ -58,16 +59,14 @@ class FileIterator { CHECK(parser_); if (parser_->Next()) { row_block_ = parser_->Value(); + using linalg::MakeVec; - indptr_ = MakeArrayInterface(row_block_.offset, row_block_.size + 1); - values_ = MakeArrayInterface(row_block_.value, - row_block_.offset[row_block_.size]); - indices_ = MakeArrayInterface(row_block_.index, - row_block_.offset[row_block_.size]); + indptr_ = MakeVec(row_block_.offset, row_block_.size + 1).ArrayInterfaceStr(); + values_ = MakeVec(row_block_.value, row_block_.offset[row_block_.size]).ArrayInterfaceStr(); + indices_ = MakeVec(row_block_.index, row_block_.offset[row_block_.size]).ArrayInterfaceStr(); - size_t n_columns = *std::max_element( - row_block_.index, - row_block_.index + row_block_.offset[row_block_.size]); + size_t n_columns = *std::max_element(row_block_.index, + row_block_.index + row_block_.offset[row_block_.size]); // dmlc parser converts 1-based indexing back to 0-based indexing so we can ignore // this condition and just add 1 to n_columns n_columns += 1; diff --git a/src/data/simple_dmatrix.cc b/src/data/simple_dmatrix.cc index 44a8a3f8fe7c..e83559d3958a 100644 --- a/src/data/simple_dmatrix.cc +++ b/src/data/simple_dmatrix.cc @@ -137,9 +137,10 @@ SimpleDMatrix::SimpleDMatrix(AdapterT* adapter, float missing, int nthread) { batch.Weights() + batch.Size()); } if (batch.BaseMargin() != nullptr) { - auto& base_margin = info_.base_margin_.HostVector(); - base_margin.insert(base_margin.end(), batch.BaseMargin(), - batch.BaseMargin() + batch.Size()); + info_.base_margin_ = linalg::Tensor{batch.BaseMargin(), + batch.BaseMargin() + batch.Size(), + {batch.Size()}, + GenericParameter::kCpuId}; } if (batch.Qid() != nullptr) { qids.insert(qids.end(), batch.Qid(), batch.Qid() + batch.Size()); diff --git a/src/data/validation.h b/src/data/validation.h new file mode 100644 index 000000000000..6d3701114886 --- /dev/null +++ b/src/data/validation.h @@ -0,0 +1,40 @@ +/*! + * Copyright 2021 by XGBoost Contributors + */ +#ifndef XGBOOST_DATA_VALIDATION_H_ +#define XGBOOST_DATA_VALIDATION_H_ +#include +#include + +#include "xgboost/base.h" +#include "xgboost/logging.h" + +namespace xgboost { +namespace data { +struct LabelsCheck { + XGBOOST_DEVICE bool operator()(float y) { +#if defined(__CUDA_ARCH__) + return ::isnan(y) || ::isinf(y); +#else + return std::isnan(y) || std::isinf(y); +#endif + } +}; + +struct WeightsCheck { + XGBOOST_DEVICE bool operator()(float w) { return LabelsCheck{}(w) || w < 0; } // NOLINT +}; + +inline void ValidateQueryGroup(std::vector const &group_ptr_) { + bool valid_query_group = true; + for (size_t i = 1; i < group_ptr_.size(); ++i) { + valid_query_group = valid_query_group && group_ptr_[i] >= group_ptr_[i - 1]; + if (XGBOOST_EXPECT(!valid_query_group, false)) { + break; + } + } + CHECK(valid_query_group) << "Invalid group structure."; +} +} // namespace data +} // namespace xgboost +#endif // XGBOOST_DATA_VALIDATION_H_ diff --git a/src/gbm/gblinear.cc b/src/gbm/gblinear.cc index e5f2a457214c..e5f5916211cc 100644 --- a/src/gbm/gblinear.cc +++ b/src/gbm/gblinear.cc @@ -178,7 +178,7 @@ class GBLinear : public GradientBooster { unsigned layer_begin, unsigned layer_end, bool, int, unsigned) override { model_.LazyInitModel(); LinearCheckLayer(layer_begin, layer_end); - const auto& base_margin = p_fmat->Info().base_margin_.ConstHostVector(); + auto base_margin = p_fmat->Info().base_margin_.View(GenericParameter::kCpuId); const int ngroup = model_.learner_model_param->num_output_group; const size_t ncolumns = model_.learner_model_param->num_feature + 1; // allocate space for (#features + bias) times #groups times #rows @@ -203,9 +203,9 @@ class GBLinear : public GradientBooster { p_contribs[ins.index] = ins.fvalue * model_[ins.index][gid]; } // add base margin to BIAS - p_contribs[ncolumns - 1] = model_.Bias()[gid] + - ((base_margin.size() != 0) ? base_margin[row_idx * ngroup + gid] : - learner_model_param_->base_score); + p_contribs[ncolumns - 1] = + model_.Bias()[gid] + ((base_margin.Size() != 0) ? base_margin(row_idx, gid) + : learner_model_param_->base_score); } }); } @@ -270,7 +270,7 @@ class GBLinear : public GradientBooster { monitor_.Start("PredictBatchInternal"); model_.LazyInitModel(); std::vector &preds = *out_preds; - const auto& base_margin = p_fmat->Info().base_margin_.ConstHostVector(); + auto base_margin = p_fmat->Info().base_margin_.View(GenericParameter::kCpuId); // start collecting the prediction const int ngroup = model_.learner_model_param->num_output_group; preds.resize(p_fmat->Info().num_row_ * ngroup); @@ -280,16 +280,15 @@ class GBLinear : public GradientBooster { // k is number of group // parallel over local batch const auto nsize = static_cast(batch.Size()); - if (base_margin.size() != 0) { - CHECK_EQ(base_margin.size(), nsize * ngroup); + if (base_margin.Size() != 0) { + CHECK_EQ(base_margin.Size(), nsize * ngroup); } common::ParallelFor(nsize, [&](omp_ulong i) { const size_t ridx = page.base_rowid + i; // loop over output groups for (int gid = 0; gid < ngroup; ++gid) { - bst_float margin = - (base_margin.size() != 0) ? - base_margin[ridx * ngroup + gid] : learner_model_param_->base_score; + float margin = + (base_margin.Size() != 0) ? base_margin(ridx, gid) : learner_model_param_->base_score; this->Pred(batch[i], &preds[ridx * ngroup], gid, margin); } }); diff --git a/src/predictor/cpu_predictor.cc b/src/predictor/cpu_predictor.cc index d581f64a1d56..92797235d16b 100644 --- a/src/predictor/cpu_predictor.cc +++ b/src/predictor/cpu_predictor.cc @@ -282,27 +282,6 @@ class CPUPredictor : public Predictor { } } - void InitOutPredictions(const MetaInfo& info, - HostDeviceVector* out_preds, - const gbm::GBTreeModel& model) const override { - CHECK_NE(model.learner_model_param->num_output_group, 0); - size_t n = model.learner_model_param->num_output_group * info.num_row_; - const auto& base_margin = info.base_margin_.HostVector(); - out_preds->Resize(n); - std::vector& out_preds_h = out_preds->HostVector(); - if (base_margin.empty()) { - std::fill(out_preds_h.begin(), out_preds_h.end(), - model.learner_model_param->base_score); - } else { - std::string expected{ - "(" + std::to_string(info.num_row_) + ", " + - std::to_string(model.learner_model_param->num_output_group) + ")"}; - CHECK_EQ(base_margin.size(), n) - << "Invalid shape of base_margin. Expected:" << expected; - std::copy(base_margin.begin(), base_margin.end(), out_preds_h.begin()); - } - } - public: explicit CPUPredictor(GenericParameter const* generic_param) : Predictor::Predictor{generic_param} {} @@ -456,7 +435,7 @@ class CPUPredictor : public Predictor { common::ParallelFor(bst_omp_uint(ntree_limit), [&](bst_omp_uint i) { FillNodeMeanValues(model.trees[i].get(), &(mean_values[i])); }); - const std::vector& base_margin = info.base_margin_.HostVector(); + auto base_margin = info.base_margin_.View(GenericParameter::kCpuId); // start collecting the contributions for (const auto &batch : p_fmat->GetBatches()) { auto page = batch.GetView(); @@ -496,8 +475,9 @@ class CPUPredictor : public Predictor { } feats.Drop(page[i]); // add base margin to BIAS - if (base_margin.size() != 0) { - p_contribs[ncolumns - 1] += base_margin[row_idx * ngroup + gid]; + if (base_margin.Size() != 0) { + CHECK_EQ(base_margin.Shape(1), ngroup); + p_contribs[ncolumns - 1] += base_margin(row_idx, gid); } else { p_contribs[ncolumns - 1] += model.learner_model_param->base_score; } diff --git a/src/predictor/gpu_predictor.cu b/src/predictor/gpu_predictor.cu index 51674237e973..025a1c495c52 100644 --- a/src/predictor/gpu_predictor.cu +++ b/src/predictor/gpu_predictor.cu @@ -854,8 +854,8 @@ class GPUPredictor : public xgboost::Predictor { dh::tend(phis)); } // Add the base margin term to last column - p_fmat->Info().base_margin_.SetDevice(generic_param_->gpu_id); - const auto margin = p_fmat->Info().base_margin_.ConstDeviceSpan(); + p_fmat->Info().base_margin_.Data()->SetDevice(generic_param_->gpu_id); + const auto margin = p_fmat->Info().base_margin_.Data()->ConstDeviceSpan(); float base_score = model.learner_model_param->base_score; dh::LaunchN( p_fmat->Info().num_row_ * model.learner_model_param->num_output_group, @@ -913,8 +913,8 @@ class GPUPredictor : public xgboost::Predictor { dh::tend(phis)); } // Add the base margin term to last column - p_fmat->Info().base_margin_.SetDevice(generic_param_->gpu_id); - const auto margin = p_fmat->Info().base_margin_.ConstDeviceSpan(); + p_fmat->Info().base_margin_.Data()->SetDevice(generic_param_->gpu_id); + const auto margin = p_fmat->Info().base_margin_.Data()->ConstDeviceSpan(); float base_score = model.learner_model_param->base_score; size_t n_features = model.learner_model_param->num_feature; dh::LaunchN( @@ -928,27 +928,6 @@ class GPUPredictor : public xgboost::Predictor { }); } - protected: - void InitOutPredictions(const MetaInfo& info, - HostDeviceVector* out_preds, - const gbm::GBTreeModel& model) const override { - size_t n_classes = model.learner_model_param->num_output_group; - size_t n = n_classes * info.num_row_; - const HostDeviceVector& base_margin = info.base_margin_; - out_preds->SetDevice(generic_param_->gpu_id); - out_preds->Resize(n); - if (base_margin.Size() != 0) { - std::string expected{ - "(" + std::to_string(info.num_row_) + ", " + - std::to_string(model.learner_model_param->num_output_group) + ")"}; - CHECK_EQ(base_margin.Size(), n) - << "Invalid shape of base_margin. Expected:" << expected; - out_preds->Copy(base_margin); - } else { - out_preds->Fill(model.learner_model_param->base_score); - } - } - void PredictInstance(const SparsePage::Inst&, std::vector*, const gbm::GBTreeModel&, unsigned) const override { diff --git a/src/predictor/predictor.cc b/src/predictor/predictor.cc index 9aa18b19ce30..e9e10a632a7f 100644 --- a/src/predictor/predictor.cc +++ b/src/predictor/predictor.cc @@ -1,5 +1,5 @@ /*! - * Copyright 2017-2020 by Contributors + * Copyright 2017-2021 by Contributors */ #include #include @@ -8,6 +8,8 @@ #include "xgboost/data.h" #include "xgboost/generic_parameters.h" +#include "../gbm/gbtree.h" + namespace dmlc { DMLC_REGISTRY_ENABLE(::xgboost::PredictorReg); } // namespace dmlc @@ -58,6 +60,33 @@ Predictor* Predictor::Create( auto p_predictor = (e->body)(generic_param); return p_predictor; } + +void ValidateBaseMarginShape(linalg::Tensor const& margin, bst_row_t n_samples, + bst_group_t n_groups) { + // FIXME: Bindings other than Python doesn't have shape. + std::string expected{"Invalid shape of base_margin. Expected: (" + std::to_string(n_samples) + + ", " + std::to_string(n_groups) + ")"}; + CHECK_EQ(margin.Shape(0), n_samples) << expected; + CHECK_EQ(margin.Shape(1), n_groups) << expected; +} + +void Predictor::InitOutPredictions(const MetaInfo& info, HostDeviceVector* out_preds, + const gbm::GBTreeModel& model) const { + CHECK_NE(model.learner_model_param->num_output_group, 0); + size_t n_classes = model.learner_model_param->num_output_group; + size_t n = n_classes * info.num_row_; + const HostDeviceVector* base_margin = info.base_margin_.Data(); + if (generic_param_->gpu_id >= 0) { + out_preds->SetDevice(generic_param_->gpu_id); + } + out_preds->Resize(n); + if (base_margin->Size() != 0) { + ValidateBaseMarginShape(info.base_margin_, info.num_row_, n_classes); + out_preds->Copy(*base_margin); + } else { + out_preds->Fill(model.learner_model_param->base_score); + } +} } // namespace xgboost namespace xgboost { diff --git a/tests/cpp/common/test_hist_util.h b/tests/cpp/common/test_hist_util.h index c7fc4972d3ee..49952d202706 100644 --- a/tests/cpp/common/test_hist_util.h +++ b/tests/cpp/common/test_hist_util.h @@ -57,7 +57,7 @@ inline data::CupyAdapter AdapterFromData(const thrust::device_vector &x, Json(Integer(reinterpret_cast(x.data().get()))), Json(Boolean(false))}; array_interface["data"] = j_data; - array_interface["version"] = Integer(static_cast(1)); + array_interface["version"] = 3; array_interface["typestr"] = String(" indices; size_t n_features = 100, n_samples = 10; RandomDataGenerator{n_samples, n_features, 0.5}.GenerateCSR(&values, &indptr, &indices); - auto indptr_arr = MakeArrayInterface(indptr.HostPointer(), indptr.Size()); - auto values_arr = MakeArrayInterface(values.HostPointer(), values.Size()); - auto indices_arr = MakeArrayInterface(indices.HostPointer(), indices.Size()); + using linalg::MakeVec; + auto indptr_arr = MakeVec(indptr.HostPointer(), indptr.Size()).ArrayInterfaceStr(); + auto values_arr = MakeVec(values.HostPointer(), values.Size()).ArrayInterfaceStr(); + auto indices_arr = MakeVec(indices.HostPointer(), indices.Size()).ArrayInterfaceStr(); auto adapter = data::CSRArrayAdapter( StringView{indptr_arr.c_str(), indptr_arr.size()}, StringView{values_arr.c_str(), values_arr.size()}, diff --git a/tests/cpp/data/test_array_interface.cc b/tests/cpp/data/test_array_interface.cc index 875858855ed5..f438593b42a4 100644 --- a/tests/cpp/data/test_array_interface.cc +++ b/tests/cpp/data/test_array_interface.cc @@ -11,21 +11,22 @@ TEST(ArrayInterface, Initialize) { size_t constexpr kRows = 10, kCols = 10; HostDeviceVector storage; auto array = RandomDataGenerator{kRows, kCols, 0}.GenerateArrayInterface(&storage); - auto arr_interface = ArrayInterface(array); - ASSERT_EQ(arr_interface.num_rows, kRows); - ASSERT_EQ(arr_interface.num_cols, kCols); + auto arr_interface = ArrayInterface<2>(StringView{array}); + ASSERT_EQ(arr_interface.Shape(0), kRows); + ASSERT_EQ(arr_interface.Shape(1), kCols); ASSERT_EQ(arr_interface.data, storage.ConstHostPointer()); ASSERT_EQ(arr_interface.ElementSize(), 4); - ASSERT_EQ(arr_interface.type, ArrayInterface::kF4); + ASSERT_EQ(arr_interface.type, ArrayInterfaceHandler::kF4); HostDeviceVector u64_storage(storage.Size()); - std::string u64_arr_str; - Json::Dump(GetArrayInterface(&u64_storage, kRows, kCols), &u64_arr_str); + std::string u64_arr_str{linalg::TensorView{ + u64_storage.ConstHostSpan(), {kRows, kCols}, GenericParameter::kCpuId} + .ArrayInterfaceStr()}; std::copy(storage.ConstHostVector().cbegin(), storage.ConstHostVector().cend(), u64_storage.HostSpan().begin()); - auto u64_arr = ArrayInterface{u64_arr_str}; + auto u64_arr = ArrayInterface<2>{u64_arr_str}; ASSERT_EQ(u64_arr.ElementSize(), 8); - ASSERT_EQ(u64_arr.type, ArrayInterface::kU8); + ASSERT_EQ(u64_arr.type, ArrayInterfaceHandler::kU8); } TEST(ArrayInterface, Error) { @@ -38,23 +39,22 @@ TEST(ArrayInterface, Error) { Json(Boolean(false))}; auto const& column_obj = get(column); - std::pair shape{kRows, kCols}; std::string typestr{"(1)); + EXPECT_THROW(ArrayInterfaceHandler::ExtractData(column_obj, n), dmlc::Error); + column["version"] = 3; // missing data - EXPECT_THROW(ArrayInterfaceHandler::ExtractData(column_obj, shape), + EXPECT_THROW(ArrayInterfaceHandler::ExtractData(column_obj, n), dmlc::Error); column["data"] = j_data; // missing typestr - EXPECT_THROW(ArrayInterfaceHandler::ExtractData(column_obj, shape), + EXPECT_THROW(ArrayInterfaceHandler::ExtractData(column_obj, n), dmlc::Error); column["typestr"] = String(" storage; @@ -63,22 +63,41 @@ TEST(ArrayInterface, Error) { Json(Integer(reinterpret_cast(storage.ConstHostPointer()))), Json(Boolean(false))}; column["data"] = j_data; - EXPECT_NO_THROW(ArrayInterfaceHandler::ExtractData(column_obj, shape)); + EXPECT_NO_THROW(ArrayInterfaceHandler::ExtractData(column_obj, n)); } TEST(ArrayInterface, GetElement) { size_t kRows = 4, kCols = 2; HostDeviceVector storage; auto intefrace_str = RandomDataGenerator{kRows, kCols, 0}.GenerateArrayInterface(&storage); - ArrayInterface array_interface{intefrace_str}; + ArrayInterface<2> array_interface{intefrace_str}; auto const& h_storage = storage.ConstHostVector(); for (size_t i = 0; i < kRows; ++i) { for (size_t j = 0; j < kCols; ++j) { - float v0 = array_interface.GetElement(i, j); + float v0 = array_interface(i, j); float v1 = h_storage.at(i * kCols + j); ASSERT_EQ(v0, v1); } } } + +TEST(ArrayInterface, TrivialDim) { + size_t kRows{1000}, kCols = 1; + HostDeviceVector storage; + auto interface_str = RandomDataGenerator{kRows, kCols, 0}.GenerateArrayInterface(&storage); + { + ArrayInterface<1> arr_i{interface_str}; + ASSERT_EQ(arr_i.n, kRows); + ASSERT_EQ(arr_i.Shape(0), kRows); + } + + std::swap(kRows, kCols); + interface_str = RandomDataGenerator{kRows, kCols, 0}.GenerateArrayInterface(&storage); + { + ArrayInterface<1> arr_i{interface_str}; + ASSERT_EQ(arr_i.n, kCols); + ASSERT_EQ(arr_i.Shape(0), kCols); + } +} } // namespace xgboost diff --git a/tests/cpp/data/test_array_interface.cu b/tests/cpp/data/test_array_interface.cu index 75923e74ba1a..c8e07852534b 100644 --- a/tests/cpp/data/test_array_interface.cu +++ b/tests/cpp/data/test_array_interface.cu @@ -32,11 +32,24 @@ TEST(ArrayInterface, Stream) { dh::caching_device_vector out(1, 0); uint64_t dur = 1e9; dh::LaunchKernel{1, 1, 0, stream}(SleepForTest, out.data().get(), dur); - ArrayInterface arr(arr_str); + ArrayInterface<2> arr(arr_str); auto t = out[0]; CHECK_GE(t, dur); cudaStreamDestroy(stream); } + +TEST(ArrayInterface, Ptr) { + std::vector h_data(10); + ASSERT_FALSE(ArrayInterfaceHandler::IsCudaPtr(h_data.data())); + dh::safe_cuda(cudaGetLastError()); + + dh::device_vector d_data(10); + ASSERT_TRUE(ArrayInterfaceHandler::IsCudaPtr(d_data.data().get())); + dh::safe_cuda(cudaGetLastError()); + + ASSERT_FALSE(ArrayInterfaceHandler::IsCudaPtr(nullptr)); + dh::safe_cuda(cudaGetLastError()); +} } // namespace xgboost diff --git a/tests/cpp/data/test_array_interface.h b/tests/cpp/data/test_array_interface.h index 7872a9507aa5..78bce76f53e7 100644 --- a/tests/cpp/data/test_array_interface.h +++ b/tests/cpp/data/test_array_interface.h @@ -19,6 +19,7 @@ Json GenerateDenseColumn(std::string const& typestr, size_t kRows, std::vector j_shape {Json(Integer(static_cast(kRows)))}; column["shape"] = Array(j_shape); column["strides"] = Array(std::vector{Json(Integer(static_cast(sizeof(T))))}); + column["stream"] = nullptr; d_data.resize(kRows); thrust::sequence(thrust::device, d_data.begin(), d_data.end(), 0.0f, 2.0f); @@ -30,7 +31,7 @@ Json GenerateDenseColumn(std::string const& typestr, size_t kRows, Json(Boolean(false))}; column["data"] = j_data; - column["version"] = Integer(static_cast(1)); + column["version"] = 3; column["typestr"] = String(typestr); return column; } @@ -43,6 +44,7 @@ Json GenerateSparseColumn(std::string const& typestr, size_t kRows, std::vector j_shape {Json(Integer(static_cast(kRows)))}; column["shape"] = Array(j_shape); column["strides"] = Array(std::vector{Json(Integer(static_cast(sizeof(T))))}); + column["stream"] = nullptr; d_data.resize(kRows); for (size_t i = 0; i < d_data.size(); ++i) { @@ -56,7 +58,7 @@ Json GenerateSparseColumn(std::string const& typestr, size_t kRows, Json(Boolean(false))}; column["data"] = j_data; - column["version"] = Integer(static_cast(1)); + column["version"] = 3; column["typestr"] = String(typestr); return column; } @@ -75,9 +77,9 @@ Json Generate2dArrayInterface(int rows, int cols, std::string typestr, Json(Integer(reinterpret_cast(data.data().get()))), Json(Boolean(false))}; array_interface["data"] = j_data; - array_interface["version"] = Integer(static_cast(1)); + array_interface["version"] = 3; array_interface["typestr"] = String(typestr); + array_interface["stream"] = nullptr; return array_interface; } - } // namespace xgboost diff --git a/tests/cpp/data/test_iterative_device_dmatrix.cu b/tests/cpp/data/test_iterative_device_dmatrix.cu index cb64a3b5cdb2..27f6b0b3ffe9 100644 --- a/tests/cpp/data/test_iterative_device_dmatrix.cu +++ b/tests/cpp/data/test_iterative_device_dmatrix.cu @@ -103,7 +103,7 @@ TEST(IterativeDeviceDMatrix, RowMajor) { auto j_interface = Json::Load({interface_str.c_str(), interface_str.size()}); - ArrayInterface loaded {get(j_interface)}; + ArrayInterface<2> loaded {get(j_interface)}; std::vector h_data(cols * rows); common::Span s_data{static_cast(loaded.data), cols * rows}; dh::CopyDeviceSpanToVector(&h_data, s_data); @@ -128,7 +128,7 @@ TEST(IterativeDeviceDMatrix, RowMajorMissing) { std::string interface_str = iter.AsArray(); auto j_interface = Json::Load({interface_str.c_str(), interface_str.size()}); - ArrayInterface loaded {get(j_interface)}; + ArrayInterface<2> loaded {get(j_interface)}; std::vector h_data(cols * rows); common::Span s_data{static_cast(loaded.data), cols * rows}; dh::CopyDeviceSpanToVector(&h_data, s_data); diff --git a/tests/cpp/data/test_metainfo.cc b/tests/cpp/data/test_metainfo.cc index bb5452a56d28..4f379f4fdd25 100644 --- a/tests/cpp/data/test_metainfo.cc +++ b/tests/cpp/data/test_metainfo.cc @@ -1,4 +1,6 @@ -// Copyright 2016-2020 by Contributors +// Copyright 2016-2021 by Contributors +#include "test_metainfo.h" + #include #include #include @@ -122,7 +124,10 @@ TEST(MetaInfo, SaveLoadBinary) { EXPECT_EQ(inforead.labels_.HostVector(), info.labels_.HostVector()); EXPECT_EQ(inforead.group_ptr_, info.group_ptr_); EXPECT_EQ(inforead.weights_.HostVector(), info.weights_.HostVector()); - EXPECT_EQ(inforead.base_margin_.HostVector(), info.base_margin_.HostVector()); + + auto orig_margin = info.base_margin_.View(xgboost::GenericParameter::kCpuId); + auto read_margin = inforead.base_margin_.View(xgboost::GenericParameter::kCpuId); + EXPECT_TRUE(std::equal(orig_margin.cbegin(), orig_margin.cend(), read_margin.cbegin())); EXPECT_EQ(inforead.feature_type_names.size(), kCols); EXPECT_EQ(inforead.feature_types.Size(), kCols); @@ -254,10 +259,10 @@ TEST(MetaInfo, Validate) { xgboost::HostDeviceVector d_groups{groups}; d_groups.SetDevice(0); d_groups.DevicePointer(); // pull to device - auto arr_interface = xgboost::GetArrayInterface(&d_groups, 64, 1); - std::string arr_interface_str; - xgboost::Json::Dump(arr_interface, &arr_interface_str); - EXPECT_THROW(info.SetInfo("group", arr_interface_str), dmlc::Error); + std::string arr_interface_str{ + xgboost::linalg::MakeVec(d_groups.ConstDevicePointer(), d_groups.Size(), 0) + .ArrayInterfaceStr()}; + EXPECT_THROW(info.SetInfo("group", xgboost::StringView{arr_interface_str}), dmlc::Error); #endif // defined(XGBOOST_USE_CUDA) } @@ -292,3 +297,7 @@ TEST(MetaInfo, HostExtend) { ASSERT_EQ(lhs.group_ptr_.at(i), per_group * i); } } + +namespace xgboost { +TEST(MetaInfo, CPUStridedData) { TestMetaInfoStridedData(GenericParameter::kCpuId); } +} // namespace xgboost diff --git a/tests/cpp/data/test_metainfo.cu b/tests/cpp/data/test_metainfo.cu index 090374b913d6..d57e9ceda807 100644 --- a/tests/cpp/data/test_metainfo.cu +++ b/tests/cpp/data/test_metainfo.cu @@ -1,12 +1,14 @@ -/*! Copyright 2019 by Contributors */ - +/*! Copyright 2019-2021 by XGBoost Contributors */ #include #include #include +#include #include #include "test_array_interface.h" #include "../../../src/common/device_helpers.cuh" +#include "test_metainfo.h" + namespace xgboost { template @@ -23,7 +25,7 @@ std::string PrepareData(std::string typestr, thrust::device_vector* out, cons std::vector j_shape {Json(Integer(static_cast(kRows)))}; column["shape"] = Array(j_shape); column["strides"] = Array(std::vector{Json(Integer(static_cast(sizeof(T))))}); - column["version"] = Integer(static_cast(1)); + column["version"] = 3; column["typestr"] = String(typestr); auto p_d_data = d_data.data().get(); @@ -31,6 +33,7 @@ std::string PrepareData(std::string typestr, thrust::device_vector* out, cons Json(Integer(reinterpret_cast(p_d_data))), Json(Boolean(false))}; column["data"] = j_data; + column["stream"] = nullptr; Json array(std::vector{column}); std::string str; @@ -49,6 +52,7 @@ TEST(MetaInfo, FromInterface) { info.SetInfo("label", str.c_str()); auto const& h_label = info.labels_.HostVector(); + ASSERT_EQ(h_label.size(), d_data.size()); for (size_t i = 0; i < d_data.size(); ++i) { ASSERT_EQ(h_label[i], d_data[i]); } @@ -60,9 +64,10 @@ TEST(MetaInfo, FromInterface) { } info.SetInfo("base_margin", str.c_str()); - auto const& h_base_margin = info.base_margin_.HostVector(); + auto const h_base_margin = info.base_margin_.View(GenericParameter::kCpuId); + ASSERT_EQ(h_base_margin.Size(), d_data.size()); for (size_t i = 0; i < d_data.size(); ++i) { - ASSERT_EQ(h_base_margin[i], d_data[i]); + ASSERT_EQ(h_base_margin(i), d_data[i]); } thrust::device_vector d_group_data; @@ -76,6 +81,10 @@ TEST(MetaInfo, FromInterface) { EXPECT_EQ(info.group_ptr_, expected_group_ptr); } +TEST(MetaInfo, GPUStridedData) { + TestMetaInfoStridedData(0); +} + TEST(MetaInfo, Group) { cudaSetDevice(0); MetaInfo info; diff --git a/tests/cpp/data/test_metainfo.h b/tests/cpp/data/test_metainfo.h new file mode 100644 index 000000000000..67da633d4be5 --- /dev/null +++ b/tests/cpp/data/test_metainfo.h @@ -0,0 +1,82 @@ +/*! + * Copyright 2021 by XGBoost Contributors + */ +#ifndef XGBOOST_TESTS_CPP_DATA_TEST_METAINFO_H_ +#define XGBOOST_TESTS_CPP_DATA_TEST_METAINFO_H_ +#include +#include +#include +#include + +#include +#include "../../../src/data/array_interface.h" +#include "../../../src/common/linalg_op.h" + +namespace xgboost { +inline void TestMetaInfoStridedData(int32_t device) { + MetaInfo info; + { + // label + HostDeviceVector labels; + labels.Resize(64); + auto& h_labels = labels.HostVector(); + std::iota(h_labels.begin(), h_labels.end(), 0.0f); + bool is_gpu = device >= 0; + if (is_gpu) { + labels.SetDevice(0); + } + + auto t = linalg::TensorView{ + is_gpu ? labels.ConstDeviceSpan() : labels.ConstHostSpan(), {32, 2}, device}; + auto s = t.Slice(linalg::All(), 0); + + auto str = s.ArrayInterfaceStr(); + ASSERT_EQ(s.Size(), 32); + + info.SetInfo("label", StringView{str}); + auto const& h_result = info.labels_.HostVector(); + ASSERT_EQ(h_result.size(), 32); + + for (auto v : h_result) { + ASSERT_EQ(static_cast(v) % 2, 0); + } + } + { + // qid + linalg::Tensor qid; + qid.Reshape(32, 2); + auto& h_qid = qid.Data()->HostVector(); + std::iota(h_qid.begin(), h_qid.end(), 0); + auto s = qid.View(device).Slice(linalg::All(), 0); + auto str = s.ArrayInterfaceStr(); + info.SetInfo("qid", StringView{str}); + auto const& h_result = info.group_ptr_; + ASSERT_EQ(h_result.size(), s.Size() + 1); + } + { + // base margin + linalg::Tensor base_margin; + base_margin.Reshape(4, 3, 2, 3); + auto& h_margin = base_margin.Data()->HostVector(); + std::iota(h_margin.begin(), h_margin.end(), 0.0); + auto t_margin = base_margin.View(device).Slice(linalg::All(), linalg::All(), 0, linalg::All()); + ASSERT_EQ(t_margin.Shape().size(), 3); + + info.SetInfo("base_margin", StringView{t_margin.ArrayInterfaceStr()}); + auto const& h_result = info.base_margin_.View(-1); + ASSERT_EQ(h_result.Shape().size(), 3); + auto in_margin = base_margin.View(-1); + linalg::ElementWiseKernelHost(h_result, omp_get_max_threads(), [&](size_t i, float v_0) { + auto tup = linalg::UnravelIndex(i, h_result.Shape()); + auto i0 = std::get<0>(tup); + auto i1 = std::get<1>(tup); + auto i2 = std::get<2>(tup); + // Sliced at 3^th dimension. + auto v_1 = in_margin(i0, i1, 0, i2); + CHECK_EQ(v_0, v_1); + return v_0; + }); + } +} +} // namespace xgboost +#endif // XGBOOST_TESTS_CPP_DATA_TEST_METAINFO_H_ diff --git a/tests/cpp/data/test_simple_dmatrix.cc b/tests/cpp/data/test_simple_dmatrix.cc index f777b00e244c..c25e877079d2 100644 --- a/tests/cpp/data/test_simple_dmatrix.cc +++ b/tests/cpp/data/test_simple_dmatrix.cc @@ -253,8 +253,8 @@ TEST(SimpleDMatrix, Slice) { std::iota(lower.begin(), lower.end(), 0.0f); std::iota(upper.begin(), upper.end(), 1.0f); - auto& margin = p_m->Info().base_margin_.HostVector(); - margin.resize(kRows * kClasses); + auto& margin = p_m->Info().base_margin_; + margin = linalg::Tensor{{kRows, kClasses}, GenericParameter::kCpuId}; std::array ridxs {1, 3, 5}; std::unique_ptr out { p_m->Slice(ridxs) }; @@ -284,10 +284,10 @@ TEST(SimpleDMatrix, Slice) { ASSERT_EQ(p_m->Info().weights_.HostVector().at(ridx), out->Info().weights_.HostVector().at(i)); - auto& out_margin = out->Info().base_margin_.HostVector(); + auto out_margin = out->Info().base_margin_.View(GenericParameter::kCpuId); + auto in_margin = margin.View(GenericParameter::kCpuId); for (size_t j = 0; j < kClasses; ++j) { - auto in_beg = ridx * kClasses; - ASSERT_EQ(out_margin.at(i * kClasses + j), margin.at(in_beg + j)); + ASSERT_EQ(out_margin(i, j), in_margin(ridx, j)); } } } diff --git a/tests/cpp/data/test_simple_dmatrix.cu b/tests/cpp/data/test_simple_dmatrix.cu index d74f5b150087..19f13b1fddfd 100644 --- a/tests/cpp/data/test_simple_dmatrix.cu +++ b/tests/cpp/data/test_simple_dmatrix.cu @@ -122,13 +122,13 @@ TEST(SimpleDMatrix, FromColumnarWithEmptyRows) { col["data"] = j_data; std::vector j_shape{Json(Integer(static_cast(kRows)))}; col["shape"] = Array(j_shape); - col["version"] = Integer(static_cast(1)); + col["version"] = 3; col["typestr"] = String("(1)); + j_mask["version"] = 3; auto& mask_storage = column_bitfields[i]; mask_storage.resize(16); // 16 bytes @@ -220,7 +220,7 @@ TEST(SimpleCSRSource, FromColumnarSparse) { for (size_t c = 0; c < kCols; ++c) { auto& column = j_columns[c]; column = Object(); - column["version"] = Integer(static_cast(1)); + column["version"] = 3; column["typestr"] = String(" j_data { @@ -229,12 +229,12 @@ TEST(SimpleCSRSource, FromColumnarSparse) { column["data"] = j_data; std::vector j_shape {Json(Integer(static_cast(kRows)))}; column["shape"] = Array(j_shape); - column["version"] = Integer(static_cast(1)); + column["version"] = 3; column["typestr"] = String("(1)); + j_mask["version"] = 3; j_mask["data"] = std::vector{ Json(Integer(reinterpret_cast(column_bitfields[c].data().get()))), Json(Boolean(false))}; diff --git a/tests/cpp/helpers.cc b/tests/cpp/helpers.cc index db4454c9eb1e..0906d9ed87df 100644 --- a/tests/cpp/helpers.cc +++ b/tests/cpp/helpers.cc @@ -228,6 +228,7 @@ RandomDataGenerator::GenerateArrayInterfaceBatch( if (device_ >= 0) { array_interface["data"][0] = Integer(reinterpret_cast(storage->DevicePointer() + offset)); + array_interface["stream"] = Null{}; } else { array_interface["data"][0] = Integer(reinterpret_cast(storage->HostPointer() + offset)); @@ -240,7 +241,7 @@ RandomDataGenerator::GenerateArrayInterfaceBatch( array_interface["shape"][1] = cols_; array_interface["typestr"] = String(" *storage, size_t rows, size_t cols) { if (storage->DeviceCanRead()) { array_interface["data"][0] = Integer(reinterpret_cast(storage->ConstDevicePointer())); + array_interface["stream"] = nullptr; } else { array_interface["data"][0] = Integer(reinterpret_cast(storage->ConstHostPointer())); @@ -198,9 +199,9 @@ Json GetArrayInterface(HostDeviceVector *storage, size_t rows, size_t cols) { array_interface["shape"][0] = rows; array_interface["shape"][1] = cols; - char t = ArrayInterfaceHandler::TypeChar(); + char t = linalg::detail::ArrayInterfaceHandler::TypeChar(); array_interface["typestr"] = String(std::string{"<"} + t + std::to_string(sizeof(T))); - array_interface["version"] = 1; + array_interface["version"] = 3; return array_interface; } diff --git a/tests/cpp/predictor/test_gpu_predictor.cu b/tests/cpp/predictor/test_gpu_predictor.cu index ad1083f9161b..b36df742da4f 100644 --- a/tests/cpp/predictor/test_gpu_predictor.cu +++ b/tests/cpp/predictor/test_gpu_predictor.cu @@ -108,7 +108,9 @@ TEST(GPUPredictor, ExternalMemoryTest) { dmats.push_back(CreateSparsePageDMatrix(8000)); for (const auto& dmat: dmats) { - dmat->Info().base_margin_.Resize(dmat->Info().num_row_ * n_classes, 0.5); + dmat->Info().base_margin_ = + linalg::Tensor{{dmat->Info().num_row_, static_cast(n_classes)}, 0}; + dmat->Info().base_margin_.Data()->Fill(0.5); PredictionCacheEntry out_predictions; gpu_predictor->InitOutPredictions(dmat->Info(), &out_predictions.predictions, model); gpu_predictor->PredictBatch(dmat.get(), &out_predictions, model, 0); diff --git a/tests/python/test_dmatrix.py b/tests/python/test_dmatrix.py index 1b5ad266d3d9..9dcaa1c0d2c1 100644 --- a/tests/python/test_dmatrix.py +++ b/tests/python/test_dmatrix.py @@ -141,6 +141,7 @@ def test_slice(self): # base margin is per-class in multi-class classifier base_margin = rng.randn(100, 3).astype(np.float32) d.set_base_margin(base_margin) + np.testing.assert_allclose(d.get_base_margin().reshape(100, 3), base_margin) ridxs = [1, 2, 3, 4, 5, 6] sliced = d.slice(ridxs) @@ -154,7 +155,7 @@ def test_slice(self): # Slicing a DMatrix results into a DMatrix that's equivalent to a DMatrix that's # constructed from the corresponding NumPy slice d2 = xgb.DMatrix(X[1:7, :], y[1:7]) - d2.set_base_margin(base_margin[1:7, :].flatten()) + d2.set_base_margin(base_margin[1:7, :]) eval_res = {} _ = xgb.train( {'num_class': 3, 'objective': 'multi:softprob', @@ -280,7 +281,7 @@ def test_feature_weights(self): m.set_info(feature_weights=fw) np.testing.assert_allclose(fw, m.get_float_info('feature_weights')) # Handle empty - m.set_info(feature_weights=np.empty((0, 0))) + m.set_info(feature_weights=np.empty((0, ))) assert m.get_float_info('feature_weights').shape[0] == 0