Feature weights #5962

Merged: 9 commits, Aug 18, 2020
1 change: 1 addition & 0 deletions amalgamation/xgboost-all0.cc
@@ -68,6 +68,7 @@
#include "../src/learner.cc"
#include "../src/logging.cc"
#include "../src/common/common.cc"
#include "../src/common/random.cc"
#include "../src/common/charconv.cc"
#include "../src/common/timer.cc"
#include "../src/common/quantile.cc"
49 changes: 49 additions & 0 deletions demo/guide-python/feature_weights.py
@@ -0,0 +1,49 @@
'''Using feature weights to change column sampling.

.. versionadded:: 1.3.0
'''

import numpy as np
import xgboost
from matplotlib import pyplot as plt
import argparse


def main(args):
    rng = np.random.RandomState(1994)

    kRows = 1000
    kCols = 10

    X = rng.randn(kRows, kCols)
    y = rng.randn(kRows)
    fw = np.ones(shape=(kCols,))
    for i in range(kCols):
        fw[i] *= float(i)

    dtrain = xgboost.DMatrix(X, y)
    dtrain.set_info(feature_weights=fw)

    bst = xgboost.train({'tree_method': 'hist',
                         'colsample_bynode': 0.5},
                        dtrain, num_boost_round=10,
                        evals=[(dtrain, 'd')])
    feature_map = bst.get_fscore()
    # feature zero has 0 weight
    assert feature_map.get('f0', None) is None
    assert max(feature_map.values()) == feature_map.get('f9')

    if args.plot:
        xgboost.plot_importance(bst)
        plt.show()


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--plot',
        type=int,
        default=1,
        help='Set to 0 to disable plotting the feature importance.')
    args = parser.parse_args()
    main(args)
2 changes: 1 addition & 1 deletion demo/json-model/json_parser.py
@@ -94,7 +94,7 @@ def __str__(self):

class Model:
'''Gradient boosted tree model.'''
def __init__(self, m: dict):
def __init__(self, model: dict):
'''Construct the Model from JSON object.

parameters
6 changes: 5 additions & 1 deletion doc/parameter.rst
@@ -107,6 +107,10 @@ Parameters for Tree Booster
'colsample_bynode':0.5}`` with 64 features will leave 8 features to choose from at
each split.

On the Python interface, one can set ``feature_weights`` for a DMatrix to define the
probability of each feature being selected when using column sampling. There is a
similar parameter for the ``fit`` method in the sklearn interface.

* ``lambda`` [default=1, alias: ``reg_lambda``]

- L2 regularization term on weights. Increasing this value will make model more conservative.
@@ -224,7 +228,7 @@ Parameters for Tree Booster
list is a group of indices of features that are allowed to interact with each other.
See tutorial for more information

Additional parameters for ``hist`` and ```gpu_hist`` tree method
Additional parameters for ``hist`` and ``gpu_hist`` tree method
================================================================

* ``single_precision_histogram``, [default=``false``]
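To make the described behavior concrete, here is a minimal, illustrative Python sketch of selection probability proportional to feature weight. This is an approximation of the semantics, not the library's actual sampling code; fw is a hypothetical weight vector.

import numpy as np

rng = np.random.default_rng(0)
fw = np.array([0.0, 1.0, 1.0, 2.0])  # per-feature weights; feature 0 can never be drawn
p = fw / fw.sum()                    # normalized selection probabilities
# Draw half of the columns for one node, weighted, without replacement.
cols = rng.choice(fw.size, size=fw.size // 2, replace=False, p=p)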
28 changes: 28 additions & 0 deletions include/xgboost/c_api.h
@@ -483,6 +483,34 @@ XGB_DLL int XGDMatrixGetStrFeatureInfo(DMatrixHandle handle, const char *field,
bst_ulong *size,
const char ***out_features);

/*!
* \brief Set meta info from dense matrix. Valid field names are:
*
* - label
* - weight
* - base_margin
* - group
* - label_lower_bound
* - label_upper_bound
* - feature_weights
*
* \param handle An instance of data matrix
* \param field Field name
* \param data Pointer to consecutive memory storing data.
* \param size Size of the data in number of elements (NOT number of bytes).
* \param type Indicator of data type. This is defined in xgboost::DataType enum class.
*
* float = 1
* double = 2
* uint32_t = 3
* uint64_t = 4
*
* \return 0 when success, -1 when failure happens
*/
XGB_DLL int XGDMatrixSetDenseInfo(DMatrixHandle handle, const char *field,
void *data, bst_ulong size, int type);

/*!
* \brief (deprecated) Use XGDMatrixSetUIntInfo instead. Set group of the training matrix
* \param handle an instance of data matrix
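For reference, a minimal ctypes sketch of calling this entry point directly from Python, mirroring what the wrapper in python-package/xgboost/data.py below does. The helper name set_dense_info is hypothetical; it assumes lib is the loaded libxgboost and handle a valid DMatrix handle, with type code 1 denoting float per the enum above.

import ctypes
import numpy as np

def set_dense_info(lib, handle, field, values):
    # Hypothetical helper: upload a 1-D float32 array as DMatrix meta info.
    values = np.ascontiguousarray(values, dtype=np.float32)
    ptr = ctypes.c_void_p(values.__array_interface__['data'][0])
    ret = lib.XGDMatrixSetDenseInfo(
        handle, ctypes.c_char_p(field.encode()), ptr,
        ctypes.c_uint64(values.shape[0]),  # number of elements, not bytes
        ctypes.c_int(1))                   # DataType code: 1 == float
    assert ret == 0, 'XGDMatrixSetDenseInfo failed'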
29 changes: 6 additions & 23 deletions include/xgboost/data.h
@@ -88,34 +88,17 @@ class MetaInfo {
* \brief Type of each feature. Automatically set when feature_type_names is specified.
*/
HostDeviceVector<FeatureType> feature_types;
/*!
* \brief Weight of each feature, used to define the probability of each feature being
* selected when using column sampling.
*/
HostDeviceVector<float> feature_weights;

/*! \brief default constructor */
MetaInfo() = default;
MetaInfo(MetaInfo&& that) = default;
MetaInfo& operator=(MetaInfo&& that) = default;
MetaInfo& operator=(MetaInfo const& that) {
this->num_row_ = that.num_row_;
this->num_col_ = that.num_col_;
this->num_nonzero_ = that.num_nonzero_;

this->labels_.Resize(that.labels_.Size());
this->labels_.Copy(that.labels_);

this->group_ptr_ = that.group_ptr_;

this->weights_.Resize(that.weights_.Size());
this->weights_.Copy(that.weights_);

this->base_margin_.Resize(that.base_margin_.Size());
this->base_margin_.Copy(that.base_margin_);

this->labels_lower_bound_.Resize(that.labels_lower_bound_.Size());
this->labels_lower_bound_.Copy(that.labels_lower_bound_);

this->labels_upper_bound_.Resize(that.labels_upper_bound_.Size());
this->labels_upper_bound_.Copy(that.labels_upper_bound_);
return *this;
}
MetaInfo& operator=(MetaInfo const& that) = delete;

/*!
* \brief Validate all metainfo.
7 changes: 6 additions & 1 deletion python-package/xgboost/core.py
@@ -455,7 +455,8 @@ def set_info(self,
label_lower_bound=None,
label_upper_bound=None,
feature_names=None,
feature_types=None):
feature_types=None,
feature_weights=None):
'''Set meta info for DMatrix.'''
if label is not None:
self.set_label(label)
@@ -473,6 +474,10 @@
self.feature_names = feature_names
if feature_types is not None:
self.feature_types = feature_types
if feature_weights is not None:
from .data import dispatch_meta_backend
dispatch_meta_backend(matrix=self, data=feature_weights,
name='feature_weights')

def get_float_info(self, field):
"""Get float property from the DMatrix.
45 changes: 31 additions & 14 deletions python-package/xgboost/data.py
@@ -530,22 +530,38 @@ def dispatch_data_backend(data, missing, threads,
raise TypeError('Not supported type for data.' + str(type(data)))


def _to_data_type(dtype: str, name: str):
    dtype_map = {'float32': 1, 'float64': 2, 'uint32': 3, 'uint64': 4}
    if dtype not in dtype_map.keys():
        raise TypeError(
            f'Expecting float32, float64, uint32, uint64, got {dtype} ' +
            f'for {name}.')
    return dtype_map[dtype]


def _validate_meta_shape(data):
    if hasattr(data, 'shape'):
        assert len(data.shape) == 1 or (
            len(data.shape) == 2 and
            (data.shape[1] == 0 or data.shape[1] == 1))


def _meta_from_numpy(data, field, dtype, handle):
    data = _maybe_np_slice(data, dtype)
    if dtype == 'uint32':
        c_data = c_array(ctypes.c_uint32, data)
        _check_call(_LIB.XGDMatrixSetUIntInfo(handle,
                                              c_str(field),
                                              c_array(ctypes.c_uint, data),
                                              c_bst_ulong(len(data))))
    elif dtype == 'float':
        c_data = c_array(ctypes.c_float, data)
        _check_call(_LIB.XGDMatrixSetFloatInfo(handle,
                                               c_str(field),
                                               c_data,
                                               c_bst_ulong(len(data))))
    else:
        raise TypeError('Unsupported type ' + str(dtype) + ' for:' + field)
    interface = data.__array_interface__
    assert interface.get('mask', None) is None, 'Masked array is not supported'
    size = data.shape[0]

    c_type = _to_data_type(str(data.dtype), field)
    ptr = interface['data'][0]
    ptr = ctypes.c_void_p(ptr)
    _check_call(_LIB.XGDMatrixSetDenseInfo(
        handle,
        c_str(field),
        ptr,
        c_bst_ulong(size),
        c_type
    ))


def _meta_from_list(data, field, dtype, handle):
@@ -595,6 +611,7 @@ def _meta_from_dt(data, field, dtype, handle):
def dispatch_meta_backend(matrix: DMatrix, data, name: str, dtype: str = None):
    '''Dispatch for meta info.'''
    handle = matrix.handle
    _validate_meta_shape(data)
    if data is None:
        return
    if _is_list(data):
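The rewritten _meta_from_numpy above hands the C API a raw pointer instead of copying into a ctypes array. A minimal standalone sketch of the __array_interface__ mechanism it relies on:

import ctypes
import numpy as np

data = np.arange(4, dtype=np.float32)
interface = data.__array_interface__
assert interface.get('mask', None) is None   # masked arrays are rejected
ptr = ctypes.c_void_p(interface['data'][0])  # address of the first element
# str(data.dtype) == 'float32', which _to_data_type maps to type code 1.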
27 changes: 20 additions & 7 deletions python-package/xgboost/sklearn.py
@@ -441,6 +441,7 @@ def load_model(self, fname):
def fit(self, X, y, sample_weight=None, base_margin=None,
eval_set=None, eval_metric=None, early_stopping_rounds=None,
verbose=True, xgb_model=None, sample_weight_eval_set=None,
feature_weights=None,
callbacks=None):
# pylint: disable=invalid-name,attribute-defined-outside-init
"""Fit gradient boosting model
Expand All @@ -459,9 +460,6 @@ def fit(self, X, y, sample_weight=None, base_margin=None,
A list of (X, y) tuple pairs to use as validation sets, for which
metrics will be computed.
Validation metrics will help us track the performance of the model.
sample_weight_eval_set : list, optional
A list of the form [L_1, L_2, ..., L_n], where each L_i is a list of
instance weights on the i-th validation set.
eval_metric : str, list of str, or callable, optional
If a str, should be a built-in evaluation metric to use. See
doc/parameter.rst.
@@ -490,6 +488,13 @@ def fit(self, X, y, sample_weight=None, base_margin=None,
xgb_model : str
file name of stored XGBoost model or 'Booster' instance XGBoost model to be
loaded before training (allows training continuation).
sample_weight_eval_set : list, optional
A list of the form [L_1, L_2, ..., L_n], where each L_i is a list of
instance weights on the i-th validation set.
feature_weights: array_like
Weight for each feature, defining the probability of each feature
being selected when column sampling is used. All values must be
greater than 0, otherwise a `ValueError` is thrown.
callbacks : list of callback functions
List of callback functions that are applied at end of each iteration.
It is possible to use predefined callbacks by using :ref:`callback_api`.
@@ -498,13 +503,15 @@
.. code-block:: python

[xgb.callback.reset_learning_rate(custom_rates)]

"""
self.n_features_in_ = X.shape[1]

train_dmatrix = DMatrix(data=X, label=y, weight=sample_weight,
base_margin=base_margin,
missing=self.missing,
nthread=self.n_jobs)
train_dmatrix.set_info(feature_weights=feature_weights)

evals_result = {}

@@ -759,7 +766,7 @@ def __init__(self, objective="binary:logistic", **kwargs):
def fit(self, X, y, sample_weight=None, base_margin=None,
eval_set=None, eval_metric=None,
early_stopping_rounds=None, verbose=True, xgb_model=None,
sample_weight_eval_set=None, callbacks=None):
sample_weight_eval_set=None, feature_weights=None, callbacks=None):
# pylint: disable = attribute-defined-outside-init,arguments-differ

evals_result = {}
@@ -821,6 +828,7 @@ def fit(self, X, y, sample_weight=None, base_margin=None,
train_dmatrix = DMatrix(X, label=training_labels, weight=sample_weight,
base_margin=base_margin,
missing=self.missing, nthread=self.n_jobs)
train_dmatrix.set_info(feature_weights=feature_weights)

self._Booster = train(xgb_options, train_dmatrix,
self.get_num_boosting_rounds(),
@@ -1101,10 +1109,10 @@ def __init__(self, objective='rank:pairwise', **kwargs):
raise ValueError("please use XGBRanker for ranking task")

def fit(self, X, y, group, sample_weight=None, base_margin=None,
eval_set=None,
sample_weight_eval_set=None, eval_group=None, eval_metric=None,
eval_set=None, sample_weight_eval_set=None,
eval_group=None, eval_metric=None,
early_stopping_rounds=None, verbose=False, xgb_model=None,
callbacks=None):
feature_weights=None, callbacks=None):
# pylint: disable = attribute-defined-outside-init,arguments-differ
"""Fit gradient boosting ranker

@@ -1170,6 +1178,10 @@ def fit(self, X, y, group, sample_weight=None, base_margin=None,
xgb_model : str
file name of stored XGBoost model or 'Booster' instance XGBoost
model to be loaded before training (allows training continuation).
feature_weights: array_like
Weight for each feature, defining the probability of each feature
being selected when column sampling is used. All values must be
greater than 0, otherwise a `ValueError` is thrown.
callbacks : list of callback functions
List of callback functions that are applied at end of each
iteration. It is possible to use predefined callbacks by using
@@ -1205,6 +1217,7 @@ def _dmat_init(group, **params):
train_dmatrix = DMatrix(data=X, label=y, weight=sample_weight,
base_margin=base_margin,
missing=self.missing, nthread=self.n_jobs)
train_dmatrix.set_info(feature_weights=feature_weights)
train_dmatrix.set_group(group)

evals_result = {}
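A short usage sketch of the new feature_weights argument on the sklearn interface (synthetic data for illustration; assumes a build that includes this PR):

import numpy as np
import xgboost

rng = np.random.RandomState(0)
X, y = rng.randn(256, 8), rng.randn(256)
fw = np.arange(1, 9, dtype=np.float32)  # strictly positive, increasing weights
reg = xgboost.XGBRegressor(tree_method='hist', colsample_bynode=0.5)
reg.fit(X, y, feature_weights=fw)       # higher-weight features are sampled more often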
11 changes: 11 additions & 0 deletions src/c_api/c_api.cc
@@ -316,6 +316,17 @@ XGB_DLL int XGDMatrixGetStrFeatureInfo(DMatrixHandle handle, const char *field,
API_END();
}

XGB_DLL int XGDMatrixSetDenseInfo(DMatrixHandle handle, const char *field,
                                  void *data, xgboost::bst_ulong size,
                                  int type) {
  API_BEGIN();
  CHECK_HANDLE();
  auto &info = static_cast<std::shared_ptr<DMatrix> *>(handle)->get()->Info();
  CHECK(type >= 1 && type <= 4);
  info.SetInfo(field, data, static_cast<DataType>(type), size);
  API_END();
}

XGB_DLL int XGDMatrixSetGroup(DMatrixHandle handle,
const unsigned* group,
xgboost::bst_ulong len) {
12 changes: 12 additions & 0 deletions src/common/common.h
@@ -9,12 +9,15 @@
#include <xgboost/base.h>
#include <xgboost/logging.h>

#include <algorithm>
#include <exception>
#include <functional>
#include <limits>
#include <type_traits>
#include <vector>
#include <string>
#include <sstream>
#include <numeric>

#if defined(__CUDACC__)
#include <thrust/system/cuda/error.h>
@@ -160,6 +163,15 @@ inline void AssertOneAPISupport() {
#endif // XGBOOST_USE_ONEAPI
}

template <typename Idx, typename V, typename Comp = std::less<V>>
std::vector<Idx> ArgSort(std::vector<V> const &array, Comp comp = std::less<V>{}) {
  std::vector<Idx> result(array.size());
  std::iota(result.begin(), result.end(), 0);
  std::stable_sort(
      result.begin(), result.end(),
      [&array, comp](Idx const &l, Idx const &r) { return comp(array[l], array[r]); });
  return result;
}
} // namespace common
} // namespace xgboost
#endif // XGBOOST_COMMON_COMMON_H_
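The ArgSort helper added above computes a stable argsort. For readers more familiar with Python, a minimal sketch of the same semantics using the default less-than comparator (illustration only; the project code above is C++):

def arg_sort(array):
    # Stable argsort: indices that would sort `array`; ties keep input order.
    return sorted(range(len(array)), key=lambda i: array[i])

assert arg_sort([3.0, 1.0, 3.0, 2.0]) == [1, 3, 0, 2]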