Feature weights (dmlc#5962)
trivialfis authored and anusharamesh committed May 11, 2022
1 parent 311cc45 commit 9d132d8
Showing 25 changed files with 516 additions and 106 deletions.
1 change: 1 addition & 0 deletions amalgamation/xgboost-all0.cc
@@ -69,6 +69,7 @@
#include "../src/learner.cc"
#include "../src/logging.cc"
#include "../src/common/common.cc"
#include "../src/common/random.cc"
#include "../src/common/charconv.cc"
#include "../src/common/timer.cc"
#include "../src/common/host_device_vector.cc"
49 changes: 49 additions & 0 deletions demo/guide-python/feature_weights.py
@@ -0,0 +1,49 @@
'''Using feature weights to change column sampling.

.. versionadded:: 1.3.0

'''

import numpy as np
import xgboost
from matplotlib import pyplot as plt
import argparse


def main(args):
    rng = np.random.RandomState(1994)

    kRows = 1000
    kCols = 10

    X = rng.randn(kRows, kCols)
    y = rng.randn(kRows)
    fw = np.ones(shape=(kCols,))
    for i in range(kCols):
        fw[i] *= float(i)

    dtrain = xgboost.DMatrix(X, y)
    dtrain.set_info(feature_weights=fw)

    bst = xgboost.train({'tree_method': 'hist',
                         'colsample_bynode': 0.5},
                        dtrain, num_boost_round=10,
                        evals=[(dtrain, 'd')])
    feature_map = bst.get_fscore()
    # feature zero has 0 weight
    assert feature_map.get('f0', None) is None
    assert max(feature_map.values()) == feature_map.get('f9')

    if args.plot:
        xgboost.plot_importance(bst)
        plt.show()


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--plot',
        type=int,
        default=1,
        help='Set to 0 to disable plotting the feature importance.')
    args = parser.parse_args()
    main(args)
2 changes: 1 addition & 1 deletion demo/json-model/json_parser.py
@@ -94,7 +94,7 @@ def __str__(self):

class Model:
    '''Gradient boosted tree model.'''
    def __init__(self, m: dict):
    def __init__(self, model: dict):
        '''Construct the Model from JSON object.

        parameters
9 changes: 6 additions & 3 deletions doc/parameter.rst
@@ -107,6 +107,10 @@ Parameters for Tree Booster
  'colsample_bynode':0.5}`` with 64 features will leave 8 features to choose from at
  each split, since the three parameters apply cumulatively: 64 × 0.5 × 0.5 × 0.5 = 8.

  On the Python interface, one can set the ``feature_weights`` for a DMatrix to define
  the probability of each feature being selected when column sampling is used. There is
  a similar ``feature_weights`` parameter for the ``fit`` method in the sklearn interface.
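
  A minimal sketch of both entry points (random data and arbitrary weights, for
  illustration only; see ``demo/guide-python/feature_weights.py`` for the full demo):

  .. code-block:: python

      import numpy as np
      import xgboost

      rng = np.random.RandomState(0)
      X, y = rng.randn(100, 10), rng.randn(100)
      fw = np.arange(1, 11, dtype=np.float32)  # one weight per feature

      # Native interface: attach the weights to the DMatrix.
      dtrain = xgboost.DMatrix(X, label=y)
      dtrain.set_info(feature_weights=fw)
      booster = xgboost.train({'colsample_bynode': 0.5}, dtrain, num_boost_round=4)

      # sklearn interface: pass the weights to ``fit``.
      reg = xgboost.XGBRegressor(colsample_bynode=0.5)
      reg.fit(X, y, feature_weights=fw)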

* ``lambda`` [default=1, alias: ``reg_lambda``]

- L2 regularization term on weights. Increasing this value will make model more conservative.
@@ -225,9 +229,8 @@ Parameters for Tree Booster
list is a group of indices of features that are allowed to interact with each other.
See tutorial for more information

Additional parameters for `hist` and 'gpu_hist' tree method
================================================

Additional parameters for ``hist`` and ``gpu_hist`` tree method
================================================================
* ``single_precision_histogram``, [default=``false``]

- Use single precision to build histograms instead of double precision.
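
  For example, a hedged sketch of enabling it (random data for illustration only):

  .. code-block:: python

      import numpy as np
      import xgboost

      dtrain = xgboost.DMatrix(np.random.randn(100, 4), label=np.random.randn(100))
      params = {'tree_method': 'hist', 'single_precision_histogram': True}
      booster = xgboost.train(params, dtrain, num_boost_round=10)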
28 changes: 28 additions & 0 deletions include/xgboost/c_api.h
@@ -483,6 +483,34 @@ XGB_DLL int XGDMatrixGetStrFeatureInfo(DMatrixHandle handle, const char *field,
bst_ulong *size,
const char ***out_features);

/*!
* \brief Set meta info from dense matrix. Valid field names are:
*
* - label
* - weight
* - base_margin
* - group
* - label_lower_bound
* - label_upper_bound
* - feature_weights
*
* \param handle An instance of data matrix
* \param field Field name
* \param data Pointer to consecutive memory storing data.
* \param size Size of the data, counted as the number of elements of the given
*             type (NOT the number of bytes).
* \param type Indicator of data type. This is defined in xgboost::DataType enum class.
*
* float = 1
* double = 2
* uint32_t = 3
* uint64_t = 4
*
* \return 0 when success, -1 when failure happens
*/
XGB_DLL int XGDMatrixSetDenseInfo(DMatrixHandle handle, const char *field,
void *data, bst_ulong size, int type);
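
For illustration, a minimal ctypes sketch of driving this entry point from Python.
The library path and the DMatrix handle are assumptions, and the helper name is
hypothetical; the official wrapper in python-package/xgboost/data.py below does the
same thing:

    import ctypes
    import numpy as np

    lib = ctypes.cdll.LoadLibrary('libxgboost.so')  # assumed library location

    def set_dense_info(handle, field, array):
        # Hypothetical helper: `size` counts elements, not bytes, and the
        # trailing code follows the enum above (1 == float32).
        codes = {'float32': 1, 'float64': 2, 'uint32': 3, 'uint64': 4}
        ret = lib.XGDMatrixSetDenseInfo(
            handle,
            ctypes.c_char_p(field.encode()),
            array.ctypes.data_as(ctypes.c_void_p),
            ctypes.c_uint64(array.shape[0]),
            ctypes.c_int(codes[str(array.dtype)]))
        assert ret == 0, 'XGBoost C API call failed'

    # e.g. set_dense_info(dmatrix_handle, 'feature_weights',
    #                     np.arange(10, dtype=np.float32))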

/*!
* \brief (deprecated) Use XGDMatrixSetUIntInfo instead. Set group of the training matrix
* \param handle an instance of data matrix
29 changes: 6 additions & 23 deletions include/xgboost/data.h
@@ -89,34 +89,17 @@ class MetaInfo {
* \brief Type of each feature. Automatically set when feature_type_names is specified.
*/
HostDeviceVector<FeatureType> feature_types;
/*!
* \brief Weight of each feature, used to define the probability of each feature being
*        selected when using column sampling.
*/
HostDeviceVector<float> feature_weights;

/*! \brief default constructor */
MetaInfo() = default;
MetaInfo(MetaInfo&& that) = default;
MetaInfo& operator=(MetaInfo&& that) = default;
MetaInfo& operator=(MetaInfo const& that) {
this->num_row_ = that.num_row_;
this->num_col_ = that.num_col_;
this->num_nonzero_ = that.num_nonzero_;

this->labels_.Resize(that.labels_.Size());
this->labels_.Copy(that.labels_);

this->group_ptr_ = that.group_ptr_;

this->weights_.Resize(that.weights_.Size());
this->weights_.Copy(that.weights_);

this->base_margin_.Resize(that.base_margin_.Size());
this->base_margin_.Copy(that.base_margin_);

this->labels_lower_bound_.Resize(that.labels_lower_bound_.Size());
this->labels_lower_bound_.Copy(that.labels_lower_bound_);

this->labels_upper_bound_.Resize(that.labels_upper_bound_.Size());
this->labels_upper_bound_.Copy(that.labels_upper_bound_);
return *this;
}
MetaInfo& operator=(MetaInfo const& that) = delete;

/*!
* \brief Validate all metainfo.
7 changes: 6 additions & 1 deletion python-package/xgboost/core.py
@@ -455,7 +455,8 @@ def set_info(self,
                 label_lower_bound=None,
                 label_upper_bound=None,
                 feature_names=None,
                 feature_types=None,
                 feature_weights=None):
        '''Set meta info for DMatrix.'''
        if label is not None:
            self.set_label(label)
@@ -473,6 +474,10 @@
            self.feature_names = feature_names
        if feature_types is not None:
            self.feature_types = feature_types
        if feature_weights is not None:
            from .data import dispatch_meta_backend
            dispatch_meta_backend(matrix=self, data=feature_weights,
                                  name='feature_weights')

    def get_float_info(self, field):
        """Get float property from the DMatrix.
45 changes: 31 additions & 14 deletions python-package/xgboost/data.py
@@ -530,22 +530,38 @@ def dispatch_data_backend(data, missing, threads,
    raise TypeError('Not supported type for data.' + str(type(data)))


def _to_data_type(dtype: str, name: str):
    dtype_map = {'float32': 1, 'float64': 2, 'uint32': 3, 'uint64': 4}
    if dtype not in dtype_map.keys():
        raise TypeError(
            f'Expecting float32, float64, uint32, uint64, got {dtype} ' +
            f'for {name}.')
    return dtype_map[dtype]
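
# Hypothetical illustration: the returned code matches the xgboost::DataType
# enum documented for XGDMatrixSetDenseInfo in c_api.h above.
# _to_data_type('float32', 'feature_weights')  # -> 1
# _to_data_type('int32', 'label')              # raises TypeError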


def _validate_meta_shape(data):
    if hasattr(data, 'shape'):
        assert len(data.shape) == 1 or (
            len(data.shape) == 2 and
            (data.shape[1] == 0 or data.shape[1] == 1))
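
# A quick illustration of the accepted shapes (assumes numpy imported as np):
# _validate_meta_shape(np.ones(4))        # ok: shape (4,)
# _validate_meta_shape(np.ones((4, 1)))   # ok: a single column
# _validate_meta_shape(np.ones((4, 2)))   # AssertionError: two real columns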


def _meta_from_numpy(data, field, dtype, handle):
    data = _maybe_np_slice(data, dtype)
    if dtype == 'uint32':
        c_data = c_array(ctypes.c_uint32, data)
        _check_call(_LIB.XGDMatrixSetUIntInfo(handle,
                                              c_str(field),
                                              c_array(ctypes.c_uint, data),
                                              c_bst_ulong(len(data))))
    elif dtype == 'float':
        c_data = c_array(ctypes.c_float, data)
        _check_call(_LIB.XGDMatrixSetFloatInfo(handle,
                                               c_str(field),
                                               c_data,
                                               c_bst_ulong(len(data))))
    else:
        raise TypeError('Unsupported type ' + str(dtype) + ' for:' + field)
    interface = data.__array_interface__
    assert interface.get('mask', None) is None, 'Masked array is not supported'
    size = data.shape[0]

    c_type = _to_data_type(str(data.dtype), field)
    ptr = interface['data'][0]
    ptr = ctypes.c_void_p(ptr)
    _check_call(_LIB.XGDMatrixSetDenseInfo(
        handle,
        c_str(field),
        ptr,
        c_bst_ulong(size),
        c_type
    ))


def _meta_from_list(data, field, dtype, handle):
@@ -595,6 +611,7 @@ def _meta_from_dt(data, field, dtype, handle):
def dispatch_meta_backend(matrix: DMatrix, data, name: str, dtype: str = None):
    '''Dispatch for meta info.'''
    handle = matrix.handle
    _validate_meta_shape(data)
    if data is None:
        return
    if _is_list(data):
27 changes: 20 additions & 7 deletions python-package/xgboost/sklearn.py
@@ -441,6 +441,7 @@ def load_model(self, fname):
    def fit(self, X, y, sample_weight=None, base_margin=None,
            eval_set=None, eval_metric=None, early_stopping_rounds=None,
            verbose=True, xgb_model=None, sample_weight_eval_set=None,
            feature_weights=None,
            callbacks=None):
        # pylint: disable=invalid-name,attribute-defined-outside-init
        """Fit gradient boosting model
@@ -459,9 +460,6 @@ def fit(self, X, y, sample_weight=None, base_margin=None,
            A list of (X, y) tuple pairs to use as validation sets, for which
            metrics will be computed.
            Validation metrics will help us track the performance of the model.
        sample_weight_eval_set : list, optional
            A list of the form [L_1, L_2, ..., L_n], where each L_i is a list of
            instance weights on the i-th validation set.
        eval_metric : str, list of str, or callable, optional
            If a str, should be a built-in evaluation metric to use. See
            doc/parameter.rst.
@@ -490,6 +488,13 @@ def fit(self, X, y, sample_weight=None, base_margin=None,
        xgb_model : str
            file name of stored XGBoost model or 'Booster' instance XGBoost model to be
            loaded before training (allows training continuation).
        sample_weight_eval_set : list, optional
            A list of the form [L_1, L_2, ..., L_n], where each L_i is a list of
            instance weights on the i-th validation set.
        feature_weights : array_like
            Weight for each feature, defining the probability of each feature
            being selected when column sampling is used. All values must be
            nonnegative, otherwise a `ValueError` is thrown.
        callbacks : list of callback functions
            List of callback functions that are applied at end of each iteration.
            It is possible to use predefined callbacks by using :ref:`callback_api`.
Expand All @@ -498,13 +503,15 @@ def fit(self, X, y, sample_weight=None, base_margin=None,
            .. code-block:: python

                [xgb.callback.reset_learning_rate(custom_rates)]
        """
        self.n_features_in_ = X.shape[1]

        train_dmatrix = DMatrix(data=X, label=y, weight=sample_weight,
                                base_margin=base_margin,
                                missing=self.missing,
                                nthread=self.n_jobs)
        train_dmatrix.set_info(feature_weights=feature_weights)

        evals_result = {}

@@ -762,7 +769,7 @@ def __init__(self, objective="binary:logistic", **kwargs):
    def fit(self, X, y, sample_weight=None, base_margin=None,
            eval_set=None, eval_metric=None,
            early_stopping_rounds=None, verbose=True, xgb_model=None,
            sample_weight_eval_set=None, callbacks=None):
            sample_weight_eval_set=None, feature_weights=None, callbacks=None):
        # pylint: disable = attribute-defined-outside-init,arguments-differ

        evals_result = {}
Expand Down Expand Up @@ -824,6 +831,7 @@ def fit(self, X, y, sample_weight=None, base_margin=None,
        train_dmatrix = DMatrix(X, label=training_labels, weight=sample_weight,
                                base_margin=base_margin,
                                missing=self.missing, nthread=self.n_jobs)
        train_dmatrix.set_info(feature_weights=feature_weights)

        self._Booster = train(xgb_options, train_dmatrix,
                              self.get_num_boosting_rounds(),
@@ -1107,10 +1115,10 @@ def __init__(self, objective='rank:pairwise', **kwargs):
            raise ValueError("please use XGBRanker for ranking task")

    def fit(self, X, y, group, sample_weight=None, base_margin=None,
            eval_set=None,
            sample_weight_eval_set=None, eval_group=None, eval_metric=None,
            eval_set=None, sample_weight_eval_set=None,
            eval_group=None, eval_metric=None,
            early_stopping_rounds=None, verbose=False, xgb_model=None,
            callbacks=None):
            feature_weights=None, callbacks=None):
        # pylint: disable = attribute-defined-outside-init,arguments-differ
        """Fit gradient boosting ranker
@@ -1176,6 +1184,10 @@ def fit(self, X, y, group, sample_weight=None, base_margin=None,
        xgb_model : str
            file name of stored XGBoost model or 'Booster' instance XGBoost
            model to be loaded before training (allows training continuation).
        feature_weights : array_like
            Weight for each feature, defining the probability of each feature
            being selected when column sampling is used. All values must be
            nonnegative, otherwise a `ValueError` is thrown.
        callbacks : list of callback functions
            List of callback functions that are applied at end of each
            iteration. It is possible to use predefined callbacks by using
@@ -1211,6 +1223,7 @@ def _dmat_init(group, **params):
        train_dmatrix = DMatrix(data=X, label=y, weight=sample_weight,
                                base_margin=base_margin,
                                missing=self.missing, nthread=self.n_jobs)
        train_dmatrix.set_info(feature_weights=feature_weights)
        train_dmatrix.set_group(group)

        evals_result = {}
11 changes: 11 additions & 0 deletions src/c_api/c_api.cc
@@ -316,6 +316,17 @@ XGB_DLL int XGDMatrixGetStrFeatureInfo(DMatrixHandle handle, const char *field,
API_END();
}

XGB_DLL int XGDMatrixSetDenseInfo(DMatrixHandle handle, const char *field,
void *data, xgboost::bst_ulong size,
int type) {
API_BEGIN();
CHECK_HANDLE();
auto &info = static_cast<std::shared_ptr<DMatrix> *>(handle)->get()->Info();
CHECK(type >= 1 && type <= 4);
info.SetInfo(field, data, static_cast<DataType>(type), size);
API_END();
}

XGB_DLL int XGDMatrixSetGroup(DMatrixHandle handle,
const unsigned* group,
xgboost::bst_ulong len) {
