From 14463467803d79c313fb156db0af72695bf2bb62 Mon Sep 17 00:00:00 2001
From: fis
Date: Thu, 30 Jul 2020 16:43:08 +0800
Subject: [PATCH 1/9] Implement feature weights for column sampling.

---
 amalgamation/xgboost-all0.cc         |  1 +
 demo/guide-python/feature_weights.py | 49 +++++++++++++++++++
 demo/json-model/json_parser.py       |  2 +-
 doc/parameter.rst                    |  3 ++
 include/xgboost/c_api.h              | 43 +++++++++++++++++
 include/xgboost/data.h               | 32 ++++---------
 python-package/xgboost/core.py       | 39 +++++++++++++++
 python-package/xgboost/data.py       | 72 ++++++++++++++++-----------
 python-package/xgboost/sklearn.py    | 27 ++++++++---
 src/c_api/c_api.cc                   | 31 ++++++++++++
 src/common/common.h                  | 12 +++++
 src/common/random.cc                 | 38 +++++++++++++++
 src/common/random.h                  | 58 +++++++++++++---------
 src/data/data.cc                     | 37 ++++++++++++++
 src/tree/updater_colmaker.cc         |  6 ++-
 src/tree/updater_gpu_hist.cu         |  6 ++-
 src/tree/updater_quantile_hist.cc    | 10 ++--
 tests/cpp/common/test_common.cc      | 13 +++++
 tests/cpp/common/test_random.cc      | 68 ++++++++++++++++++++++----
 tests/cpp/tree/test_gpu_hist.cu      |  9 ++--
 tests/python/test_demos.py           | 16 ++++---
 tests/python/test_dmatrix.py         | 34 +++++++++++--
 tests/python/test_with_sklearn.py    | 61 +++++++++++++++++++++++
 23 files changed, 556 insertions(+), 111 deletions(-)
 create mode 100644 demo/guide-python/feature_weights.py
 create mode 100644 src/common/random.cc
 create mode 100644 tests/cpp/common/test_common.cc

diff --git a/amalgamation/xgboost-all0.cc b/amalgamation/xgboost-all0.cc
index 8220135d9e32..792b43797ce5 100644
--- a/amalgamation/xgboost-all0.cc
+++ b/amalgamation/xgboost-all0.cc
@@ -68,6 +68,7 @@
 #include "../src/learner.cc"
 #include "../src/logging.cc"
 #include "../src/common/common.cc"
+#include "../src/common/random.cc"
 #include "../src/common/charconv.cc"
 #include "../src/common/timer.cc"
 #include "../src/common/quantile.cc"
diff --git a/demo/guide-python/feature_weights.py b/demo/guide-python/feature_weights.py
new file mode 100644
index 000000000000..07a8719422c6
--- /dev/null
+++ b/demo/guide-python/feature_weights.py
@@ -0,0 +1,49 @@
+'''Using feature weights to change column sampling.
+
+  .. versionadded:: 1.3.0
+'''
+
+import numpy as np
+import xgboost
+from matplotlib import pyplot as plt
+import argparse
+
+
+def main(args):
+    rng = np.random.RandomState(1994)
+
+    kRows = 1000
+    kCols = 10
+
+    X = rng.randn(kRows, kCols)
+    y = rng.randn(kRows)
+    fw = np.ones(shape=(kCols,))
+    for i in range(kCols):
+        fw[i] *= float(i)
+
+    dtrain = xgboost.DMatrix(X, y)
+    dtrain.feature_weights = fw
+
+    bst = xgboost.train({'tree_method': 'hist',
+                         'colsample_bynode': 0.5},
+                        dtrain, num_boost_round=10,
+                        evals=[(dtrain, 'd')])
+    feature_map = bst.get_fscore()
+    # feature zero has 0 weight
+    assert feature_map.get('f0', None) is None
+    assert max(feature_map.values()) == feature_map.get('f9')
+
+    if args.plot:
+        xgboost.plot_importance(bst)
+        plt.show()
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        '--plot',
+        type=int,
+        default=1,
+        help='Set to 0 to disable plotting the feature importance.')
+    args = parser.parse_args()
+    main(args)
diff --git a/demo/json-model/json_parser.py b/demo/json-model/json_parser.py
index eedcbf9c2287..c41a44d881c8 100644
--- a/demo/json-model/json_parser.py
+++ b/demo/json-model/json_parser.py
@@ -94,7 +94,7 @@ def __str__(self):
 class Model:
     '''Gradient boosted tree model.'''
-    def __init__(self, m: dict):
+    def __init__(self, model: dict):
         '''Construct the Model from JSON object.
        parameters
diff --git a/doc/parameter.rst b/doc/parameter.rst
index 626ddf10f8ab..685bddbc815d 100644
--- a/doc/parameter.rst
+++ b/doc/parameter.rst
@@ -107,6 +107,9 @@ Parameters for Tree Booster
   'colsample_bynode':0.5}`` with 64 features will leave 8 features to choose from at
   each split.

+  On the Python interface, one can set the ``feature_weights`` for DMatrix to define the
+  probability of each feature being selected when using column sampling.
+
 * ``lambda`` [default=1, alias: ``reg_lambda``]

   - L2 regularization term on weights. Increasing this value will make model more
     conservative.

diff --git a/include/xgboost/c_api.h b/include/xgboost/c_api.h
index 794cbdf19e8f..ddb01e7d98db 100644
--- a/include/xgboost/c_api.h
+++ b/include/xgboost/c_api.h
@@ -483,6 +483,49 @@ XGB_DLL int XGDMatrixGetStrFeatureInfo(DMatrixHandle handle,
                                        const char *field,
                                        bst_ulong *size,
                                        const char ***out_features);

+/*!
+ * \brief Set feature info that's not strings. Currently accepted fields are:
+ *
+ *  - feature_weight
+ *
+ * \param handle An instance of data matrix
+ * \param field  Field name
+ * \param data   Pointer to consecutive memory storing data.
+ * \param size   Size of the data, this is relative to size of type. (Meaning NOT number
+ *               of bytes.)
+ * \param type   Indicator of data type. This is defined in xgboost::DataType enum class.
+ *
+ *    float    = 1
+ *    double   = 2
+ *    uint32_t = 3
+ *    uint64_t = 4
+ *
+ * \return 0 when success, -1 when failure happens
+ */
+XGB_DLL int XGDMatrixSetFeatureInfo(DMatrixHandle handle, const char *field,
+                                    void *data, bst_ulong size, int type);
+
+/*!
+ * \brief Get feature info in a thread local buffer.
+ *
+ * Caller is responsible for copying out the data before the next call to any API
+ * function of XGBoost. The data is always on CPU thread local storage.
+ *
+ * \param handle An instance of data matrix.
+ * \param field Field name.
+ * \param out_type Type of this field. This is defined in xgboost::DataType enum class.
+ * \param out_size Length of output data, this is relative to size of out_type. (Meaning
+ *                 NOT number of bytes.)
+ * \param out_dptr Pointer to output buffer.
+ *
+ * \return 0 when success, -1 when failure happens
+ */
+XGB_DLL int XGDMatrixGetFeatureInfo(DMatrixHandle handle,
+                                    const char* field,
+                                    int* out_type,
+                                    bst_ulong* out_size,
+                                    const void** out_dptr);
+
 /*!
  * \brief (deprecated) Use XGDMatrixSetUIntInfo instead. Set group of the training matrix
  * \param handle a instance of data matrix
diff --git a/include/xgboost/data.h b/include/xgboost/data.h
index 1ee292a89edb..fbd3da5c84b7 100644
--- a/include/xgboost/data.h
+++ b/include/xgboost/data.h
@@ -88,34 +88,17 @@ class MetaInfo {
   /*!
    * \brief Type of each feature. Automatically set when feature_type_names is specified.
    */
   HostDeviceVector<FeatureType> feature_types;
+  /*!
+   * \brief Weight of each feature, used to define the probability of each feature being
+   *        selected when using column sampling.
+   */
+  HostDeviceVector<float> feature_weigths;
   /*! \brief default constructor */
  MetaInfo()  = default;
  MetaInfo(MetaInfo&& that) = default;
  MetaInfo& operator=(MetaInfo&& that) = default;
-  MetaInfo& operator=(MetaInfo const& that) {
-    this->num_row_ = that.num_row_;
-    this->num_col_ = that.num_col_;
-    this->num_nonzero_ = that.num_nonzero_;
-
-    this->labels_.Resize(that.labels_.Size());
-    this->labels_.Copy(that.labels_);
-
-    this->group_ptr_ = that.group_ptr_;
-
-    this->weights_.Resize(that.weights_.Size());
-    this->weights_.Copy(that.weights_);
-
-    this->base_margin_.Resize(that.base_margin_.Size());
-    this->base_margin_.Copy(that.base_margin_);
-
-    this->labels_lower_bound_.Resize(that.labels_lower_bound_.Size());
-    this->labels_lower_bound_.Copy(that.labels_lower_bound_);
-
-    this->labels_upper_bound_.Resize(that.labels_upper_bound_.Size());
-    this->labels_upper_bound_.Copy(that.labels_upper_bound_);
-    return *this;
-  }
+  MetaInfo& operator=(MetaInfo const& that) = delete;

   /*!
    * \brief Validate all metainfo.
@@ -180,6 +163,9 @@ class MetaInfo {
   void SetFeatureInfo(const char *key, const char **info, const bst_ulong size);
   void GetFeatureInfo(const char *field, std::vector<std::string>* out_str_vecs) const;
+  void SetFeatureInfo(const char *field, const void *info, DataType type,
+                      bst_ulong size);
+  void GetFeatureInfo(const char *field, DataType *out_type, std::vector<float>* out) const;

   /*
    * \brief Extend with other MetaInfo.
diff --git a/python-package/xgboost/core.py b/python-package/xgboost/core.py
index 4bc77783ee91..294a0ed5a482 100644
--- a/python-package/xgboost/core.py
+++ b/python-package/xgboost/core.py
@@ -841,6 +841,45 @@ def feature_types(self, feature_types):
                 None,
                 c_bst_ulong(0)))

+    @property
+    def feature_weights(self):
+        '''Weight for each feature, defines the probability of each feature
+        being selected when colsample is being used. All values must be
+        greater than or equal to 0, otherwise an `XGBoostError` is thrown.
+
+        .. versionadded:: 1.3.0
+
+        '''
+        length = c_bst_ulong()
+        ret = ctypes.POINTER(ctypes.c_void_p)()
+        out_type = ctypes.c_int()
+        _check_call(_LIB.XGDMatrixGetFeatureInfo(
+            self.handle,
+            c_str('feature_weight'),
+            ctypes.byref(out_type),
+            ctypes.byref(length),
+            ctypes.byref(ret)
+        ))
+        to_data_type = {1: np.float32, 2: np.float64, 3: np.uint32,
+                        4: np.uint64}
+        to_c_type = {1: ctypes.c_float, 2: ctypes.c_double, 3: ctypes.c_uint32,
+                     4: ctypes.c_uint64}
+        dtype = to_data_type[out_type.value]
+        ptr = ctypes.cast(ret, ctypes.POINTER(to_c_type[out_type.value]))
+        return ctypes2numpy(ptr, length.value, dtype)
+
+    @feature_weights.setter
+    def feature_weights(self, array):
+        '''Setter for feature weights. Clear the feature weights if array is
+        None.
+
+        '''
+        from .data import dispatch_meta_backend
+        if array is None:
+            array = np.empty((0, 0))
+        dispatch_meta_backend(matrix=self, data=array, name='feature_weight',
+                              is_feature=True)
+

 class DeviceQuantileDMatrix(DMatrix):
     """Device memory Data Matrix used in XGBoost for training with
diff --git a/python-package/xgboost/data.py b/python-package/xgboost/data.py
index 9491efd1c38c..65f7e179a3a7 100644
--- a/python-package/xgboost/data.py
+++ b/python-package/xgboost/data.py
@@ -530,31 +530,52 @@ def dispatch_data_backend(data, missing, threads,
     raise TypeError('Not supported type for data.'
+ str(type(data))) -def _meta_from_numpy(data, field, dtype, handle): - data = _maybe_np_slice(data, dtype) - if dtype == 'uint32': - c_data = c_array(ctypes.c_uint32, data) - _check_call(_LIB.XGDMatrixSetUIntInfo(handle, - c_str(field), - c_array(ctypes.c_uint, data), - c_bst_ulong(len(data)))) - elif dtype == 'float': - c_data = c_array(ctypes.c_float, data) - _check_call(_LIB.XGDMatrixSetFloatInfo(handle, - c_str(field), - c_data, - c_bst_ulong(len(data)))) +def _to_data_type(dtype: str): + dtype_map = {'float32': 1, 'float64': 2, 'uint32': 3, 'uint64': 4} + return dtype_map[dtype] + + +def _meta_from_numpy(data, field, dtype, handle, is_feature: bool = False): + if is_feature: + data = _maybe_np_slice(data, dtype) + interface = data.__array_interface__ + assert interface.get('mask', None) is None + size = data.shape[0] + c_type = _to_data_type(str(data.dtype)) + data = interface['data'] + data = ctypes.c_void_p(data[0]) + _check_call(_LIB.XGDMatrixSetFeatureInfo( + handle, + c_str(field), + data, + size, + c_type + )) else: - raise TypeError('Unsupported type ' + str(dtype) + ' for:' + field) + data = _maybe_np_slice(data, dtype) + if dtype == 'uint32': + c_data = c_array(ctypes.c_uint32, data) + _check_call(_LIB.XGDMatrixSetUIntInfo(handle, + c_str(field), + c_array(ctypes.c_uint, data), + c_bst_ulong(len(data)))) + elif dtype == 'float': + c_data = c_array(ctypes.c_float, data) + _check_call(_LIB.XGDMatrixSetFloatInfo(handle, + c_str(field), + c_data, + c_bst_ulong(len(data)))) + else: + raise TypeError('Unsupported type ' + str(dtype) + ' for:' + field) -def _meta_from_list(data, field, dtype, handle): +def _meta_from_list(data, field, dtype, handle, is_feature): data = np.array(data) - _meta_from_numpy(data, field, dtype, handle) + _meta_from_numpy(data, field, dtype, handle, is_feature) -def _meta_from_tuple(data, field, dtype, handle): - return _meta_from_list(data, field, dtype, handle) +def _meta_from_tuple(data, field, dtype, handle, is_feature): + return _meta_from_list(data, field, dtype, handle, is_feature) def _meta_from_cudf_df(data, field, handle): @@ -592,28 +613,29 @@ def _meta_from_dt(data, field, dtype, handle): _meta_from_numpy(data, field, dtype, handle) -def dispatch_meta_backend(matrix: DMatrix, data, name: str, dtype: str = None): +def dispatch_meta_backend(matrix: DMatrix, data, name: str, dtype: str = None, + is_feature: bool = False): '''Dispatch for meta info.''' handle = matrix.handle if data is None: return if _is_list(data): - _meta_from_list(data, name, dtype, handle) + _meta_from_list(data, name, dtype, handle, is_feature) return if _is_tuple(data): - _meta_from_tuple(data, name, dtype, handle) + _meta_from_tuple(data, name, dtype, handle, is_feature) return if _is_numpy_array(data): - _meta_from_numpy(data, name, dtype, handle) + _meta_from_numpy(data, name, dtype, handle, is_feature) return if _is_pandas_df(data): data, _, _ = _transform_pandas_df(data, meta=name, meta_type=dtype) - _meta_from_numpy(data, name, dtype, handle) + _meta_from_numpy(data, name, dtype, handle, is_feature) return if _is_pandas_series(data): data = data.values.astype('float') assert len(data.shape) == 1 or data.shape[1] == 0 or data.shape[1] == 1 - _meta_from_numpy(data, name, dtype, handle) + _meta_from_numpy(data, name, dtype, handle, is_feature) return if _is_dlpack(data): data = _transform_dlpack(data) diff --git a/python-package/xgboost/sklearn.py b/python-package/xgboost/sklearn.py index f533f7f3477d..6c666961f432 100644 --- a/python-package/xgboost/sklearn.py +++ 
b/python-package/xgboost/sklearn.py
@@ -441,6 +441,7 @@ def load_model(self, fname):
     def fit(self, X, y, sample_weight=None, base_margin=None,
             eval_set=None, eval_metric=None, early_stopping_rounds=None,
             verbose=True, xgb_model=None, sample_weight_eval_set=None,
+            feature_weights=None,
             callbacks=None):
         # pylint: disable=invalid-name,attribute-defined-outside-init
         """Fit gradient boosting model
@@ -459,9 +460,6 @@ def fit(self, X, y, sample_weight=None, base_margin=None,
             A list of (X, y) tuple pairs to use as validation sets, for which
             metrics will be computed.
             Validation metrics will help us track the performance of the model.
-        sample_weight_eval_set : list, optional
-            A list of the form [L_1, L_2, ..., L_n], where each L_i is a list of
-            instance weights on the i-th validation set.
         eval_metric : str, list of str, or callable, optional
             If a str, should be a built-in evaluation metric to use. See
             doc/parameter.rst.
@@ -490,6 +488,13 @@ def fit(self, X, y, sample_weight=None, base_margin=None,
         xgb_model : str
             file name of stored XGBoost model or 'Booster' instance XGBoost
             model to be loaded before training (allows training continuation).
+        sample_weight_eval_set : list, optional
+            A list of the form [L_1, L_2, ..., L_n], where each L_i is a list of
+            instance weights on the i-th validation set.
+        feature_weights: array_like
+            Weight for each feature, defines the probability of each feature
+            being selected when colsample is being used. All values must be
+            greater than or equal to 0, otherwise a `ValueError` is thrown.
         callbacks : list of callback functions
             List of callback functions that are applied at end of each iteration.
             It is possible to use predefined callbacks by using :ref:`callback_api`.
@@ -498,6 +503,7 @@ def fit(self, X, y, sample_weight=None, base_margin=None,

             .. code-block:: python
                [xgb.callback.reset_learning_rate(custom_rates)]
+
         """
         self.n_features_in_ = X.shape[1]

@@ -505,6 +511,7 @@ def fit(self, X, y, sample_weight=None, base_margin=None,
                                 base_margin=base_margin,
                                 missing=self.missing,
                                 nthread=self.n_jobs)
+        train_dmatrix.feature_weights = feature_weights

         evals_result = {}
@@ -759,7 +766,7 @@ def __init__(self, objective="binary:logistic", **kwargs):
     def fit(self, X, y, sample_weight=None, base_margin=None,
             eval_set=None, eval_metric=None,
             early_stopping_rounds=None, verbose=True, xgb_model=None,
-            sample_weight_eval_set=None, callbacks=None):
+            sample_weight_eval_set=None, feature_weights=None, callbacks=None):
         # pylint: disable = attribute-defined-outside-init,arguments-differ

         evals_result = {}
@@ -821,6 +828,7 @@ def fit(self, X, y, sample_weight=None, base_margin=None,
         train_dmatrix = DMatrix(X, label=training_labels, weight=sample_weight,
                                 base_margin=base_margin,
                                 missing=self.missing, nthread=self.n_jobs)
+        train_dmatrix.feature_weights = feature_weights

         self._Booster = train(xgb_options, train_dmatrix,
                               self.get_num_boosting_rounds(),
@@ -1101,10 +1109,10 @@ def __init__(self, objective='rank:pairwise', **kwargs):
             raise ValueError("please use XGBRanker for ranking task")

     def fit(self, X, y, group, sample_weight=None, base_margin=None,
-            eval_set=None,
-            sample_weight_eval_set=None, eval_group=None, eval_metric=None,
+            eval_set=None, sample_weight_eval_set=None,
+            eval_group=None, eval_metric=None,
             early_stopping_rounds=None, verbose=False, xgb_model=None,
-            callbacks=None):
+            feature_weights=None, callbacks=None):
         # pylint: disable = attribute-defined-outside-init,arguments-differ
         """Fit gradient boosting ranker

@@ -1170,6 +1178,10 @@ def fit(self, X, y, group, sample_weight=None, base_margin=None,
         xgb_model : str
             file name of stored XGBoost model or 'Booster' instance XGBoost
             model to be loaded before training (allows training continuation).
+        feature_weights: array_like
+            Weight for each feature, defines the probability of each feature
+            being selected when colsample is being used. All values must be
+            greater than or equal to 0, otherwise a `ValueError` is thrown.
         callbacks : list of callback functions
             List of callback functions that are applied at end of each iteration.
             It is possible to use predefined callbacks by using
             :ref:`callback_api`.
@@ -1205,6 +1217,7 @@ def _dmat_init(group, **params):
         train_dmatrix = DMatrix(data=X, label=y, weight=sample_weight,
                                 base_margin=base_margin,
                                 missing=self.missing, nthread=self.n_jobs)
+        train_dmatrix.feature_weights = feature_weights
         train_dmatrix.set_group(group)

         evals_result = {}
diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc
index aa6ecf43a784..54e675bb1883 100644
--- a/src/c_api/c_api.cc
+++ b/src/c_api/c_api.cc
@@ -316,6 +316,37 @@ XGB_DLL int XGDMatrixGetStrFeatureInfo(DMatrixHandle handle, const char *field,
   API_END();
 }

+XGB_DLL int XGDMatrixSetFeatureInfo(DMatrixHandle handle, const char *field,
+                                    void *data, xgboost::bst_ulong size,
+                                    int type) {
+  API_BEGIN();
+  CHECK_HANDLE();
+  auto &info = static_cast<std::shared_ptr<DMatrix> *>(handle)->get()->Info();
+  info.SetFeatureInfo(field, data, static_cast<DataType>(type), size);
+  API_END();
+}
+
+XGB_DLL int XGDMatrixGetFeatureInfo(DMatrixHandle handle,
+                                    const char* field,
+                                    int* out_type,
+                                    xgboost::bst_ulong* out_len,
+                                    const void** out_dptr) {
+  API_BEGIN();
+  CHECK_HANDLE();
+  auto m = *static_cast<std::shared_ptr<DMatrix>*>(handle);
+  auto &info = static_cast<std::shared_ptr<DMatrix> *>(handle)->get()->Info();
+  DataType out_t;
+
+  // Right now only float for feature weights is used.  If other data types are required,
+  // we can pass in a general memory buffer.
+  auto& float_vec = m->GetThreadLocal().ret_vec_float;
+  info.GetFeatureInfo(field, &out_t, &float_vec);
+  *out_type = static_cast<int>(out_t);
+  *out_len = float_vec.size();
+  *out_dptr = float_vec.data();
+  API_END();
+}
+
 XGB_DLL int XGDMatrixSetGroup(DMatrixHandle handle,
                               const unsigned* group,
                               xgboost::bst_ulong len) {
diff --git a/src/common/common.h b/src/common/common.h
index b0bd6b6d6cec..a4397d1c89aa 100644
--- a/src/common/common.h
+++ b/src/common/common.h
@@ -9,12 +9,15 @@
 #include <exception>
 #include <limits>
+#include <numeric>
 #include <type_traits>
+#include <functional>
 #include <vector>
 #include <string>
 #include <sstream>
+#include <algorithm>

 #if defined(__CUDACC__)
 #include <thrust/system/cuda/error.h>
@@ -160,6 +163,15 @@ inline void AssertOneAPISupport() {
 #endif  // XGBOOST_USE_ONEAPI
 }

+template <typename T, typename Idx = std::size_t, typename Comp = std::less<T>>
+std::vector<Idx> ArgSort(std::vector<T> const &array, Comp comp = std::less<T>{}) {
+  std::vector<Idx> result(array.size());
+  std::iota(result.begin(), result.end(), 0);
+  std::stable_sort(
+      result.begin(), result.end(),
+      [&array, comp](Idx const &l, Idx const &r) { return comp(array[l], array[r]); });
+  return result;
+}
 }  // namespace common
 }  // namespace xgboost
 #endif  // XGBOOST_COMMON_COMMON_H_
diff --git a/src/common/random.cc b/src/common/random.cc
new file mode 100644
index 000000000000..f386cad916b2
--- /dev/null
+++ b/src/common/random.cc
@@ -0,0 +1,38 @@
+/*!
+ * Copyright 2020 by XGBoost Contributors
+ * \file random.cc
+ */
+#include "random.h"
+
+namespace xgboost {
+namespace common {
+std::shared_ptr<HostDeviceVector<bst_feature_t>> ColumnSampler::ColSample(
+    std::shared_ptr<HostDeviceVector<bst_feature_t>> p_features,
+    float colsample) {
+  if (colsample == 1.0f) {
+    return p_features;
+  }
+  const auto &features = p_features->HostVector();
+  CHECK_GT(features.size(), 0);
+
+  int n = std::max(1, static_cast<int>(colsample * features.size()));
+  auto p_new_features = std::make_shared<HostDeviceVector<bst_feature_t>>();
+  auto &new_features = *p_new_features;
+
+  if (feature_weights_.size() != 0) {
+    new_features.HostVector() = WeightedSamplingWithoutReplacement(
+        p_features->HostVector(), feature_weights_, n);
+  } else {
+    new_features.Resize(features.size());
+    std::copy(features.begin(), features.end(),
+              new_features.HostVector().begin());
+    std::shuffle(new_features.HostVector().begin(),
+                 new_features.HostVector().end(), rng_);
+    new_features.Resize(n);
+  }
+  std::sort(new_features.HostVector().begin(), new_features.HostVector().end());
+  return p_new_features;
+}
+
+}  // namespace common
+}  // namespace xgboost
diff --git a/src/common/random.h b/src/common/random.h
index 45af80ce030b..3842c975c50a 100644
--- a/src/common/random.h
+++ b/src/common/random.h
@@ -1,5 +1,5 @@
 /*!
- * Copyright 2015 by Contributors
+ * Copyright 2015-2020 by Contributors
 * \file random.h
 * \brief Utility related to random.
 * \author Tianqi Chen
@@ -10,14 +10,17 @@
 #include <xgboost/logging.h>
 #include <algorithm>
 #include <functional>
+#include <cmath>
 #include <limits>
 #include <map>
 #include <memory>
 #include <numeric>
 #include <random>
 #include <vector>
+#include <utility>

 #include "xgboost/host_device_vector.h"
+#include "common.h"

 namespace xgboost {
 namespace common {
@@ -75,6 +78,31 @@ using GlobalRandomEngine = RandomEngine;
 */
GlobalRandomEngine& GlobalRandom(); // NOLINT(*)

+template <typename T>
+std::vector<T> WeightedSamplingWithoutReplacement(
+    std::vector<T> const &array, std::vector<float> const &weights, size_t n) {
+  // ES sampling.
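+  // Each candidate i draws a key log(u_i) / w_i with u_i ~ Uniform(0, 1); keeping the
+  // n largest keys selects features without replacement, with inclusion probability
+  // growing in proportion to the feature weight (cf. Efraimidis & Spirakis).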
+  CHECK_EQ(array.size(), weights.size());
+  std::vector<float> keys(weights.size());
+  std::uniform_real_distribution<float> dist;
+  auto& rng = GlobalRandom();
+  for (size_t i = 0; i < array.size(); ++i) {
+    auto w = std::max(weights.at(i), kRtEps);
+    auto u = dist(rng);
+    auto k = std::log(u) / w;
+    keys[i] = k;
+  }
+  auto ind = ArgSort(keys, std::greater<>{});
+  ind.resize(n);
+
+  std::vector<T> results(ind.size());
+  for (size_t k = 0; k < ind.size(); ++k) {
+    auto idx = ind[k];
+    results[k] = array[idx];
+  }
+  return results;
+}
+
 /**
 * \class ColumnSampler
 *
 * \brief Handles selection of columns due to colsample_bytree, colsample_bylevel and
 * colsample_bynode parameters. Should be initialised before tree construction and to
 * reset when tree construction is completed.
 */
-
 class ColumnSampler {
   std::shared_ptr<HostDeviceVector<bst_feature_t>> feature_set_tree_;
   std::map<int, std::shared_ptr<HostDeviceVector<bst_feature_t>>> feature_set_level_;
+  std::vector<float> feature_weights_;
   float colsample_bylevel_{1.0f};
   float colsample_bytree_{1.0f};
   float colsample_bynode_{1.0f};
   GlobalRandomEngine rng_;

-  std::shared_ptr<HostDeviceVector<bst_feature_t>> ColSample(
-      std::shared_ptr<HostDeviceVector<bst_feature_t>> p_features, float colsample) {
-    if (colsample == 1.0f) return p_features;
-    const auto& features = p_features->HostVector();
-    CHECK_GT(features.size(), 0);
-    int n = std::max(1, static_cast<int>(colsample * features.size()));
-    auto p_new_features = std::make_shared<HostDeviceVector<bst_feature_t>>();
-    auto& new_features = *p_new_features;
-    new_features.Resize(features.size());
-    std::copy(features.begin(), features.end(),
-              new_features.HostVector().begin());
-    std::shuffle(new_features.HostVector().begin(),
-                 new_features.HostVector().end(), rng_);
-    new_features.Resize(n);
-    std::sort(new_features.HostVector().begin(),
-              new_features.HostVector().end());
-
-    return p_new_features;
-  }
-
  public:
+  std::shared_ptr<HostDeviceVector<bst_feature_t>> ColSample(
+      std::shared_ptr<HostDeviceVector<bst_feature_t>> p_features, float colsample);
  /**
   * \brief Column sampler constructor.
   * \note This constructor manually sets the rng seed
@@ -139,8 +149,10 @@ class ColumnSampler {
   * \param colsample_bytree
   * \param skip_index_0 (Optional) True to skip index 0.
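+   * \param feature_weights Weight of each feature; when non-empty, features are drawn
+   *        with probability proportional to these weights.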
   */
-  void Init(int64_t num_col, float colsample_bynode, float colsample_bylevel,
+  void Init(int64_t num_col, std::vector<float> feature_weights,
+            float colsample_bynode, float colsample_bylevel,
             float colsample_bytree, bool skip_index_0 = false) {
+    feature_weights_ = std::move(feature_weights);
     colsample_bylevel_ = colsample_bylevel;
     colsample_bytree_ = colsample_bytree;
     colsample_bynode_ = colsample_bynode;
diff --git a/src/data/data.cc b/src/data/data.cc
index 401a35081830..d446cc995a43 100644
--- a/src/data/data.cc
+++ b/src/data/data.cc
@@ -293,6 +293,9 @@ MetaInfo MetaInfo::Slice(common::Span<int32_t const> ridxs) const {
   } else {
     out.base_margin_.HostVector() = Gather(this->base_margin_.HostVector(), ridxs);
   }
+
+  out.feature_weigths.Resize(this->feature_weigths.Size());
+  out.feature_weigths.Copy(this->feature_weigths);
   return out;
 }
@@ -452,6 +455,30 @@ void MetaInfo::GetFeatureInfo(const char *field,
   }
 }

+void MetaInfo::SetFeatureInfo(const char *c_field, const void *info, DataType type,
+                              bst_ulong size) {
+  std::string field {c_field};
+  CHECK_EQ(field, "feature_weight") << "Only feature weight is supported for feature info.";
+  auto& h_feature_weights = feature_weigths.HostVector();
+  h_feature_weights.resize(size);
+  DISPATCH_CONST_PTR(
+      type, info, cast_dptr,
+      std::copy(cast_dptr, cast_dptr + size, h_feature_weights.begin()));
+  bool valid = std::all_of(h_feature_weights.cbegin(), h_feature_weights.cend(),
+                           [](float w) { return w >= 0; });
+  CHECK(valid) << "Feature weight must be greater than or equal to 0.";
+}
+
+void MetaInfo::GetFeatureInfo(const char *c_field, DataType *out_type,
+                              std::vector<float> *out) const {
+  std::string field {c_field};
+  CHECK_EQ(field, "feature_weight") << "Only feature weight is supported for feature info.";
+  *out_type = DataType::kFloat32;
+  out->resize(feature_weigths.Size());
+  auto const& h_feature_weights = feature_weigths.ConstHostVector();
+  std::copy(h_feature_weights.cbegin(), h_feature_weights.cend(), out->begin());
+}
+
 void MetaInfo::Extend(MetaInfo const& that, bool accumulate_rows) {
   if (accumulate_rows) {
     this->num_row_ += that.num_row_;
@@ -497,6 +524,11 @@ void MetaInfo::Extend(MetaInfo const& that, bool accumulate_rows) {
     auto &h_feature_types = feature_types.HostVector();
     LoadFeatureType(this->feature_type_names, &h_feature_types);
   }
+  if (!that.feature_weigths.Empty()) {
+    this->feature_weigths.Resize(that.feature_weigths.Size());
+    this->feature_weigths.SetDevice(that.feature_weigths.DeviceIdx());
+    this->feature_weigths.Copy(that.feature_weigths);
+  }
 }

 void MetaInfo::Validate(int32_t device) const {
@@ -538,6 +570,11 @@ void MetaInfo::Validate(int32_t device) const {
     check_device(labels_lower_bound_);
     return;
   }
+  if (feature_weigths.Size() != 0) {
+    CHECK_EQ(feature_weigths.Size(), num_col_)
+        << "Size of feature_weights must equal to number of columns.";
+    check_device(feature_weigths);
+  }
   if (labels_upper_bound_.Size() != 0) {
     CHECK_EQ(labels_upper_bound_.Size(), num_row_)
         << "Size of label_upper_bound must equal to number of rows.";
diff --git a/src/tree/updater_colmaker.cc b/src/tree/updater_colmaker.cc
index 951cfdb5ec27..45cdb0ba9163 100644
--- a/src/tree/updater_colmaker.cc
+++ b/src/tree/updater_colmaker.cc
@@ -235,8 +235,10 @@ class ColMaker: public TreeUpdater {
       }
     }
     {
-      column_sampler_.Init(fmat.Info().num_col_, param_.colsample_bynode,
-                           param_.colsample_bylevel, param_.colsample_bytree);
+      column_sampler_.Init(fmat.Info().num_col_,
+                           fmat.Info().feature_weigths.ConstHostVector(),
+                           param_.colsample_bynode, param_.colsample_bylevel,
+                           param_.colsample_bytree);
     }
     {
       // setup temp space for each thread
diff --git a/src/tree/updater_gpu_hist.cu b/src/tree/updater_gpu_hist.cu
index 5cbe75350402..3535a59d6f85 100644
--- a/src/tree/updater_gpu_hist.cu
+++ b/src/tree/updater_gpu_hist.cu
@@ -266,8 +266,10 @@ struct GPUHistMakerDevice {
   // Note that the column sampler must be passed by value because it is not
   // thread safe
   void Reset(HostDeviceVector<GradientPair>* dh_gpair, DMatrix* dmat, int64_t num_columns) {
-    this->column_sampler.Init(num_columns, param.colsample_bynode,
-                              param.colsample_bylevel, param.colsample_bytree);
+    auto const& info = dmat->Info();
+    this->column_sampler.Init(num_columns, info.feature_weigths.HostVector(),
+                              param.colsample_bynode, param.colsample_bylevel,
+                              param.colsample_bytree);
     dh::safe_cuda(cudaSetDevice(device_id));
     this->interaction_constraints.Reset();
     std::fill(node_sum_gradients.begin(), node_sum_gradients.end(),
diff --git a/src/tree/updater_quantile_hist.cc b/src/tree/updater_quantile_hist.cc
index 37a90dfebd74..95d3c2008ef9 100644
--- a/src/tree/updater_quantile_hist.cc
+++ b/src/tree/updater_quantile_hist.cc
@@ -841,11 +841,13 @@ void QuantileHistMaker::Builder::InitData(const GHistIndexMatrix& gmat,
   // store a pointer to the tree
   p_last_tree_ = &tree;
   if (data_layout_ == kDenseDataOneBased) {
-    column_sampler_.Init(info.num_col_, param_.colsample_bynode, param_.colsample_bylevel,
-                         param_.colsample_bytree, true);
+    column_sampler_.Init(info.num_col_, info.feature_weigths.ConstHostVector(),
+                         param_.colsample_bynode, param_.colsample_bylevel,
+                         param_.colsample_bytree, true);
   } else {
-    column_sampler_.Init(info.num_col_, param_.colsample_bynode, param_.colsample_bylevel,
-                         param_.colsample_bytree, false);
+    column_sampler_.Init(info.num_col_, info.feature_weigths.ConstHostVector(),
+                         param_.colsample_bynode, param_.colsample_bylevel,
+                         param_.colsample_bytree, false);
   }
   if (data_layout_ == kDenseDataZeroBased || data_layout_ == kDenseDataOneBased) {
     /* specialized code for dense data:
diff --git a/tests/cpp/common/test_common.cc b/tests/cpp/common/test_common.cc
new file mode 100644
index 000000000000..006860b11af2
--- /dev/null
+++ b/tests/cpp/common/test_common.cc
@@ -0,0 +1,13 @@
+#include <gtest/gtest.h>
+#include "../../../src/common/common.h"
+
+namespace xgboost {
+namespace common {
+TEST(ArgSort, Basic) {
+  std::vector<float> inputs {3.0f, 2.0f, 1.0f};
+  auto ret = ArgSort(inputs);
+  std::vector<size_t> sol{2, 1, 0};
+  ASSERT_EQ(ret, sol);
+}
+}  // namespace common
+}  // namespace xgboost
diff --git a/tests/cpp/common/test_random.cc b/tests/cpp/common/test_random.cc
index dc7b38554162..9b2a1515543f 100644
--- a/tests/cpp/common/test_random.cc
+++ b/tests/cpp/common/test_random.cc
@@ -8,9 +8,10 @@ namespace common {
 TEST(ColumnSampler, Test) {
   int n = 128;
   ColumnSampler cs;
+  std::vector<float> feature_weights;

   // No node sampling
-  cs.Init(n, 1.0f, 0.5f, 0.5f);
+  cs.Init(n, feature_weights, 1.0f, 0.5f, 0.5f);
   auto set0 = cs.GetFeatureSet(0);
   ASSERT_EQ(set0->Size(), 32);

@@ -23,7 +24,7 @@ TEST(ColumnSampler, Test) {
   ASSERT_EQ(set2->Size(), 32);

   // Node sampling
-  cs.Init(n, 0.5f, 1.0f, 0.5f);
+  cs.Init(n, feature_weights, 0.5f, 1.0f, 0.5f);
   auto set3 = cs.GetFeatureSet(0);
   ASSERT_EQ(set3->Size(), 32);

@@ -33,19 +34,19 @@ TEST(ColumnSampler, Test) {
   ASSERT_EQ(set4->Size(), 32);

   // No level or node sampling, should be the same at different depth
-  cs.Init(n, 1.0f, 1.0f, 0.5f);
+  cs.Init(n, feature_weights, 1.0f, 1.0f, 0.5f);
   ASSERT_EQ(cs.GetFeatureSet(0)->HostVector(),
             cs.GetFeatureSet(1)->HostVector());

-  cs.Init(n, 1.0f, 1.0f, 1.0f);
+  cs.Init(n, feature_weights, 1.0f, 1.0f, 1.0f);
   auto set5 = cs.GetFeatureSet(0);
   ASSERT_EQ(set5->Size(), n);

-  cs.Init(n, 1.0f, 1.0f, 1.0f);
+  cs.Init(n, feature_weights, 1.0f, 1.0f, 1.0f);
   auto set6 = cs.GetFeatureSet(0);
   ASSERT_EQ(set5->HostVector(), set6->HostVector());

   // Should always be a minimum of one feature
-  cs.Init(n, 1e-16f, 1e-16f, 1e-16f);
+  cs.Init(n, feature_weights, 1e-16f, 1e-16f, 1e-16f);
   ASSERT_EQ(cs.GetFeatureSet(0)->Size(), 1);
 }

@@ -56,13 +57,13 @@ TEST(ColumnSampler, ThreadSynchronisation) {
   size_t iterations = 10;
   size_t levels = 5;
   std::vector<bst_feature_t> reference_result;
-  bool success =
-      true;  // Cannot use google test asserts in multithreaded region
+  std::vector<float> feature_weights;
+  bool success = true;  // Cannot use google test asserts in multithreaded region
 #pragma omp parallel num_threads(num_threads)
   {
     for (auto j = 0ull; j < iterations; j++) {
       ColumnSampler cs(j);
-      cs.Init(n, 0.5f, 0.5f, 0.5f);
+      cs.Init(n, feature_weights, 0.5f, 0.5f, 0.5f);
       for (auto level = 0ull; level < levels; level++) {
         auto result = cs.GetFeatureSet(level)->ConstHostVector();
 #pragma omp single
@@ -76,5 +77,54 @@ TEST(ColumnSampler, ThreadSynchronisation) {
   }
   ASSERT_TRUE(success);
 }
+
+TEST(ColumnSampler, WeightedSampling) {
+  auto test_basic = [](int first) {
+    std::vector<float> feature_weights(2);
+    feature_weights[0] = std::abs(first - 1.0f);
+    feature_weights[1] = first - 0.0f;
+    ColumnSampler cs{0};
+    cs.Init(2, feature_weights, 1.0, 1.0, 0.5);
+    auto feature_sets = cs.GetFeatureSet(0);
+    auto const &h_feat_set = feature_sets->HostVector();
+    ASSERT_EQ(h_feat_set.size(), 1);
+    ASSERT_EQ(h_feat_set[0], first - 0);
+  };
+
+  test_basic(0);
+  test_basic(1);
+
+  size_t constexpr kCols = 64;
+  std::vector<float> feature_weights(kCols);
+  SimpleLCG rng;
+  SimpleRealUniformDistribution<float> dist(.0f, 12.0f);
+  std::generate(feature_weights.begin(), feature_weights.end(), [&]() { return dist(&rng); });
+  ColumnSampler cs{0};
+  cs.Init(kCols, feature_weights, 0.5f, 1.0f, 1.0f);
+  std::vector<bst_feature_t> features(kCols);
+  std::iota(features.begin(), features.end(), 0);
+  std::vector<float> freq(kCols, 0);
+  for (size_t i = 0; i < 1024; ++i) {
+    auto fset = cs.GetFeatureSet(0);
+    ASSERT_EQ(kCols * 0.5, fset->Size());
+    auto const& h_fset = fset->HostVector();
+    for (auto f : h_fset) {
+      freq[f] += 1.0f;
+    }
+  }
+
+  auto norm = std::accumulate(freq.cbegin(), freq.cend(), .0f);
+  for (auto& f : freq) {
+    f /= norm;
+  }
+  norm = std::accumulate(feature_weights.cbegin(), feature_weights.cend(), .0f);
+  for (auto& f : feature_weights) {
+    f /= norm;
+  }
+
+  for (size_t i = 0; i < feature_weights.size(); ++i) {
+    EXPECT_NEAR(freq[i], feature_weights[i], 1e-2);
+  }
+}
 }  // namespace common
 }  // namespace xgboost
diff --git a/tests/cpp/tree/test_gpu_hist.cu b/tests/cpp/tree/test_gpu_hist.cu
index fd5c9f43fb2a..153cafb88fd8 100644
--- a/tests/cpp/tree/test_gpu_hist.cu
+++ b/tests/cpp/tree/test_gpu_hist.cu
@@ -204,12 +204,11 @@ TEST(GpuHist, EvaluateRootSplit) {
   ASSERT_EQ(maker.hist.Data().size(), hist.size());
   thrust::copy(hist.begin(), hist.end(), maker.hist.Data().begin());
+  std::vector<float> feature_weights;

-  maker.column_sampler.Init(kNCols,
-                            param.colsample_bynode,
-                            param.colsample_bylevel,
-                            param.colsample_bytree,
-                            false);
+  maker.column_sampler.Init(kNCols, feature_weights, param.colsample_bynode,
+                            param.colsample_bylevel, param.colsample_bytree,
+                            false);

   RegTree tree;
   MetaInfo info;
diff --git a/tests/python/test_demos.py b/tests/python/test_demos.py
index 8b6535dbff45..25c1c4de6c1f 100644
--- a/tests/python/test_demos.py
+++
b/tests/python/test_demos.py @@ -1,12 +1,10 @@ import os import subprocess -import sys import pytest import testing as tm -CURRENT_DIR = os.path.dirname(__file__) -ROOT_DIR = os.path.dirname(os.path.dirname(CURRENT_DIR)) +ROOT_DIR = tm.PROJECT_ROOT DEMO_DIR = os.path.join(ROOT_DIR, 'demo') PYTHON_DEMO_DIR = os.path.join(DEMO_DIR, 'guide-python') @@ -19,21 +17,27 @@ def test_basic_walkthrough(): os.remove('dump.raw.txt') +@pytest.mark.skipif(**tm.no_matplotlib()) def test_custom_multiclass_objective(): script = os.path.join(PYTHON_DEMO_DIR, 'custom_softmax.py') cmd = ['python', script, '--plot=0'] subprocess.check_call(cmd) +@pytest.mark.skipif(**tm.no_matplotlib()) def test_custom_rmsle_objective(): - major, minor = sys.version_info[:2] - if minor < 6: - pytest.skip('Skipping RMLSE test due to Python version being too low.') script = os.path.join(PYTHON_DEMO_DIR, 'custom_rmsle.py') cmd = ['python', script, '--plot=0'] subprocess.check_call(cmd) +@pytest.mark.skipif(**tm.no_matplotlib()) +def test_feature_weights_demo(): + script = os.path.join(PYTHON_DEMO_DIR, 'feature_weights.py') + cmd = ['python', script, '--plot=0'] + subprocess.check_call(cmd) + + @pytest.mark.skipif(**tm.no_sklearn()) def test_sklearn_demo(): script = os.path.join(PYTHON_DEMO_DIR, 'sklearn_examples.py') diff --git a/tests/python/test_dmatrix.py b/tests/python/test_dmatrix.py index ecf5f60411bf..46d22185b1ad 100644 --- a/tests/python/test_dmatrix.py +++ b/tests/python/test_dmatrix.py @@ -99,6 +99,9 @@ def test_slice(self): X = rng.randn(100, 100) y = rng.randint(low=0, high=3, size=100) d = xgb.DMatrix(X, y) + fw = rng.uniform(size=100) + d.feature_weights = fw + eval_res_0 = {} booster = xgb.train( {'num_class': 3, 'objective': 'multi:softprob'}, d, @@ -109,13 +112,17 @@ def test_slice(self): d.set_base_margin(predt) ridxs = [1, 2, 3, 4, 5, 6] - d = d.slice(ridxs) - sliced_margin = d.get_float_info('base_margin') + sliced = d.slice(ridxs) + np.testing.assert_equal(sliced.feature_weights, d.feature_weights) + + sliced = d.slice(ridxs) + sliced_margin = sliced.get_float_info('base_margin') assert sliced_margin.shape[0] == len(ridxs) * 3 eval_res_1 = {} - xgb.train({'num_class': 3, 'objective': 'multi:softprob'}, d, - num_boost_round=2, evals=[(d, 'd')], evals_result=eval_res_1) + xgb.train({'num_class': 3, 'objective': 'multi:softprob'}, sliced, + num_boost_round=2, evals=[(sliced, 'd')], + evals_result=eval_res_1) eval_res_0 = eval_res_0['d']['merror'] eval_res_1 = eval_res_1['d']['merror'] @@ -196,6 +203,25 @@ def test_get_info(self): dtrain.get_float_info('base_margin') dtrain.get_uint_info('group_ptr') + def test_feature_info(self): + kRows = 10 + kCols = 50 + rng = np.random.RandomState(1994) + fw = rng.uniform(size=kCols) + X = rng.randn(kRows, kCols) + m = xgb.DMatrix(X) + m.feature_weights = fw + np.testing.assert_allclose(fw, m.feature_weights) + # Handle None + m.feature_weights = None + assert m.feature_weights.shape[0] == 0 + + fw -= 1 + + def assign_weight(): + m.feature_weights = fw + self.assertRaises(ValueError, assign_weight) + def test_sparse_dmatrix_csr(self): nrow = 100 ncol = 1000 diff --git a/tests/python/test_with_sklearn.py b/tests/python/test_with_sklearn.py index 7f62a3e83052..ce0b57e823ff 100644 --- a/tests/python/test_with_sklearn.py +++ b/tests/python/test_with_sklearn.py @@ -1,3 +1,5 @@ +import collections +import importlib.util import numpy as np import xgboost as xgb from xgboost.sklearn import XGBoostLabelEncoder @@ -654,6 +656,7 @@ def test_validation_weights_xgbmodel(): 
                  eval_set=[(X_train, y_train), (X_test, y_test)],
                  sample_weight_eval_set=[weights_train])

+
 def test_validation_weights_xgbclassifier():
     from sklearn.datasets import make_hastie_10_2

@@ -920,6 +923,64 @@ def test_pandas_input():
                             np.array([0, 1]))


+def run_feature_weights(increasing):
+    with TemporaryDirectory() as tmpdir:
+        kRows = 512
+        kCols = 64
+        colsample_bynode = 0.5
+        reg = xgb.XGBRegressor(tree_method='hist',
+                               colsample_bynode=colsample_bynode)
+        X = rng.randn(kRows, kCols)
+        y = rng.randn(kRows)
+        fw = np.ones(shape=(kCols,))
+        for i in range(kCols):
+            if increasing:
+                fw[i] *= float(i)
+            else:
+                fw[i] *= float(kCols - i)
+
+        reg.fit(X, y, feature_weights=fw)
+        model_path = os.path.join(tmpdir, 'model.json')
+        reg.save_model(model_path)
+        with open(model_path) as fd:
+            model = json.load(fd)
+
+        parser_path = os.path.join(tm.PROJECT_ROOT, 'demo', 'json-model',
+                                   'json_parser.py')
+        spec = importlib.util.spec_from_file_location("JsonParser",
+                                                      parser_path)
+        foo = importlib.util.module_from_spec(spec)
+        spec.loader.exec_module(foo)
+        model = foo.Model(model)
+        splits = {}
+        total_nodes = 0
+        for tree in model.trees:
+            n_nodes = len(tree.nodes)
+            total_nodes += n_nodes
+            for n in range(n_nodes):
+                if tree.is_leaf(n):
+                    continue
+                if splits.get(tree.split_index(n), None) is None:
+                    splits[tree.split_index(n)] = 1
+                else:
+                    splits[tree.split_index(n)] += 1
+
+        od = collections.OrderedDict(sorted(splits.items()))
+        tuples = [(k, v) for k, v in od.items()]
+        k, v = list(zip(*tuples))
+        w = np.polyfit(k, v, deg=1)
+        return w
+
+
+def test_feature_weights():
+    poly_increasing = run_feature_weights(True)
+    poly_decreasing = run_feature_weights(False)
+    # Approximate test; the result depends on the implementation of the
+    # random number generator in the standard library.
+    assert poly_increasing[0] > 0.08
+    assert poly_decreasing[0] < -0.08
+
+
 class TestBoostFromPrediction(unittest.TestCase):
     def run_boost_from_prediction(self, tree_method):
         from sklearn.datasets import load_breast_cancer

From 8f285d1123f212d99f708e3c5d8ef453458e488b Mon Sep 17 00:00:00 2001
From: fis
Date: Fri, 14 Aug 2020 12:35:16 +0800
Subject: [PATCH 2/9] Switch to set info.

---
 demo/guide-python/feature_weights.py |  2 +-
 include/xgboost/c_api.h              | 25 +--------
 include/xgboost/data.h               |  3 --
 python-package/xgboost/core.py       | 48 +++--------------
 python-package/xgboost/data.py       | 78 ++++++++++++----------------
 python-package/xgboost/sklearn.py    |  6 +--
 src/c_api/c_api.cc                   | 29 ++---------
 src/data/data.cc                     | 36 +++++--------
 8 files changed, 62 insertions(+), 165 deletions(-)

diff --git a/demo/guide-python/feature_weights.py b/demo/guide-python/feature_weights.py
index 07a8719422c6..b9cee8c050af 100644
--- a/demo/guide-python/feature_weights.py
+++ b/demo/guide-python/feature_weights.py
@@ -22,7 +22,7 @@ def main(args):
         fw[i] *= float(i)

     dtrain = xgboost.DMatrix(X, y)
-    dtrain.feature_weights = fw
+    dtrain.set_info(feature_weights=fw)

     bst = xgboost.train({'tree_method': 'hist',
diff --git a/include/xgboost/c_api.h b/include/xgboost/c_api.h
index ddb01e7d98db..7ca9ece06a4f 100644
--- a/include/xgboost/c_api.h
+++ b/include/xgboost/c_api.h
@@ -502,29 +502,8 @@ XGB_DLL int XGDMatrixGetStrFeatureInfo(DMatrixHandle handle, const char *field,
 *
 * \return 0 when success, -1 when failure happens
 */
-XGB_DLL int XGDMatrixSetFeatureInfo(DMatrixHandle handle, const char *field,
-                                    void *data, bst_ulong size, int type);
-
-/*!
- * \brief Get feature info in a thread local buffer.
- *
- * Caller is responsible for copying out the data before the next call to any API
- * function of XGBoost. The data is always on CPU thread local storage.
- *
- * \param handle An instance of data matrix.
- * \param field Field name.
- * \param out_type Type of this field. This is defined in xgboost::DataType enum class.
- * \param out_size Length of output data, this is relative to size of out_type. (Meaning
- *                 NOT number of bytes.)
- * \param out_dptr Pointer to output buffer.
- *
- * \return 0 when success, -1 when failure happens
- */
-XGB_DLL int XGDMatrixGetFeatureInfo(DMatrixHandle handle,
-                                    const char* field,
-                                    int* out_type,
-                                    bst_ulong* out_size,
-                                    const void** out_dptr);
+XGB_DLL int XGDMatrixSetDenseInfo(DMatrixHandle handle, const char *field,
+                                  void *data, bst_ulong size, int type);

 /*!
 * \brief (deprecated) Use XGDMatrixSetUIntInfo instead. Set group of the training matrix
 * \param handle a instance of data matrix
diff --git a/include/xgboost/data.h b/include/xgboost/data.h
index fbd3da5c84b7..f74dbd2c5a76 100644
--- a/include/xgboost/data.h
+++ b/include/xgboost/data.h
@@ -163,9 +163,6 @@ class MetaInfo {
   void SetFeatureInfo(const char *key, const char **info, const bst_ulong size);
   void GetFeatureInfo(const char *field, std::vector<std::string>* out_str_vecs) const;
-  void SetFeatureInfo(const char *field, const void *info, DataType type,
-                      bst_ulong size);
-  void GetFeatureInfo(const char *field, DataType *out_type, std::vector<float>* out) const;

   /*
   * \brief Extend with other MetaInfo.
diff --git a/python-package/xgboost/core.py b/python-package/xgboost/core.py
index 294a0ed5a482..beb2b6f1284a 100644
--- a/python-package/xgboost/core.py
+++ b/python-package/xgboost/core.py
@@ -455,7 +455,8 @@ def set_info(self,
                  label_lower_bound=None,
                  label_upper_bound=None,
                  feature_names=None,
-                 feature_types=None):
+                 feature_types=None,
+                 feature_weights=None):
         '''Set meta info for DMatrix.'''
         if label is not None:
             self.set_label(label)
@@ -473,6 +474,12 @@ def set_info(self,
             self.feature_names = feature_names
         if feature_types is not None:
             self.feature_types = feature_types
+        if feature_weights is not None:
+            from .data import dispatch_meta_backend
+            # The feature_weights API is newer than the others and accepts
+            # general data types instead of only float.
+            dispatch_meta_backend(matrix=self, data=feature_weights,
+                                  name='feature_weights')

     def get_float_info(self, field):
         """Get float property from the DMatrix.
@@ -848,45 +848,6 @@ def feature_types(self, feature_types):
                 None,
                 c_bst_ulong(0)))

-    @property
-    def feature_weights(self):
-        '''Weight for each feature, defines the probability of each feature
-        being selected when colsample is being used. All values must be
-        greater than or equal to 0, otherwise an `XGBoostError` is thrown.
-
-        .. versionadded:: 1.3.0
-
-        '''
-        length = c_bst_ulong()
-        ret = ctypes.POINTER(ctypes.c_void_p)()
-        out_type = ctypes.c_int()
-        _check_call(_LIB.XGDMatrixGetFeatureInfo(
-            self.handle,
-            c_str('feature_weight'),
-            ctypes.byref(out_type),
-            ctypes.byref(length),
-            ctypes.byref(ret)
-        ))
-        to_data_type = {1: np.float32, 2: np.float64, 3: np.uint32,
-                        4: np.uint64}
-        to_c_type = {1: ctypes.c_float, 2: ctypes.c_double, 3: ctypes.c_uint32,
-                     4: ctypes.c_uint64}
-        dtype = to_data_type[out_type.value]
-        ptr = ctypes.cast(ret, ctypes.POINTER(to_c_type[out_type.value]))
-        return ctypes2numpy(ptr, length.value, dtype)
-
-    @feature_weights.setter
-    def feature_weights(self, array):
-        '''Setter for feature weights. Clear the feature weights if array is
-        None.
- - ''' - from .data import dispatch_meta_backend - if array is None: - array = np.empty((0, 0)) - dispatch_meta_backend(matrix=self, data=array, name='feature_weight', - is_feature=True) - class DeviceQuantileDMatrix(DMatrix): """Device memory Data Matrix used in XGBoost for training with diff --git a/python-package/xgboost/data.py b/python-package/xgboost/data.py index 65f7e179a3a7..5eb98106793a 100644 --- a/python-package/xgboost/data.py +++ b/python-package/xgboost/data.py @@ -530,52 +530,39 @@ def dispatch_data_backend(data, missing, threads, raise TypeError('Not supported type for data.' + str(type(data))) -def _to_data_type(dtype: str): +def _to_data_type(dtype: str, name: str): dtype_map = {'float32': 1, 'float64': 2, 'uint32': 3, 'uint64': 4} + if dtype not in dtype_map.keys(): + raise TypeError( + f'Expecting float32, float64, uint32, uint64, got {dtype} ' + + f'for {name}.') return dtype_map[dtype] -def _meta_from_numpy(data, field, dtype, handle, is_feature: bool = False): - if is_feature: - data = _maybe_np_slice(data, dtype) - interface = data.__array_interface__ - assert interface.get('mask', None) is None - size = data.shape[0] - c_type = _to_data_type(str(data.dtype)) - data = interface['data'] - data = ctypes.c_void_p(data[0]) - _check_call(_LIB.XGDMatrixSetFeatureInfo( - handle, - c_str(field), - data, - size, - c_type - )) - else: - data = _maybe_np_slice(data, dtype) - if dtype == 'uint32': - c_data = c_array(ctypes.c_uint32, data) - _check_call(_LIB.XGDMatrixSetUIntInfo(handle, - c_str(field), - c_array(ctypes.c_uint, data), - c_bst_ulong(len(data)))) - elif dtype == 'float': - c_data = c_array(ctypes.c_float, data) - _check_call(_LIB.XGDMatrixSetFloatInfo(handle, - c_str(field), - c_data, - c_bst_ulong(len(data)))) - else: - raise TypeError('Unsupported type ' + str(dtype) + ' for:' + field) - - -def _meta_from_list(data, field, dtype, handle, is_feature): +def _meta_from_numpy(data, field, dtype, handle): + data = _maybe_np_slice(data, dtype) + interface = data.__array_interface__ + assert interface.get('mask', None) is None, 'Masked array is not supported' + size = data.shape[0] + c_type = _to_data_type(str(data.dtype), field) + data = interface['data'] + data = ctypes.c_void_p(data[0]) + _check_call(_LIB.XGDMatrixSetDenseInfo( + handle, + c_str(field), + data, + size, + c_type + )) + + +def _meta_from_list(data, field, dtype, handle): data = np.array(data) - _meta_from_numpy(data, field, dtype, handle, is_feature) + _meta_from_numpy(data, field, dtype, handle) -def _meta_from_tuple(data, field, dtype, handle, is_feature): - return _meta_from_list(data, field, dtype, handle, is_feature) +def _meta_from_tuple(data, field, dtype, handle): + return _meta_from_list(data, field, dtype, handle) def _meta_from_cudf_df(data, field, handle): @@ -613,29 +600,28 @@ def _meta_from_dt(data, field, dtype, handle): _meta_from_numpy(data, field, dtype, handle) -def dispatch_meta_backend(matrix: DMatrix, data, name: str, dtype: str = None, - is_feature: bool = False): +def dispatch_meta_backend(matrix: DMatrix, data, name: str, dtype: str = None): '''Dispatch for meta info.''' handle = matrix.handle if data is None: return if _is_list(data): - _meta_from_list(data, name, dtype, handle, is_feature) + _meta_from_list(data, name, dtype, handle) return if _is_tuple(data): - _meta_from_tuple(data, name, dtype, handle, is_feature) + _meta_from_tuple(data, name, dtype, handle) return if _is_numpy_array(data): - _meta_from_numpy(data, name, dtype, handle, is_feature) + 
_meta_from_numpy(data, name, dtype, handle)
         return
     if _is_pandas_df(data):
         data, _, _ = _transform_pandas_df(data, meta=name, meta_type=dtype)
-        _meta_from_numpy(data, name, dtype, handle, is_feature)
+        _meta_from_numpy(data, name, dtype, handle)
         return
     if _is_pandas_series(data):
         data = data.values.astype('float')
         assert len(data.shape) == 1 or data.shape[1] == 0 or data.shape[1] == 1
-        _meta_from_numpy(data, name, dtype, handle, is_feature)
+        _meta_from_numpy(data, name, dtype, handle)
         return
     if _is_dlpack(data):
         data = _transform_dlpack(data)
diff --git a/python-package/xgboost/sklearn.py b/python-package/xgboost/sklearn.py
index 6c666961f432..c6c34dce1c99 100644
--- a/python-package/xgboost/sklearn.py
+++ b/python-package/xgboost/sklearn.py
@@ -511,7 +511,7 @@ def fit(self, X, y, sample_weight=None, base_margin=None,
                                 base_margin=base_margin,
                                 missing=self.missing,
                                 nthread=self.n_jobs)
-        train_dmatrix.feature_weights = feature_weights
+        train_dmatrix.set_info(feature_weights=feature_weights)

         evals_result = {}
@@ -828,7 +828,7 @@ def fit(self, X, y, sample_weight=None, base_margin=None,
         train_dmatrix = DMatrix(X, label=training_labels, weight=sample_weight,
                                 base_margin=base_margin,
                                 missing=self.missing, nthread=self.n_jobs)
-        train_dmatrix.feature_weights = feature_weights
+        train_dmatrix.set_info(feature_weights=feature_weights)

         self._Booster = train(xgb_options, train_dmatrix,
                               self.get_num_boosting_rounds(),
@@ -1217,7 +1217,7 @@ def _dmat_init(group, **params):
         train_dmatrix = DMatrix(data=X, label=y, weight=sample_weight,
                                 base_margin=base_margin,
                                 missing=self.missing, nthread=self.n_jobs)
-        train_dmatrix.feature_weights = feature_weights
+        train_dmatrix.set_info(feature_weights=feature_weights)
         train_dmatrix.set_group(group)

         evals_result = {}
diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc
index 54e675bb1883..75dc8a4bf07b 100644
--- a/src/c_api/c_api.cc
+++ b/src/c_api/c_api.cc
@@ -316,34 +316,13 @@ XGB_DLL int XGDMatrixGetStrFeatureInfo(DMatrixHandle handle, const char *field,
   API_END();
 }

-XGB_DLL int XGDMatrixSetFeatureInfo(DMatrixHandle handle, const char *field,
-                                    void *data, xgboost::bst_ulong size,
-                                    int type) {
+XGB_DLL int XGDMatrixSetDenseInfo(DMatrixHandle handle, const char *field,
+                                  void *data, xgboost::bst_ulong size,
+                                  int type) {
   API_BEGIN();
   CHECK_HANDLE();
   auto &info = static_cast<std::shared_ptr<DMatrix> *>(handle)->get()->Info();
-  info.SetFeatureInfo(field, data, static_cast<DataType>(type), size);
+  info.SetInfo(field, data, static_cast<DataType>(type), size);
   API_END();
 }
-
-XGB_DLL int XGDMatrixGetFeatureInfo(DMatrixHandle handle,
-                                    const char* field,
-                                    int* out_type,
-                                    xgboost::bst_ulong* out_len,
-                                    const void** out_dptr) {
-  API_BEGIN();
-  CHECK_HANDLE();
-  auto m = *static_cast<std::shared_ptr<DMatrix>*>(handle);
-  auto &info = static_cast<std::shared_ptr<DMatrix> *>(handle)->get()->Info();
-  DataType out_t;
-
-  // Right now only float for feature weights is used.  If other data types are required,
- auto& float_vec = m->GetThreadLocal().ret_vec_float; - info.GetFeatureInfo(field, &out_t, &float_vec); - *out_type = static_cast(out_t); - *out_len = float_vec.size(); - *out_dptr = float_vec.data(); + info.SetInfo(field, data, static_cast(type), size); API_END(); } diff --git a/src/data/data.cc b/src/data/data.cc index d446cc995a43..677812ebba7a 100644 --- a/src/data/data.cc +++ b/src/data/data.cc @@ -380,6 +380,16 @@ void MetaInfo::SetInfo(const char* key, const void* dptr, DataType dtype, size_t labels.resize(num); DISPATCH_CONST_PTR(dtype, dptr, cast_dptr, std::copy(cast_dptr, cast_dptr + num, labels.begin())); + } else if (!std::strcmp(key, "feature_weights")) { + auto &h_feature_weights = feature_weigths.HostVector(); + h_feature_weights.resize(num); + DISPATCH_CONST_PTR( + dtype, dptr, cast_dptr, + std::copy(cast_dptr, cast_dptr + num, h_feature_weights.begin())); + bool valid = + std::all_of(h_feature_weights.cbegin(), h_feature_weights.cend(), + [](float w) { return w >= 0; }); + CHECK(valid) << "Feature weight must be greater than 0."; } else { LOG(FATAL) << "Unknown key for MetaInfo: " << key; } @@ -399,6 +409,8 @@ void MetaInfo::GetInfo(char const *key, bst_ulong *out_len, DataType dtype, vec = &this->labels_lower_bound_.HostVector(); } else if (!std::strcmp(key, "label_upper_bound")) { vec = &this->labels_upper_bound_.HostVector(); + } else if (!std::strcmp(key, "feature_weights")) { + vec = &this->feature_weigths.HostVector(); } else { LOG(FATAL) << "Unknown float field name: " << key; } @@ -455,30 +467,6 @@ void MetaInfo::GetFeatureInfo(const char *field, } } -void MetaInfo::SetFeatureInfo(const char *c_field, const void *info, DataType type, - bst_ulong size) { - std::string field {c_field}; - CHECK_EQ(field, "feature_weight") << "Only feature weight is supported for feature info."; - auto& h_feature_weights = feature_weigths.HostVector(); - h_feature_weights.resize(size); - DISPATCH_CONST_PTR( - type, info, cast_dptr, - std::copy(cast_dptr, cast_dptr + size, h_feature_weights.begin())); - bool valid = std::all_of(h_feature_weights.cbegin(), h_feature_weights.cend(), - [](float w) { return w >= 0; }); - CHECK(valid) << "Feature weight must be greater than 0."; -} - -void MetaInfo::GetFeatureInfo(const char *c_field, DataType *out_type, - std::vector *out) const { - std::string field {c_field}; - CHECK_EQ(field, "feature_weight") << "Only feature weight is supported for feature info."; - *out_type = DataType::kFloat32; - out->resize(feature_weigths.Size()); - auto const& h_feature_weights = feature_weigths.ConstHostVector(); - std::copy(h_feature_weights.cbegin(), h_feature_weights.cend(), out->begin()); -} - void MetaInfo::Extend(MetaInfo const& that, bool accumulate_rows) { if (accumulate_rows) { this->num_row_ += that.num_row_; From 199ef5627301b5d96c9ec8382fb837cee5ea003f Mon Sep 17 00:00:00 2001 From: fis Date: Fri, 14 Aug 2020 13:05:25 +0800 Subject: [PATCH 3/9] Debug. 
---
 src/c_api/c_api.cc           |  1 +
 tests/python/test_dmatrix.py | 55 +++++++++++++++++++++++++++---------
 2 files changed, 43 insertions(+), 13 deletions(-)

diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc
index 75dc8a4bf07b..397f83e69bf8 100644
--- a/src/c_api/c_api.cc
+++ b/src/c_api/c_api.cc
@@ -322,6 +322,7 @@ XGB_DLL int XGDMatrixSetDenseInfo(DMatrixHandle handle, const char *field,
   API_BEGIN();
   CHECK_HANDLE();
   auto &info = static_cast<std::shared_ptr<DMatrix> *>(handle)->get()->Info();
+  CHECK(type >= 1 && type <= 4);
   info.SetInfo(field, data, static_cast<DataType>(type), size);
   API_END();
 }
diff --git a/tests/python/test_dmatrix.py b/tests/python/test_dmatrix.py
index 46d22185b1ad..1bde47ee097c 100644
--- a/tests/python/test_dmatrix.py
+++ b/tests/python/test_dmatrix.py
@@ -96,11 +96,18 @@ def test_np_view(self):
         assert (from_view == from_array).all()

     def test_slice(self):
+        # 2887052510386eb7d12e09c859529a4f2b01ba35b847c807f95cbaddbb5eea7a
         X = rng.randn(100, 100)
+        # 85a40616f6748d98e59baf2961b655e0ebb2fc8ac298fc638a173e434073a0f9
         y = rng.randint(low=0, high=3, size=100)
+        # ab7b1162766f953935c244dd73d663ac0b79e4bd0a914a64e5b0fe66ae55b7ef
+        # failed:
+        # 33d8eee9310c7f6ff57f0c876b5977f39f7e450eb7c71513cc4c133c57921a6
         d = xgb.DMatrix(X, y)
-        fw = rng.uniform(size=100)
-        d.feature_weights = fw
+        np.testing.assert_equal(d.get_label(), y.astype(np.float32))
+
+        # fw = rng.uniform(size=100).astype(np.float32)
+        # d.set_info(feature_weights=fw)

         eval_res_0 = {}
         booster = xgb.train(
@@ -109,16 +116,33 @@ def test_slice(self):

         predt = booster.predict(d)
         predt = predt.reshape(100 * 3, 1)
+
+        i = 0
+        import os
+        while os.path.exists(f'test_predict-{i}.txt'):
+            i += 1
+        with open(f'test_predict-{i}.txt', 'w') as fd:
+            print(predt, 'pred', file=fd)
+
         d.set_base_margin(predt)

         ridxs = [1, 2, 3, 4, 5, 6]
+        # failed:
+        # f3459dc754e30ff09e91c9660789cef53d998d6256f8ee81cc9c304cc54fbf40
+        # passed:
+        # ab7b1162766f953935c244dd73d663ac0b79e4bd0a914a64e5b0fe66ae55b7ef
         sliced = d.slice(ridxs)
-        np.testing.assert_equal(sliced.feature_weights, d.feature_weights)
+        # np.testing.assert_equal(sliced.get_float_info('feature_weights'), fw)

-        sliced = d.slice(ridxs)
         sliced_margin = sliced.get_float_info('base_margin')
         assert sliced_margin.shape[0] == len(ridxs) * 3

+        i = 0
+        while os.path.exists(f'test_slice-{i}.dmatrix'):
+            i += 1
+        d.save_binary(f'd_test_slice-{i}.dmatrix')
+        sliced.save_binary(f'test_slice-{i}.dmatrix')
+
         eval_res_1 = {}
         xgb.train({'num_class': 3, 'objective': 'multi:softprob'}, sliced,
                   num_boost_round=2, evals=[(sliced, 'd')],
                   evals_result=eval_res_1)

         eval_res_0 = eval_res_0['d']['merror']
         eval_res_1 = eval_res_1['d']['merror']
+
+        np.savetxt('test_sliced-X.txt', X)
+        np.savetxt('test_sliced-y.txt', y)
+
         for i in range(len(eval_res_0)):
             assert abs(eval_res_0[i] - eval_res_1[i]) < 0.02

@@ -203,23 +231,24 @@ def test_get_info(self):
             dtrain.get_float_info('base_margin')
             dtrain.get_uint_info('group_ptr')

-    def test_feature_info(self):
+    def test_feature_weights(self):
         kRows = 10
         kCols = 50
         rng = np.random.RandomState(1994)
         fw = rng.uniform(size=kCols)
         X = rng.randn(kRows, kCols)
         m = xgb.DMatrix(X)
-        m.feature_weights = fw
-        np.testing.assert_allclose(fw, m.feature_weights)
-        # Handle None
-        m.feature_weights = None
-        assert m.feature_weights.shape[0] == 0
+        m.set_info(feature_weights=fw)
+        np.testing.assert_allclose(fw, m.get_float_info('feature_weights'))
+        # Handle empty
+        m.set_info(feature_weights=np.empty((0, 0)))
+
+        assert m.get_float_info('feature_weights').shape[0] == 0

         fw -= 1

         def assign_weight():
-            m.feature_weights = fw
+            m.set_info(feature_weights=fw)
 
         self.assertRaises(ValueError, assign_weight)
 
     def test_sparse_dmatrix_csr(self):
@@ -228,7 +257,7 @@ def test_sparse_dmatrix_csr(self):
         x = rand(nrow, ncol, density=0.0005, format='csr', random_state=rng)
         assert x.indices.max() < ncol - 1
         x.data[:] = 1
-        dtrain = xgb.DMatrix(x, label=np.random.binomial(1, 0.3, nrow))
+        dtrain = xgb.DMatrix(x, label=rng.binomial(1, 0.3, nrow))
         assert (dtrain.num_row(), dtrain.num_col()) == (nrow, ncol)
         watchlist = [(dtrain, 'train')]
         param = {'max_depth': 3, 'objective': 'binary:logistic', 'verbosity': 0}
@@ -241,7 +270,7 @@ def test_sparse_dmatrix_csc(self):
         x = rand(nrow, ncol, density=0.0005, format='csc', random_state=rng)
         assert x.indices.max() < nrow - 1
         x.data[:] = 1
-        dtrain = xgb.DMatrix(x, label=np.random.binomial(1, 0.3, nrow))
+        dtrain = xgb.DMatrix(x, label=rng.binomial(1, 0.3, nrow))
         assert (dtrain.num_row(), dtrain.num_col()) == (nrow, ncol)
         watchlist = [(dtrain, 'train')]
         param = {'max_depth': 3, 'objective': 'binary:logistic', 'verbosity': 0}

From 7980e84af1d5f23426ef819d7dda48c238e09378 Mon Sep 17 00:00:00 2001
From: fis
Date: Fri, 14 Aug 2020 13:18:04 +0800
Subject: [PATCH 4/9] Doc.

--- doc/parameter.rst | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/doc/parameter.rst b/doc/parameter.rst
index 685bddbc815d..7e7e774a2bfa 100644
--- a/doc/parameter.rst
+++ b/doc/parameter.rst
@@ -108,7 +108,8 @@ Parameters for Tree Booster
   each split.
 
   On Python interface, one can set the ``feature_weights`` for DMatrix to define the
-  probability of each feature being selected when using column sampling.
+  probability of each feature being selected when using column sampling. A similar
+  parameter is available for the ``fit`` method in the sklearn interface.
 
 * ``lambda`` [default=1, alias: ``reg_lambda``]
 
@@ -227,7 +228,7 @@ Parameters for Tree Booster
   list is a group of indices of features that are allowed to interact with each
   other. See tutorial for more information
 
-Additional parameters for ``hist`` and ```gpu_hist`` tree method
+Additional parameters for ``hist`` and ``gpu_hist`` tree method
 ================================================================
 
 * ``single_precision_histogram``, [default=``false``]

From afe11e429f55b8440a424d77d503ccc2de984a73 Mon Sep 17 00:00:00 2001
From: fis
Date: Fri, 14 Aug 2020 13:34:17 +0800
Subject: [PATCH 5/9] Fix the new API.
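
The fix below stops rebinding ``data`` while extracting the pointer, and wraps
``size`` in ``c_bst_ulong`` before crossing the C boundary. For reference, a
standalone sketch of the NumPy ``__array_interface__`` idiom the fix relies on
(variable names here are illustrative only):

    import ctypes
    import numpy as np

    arr = np.arange(5, dtype=np.float32)
    interface = arr.__array_interface__
    assert interface.get('mask', None) is None, 'Masked array is not supported'
    # interface['data'] is an (address, read_only) pair; element 0 is the pointer.
    ptr = ctypes.c_void_p(interface['data'][0])
    size = ctypes.c_uint64(arr.shape[0])  # length in elements, not bytes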
--- python-package/xgboost/data.py | 17 +++++++++++++---- tests/python/test_dmatrix.py | 14 ++------------ 2 files changed, 15 insertions(+), 16 deletions(-) diff --git a/python-package/xgboost/data.py b/python-package/xgboost/data.py index 5eb98106793a..e4c05dcc244e 100644 --- a/python-package/xgboost/data.py +++ b/python-package/xgboost/data.py @@ -539,19 +539,27 @@ def _to_data_type(dtype: str, name: str): return dtype_map[dtype] +def _validate_meta_shape(data): + if hasattr(data, 'shape'): + assert len(data.shape) == 1 or ( + len(data.shape) == 2 and + (data.shape[1] == 0 or data.shape[1] == 1)) + + def _meta_from_numpy(data, field, dtype, handle): data = _maybe_np_slice(data, dtype) interface = data.__array_interface__ assert interface.get('mask', None) is None, 'Masked array is not supported' size = data.shape[0] + c_type = _to_data_type(str(data.dtype), field) - data = interface['data'] - data = ctypes.c_void_p(data[0]) + ptr = interface['data'][0] + ptr = ctypes.c_void_p(ptr) _check_call(_LIB.XGDMatrixSetDenseInfo( handle, c_str(field), - data, - size, + ptr, + c_bst_ulong(size), c_type )) @@ -603,6 +611,7 @@ def _meta_from_dt(data, field, dtype, handle): def dispatch_meta_backend(matrix: DMatrix, data, name: str, dtype: str = None): '''Dispatch for meta info.''' handle = matrix.handle + _validate_meta_shape(data) if data is None: return if _is_list(data): diff --git a/tests/python/test_dmatrix.py b/tests/python/test_dmatrix.py index 1bde47ee097c..a20a576bf047 100644 --- a/tests/python/test_dmatrix.py +++ b/tests/python/test_dmatrix.py @@ -96,18 +96,13 @@ def test_np_view(self): assert (from_view == from_array).all() def test_slice(self): - # 2887052510386eb7d12e09c859529a4f2b01ba35b847c807f95cbaddbb5eea7a X = rng.randn(100, 100) - # 85a40616f6748d98e59baf2961b655e0ebb2fc8ac298fc638a173e434073a0f9 y = rng.randint(low=0, high=3, size=100) - # ab7b1162766f953935c244dd73d663ac0b79e4bd0a914a64e5b0fe66ae55b7ef - # failed: - # 33d8eee9310c7f6ff57f0c876b5977f39f7e450eb7c71513cc4c133c57921a6 d = xgb.DMatrix(X, y) np.testing.assert_equal(d.get_label(), y.astype(np.float32)) - # fw = rng.uniform(size=100).astype(np.float32) - # d.set_info(feature_weights=fw) + fw = rng.uniform(size=100).astype(np.float32) + d.set_info(feature_weights=fw) eval_res_0 = {} booster = xgb.train( @@ -127,12 +122,7 @@ def test_slice(self): d.set_base_margin(predt) ridxs = [1, 2, 3, 4, 5, 6] - # failed: - # f3459dc754e30ff09e91c9660789cef53d998d6256f8ee81cc9c304cc54fbf40 - # passed: - # ab7b1162766f953935c244dd73d663ac0b79e4bd0a914a64e5b0fe66ae55b7ef sliced = d.slice(ridxs) - # np.testing.assert_equal(sliced.get_float_info('feature_weights'), fw) sliced_margin = sliced.get_float_info('base_margin') assert sliced_margin.shape[0] == len(ridxs) * 3 From 171b062e7f2a1345c50837047998994c846df12e Mon Sep 17 00:00:00 2001 From: fis Date: Fri, 14 Aug 2020 13:38:36 +0800 Subject: [PATCH 6/9] Missing debug. 
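
The remaining assertions in test_slice depend on multi-class base margin being
stored flat, one float per (row, class) pair, which is why the slice keeps
``len(ridxs) * 3`` entries. A small NumPy sketch of that layout, assuming
row-major order (which the ``reshape(100 * 3, 1)`` in the test implies):

    import numpy as np

    n_rows, n_classes = 100, 3
    margin = np.arange(n_rows * n_classes, dtype=np.float32)

    ridxs = [1, 2, 3, 4, 5, 6]
    # Row r's per-class margins occupy [r * n_classes, (r + 1) * n_classes).
    rows = [margin[r * n_classes:(r + 1) * n_classes] for r in ridxs]
    assert np.concatenate(rows).shape[0] == len(ridxs) * n_classes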
--- tests/python/test_dmatrix.py | 16 ----------------
 1 file changed, 16 deletions(-)

diff --git a/tests/python/test_dmatrix.py b/tests/python/test_dmatrix.py
index a20a576bf047..f641ea2c54f4 100644
--- a/tests/python/test_dmatrix.py
+++ b/tests/python/test_dmatrix.py
@@ -112,13 +112,6 @@ def test_slice(self):
         predt = booster.predict(d)
         predt = predt.reshape(100 * 3, 1)
 
-        i = 0
-        import os
-        while os.path.exists(f'test_predict-{i}.txt'):
-            i += 1
-        with open(f'test_predict-{i}.txt', 'w') as fd:
-            print(predt, 'pred', file=fd)
-
         d.set_base_margin(predt)
 
         ridxs = [1, 2, 3, 4, 5, 6]
@@ -127,12 +120,6 @@ def test_slice(self):
         sliced_margin = sliced.get_float_info('base_margin')
         assert sliced_margin.shape[0] == len(ridxs) * 3
 
-        i = 0
-        while os.path.exists(f'test_slice-{i}.dmatrix'):
-            i += 1
-        d.save_binary(f'd_test_slice-{i}.dmatrix')
-        sliced.save_binary(f'test_slice-{i}.dmatrix')
-
         eval_res_1 = {}
         xgb.train({'num_class': 3, 'objective': 'multi:softprob'}, sliced,
                   num_boost_round=2, evals=[(sliced, 'd')],
@@ -141,9 +128,6 @@ def test_slice(self):
         eval_res_0 = eval_res_0['d']['merror']
         eval_res_1 = eval_res_1['d']['merror']
 
-        np.savetxt('test_sliced-X.txt', X)
-        np.savetxt('test_sliced-y.txt', y)
-
         for i in range(len(eval_res_0)):
             assert abs(eval_res_0[i] - eval_res_1[i]) < 0.02
 

From cbd8613941e0c2efe1cefab3c281a047838a81e8 Mon Sep 17 00:00:00 2001
From: fis
Date: Fri, 14 Aug 2020 13:51:40 +0800
Subject: [PATCH 7/9] Add GPU support.

--- src/data/data.cu | 24 +++++++++++++++++++
 .../test_device_quantile_dmatrix.py | 14 +++++++++++
 2 files changed, 38 insertions(+)

diff --git a/src/data/data.cu b/src/data/data.cu
index 5e63a828c207..15260498734d 100644
--- a/src/data/data.cu
+++ b/src/data/data.cu
@@ -58,6 +58,15 @@ void CopyGroupInfoImpl(ArrayInterface column, std::vector<bst_group_t>* out) {
   std::partial_sum(out->begin(), out->end(), out->begin());
 }
 
+namespace {
+// thrust::all_of tries to copy the lambda function.
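+// A named functor whose __device__ call operator is trivially copyable avoids
+// that restriction.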
+struct AllOfOp {
+  __device__ bool operator()(float w) {
+    return w >= 0;
+  }
+};
+}  // anonymous namespace
+
 void MetaInfo::SetInfo(const char * c_key, std::string const& interface_str) {
   Json j_interface = Json::Load({interface_str.c_str(), interface_str.size()});
   auto const& j_arr = get<Array>(j_interface);
@@ -82,6 +91,21 @@ void MetaInfo::SetInfo(const char * c_key, std::string const& interface_str) {
   } else if (key == "group") {
     CopyGroupInfoImpl(array_interface, &group_ptr_);
     return;
+  } else if (key == "label_lower_bound") {
+    CopyInfoImpl(array_interface, &labels_lower_bound_);
+    return;
+  } else if (key == "label_upper_bound") {
+    CopyInfoImpl(array_interface, &labels_upper_bound_);
+    return;
+  } else if (key == "feature_weights") {
+    CopyInfoImpl(array_interface, &feature_weigths);
+    auto d_feature_weights = feature_weigths.ConstDeviceSpan();
+    auto valid =
+        thrust::all_of(thrust::device, d_feature_weights.data(),
+                       d_feature_weights.data() + d_feature_weights.size(),
+                       AllOfOp{});
+    CHECK(valid) << "Feature weight must be greater than or equal to 0.";
+    return;
   } else {
     LOG(FATAL) << "Unknown metainfo: " << key;
   }
diff --git a/tests/python-gpu/test_device_quantile_dmatrix.py b/tests/python-gpu/test_device_quantile_dmatrix.py
index f0978a0afaf4..c44de28bd2ff 100644
--- a/tests/python-gpu/test_device_quantile_dmatrix.py
+++ b/tests/python-gpu/test_device_quantile_dmatrix.py
@@ -16,6 +16,20 @@ def test_dmatrix_numpy_init(self):
                            match='is not supported for DeviceQuantileDMatrix'):
             xgb.DeviceQuantileDMatrix(data, np.ones(5, dtype=np.float64))
 
+    @pytest.mark.skipif(**tm.no_cupy())
+    def test_dmatrix_feature_weights(self):
+        import cupy as cp
+        rng = cp.random.RandomState(1994)
+        data = rng.randn(5, 5)
+        m = xgb.DMatrix(data)
+
+        feature_weights = rng.uniform(size=5)
+        m.set_info(feature_weights=feature_weights)
+
+        cp.testing.assert_array_equal(
+            cp.array(m.get_float_info('feature_weights')),
+            feature_weights.astype(np.float32))
+
     @pytest.mark.skipif(**tm.no_cupy())
     def test_dmatrix_cupy_init(self):
         import cupy as cp

From 050cbdad43a5030b443b2d74c28101f0fb7cc1ca Mon Sep 17 00:00:00 2001
From: fis
Date: Fri, 14 Aug 2020 14:15:21 +0800
Subject: [PATCH 8/9] Revise C doc.

--- include/xgboost/c_api.h | 12 +++++++++---
 1 file changed, 9 insertions(+), 3 deletions(-)

diff --git a/include/xgboost/c_api.h b/include/xgboost/c_api.h
index 7ca9ece06a4f..4db461d11b1c 100644
--- a/include/xgboost/c_api.h
+++ b/include/xgboost/c_api.h
@@ -484,9 +484,15 @@ XGB_DLL int XGDMatrixGetStrFeatureInfo(DMatrixHandle handle, const char *field,
                                        const char ***out_features);
 
 /*!
- * \brief Set feature info that's not strings. Currently accepted fields are:
- *
- * - feature_weight
+ * \brief Set meta info from dense matrix. Valid field names are:
+ *
+ * - label
+ * - weight
+ * - base_margin
+ * - group
+ * - label_lower_bound
+ * - label_upper_bound
+ * - feature_weights
  *
  * \param handle An instance of data matrix

From 0d230932aebce7b490c1fd560e387af37325d8a0 Mon Sep 17 00:00:00 2001
From: fis
Date: Tue, 18 Aug 2020 18:43:21 +0800
Subject: [PATCH 9/9] References.
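
For reviewers unfamiliar with the citation added below: Efraimidis & Spirakis
sample n items without replacement by drawing one key u ** (1 / w) per item
(u uniform in (0, 1)) and keeping the n items with the largest keys. A minimal
NumPy sketch of that scheme, assuming strictly positive weights (the helper
name is illustrative):

    import numpy as np

    def weighted_sample_without_replacement(items, weights, n, rng):
        # Key trick from Efraimidis & Spirakis (2005): key = u ** (1 / w).
        u = rng.uniform(size=len(items))
        keys = u ** (1.0 / np.asarray(weights, dtype=np.float64))
        top_n = np.argsort(keys)[::-1][:n]  # indices of the n largest keys
        return [items[i] for i in top_n]

    rng = np.random.RandomState(1994)
    # Items with weight 5 dominate a sample of size 3.
    print(weighted_sample_without_replacement(list('abcdef'), [1, 1, 1, 5, 5, 5], 3, rng))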
--- python-package/xgboost/core.py | 2 --
 src/common/random.h | 7 +++++++
 2 files changed, 7 insertions(+), 2 deletions(-)

diff --git a/python-package/xgboost/core.py b/python-package/xgboost/core.py
index beb2b6f1284a..cf22453245ac 100644
--- a/python-package/xgboost/core.py
+++ b/python-package/xgboost/core.py
@@ -476,8 +476,6 @@ def set_info(self,
             self.feature_types = feature_types
         if feature_weights is not None:
             from .data import dispatch_meta_backend
-            # feature weight API is newer than the others and it accepts
-            # general data type instead of float only.
             dispatch_meta_backend(matrix=self, data=feature_weights,
                                   name='feature_weights')
 
diff --git a/src/common/random.h b/src/common/random.h
index 3842c975c50a..7fd461d22d0f 100644
--- a/src/common/random.h
+++ b/src/common/random.h
@@ -78,6 +78,13 @@ using GlobalRandomEngine = RandomEngine;
  */
 GlobalRandomEngine& GlobalRandom(); // NOLINT(*)
 
+/*
+ * Original paper:
+ * Weighted Random Sampling (2005; Efraimidis, Spirakis)
+ *
+ * Blog:
+ * https://timvieira.github.io/blog/post/2019/09/16/algorithms-for-sampling-without-replacement/
+ */
 template <typename T>
 std::vector<T> WeightedSamplingWithoutReplacement(
     std::vector<T> const &array, std::vector<float> const &weights, size_t n) {