Commit

Implement typed and type erased tensor.
* Use typed tensor for storing meta info like base margin.
* Extend the array interface handler to multi-dim.

Implement a general array view.

* Replace existing matrix and vector view.

lint.

Remove const too.

Doc/Test.

Include.

Use it in AUC.

Win build.

Use int32_t.

Use integral.

force the same type.

Use constexpr for old nvcc.

Test for empty tensor.

Rename to view.

Format.

Better document and perf.

Address reviewer's comment.

tidy.

Prototype.

Move string view.

Compile on CPU.

Some fixes for GPU compilation.

Array interface.

Use it in file iter.

Cleanup.

itemsize.

Documents.

Cache the pointer.

port.

cuda compilation.

Start working on ari.

Add clang-format config. (#7383)

Generated using `clang-format -style=google -dump-config > .clang-format`, with column
width changed from 80 to 100 to be consistent with existing cpplint check.

Define shape and stride.

Convert some uses.

Prototype for copy tensor info.

proto unravel

Indexer.

Unravel.

Cleanup.

Fixes.

Fixes.

WAR.

as column vector.

Convert some code.

some more.

some more.

Compile.

Ensure column vector from the beginning.

IO.

Add code comments.

Test for trivial dimension.

Start CPU implementation.

Refactor.

Dispatch.

Compile.

Cleanup slice and index for array interface.

Add document about handling user input.

Validate the shape of base margin.

Cleanup out prediction.

Use it in Python.

Optimization.

Cleanup.

Define unravel index as an interface.

Dispatch DType.

Err.

Bypass the iterator.

Cleanup old code.

comment.

Cleanup.

Remove duplicated function.

Add contiguous.

Typo.

Fix copying for group.

comment.

Fix CSR.

Fix empty dimensions.

Use linalg for utilities.

Fix test.

Fix streams.

Basic elementwise kernel.

Fixes.

fix dtype.

Fix index.

Comment.

Cleanup.

popcnt implementation.

Move to compile time.

Fix.

Fix.

Move to compile time.

Forward.

Forward.

Include.

Lint.

Reintroduce the checks.

Fix long double.

Remove check.

Comment.

Remove macro.

Some changes in jvm package.

Include.

Ignore 0 shape.

Fix.

Restore bound check for now.

Fix slice.

Fix test feature weights.

Stricter requirements.

Reshape.

Require v3 interface.

Use CUDA ptr attr.

Fix test.

Invalid stream.

tidy.

Fix versions.

test stack.

CPU only.

Simplifies.

Reverse version.

Force c++14 for msvc.

Lint.

Remove imports.

Revert cmake changes.

Reset.

Check.

Unify device initialization.

Tests.

Just bypass the error.

Fix.

Fix.

test qid.

Fix test.

Cleanup.

lint.

Fix.

Restore the heuristic.

Update metainfo binary.

Fix typo.

Restore the optimization.

Cleanup sklearn tests.

Cleanup dask test.

Tidy.

Tidy is happy.

Polish.

Unittest.

Typo.
trivialfis committed Nov 14, 2021
1 parent d4274bc commit b1b1cde
Showing 36 changed files with 1,120 additions and 673 deletions.
46 changes: 46 additions & 0 deletions doc/contrib/coding_guide.rst
@@ -134,3 +134,49 @@ Similarly, if you want to exclude C++ source from linting:
   cd /path/to/xgboost/
   python3 tests/ci_build/tidy.py --cpp=0

+**********************************
+Guide for handling user input data
+**********************************
+
+This is a non-comprehensive guide to handling user input data. XGBoost has a wide variety
+of natively supported data structures, most of which come from higher-level language bindings. The
+inputs range from a basic contiguous 1-dimensional memory buffer to more sophisticated data
+structures like columnar data with a validity mask. Raw input data can be used in 2 places:
+first in the construction of the various ``DMatrix`` types, and second in in-place
+prediction. For a plain memory buffer, there's not much to discuss since it's just a
+pointer with a size. But for general n-dimensional arrays and columnar data, there are many
+subtleties. XGBoost has 3 different data structures for handling optionally masked arrays
+(tensors); for consuming user inputs, ``ArrayInterface`` should be chosen. There are many
+existing functions that accept only a plain pointer for legacy reasons (XGBoost started
+as a much simpler library and didn't care about memory usage that much back then). The
+``ArrayInterface`` is an in-memory representation of the ``__array_interface__`` protocol
+defined by NumPy or the ``__cuda_array_interface__`` protocol defined by Numba. The following is a
+checklist of things to keep in mind when accepting related user inputs:
+
+- [ ] Is it strided? (identified by the ``strides`` field)
+- [ ] If it's a vector, is it a row vector or a column vector? (identified by both ``shape``
+  and ``strides``)
+- [ ] Is the data type supported? Half-precision floats and 128-bit integer types should be converted
+  before going into XGBoost.
+- [ ] Does it have more than 1 dimension? (identified by the ``shape`` field)
+- [ ] Are some of the dimensions trivial? (``shape[dim] <= 1``)
+- [ ] Does it have a mask? (identified by the ``mask`` field)
+- [ ] Can the mask be broadcast? (unsupported at the moment)
+- [ ] Is it in CUDA memory? (identified by the ``data`` field, and optionally ``stream``)
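Reviewer note: the checklist above can be made concrete with a minimal Python sketch that answers several of its questions for a dict laid out like ``__array_interface__``. The helper name ``inspect_interface`` and the keys it returns are hypothetical and are not part of XGBoost. The sketch assumes a type string of the form byte order, kind character, item size (for example ``<f4``), and it treats the presence of ``stream`` only as a hint, since that field is optional.

```python
from typing import Any, Dict


def inspect_interface(iface: Dict[str, Any]) -> Dict[str, bool]:
    """Answer a few checklist questions for an array-interface style dict."""
    shape = tuple(iface["shape"])
    typestr = iface["typestr"]           # e.g. "<f4": little-endian, float, 4 bytes
    kind, itemsize = typestr[1], int(typestr[2:])
    strides = iface.get("strides")       # None means C-contiguous

    # Strides (in bytes) that a C-contiguous array of this shape would have.
    expected, step = [], itemsize
    for extent in reversed(shape):
        expected.append(step)
        step *= extent
    expected = tuple(reversed(expected))

    return {
        "strided": strides is not None and tuple(strides) != expected,
        "multi_dim": len(shape) > 1,
        "has_trivial_dim": any(extent <= 1 for extent in shape),
        "has_mask": iface.get("mask") is not None,
        "has_stream": "stream" in iface,
        # Half floats and 128-bit integers need conversion by the caller.
        "needs_cast": (kind == "f" and itemsize == 2)
        or (kind in ("i", "u") and itemsize == 16),
    }


# A column vector that views every other row of a larger float32 buffer.
column = {"shape": (4, 1), "typestr": "<f4", "strides": (8, 4), "data": (0, False), "version": 3}
report = inspect_interface(column)
```

For ``column`` the report flags the input as strided, multi-dimensional, and carrying a trivial dimension, which are the cases the guide asks tests to cover.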

+Most of the checks are handled by the ``ArrayInterface`` during construction, except for
+the data type issue since it doesn't know how to cast such pointers with C built-in types.
+But for safety reasons one should still try to write related tests for all of the items. The
+data type issue should be taken care of in the language binding for each specific data
+input. A single-chunk columnar format is just a masked array for each column, so it
+should be treated uniformly as a normal array. For the input predictor ``X``, we have adapters
+for each type of input. Some are compositions of others. For instance, a CSR matrix has 3
+potentially strided arrays for ``indptr``, ``indices`` and ``values``. No assumption
+should be made about these components (all the check boxes should be considered). Slicing a row
+of a CSR matrix should calculate the offset of each field based on its respective strides.
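Reviewer note: to illustrate the last sentence of the paragraph above, here is a small Python sketch of slicing one CSR row when each of the three component arrays has its own stride. The function ``csr_row`` is hypothetical, the buffers are plain Python lists, and strides are counted in elements rather than bytes to keep the arithmetic readable; none of this mirrors XGBoost's actual adapter code.

```python
from typing import List, Sequence, Tuple


def csr_row(
    indptr: Sequence[int], indptr_stride: int,
    indices: Sequence[int], indices_stride: int,
    values: Sequence[float], values_stride: int,
    row: int,
) -> List[Tuple[int, float]]:
    """Return (column, value) pairs of one row, honouring per-array strides."""
    # The row boundaries live in indptr, which may itself be strided.
    begin = indptr[row * indptr_stride]
    end = indptr[(row + 1) * indptr_stride]
    # Each logical position k is mapped through the stride of its own array.
    return [
        (indices[k * indices_stride], values[k * values_stride])
        for k in range(begin, end)
    ]


# Logical matrix [[1, 0, 2], [0, 0, 3]] stored in padded (strided) buffers.
indptr = [0, -1, 2, -1, 3]                           # stride 2
indices = [0, 2, 2]                                  # stride 1 (contiguous)
values = [1.0, 9.0, 9.0, 2.0, 9.0, 9.0, 3.0]         # stride 3
second_row = csr_row(indptr, 2, indices, 1, values, 3, row=1)
```

Using one shared stride for all three arrays would read the padding entries instead of the data, which is why each field gets its own offset computation.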

+For meta info like labels, which is growing both in size and complexity, we accept only
+masked arrays at the moment (no specialized adapter). One should be careful about the
+input data shape. The base margin can have 2 or more dimensions if we have multiple targets in
+the future. The getters in ``DMatrix`` return only 1-dimensional flattened vectors at the
+moment, which can be improved in the future when it's needed.
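Reviewer note: since the getters return a flattened vector while the base margin may be 2-dimensional, the following Python sketch shows how a flat position maps back to per-dimension coordinates. It assumes row-major (C order) flattening, which is an assumption of this note and not something the guide states; ``ravel`` and ``unravel`` are hypothetical helper names.

```python
from typing import Sequence, Tuple


def ravel(coords: Sequence[int], shape: Sequence[int]) -> int:
    """Map per-dimension coordinates to a flat row-major offset."""
    offset = 0
    for coord, extent in zip(coords, shape):
        offset = offset * extent + coord
    return offset


def unravel(index: int, shape: Sequence[int]) -> Tuple[int, ...]:
    """Inverse of ravel: recover the coordinates of a flat row-major offset."""
    coords = []
    for extent in reversed(shape):
        index, remainder = divmod(index, extent)
        coords.append(remainder)
    return tuple(reversed(coords))


# A base margin for 3 samples and 2 targets, seen through a flat getter.
shape = (3, 2)
sample, target = unravel(5, shape)   # last element of the flat vector
```

A caller that knows the number of targets can therefore reshape the flat result without any extra information from the library.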
31 changes: 16 additions & 15 deletions include/xgboost/c_api.h
@@ -249,7 +249,7 @@ XGB_DLL int XGDMatrixCreateFromCudaArrayInterface(char const *data,
                                                   char const* json_config,
                                                   DMatrixHandle *out);

-/*
+/**
  * ========================== Begin data callback APIs =========================
  *
  * Short notes for data callback
@@ -258,9 +258,9 @@ XGB_DLL int XGDMatrixCreateFromCudaArrayInterface(char const *data,
  * used by JVM packages. It uses `XGBoostBatchCSR` to accept batches for CSR formatted
  * input, and concatenate them into 1 final big CSR. The related functions are:
  *
- * - XGBCallbackSetData
- * - XGBCallbackDataIterNext
- * - XGDMatrixCreateFromDataIter
+ * - \ref XGBCallbackSetData
+ * - \ref XGBCallbackDataIterNext
+ * - \ref XGDMatrixCreateFromDataIter
  *
  * Another set is used by external data iterator. It accepts foreign data iterators as
  * callbacks. There are 2 different scenarios where users might want to pass in callbacks
@@ -276,17 +276,17 @@ XGB_DLL int XGDMatrixCreateFromCudaArrayInterface(char const *data,
  * Related functions are:
  *
  * # Factory functions
- * - `XGDMatrixCreateFromCallback` for external memory
- * - `XGDeviceQuantileDMatrixCreateFromCallback` for quantile DMatrix
+ * - \ref XGDMatrixCreateFromCallback for external memory
+ * - \ref XGDeviceQuantileDMatrixCreateFromCallback for quantile DMatrix
  *
  * # Proxy that callers can use to pass data to XGBoost
- * - XGProxyDMatrixCreate
- * - XGDMatrixCallbackNext
- * - DataIterResetCallback
- * - XGProxyDMatrixSetDataCudaArrayInterface
- * - XGProxyDMatrixSetDataCudaColumnar
- * - XGProxyDMatrixSetDataDense
- * - XGProxyDMatrixSetDataCSR
+ * - \ref XGProxyDMatrixCreate
+ * - \ref XGDMatrixCallbackNext
+ * - \ref DataIterResetCallback
+ * - \ref XGProxyDMatrixSetDataCudaArrayInterface
+ * - \ref XGProxyDMatrixSetDataCudaColumnar
+ * - \ref XGProxyDMatrixSetDataDense
+ * - \ref XGProxyDMatrixSetDataCSR
  * - ... (data setters)
  */

@@ -411,7 +411,7 @@ XGB_EXTERN_C typedef void DataIterResetCallback(DataIterHandle handle); // NOLIN
  *   - cache_prefix: The path of cache file, caller must initialize all the directories in this path.
  *   - nthread (optional): Number of threads used for initializing DMatrix.
  *
- * \param out The created external memory DMatrix
+ * \param[out] out The created external memory DMatrix
  *
  * \return 0 when success, -1 when failure happens
  */
@@ -605,7 +605,8 @@ XGB_DLL int XGDMatrixSetUIntInfo(DMatrixHandle handle,
  * char const* feat_names [] {"feat_0", "feat_1"};
  * XGDMatrixSetStrFeatureInfo(handle, "feature_name", feat_names, 2);
  *
- * // i for integer, q for quantitative. Similarly "int" and "float" are also recognized.
+ * // i for integer, q for quantitative, c for categorical. Similarly "int" and "float"
+ * // are also recognized.
  * char const* feat_types [] {"i", "q"};
  * XGDMatrixSetStrFeatureInfo(handle, "feature_type", feat_types, 2);
  *
18 changes: 12 additions & 6 deletions include/xgboost/data.h
@@ -11,12 +11,14 @@
 #include <dmlc/data.h>
 #include <dmlc/serializer.h>
 #include <xgboost/base.h>
-#include <xgboost/span.h>
 #include <xgboost/host_device_vector.h>
+#include <xgboost/linalg.h>
+#include <xgboost/span.h>
+#include <xgboost/string_view.h>

+#include <algorithm>
 #include <memory>
 #include <numeric>
-#include <algorithm>
 #include <string>
 #include <utility>
 #include <vector>
@@ -45,7 +47,7 @@ enum class FeatureType : uint8_t {
 class MetaInfo {
  public:
   /*! \brief number of data fields in MetaInfo */
-  static constexpr uint64_t kNumField = 11;
+  static constexpr uint64_t kNumField = 12;

   /*! \brief number of rows in the data */
   uint64_t num_row_{0};  // NOLINT
@@ -67,7 +69,7 @@ class MetaInfo {
    * if specified, xgboost will start from this init margin
    * can be used to specify initial prediction to boost from.
    */
-  HostDeviceVector<bst_float> base_margin_;  // NOLINT
+  linalg::Tensor<float, 3> base_margin_;  // NOLINT
   /*!
    * \brief lower bound of the label, to be used for survival analysis (censored regression)
    */
@@ -157,7 +159,8 @@ class MetaInfo {
    *
    * Right now only 1 column is permitted.
    */
-  void SetInfo(const char* key, std::string const& interface_str);
+
+  void SetInfo(StringView key, StringView interface_str);

   void GetInfo(char const* key, bst_ulong* out_len, DataType dtype,
                const void** out_dptr) const;
@@ -179,6 +182,9 @@ class MetaInfo {
   void Extend(MetaInfo const& that, bool accumulate_rows, bool check_column);

  private:
+  void SetInfoFromHost(StringView key, Json arr);
+  void SetInfoFromCUDA(StringView key, Json arr);
+
   /*! \brief argsort of labels */
   mutable std::vector<size_t> label_order_cache_;
 };
@@ -477,7 +483,7 @@ class DMatrix {
     this->Info().SetInfo(key, dptr, dtype, num);
   }
   virtual void SetInfo(const char* key, std::string const& interface_str) {
-    this->Info().SetInfo(key, interface_str);
+    this->Info().SetInfo(key, StringView{interface_str});
   }
   /*! \brief meta information of the dataset */
   virtual const MetaInfo& Info() const = 0;
4 changes: 2 additions & 2 deletions include/xgboost/intrusive_ptr.h
@@ -19,7 +19,7 @@ namespace xgboost {
  */
 class IntrusivePtrCell {
  private:
-  std::atomic<int32_t> count_;
+  std::atomic<int32_t> count_ {0};
   template <typename T> friend class IntrusivePtr;

   std::int32_t IncRef() noexcept {
@@ -31,7 +31,7 @@ class IntrusivePtrCell {
   bool IsZero() const { return Count() == 0; }

  public:
-  IntrusivePtrCell() noexcept : count_{0} {}
+  IntrusivePtrCell() noexcept = default;
   int32_t Count() const { return count_.load(std::memory_order_relaxed); }
 };

2 changes: 1 addition & 1 deletion include/xgboost/predictor.h
@@ -128,7 +128,7 @@ class Predictor {
    */
   virtual void InitOutPredictions(const MetaInfo &info,
                                   HostDeviceVector<bst_float> *out_predt,
-                                  const gbm::GBTreeModel &model) const = 0;
+                                  const gbm::GBTreeModel &model) const;

   /**
    * \brief Generate batch predictions for a given feature matrix. May use
30 changes: 14 additions & 16 deletions jvm-packages/xgboost4j-gpu/src/native/xgboost4j-gpu.cu
@@ -35,13 +35,12 @@ template <typename T> T CheckJvmCall(T const &v, JNIEnv *jenv) {
 }

 template <typename VCont>
-void CopyColumnMask(xgboost::ArrayInterface const &interface,
+void CopyColumnMask(xgboost::ArrayInterface<1> const &interface,
                     std::vector<Json> const &columns, cudaMemcpyKind kind,
                     size_t c, VCont *p_mask, Json *p_out, cudaStream_t stream) {
   auto &mask = *p_mask;
   auto &out = *p_out;
-  auto size = sizeof(typename VCont::value_type) * interface.num_rows *
-              interface.num_cols;
+  auto size = sizeof(typename VCont::value_type) * interface.n;
   mask.resize(size);
   CHECK(RawPtr(mask));
   CHECK(size);
@@ -67,11 +66,11 @@ void CopyColumnMask(xgboost::ArrayInterface const &interface,
     LOG(FATAL) << "Invalid shape of mask";
   }
   out["mask"]["typestr"] = String("<t1");
-  out["mask"]["version"] = Integer(1);
+  out["mask"]["version"] = Integer(3);
 }

 template <typename DCont, typename VCont>
-void CopyInterface(std::vector<xgboost::ArrayInterface> &interface_arr,
+void CopyInterface(std::vector<xgboost::ArrayInterface<1>> &interface_arr,
                    std::vector<Json> const &columns, cudaMemcpyKind kind,
                    std::vector<DCont> *p_data, std::vector<VCont> *p_mask,
                    std::vector<xgboost::Json> *p_out, cudaStream_t stream) {
@@ -81,7 +80,7 @@ void CopyInterface(std::vector<xgboost::ArrayInterface> &interface_arr,
   for (size_t c = 0; c < interface_arr.size(); ++c) {
     auto &interface = interface_arr.at(c);
     size_t element_size = interface.ElementSize();
-    size_t size = element_size * interface.num_rows * interface.num_cols;
+    size_t size = element_size * interface.n;

     auto &data = (*p_data)[c];
     auto &mask = (*p_mask)[c];
@@ -95,25 +94,24 @@ void CopyInterface(std::vector<xgboost::ArrayInterface> &interface_arr,
                                Json{Boolean{false}}};

     out["data"] = Array(std::move(j_data));
-    out["shape"] = Array(std::vector<Json>{Json(Integer(interface.num_rows)),
-                                           Json(Integer(interface.num_cols))});
+    out["shape"] = Array(std::vector<Json>{Json(Integer(interface.Shape(0)))});

     if (interface.valid.Data()) {
       CopyColumnMask(interface, columns, kind, c, &mask, &out, stream);
     }
     out["typestr"] = String("<f4");
-    out["version"] = Integer(1);
+    out["version"] = Integer(3);
   }
 }

 void CopyMetaInfo(Json *p_interface, dh::device_vector<float> *out, cudaStream_t stream) {
   auto &j_interface = *p_interface;
   CHECK_EQ(get<Array const>(j_interface).size(), 1);
   auto object = get<Object>(get<Array>(j_interface)[0]);
-  ArrayInterface interface(object);
-  out->resize(interface.num_rows);
+  ArrayInterface<1> interface(object);
+  out->resize(interface.Shape(0));
   size_t element_size = interface.ElementSize();
-  size_t size = element_size * interface.num_rows;
+  size_t size = element_size * interface.n;
   dh::safe_cuda(cudaMemcpyAsync(RawPtr(*out), interface.data, size,
                                 cudaMemcpyDeviceToDevice, stream));
   j_interface[0]["data"][0] = reinterpret_cast<Integer::Int>(RawPtr(*out));
@@ -285,11 +283,11 @@ class DataIteratorProxy {

     Json features = json_interface["features_str"];
     auto json_columns = get<Array const>(features);
-    std::vector<ArrayInterface> interfaces;
+    std::vector<ArrayInterface<1>> interfaces;

     // Stage the data
     for (auto &json_col : json_columns) {
-      auto column = ArrayInterface(get<Object const>(json_col));
+      auto column = ArrayInterface<1>(get<Object const>(json_col));
       interfaces.emplace_back(column);
     }
     Json::Dump(features, &interface_str);
@@ -342,9 +340,9 @@ class DataIteratorProxy {
     // Data
     auto const &json_interface = host_columns_.at(it_)->interfaces;

-    std::vector<ArrayInterface> in;
+    std::vector<ArrayInterface<1>> in;
     for (auto interface : json_interface) {
-      auto column = ArrayInterface(get<Object const>(interface));
+      auto column = ArrayInterface<1>(get<Object const>(interface));
       in.emplace_back(column);
     }
     std::vector<Json> out;
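Reviewer note: the hunks in this file replace `num_rows * num_cols` with a single element count `n` and compute the copy size as `element_size * n`. A minimal Python sketch of the same arithmetic on an interface dict follows. The function name `element_count_and_bytes` is hypothetical, and the sketch assumes a byte-sized type string such as `<f4`; bit-field masks like `<t1` are measured in bits and are rejected here instead of guessed at.

```python
from math import prod
from typing import Any, Dict, Tuple


def element_count_and_bytes(iface: Dict[str, Any]) -> Tuple[int, int]:
    """Return (number of elements, size in bytes) for an interface dict."""
    kind = iface["typestr"][1]
    if kind == "t":
        # Bit fields count bits, not bytes, so the formula below does not apply.
        raise ValueError("bit-field type strings are not handled by this sketch")
    itemsize = int(iface["typestr"][2:])
    n = prod(iface["shape"])          # total elements across all dimensions
    return n, n * itemsize


# A 1-dimensional float32 column described with version 3 of the protocol.
column = {"shape": [5], "typestr": "<f4", "version": 3}
n, nbytes = element_count_and_bytes(column)
```

Taking the product over `shape` keeps the formula valid for any number of dimensions, which matches the move away from separate row and column counts.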
34 changes: 15 additions & 19 deletions python-package/xgboost/data.py
@@ -5,7 +5,7 @@
 import json
 import warnings
 import os
-from typing import Any, Tuple, Callable, Optional, List
+from typing import Any, Tuple, Callable, Optional, List, Union

 import numpy as np

@@ -138,14 +138,14 @@ def _is_numpy_array(data):
     return isinstance(data, (np.ndarray, np.matrix))


-def _ensure_np_dtype(data, dtype):
+def _ensure_np_dtype(data, dtype) -> Tuple[np.ndarray, np.dtype]:
     if data.dtype.hasobject or data.dtype in [np.float16, np.bool_]:
         data = data.astype(np.float32, copy=False)
         dtype = np.float32
     return data, dtype


-def _maybe_np_slice(data, dtype):
+def _maybe_np_slice(data: np.ndarray, dtype) -> np.ndarray:
     '''Handle numpy slice. This can be removed if we use __array_interface__.
     '''
     try:
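Reviewer note: `_ensure_np_dtype` converts half-precision input to `float32` before it reaches the native library. The following sketch shows the same widening step on a raw little-endian buffer using only the standard library, so it runs without NumPy. The function name `half_to_float32` is hypothetical and the sketch is not how the Python package performs the conversion, which uses `astype` as shown in the hunk above.

```python
import struct


def half_to_float32(buf: bytes) -> bytes:
    """Widen a little-endian float16 buffer to a little-endian float32 buffer."""
    if len(buf) % 2 != 0:
        raise ValueError("float16 buffer length must be a multiple of 2")
    n = len(buf) // 2
    values = struct.unpack(f"<{n}e", buf)      # 'e' is IEEE 754 half precision
    return struct.pack(f"<{n}f", *values)      # 'f' is IEEE 754 single precision


# Three half-precision values that are exactly representable in both formats.
source = struct.pack("<3e", 0.5, 1.0, -2.0)
widened = half_to_float32(source)
```

Every float16 value is exactly representable as a float32, so the widening loses no information; only the buffer size doubles.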
@@ -852,23 +852,17 @@ def _validate_meta_shape(data: Any, name: str) -> None:


 def _meta_from_numpy(
-    data: np.ndarray, field: str, dtype, handle: ctypes.c_void_p
+    data: np.ndarray,
+    field: str,
+    dtype: Optional[Union[np.dtype, str]],
+    handle: ctypes.c_void_p,
 ) -> None:
-    data = _maybe_np_slice(data, dtype)
+    data, dtype = _ensure_np_dtype(data, dtype)
     interface = data.__array_interface__
-    assert interface.get('mask', None) is None, 'Masked array is not supported'
-    size = data.size
-
-    c_type = _to_data_type(str(data.dtype), field)
-    ptr = interface['data'][0]
-    ptr = ctypes.c_void_p(ptr)
-    _check_call(_LIB.XGDMatrixSetDenseInfo(
-        handle,
-        c_str(field),
-        ptr,
-        c_bst_ulong(size),
-        c_type
-    ))
+    if interface.get("mask", None) is not None:
+        raise ValueError("Masked array is not supported.")
+    interface_str = _array_interface(data)
+    _check_call(_LIB.XGDMatrixSetInfoFromInterface(handle, c_str(field), interface_str))


 def _meta_from_list(data, field, dtype, handle):
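Reviewer note: the new `_meta_from_numpy` hands the native side a serialized interface description instead of a raw pointer plus size. The sketch below shows what such a string can look like for a 1-dimensional buffer, built with the standard library `array` module so that it runs without NumPy. The helper `interface_bytes` is hypothetical and is not the package's `_array_interface`; the field layout follows the NumPy array interface protocol, and the pointer is only valid while `col` is alive and not resized.

```python
import array
import json
import sys
from typing import Optional


def interface_bytes(col: array.array, mask: Optional[object] = None) -> bytes:
    """Describe a float buffer as a JSON-encoded array-interface dict."""
    if mask is not None:
        raise ValueError("Masked array is not supported.")
    if col.typecode not in ("f", "d"):
        raise TypeError("this sketch only handles float and double buffers")
    address, n = col.buffer_info()            # (memory address, element count)
    endian = "<" if sys.byteorder == "little" else ">"
    interface = {
        "data": [address, False],             # pointer and read-only flag
        "shape": [n],
        "strides": None,                      # None means contiguous
        "typestr": f"{endian}f{col.itemsize}",
        "version": 3,
    }
    return json.dumps(interface).encode("utf-8")


labels = array.array("f", [0.0, 1.0, 1.0])
payload = interface_bytes(labels)
```

Because the shape and type travel with the pointer, the receiving side can validate dimensions itself, which is what allows checks such as the base margin shape validation mentioned in the commit message.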
@@ -911,7 +905,9 @@ def _meta_from_dt(data, field: str, dtype, handle: ctypes.c_void_p):
     _meta_from_numpy(data, field, dtype, handle)


-def dispatch_meta_backend(matrix: DMatrix, data, name: str, dtype: str = None):
+def dispatch_meta_backend(
+    matrix: DMatrix, data, name: str, dtype: Optional[Union[str, np.dtype]] = None
+):
     '''Dispatch for meta info.'''
     handle = matrix.handle
     assert handle is not None
3 changes: 2 additions & 1 deletion src/common/common.cu
@@ -12,7 +12,8 @@ int AllVisibleGPUs() {
     // When compiled with CUDA but running on CPU only device,
     // cudaGetDeviceCount will fail.
     dh::safe_cuda(cudaGetDeviceCount(&n_visgpus));
-  } catch(const dmlc::Error &except) {
+  } catch (const dmlc::Error &) {
+    cudaGetLastError();  // reset error.
     return 0;
   }
   return n_visgpus;
