Implement feature weights for column sampling.
trivialfis committed Aug 12, 2020
1 parent ee70a23 commit 2cd1792
Showing 23 changed files with 556 additions and 111 deletions.
1 change: 1 addition & 0 deletions amalgamation/xgboost-all0.cc
@@ -68,6 +68,7 @@
#include "../src/learner.cc"
#include "../src/logging.cc"
#include "../src/common/common.cc"
#include "../src/common/random.cc"
#include "../src/common/charconv.cc"
#include "../src/common/timer.cc"
#include "../src/common/quantile.cc"
49 changes: 49 additions & 0 deletions demo/guide-python/feature_weights.py
@@ -0,0 +1,49 @@
'''Using feature weights to change column sampling.

.. versionadded:: 1.3.0
'''

import numpy as np
import xgboost
from matplotlib import pyplot as plt
import argparse


def main(args):
    rng = np.random.RandomState(1994)

    kRows = 1000
    kCols = 10

    X = rng.randn(kRows, kCols)
    y = rng.randn(kRows)
    fw = np.ones(shape=(kCols,))
    for i in range(kCols):
        fw[i] *= float(i)

    dtrain = xgboost.DMatrix(X, y)
    dtrain.feature_weights = fw

    bst = xgboost.train({'tree_method': 'hist',
                         'colsample_bynode': 0.5},
                        dtrain, num_boost_round=10,
                        evals=[(dtrain, 'd')])
    feature_map = bst.get_fscore()
    # feature zero has 0 weight, so it should never be sampled
    assert feature_map.get('f0', None) is None
    assert max(feature_map.values()) == feature_map.get('f9')

    if args.plot:
        xgboost.plot_importance(bst)
        plt.show()


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--plot',
        type=int,
        default=1,
        help='Set to 0 to disable plotting the feature importance.')
    args = parser.parse_args()
    main(args)
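The demo can be run directly; for instance, ``python feature_weights.py --plot 0`` trains the model and runs the assertions without showing the importance plot.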
2 changes: 1 addition & 1 deletion demo/json-model/json_parser.py
@@ -94,7 +94,7 @@ def __str__(self):

class Model:
    '''Gradient boosted tree model.'''
    def __init__(self, m: dict):
    def __init__(self, model: dict):
        '''Construct the Model from JSON object.

        parameters
3 changes: 3 additions & 0 deletions doc/parameter.rst
@@ -107,6 +107,9 @@ Parameters for Tree Booster
'colsample_bynode':0.5}`` with 64 features will leave 8 features to choose from at
each split.

On the Python interface, one can set the ``feature_weights`` for a DMatrix to define
the probability of each feature being selected when column sampling is used.
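As a minimal sketch (hypothetical data; ``fw`` holds one non-negative weight per
feature, and a weight of 0 excludes that feature from sampling):

    import numpy as np
    import xgboost

    X = np.random.randn(128, 4)
    y = np.random.randn(128)
    fw = np.array([0.0, 1.0, 2.0, 3.0])  # feature 0 is never selected

    dtrain = xgboost.DMatrix(X, y)
    dtrain.feature_weights = fw
    bst = xgboost.train({'colsample_bynode': 0.5}, dtrain, num_boost_round=10)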

* ``lambda`` [default=1, alias: ``reg_lambda``]

- L2 regularization term on weights. Increasing this value will make model more conservative.
43 changes: 43 additions & 0 deletions include/xgboost/c_api.h
@@ -483,6 +483,49 @@ XGB_DLL int XGDMatrixGetStrFeatureInfo(DMatrixHandle handle, const char *field,
                                       bst_ulong *size,
                                       const char ***out_features);

/*!
 * \brief Set feature info that is not string-typed. Currently accepted fields are:
 *
 * - feature_weight
 *
 * \param handle An instance of data matrix.
 * \param field Field name.
 * \param data Pointer to consecutive memory storing the data.
 * \param size Size of the data in number of elements of `type` (NOT number of bytes).
 * \param type Indicator of the data type, as defined in the xgboost::DataType enum class:
 *
 *   float    = 1
 *   double   = 2
 *   uint32_t = 3
 *   uint64_t = 4
 *
 * \return 0 when successful, -1 when failure happens
 */
XGB_DLL int XGDMatrixSetFeatureInfo(DMatrixHandle handle, const char *field,
                                    void *data, bst_ulong size, int type);
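An illustrative ctypes sketch of calling this from Python, assuming ``_LIB``, ``_check_call``, ``c_str`` and ``c_bst_ulong`` from ``xgboost.core`` and an existing ``DMatrix`` named ``dtrain``:

    import ctypes
    import numpy as np
    from xgboost.core import _LIB, _check_call, c_str, c_bst_ulong

    fw = np.arange(10, dtype=np.float32)       # one weight per feature
    _check_call(_LIB.XGDMatrixSetFeatureInfo(
        dtrain.handle, c_str('feature_weight'),
        fw.ctypes.data_as(ctypes.c_void_p),    # pointer to the raw buffer
        c_bst_ulong(fw.size),
        ctypes.c_int(1)))                      # 1 == float, per the table above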

/*!
 * \brief Get feature info from a thread-local buffer.
 *
 * The caller is responsible for copying out the data before the next call to any
 * XGBoost API function. The data always resides in CPU thread-local storage.
 *
 * \param handle An instance of data matrix.
 * \param field Field name.
 * \param out_type Type of this field, as defined in the xgboost::DataType enum class.
 * \param out_size Length of the output data in number of elements of `out_type`
 *   (NOT number of bytes).
 * \param out_dptr Pointer to the output buffer.
 *
 * \return 0 when successful, -1 when failure happens
 */
XGB_DLL int XGDMatrixGetFeatureInfo(DMatrixHandle handle,
                                    const char* field,
                                    int* out_type,
                                    bst_ulong* out_size,
                                    const void** out_dptr);
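Continuing the sketch above, reading the weights back (illustrative only; the returned buffer is only valid until the next XGBoost call, hence the copy):

    out_type = ctypes.c_int()
    length = c_bst_ulong()
    out_dptr = ctypes.POINTER(ctypes.c_void_p)()
    _check_call(_LIB.XGDMatrixGetFeatureInfo(
        dtrain.handle, c_str('feature_weight'),
        ctypes.byref(out_type), ctypes.byref(length), ctypes.byref(out_dptr)))
    assert out_type.value == 1                 # float, per the enum above
    floats = ctypes.cast(out_dptr, ctypes.POINTER(ctypes.c_float))
    weights = np.ctypeslib.as_array(floats, shape=(length.value,)).copy()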

/*!
* \brief (deprecated) Use XGDMatrixSetUIntInfo instead. Set group of the training matrix
 * \param handle an instance of data matrix
32 changes: 9 additions & 23 deletions include/xgboost/data.h
@@ -89,34 +89,17 @@ class MetaInfo {
   * \brief Type of each feature. Automatically set when feature_type_names is specified.
   */
  HostDeviceVector<FeatureType> feature_types;
  /*!
   * \brief Weight of each feature, used to define the probability of each feature being
   * selected when using column sampling.
   */
  HostDeviceVector<float> feature_weights;

  /*! \brief default constructor */
  MetaInfo() = default;
  MetaInfo(MetaInfo&& that) = default;
  MetaInfo& operator=(MetaInfo&& that) = default;
  MetaInfo& operator=(MetaInfo const& that) {
    this->num_row_ = that.num_row_;
    this->num_col_ = that.num_col_;
    this->num_nonzero_ = that.num_nonzero_;

    this->labels_.Resize(that.labels_.Size());
    this->labels_.Copy(that.labels_);

    this->group_ptr_ = that.group_ptr_;

    this->weights_.Resize(that.weights_.Size());
    this->weights_.Copy(that.weights_);

    this->base_margin_.Resize(that.base_margin_.Size());
    this->base_margin_.Copy(that.base_margin_);

    this->labels_lower_bound_.Resize(that.labels_lower_bound_.Size());
    this->labels_lower_bound_.Copy(that.labels_lower_bound_);

    this->labels_upper_bound_.Resize(that.labels_upper_bound_.Size());
    this->labels_upper_bound_.Copy(that.labels_upper_bound_);
    return *this;
  }
  MetaInfo& operator=(MetaInfo const& that) = delete;

/*!
* \brief Validate all metainfo.
@@ -181,6 +164,9 @@

  void SetFeatureInfo(const char *key, const char **info, const bst_ulong size);
  void GetFeatureInfo(const char *field, std::vector<std::string>* out_str_vecs) const;
  void SetFeatureInfo(const char *field, const void *info, DataType type,
                      bst_ulong size);
  void GetFeatureInfo(const char *field, DataType *out_type, std::vector<float>* out) const;

  /*
   * \brief Extend with other MetaInfo.
39 changes: 39 additions & 0 deletions python-package/xgboost/core.py
@@ -841,6 +841,45 @@ def feature_types(self, feature_types):
            None,
            c_bst_ulong(0)))

    @property
    def feature_weights(self):
        '''Weight for each feature, defining the probability of each feature
        being selected when column sampling is used. All values must be
        non-negative, otherwise an `XGBoostError` is thrown.

        .. versionadded:: 1.3.0

        '''
        length = c_bst_ulong()
        ret = ctypes.POINTER(ctypes.c_void_p)()
        out_type = ctypes.c_int()
        _check_call(_LIB.XGDMatrixGetFeatureInfo(
            self.handle,
            c_str('feature_weight'),
            ctypes.byref(out_type),
            ctypes.byref(length),
            ctypes.byref(ret)
        ))
        to_data_type = {1: np.float32, 2: np.float64, 3: np.uint32,
                        4: np.uint64}
        to_c_type = {1: ctypes.c_float, 2: ctypes.c_double, 3: ctypes.c_uint32,
                     4: ctypes.c_uint64}
        dtype = to_data_type[out_type.value]
        ptr = ctypes.cast(ret, ctypes.POINTER(to_c_type[out_type.value]))
        return ctypes2numpy(ptr, length.value, dtype)

    @feature_weights.setter
    def feature_weights(self, array):
        '''Setter for feature weights. Clears the feature weights if `array`
        is None.

        '''
        from .data import dispatch_meta_backend
        if array is None:
            array = np.empty((0, 0))
        dispatch_meta_backend(matrix=self, data=array, name='feature_weight',
                              is_feature=True)
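A short usage sketch for this property, with hypothetical data:

    import numpy as np
    import xgboost

    dtrain = xgboost.DMatrix(np.random.randn(50, 3), np.random.randn(50))
    dtrain.feature_weights = [1.0, 2.0, 3.0]  # lists and tuples are dispatched too
    print(dtrain.feature_weights)             # read back as a numpy array
    dtrain.feature_weights = None             # clears the weights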


class DeviceQuantileDMatrix(DMatrix):
"""Device memory Data Matrix used in XGBoost for training with
72 changes: 47 additions & 25 deletions python-package/xgboost/data.py
@@ -530,31 +530,52 @@ def dispatch_data_backend(data, missing, threads,
    raise TypeError('Not supported type for data.' + str(type(data)))


def _meta_from_numpy(data, field, dtype, handle):
    data = _maybe_np_slice(data, dtype)
    if dtype == 'uint32':
        c_data = c_array(ctypes.c_uint32, data)
        _check_call(_LIB.XGDMatrixSetUIntInfo(handle,
                                              c_str(field),
                                              c_array(ctypes.c_uint, data),
                                              c_bst_ulong(len(data))))
    elif dtype == 'float':
        c_data = c_array(ctypes.c_float, data)
        _check_call(_LIB.XGDMatrixSetFloatInfo(handle,
                                               c_str(field),
                                               c_data,
                                               c_bst_ulong(len(data))))
    else:
        raise TypeError('Unsupported type ' + str(dtype) + ' for:' + field)
def _to_data_type(dtype: str):
    dtype_map = {'float32': 1, 'float64': 2, 'uint32': 3, 'uint64': 4}
    return dtype_map[dtype]


def _meta_from_numpy(data, field, dtype, handle, is_feature: bool = False):
    if is_feature:
        data = _maybe_np_slice(data, dtype)
        interface = data.__array_interface__
        assert interface.get('mask', None) is None
        size = data.shape[0]
        c_type = _to_data_type(str(data.dtype))
        data = interface['data']
        data = ctypes.c_void_p(data[0])
        _check_call(_LIB.XGDMatrixSetFeatureInfo(
            handle,
            c_str(field),
            data,
            size,
            c_type
        ))
    else:
        data = _maybe_np_slice(data, dtype)
        if dtype == 'uint32':
            c_data = c_array(ctypes.c_uint32, data)
            _check_call(_LIB.XGDMatrixSetUIntInfo(handle,
                                                  c_str(field),
                                                  c_data,
                                                  c_bst_ulong(len(data))))
        elif dtype == 'float':
            c_data = c_array(ctypes.c_float, data)
            _check_call(_LIB.XGDMatrixSetFloatInfo(handle,
                                                   c_str(field),
                                                   c_data,
                                                   c_bst_ulong(len(data))))
        else:
            raise TypeError('Unsupported type ' + str(dtype) + ' for:' + field)
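The ``is_feature`` branch above hands the raw NumPy buffer to the C API through ``__array_interface__``; a standalone sketch of that pointer extraction, for illustration only:

    import ctypes
    import numpy as np

    arr = np.arange(4, dtype=np.float32)
    interface = arr.__array_interface__
    assert interface.get('mask', None) is None  # masked arrays are rejected
    address, _read_only = interface['data']     # (pointer, read-only flag)
    ptr = ctypes.c_void_p(address)
    # `ptr` now points at arr's buffer; arr must stay alive while it is in use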


def _meta_from_list(data, field, dtype, handle):
def _meta_from_list(data, field, dtype, handle, is_feature):
    data = np.array(data)
    _meta_from_numpy(data, field, dtype, handle)
    _meta_from_numpy(data, field, dtype, handle, is_feature)


def _meta_from_tuple(data, field, dtype, handle):
    return _meta_from_list(data, field, dtype, handle)
def _meta_from_tuple(data, field, dtype, handle, is_feature):
    return _meta_from_list(data, field, dtype, handle, is_feature)


def _meta_from_cudf_df(data, field, handle):
@@ -592,28 +613,29 @@ def _meta_from_dt(data, field, dtype, handle):
    _meta_from_numpy(data, field, dtype, handle)


def dispatch_meta_backend(matrix: DMatrix, data, name: str, dtype: str = None):
def dispatch_meta_backend(matrix: DMatrix, data, name: str, dtype: str = None,
                          is_feature: bool = False):
    '''Dispatch for meta info.'''
    handle = matrix.handle
    if data is None:
        return
    if _is_list(data):
        _meta_from_list(data, name, dtype, handle)
        _meta_from_list(data, name, dtype, handle, is_feature)
        return
    if _is_tuple(data):
        _meta_from_tuple(data, name, dtype, handle)
        _meta_from_tuple(data, name, dtype, handle, is_feature)
        return
    if _is_numpy_array(data):
        _meta_from_numpy(data, name, dtype, handle)
        _meta_from_numpy(data, name, dtype, handle, is_feature)
        return
    if _is_pandas_df(data):
        data, _, _ = _transform_pandas_df(data, meta=name, meta_type=dtype)
        _meta_from_numpy(data, name, dtype, handle)
        _meta_from_numpy(data, name, dtype, handle, is_feature)
        return
    if _is_pandas_series(data):
        data = data.values.astype('float')
        assert len(data.shape) == 1 or data.shape[1] == 0 or data.shape[1] == 1
        _meta_from_numpy(data, name, dtype, handle)
        _meta_from_numpy(data, name, dtype, handle, is_feature)
        return
    if _is_dlpack(data):
        data = _transform_dlpack(data)
