Feature weights (dmlc#5962) (#3)
* Update BoosterParams.scala

* fix scala checkstyle error

* fix whitespace checkstyle error

* fix type cast error

* fix conversion issue

* update version for 1.2.5-al

* Feature weights (dmlc#5962)

* update version

* version update 2

Co-authored-by: Oscar Pan <oscar.pan@applovin.com>
Co-authored-by: Jiaming Yuan <jm.yuan@outlook.com>
3 people committed May 12, 2022
1 parent 311cc45 commit 5a34e41
Showing 32 changed files with 543 additions and 123 deletions.
1 change: 1 addition & 0 deletions amalgamation/xgboost-all0.cc
@@ -69,6 +69,7 @@
#include "../src/learner.cc"
#include "../src/logging.cc"
#include "../src/common/common.cc"
#include "../src/common/random.cc"
#include "../src/common/charconv.cc"
#include "../src/common/timer.cc"
#include "../src/common/host_device_vector.cc"
49 changes: 49 additions & 0 deletions demo/guide-python/feature_weights.py
@@ -0,0 +1,49 @@
'''Using feature weights to change column sampling.

.. versionadded:: 1.3.0
'''

import numpy as np
import xgboost
from matplotlib import pyplot as plt
import argparse


def main(args):
    rng = np.random.RandomState(1994)

    kRows = 1000
    kCols = 10

    X = rng.randn(kRows, kCols)
    y = rng.randn(kRows)
    fw = np.ones(shape=(kCols,))
    for i in range(kCols):
        fw[i] *= float(i)

    dtrain = xgboost.DMatrix(X, y)
    dtrain.set_info(feature_weights=fw)

    bst = xgboost.train({'tree_method': 'hist',
                         'colsample_bynode': 0.5},
                        dtrain, num_boost_round=10,
                        evals=[(dtrain, 'd')])
    feature_map = bst.get_fscore()
    # feature zero has 0 weight, so it is never selected
    assert feature_map.get('f0', None) is None
    assert max(feature_map.values()) == feature_map.get('f9')

    if args.plot:
        xgboost.plot_importance(bst)
        plt.show()


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--plot',
        type=int,
        default=1,
        help='Set to 0 to disable plotting the feature importance.')
    args = parser.parse_args()
    main(args)
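
Since only the relative magnitudes of the weights matter, a quick way to see the sampling distribution the demo sets up is to normalize them by hand (a minimal sketch; the weights are used only up to scale):

import numpy as np

# Weights 0..9, as constructed in the demo above.
fw = np.arange(10, dtype=np.float32)
# Normalized, feature 0 gets probability 0 and can never be sampled,
# which is why the demo asserts that 'f0' is absent from the model.
print(fw / fw.sum())   # [0.     0.0222 0.0444 ... 0.2]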
2 changes: 1 addition & 1 deletion demo/json-model/json_parser.py
@@ -94,7 +94,7 @@ def __str__(self):

class Model:
    '''Gradient boosted tree model.'''
    def __init__(self, m: dict):
    def __init__(self, model: dict):
        '''Construct the Model from JSON object.

        parameters
9 changes: 6 additions & 3 deletions doc/parameter.rst
@@ -107,6 +107,10 @@ Parameters for Tree Booster
'colsample_bynode':0.5}`` with 64 features will leave 8 features to choose from at
each split.

On the Python interface, one can set ``feature_weights`` on the DMatrix to define
the probability of each feature being selected when using column sampling. There is
a similar parameter for the ``fit`` method in the sklearn interface.

* ``lambda`` [default=1, alias: ``reg_lambda``]

- L2 regularization term on weights. Increasing this value will make the model more conservative.
@@ -225,9 +229,8 @@ Parameters for Tree Booster
list is a group of indices of features that are allowed to interact with each other.
See the tutorial for more information.

Additional parameters for `hist` and 'gpu_hist' tree method
================================================

Additional parameters for ``hist`` and ``gpu_hist`` tree method
================================================================
* ``single_precision_histogram``, [default=``false``]

- Use single precision to build histograms instead of double precision.
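
For reference, a minimal sketch of the two entry points described in the documentation change above, assuming the 1.3-era API surface (``DMatrix.set_info`` accepting ``feature_weights``, and a ``feature_weights`` argument on the sklearn ``fit`` method):

import numpy as np
import xgboost

X = np.random.randn(128, 8)
y = np.random.randn(128)
fw = np.arange(1, 9, dtype=np.float32)  # higher weight -> sampled more often

# Native interface: attach the weights to the DMatrix.
dtrain = xgboost.DMatrix(X, y)
dtrain.set_info(feature_weights=fw)
bst = xgboost.train({'colsample_bynode': 0.5}, dtrain, num_boost_round=4)

# sklearn interface: the analogous parameter on fit().
reg = xgboost.XGBRegressor(colsample_bynode=0.5, n_estimators=4)
reg.fit(X, y, feature_weights=fw)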
28 changes: 28 additions & 0 deletions include/xgboost/c_api.h
@@ -483,6 +483,34 @@ XGB_DLL int XGDMatrixGetStrFeatureInfo(DMatrixHandle handle, const char *field,
                                       bst_ulong *size,
                                       const char ***out_features);

/*!
 * \brief Set meta info from dense matrix. Valid field names are:
 *
 *  - label
 *  - weight
 *  - base_margin
 *  - group
 *  - label_lower_bound
 *  - label_upper_bound
 *  - feature_weights
 *
 * \param handle An instance of data matrix
 * \param field  Field name
 * \param data   Pointer to consecutive memory storing the data.
 * \param size   Size of the data, measured in elements of the given type (NOT in bytes).
 * \param type   Indicator of data type, defined in the xgboost::DataType enum class:
 *
 *    float    = 1
 *    double   = 2
 *    uint32_t = 3
 *    uint64_t = 4
 *
 * \return 0 when success, -1 when failure happens
 */
XGB_DLL int XGDMatrixSetDenseInfo(DMatrixHandle handle, const char *field,
                                  void *data, bst_ulong size, int type);

/*!
 * \brief (deprecated) Use XGDMatrixSetUIntInfo instead. Set group of the training matrix
 * \param handle an instance of data matrix
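
To make the new entry point concrete, here is a hedged sketch that calls it directly through the Python wrapper's private ctypes helpers (``_LIB``, ``_check_call``, ``c_str``, ``c_bst_ulong``), mirroring what the ``data.py`` change below does; in normal use ``DMatrix.set_info`` is the right call:

import ctypes
import numpy as np
import xgboost
from xgboost.core import _LIB, _check_call, c_str, c_bst_ulong

dtrain = xgboost.DMatrix(np.random.randn(64, 4), np.random.randn(64))

fw = np.ascontiguousarray([0.0, 1.0, 2.0, 3.0], dtype=np.float32)
ptr = ctypes.c_void_p(fw.__array_interface__['data'][0])
# Type indicator 1 == float, per the enum above; size is the number
# of elements, not bytes.
_check_call(_LIB.XGDMatrixSetDenseInfo(
    dtrain.handle, c_str('feature_weights'), ptr,
    c_bst_ulong(fw.shape[0]), ctypes.c_int(1)))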
29 changes: 6 additions & 23 deletions include/xgboost/data.h
@@ -89,34 +89,17 @@ class MetaInfo {
   * \brief Type of each feature. Automatically set when feature_type_names is specified.
   */
  HostDeviceVector<FeatureType> feature_types;
  /*!
   * \brief Weight of each feature, used to define the probability of each feature being
   *        selected when using column sampling.
   */
  HostDeviceVector<float> feature_weights;

  /*! \brief default constructor */
  MetaInfo() = default;
  MetaInfo(MetaInfo&& that) = default;
  MetaInfo& operator=(MetaInfo&& that) = default;
  MetaInfo& operator=(MetaInfo const& that) {
    this->num_row_ = that.num_row_;
    this->num_col_ = that.num_col_;
    this->num_nonzero_ = that.num_nonzero_;

    this->labels_.Resize(that.labels_.Size());
    this->labels_.Copy(that.labels_);

    this->group_ptr_ = that.group_ptr_;

    this->weights_.Resize(that.weights_.Size());
    this->weights_.Copy(that.weights_);

    this->base_margin_.Resize(that.base_margin_.Size());
    this->base_margin_.Copy(that.base_margin_);

    this->labels_lower_bound_.Resize(that.labels_lower_bound_.Size());
    this->labels_lower_bound_.Copy(that.labels_lower_bound_);

    this->labels_upper_bound_.Resize(that.labels_upper_bound_.Size());
    this->labels_upper_bound_.Copy(that.labels_upper_bound_);
    return *this;
  }
  // Copy-assignment is deleted: the hand-written version above had to
  // enumerate every field and silently dropped any it missed (note it
  // never copied feature_types or feature_weights).
  MetaInfo& operator=(MetaInfo const& that) = delete;

  /*!
   * \brief Validate all metainfo.
2 changes: 1 addition & 1 deletion jvm-packages/pom.xml
@@ -6,7 +6,7 @@

<groupId>ml.dmlc</groupId>
<artifactId>xgboost-jvm_2.12</artifactId>
<version>1.2.4-al</version>
<version>1.2.6-al</version>
<packaging>pom</packaging>
<name>XGBoost JVM Package</name>
<description>JVM Package for XGBoost</description>
8 changes: 4 additions & 4 deletions jvm-packages/xgboost4j-example/pom.xml
@@ -6,10 +6,10 @@
<parent>
<groupId>ml.dmlc</groupId>
<artifactId>xgboost-jvm_2.12</artifactId>
<version>1.2.4-al</version>
<version>1.2.6-al</version>
</parent>
<artifactId>xgboost4j-example_2.12</artifactId>
<version>1.2.4-al</version>
<version>1.2.6-al</version>
<packaging>jar</packaging>
<build>
<plugins>
@@ -26,7 +26,7 @@
<dependency>
<groupId>ml.dmlc</groupId>
<artifactId>xgboost4j-spark_${scala.binary.version}</artifactId>
<version>1.2.4-al</version>
<version>1.2.6-al</version>
</dependency>
<dependency>
<groupId>org.apache.spark</groupId>
@@ -37,7 +37,7 @@
<dependency>
<groupId>ml.dmlc</groupId>
<artifactId>xgboost4j-flink_${scala.binary.version}</artifactId>
<version>1.2.4-al</version>
<version>1.2.6-al</version>
</dependency>
<dependency>
<groupId>org.apache.commons</groupId>
6 changes: 3 additions & 3 deletions jvm-packages/xgboost4j-flink/pom.xml
@@ -6,10 +6,10 @@
<parent>
<groupId>ml.dmlc</groupId>
<artifactId>xgboost-jvm_2.12</artifactId>
<version>1.2.4-al</version>
<version>1.2.6-al</version>
</parent>
<artifactId>xgboost4j-flink_2.12</artifactId>
<version>1.2.4-al</version>
<version>1.2.6-al</version>
<build>
<plugins>
<plugin>
@@ -26,7 +26,7 @@
<dependency>
<groupId>ml.dmlc</groupId>
<artifactId>xgboost4j_${scala.binary.version}</artifactId>
<version>1.2.4-al</version>
<version>1.2.6-al</version>
</dependency>
<dependency>
<groupId>org.apache.commons</groupId>
4 changes: 2 additions & 2 deletions jvm-packages/xgboost4j-spark/pom.xml
@@ -6,7 +6,7 @@
<parent>
<groupId>ml.dmlc</groupId>
<artifactId>xgboost-jvm_2.12</artifactId>
<version>1.2.4-al</version>
<version>1.2.6-al</version>
</parent>
<artifactId>xgboost4j-spark_2.12</artifactId>
<build>
@@ -24,7 +24,7 @@
<dependency>
<groupId>ml.dmlc</groupId>
<artifactId>xgboost4j_${scala.binary.version}</artifactId>
<version>1.2.4-al</version>
<version>1.2.6-al</version>
</dependency>
<dependency>
<groupId>org.apache.spark</groupId>
@@ -99,7 +99,7 @@ private[this] class XGBoostExecutionParamsFactory(rawParams: Map[String, Any], s

  private val isLocal = sc.isLocal

  private val overridedParams = overrideParams(rawParams, sc)
  private val overridedParams: Map[String, Any] = overrideParams(rawParams, sc)

/**
* Check to see if Spark expects SSL encryption (`spark.ssl.enabled` set to true).
@@ -213,7 +213,8 @@ private[this] class XGBoostExecutionParamsFactory(rawParams: Map[String, Any], s
      .asInstanceOf[Double]
    val seed = overridedParams.getOrElse("seed", System.nanoTime()).asInstanceOf[Long]
    // Spark ML's DoubleArrayParam stores doubles, while the native booster
    // expects floats, hence the element-wise conversion below.
    val featureWeights = overridedParams.getOrElse(
      "feature_weights", new Array[Float](0)).asInstanceOf[Array[Float]]
      "feature_weights", new Array[Double](0)).asInstanceOf[Array[Double]]
      .map(_.toFloat)
    val inputParams = XGBoostExecutionInputParams(trainTestRatio, seed, featureWeights)

    val earlyStoppingRounds = overridedParams.getOrElse(
@@ -17,8 +17,7 @@
package ml.dmlc.xgboost4j.scala.spark.params

import scala.collection.immutable.HashSet

import org.apache.spark.ml.param.{DoubleParam, IntParam, BooleanParam, Param, Params}
import org.apache.spark.ml.param.{BooleanParam, DoubleArrayParam, DoubleParam, IntParam, Param, Params}

private[spark] trait BoosterParams extends Params {

@@ -110,6 +109,15 @@ private[spark] trait BoosterParams extends Params {

  final def getSubsample: Double = $(subsample)

  /**
   * Probability distribution for column sampling. Doesn't have to be normalized.
   */
  final val featureWeights = new DoubleArrayParam(this, "featureWeights",
    "probability distribution for feature sampling.", (value: Array[Double]) => true)

  final def getFeatureWeights: Array[Double] = $(featureWeights)

  /**
   * subsample ratio of columns when constructing each tree. [default=1] range: (0,1]
   */
@@ -286,7 +294,8 @@ private[spark] trait BoosterParams extends Params {
    subsample -> 1, colsampleBytree -> 1, colsampleBylevel -> 1,
    lambda -> 1, alpha -> 0, treeMethod -> "auto", sketchEps -> 0.03,
    scalePosWeight -> 1.0, sampleType -> "uniform", normalizeType -> "tree",
    rateDrop -> 0.0, skipDrop -> 0.0, lambdaBias -> 0, treeLimit -> 0)
    rateDrop -> 0.0, skipDrop -> 0.0, lambdaBias -> 0, treeLimit -> 0,
    featureWeights -> new Array[Double](0))
}

private[spark] object BoosterParams {
4 changes: 2 additions & 2 deletions jvm-packages/xgboost4j/pom.xml
@@ -6,10 +6,10 @@
<parent>
<groupId>ml.dmlc</groupId>
<artifactId>xgboost-jvm_2.12</artifactId>
<version>1.2.4-al</version>
<version>1.2.6-al</version>
</parent>
<artifactId>xgboost4j_2.12</artifactId>
<version>1.2.4-al</version>
<version>1.2.6-al</version>
<packaging>jar</packaging>

<dependencies>
7 changes: 6 additions & 1 deletion python-package/xgboost/core.py
@@ -455,7 +455,8 @@ def set_info(self,
                 label_lower_bound=None,
                 label_upper_bound=None,
                 feature_names=None,
                 feature_types=None):
                 feature_types=None,
                 feature_weights=None):
        '''Set meta info for DMatrix.'''
        if label is not None:
            self.set_label(label)
@@ -473,6 +474,10 @@
            self.feature_names = feature_names
        if feature_types is not None:
            self.feature_types = feature_types
        if feature_weights is not None:
            from .data import dispatch_meta_backend
            dispatch_meta_backend(matrix=self, data=feature_weights,
                                  name='feature_weights')

    def get_float_info(self, field):
        """Get float property from the DMatrix.
45 changes: 31 additions & 14 deletions python-package/xgboost/data.py
@@ -530,22 +530,38 @@ def dispatch_data_backend(data, missing, threads,
    raise TypeError('Not supported type for data.' + str(type(data)))


def _to_data_type(dtype: str, name: str):
    dtype_map = {'float32': 1, 'float64': 2, 'uint32': 3, 'uint64': 4}
    if dtype not in dtype_map.keys():
        raise TypeError(
            f'Expecting float32, float64, uint32, uint64, got {dtype} ' +
            f'for {name}.')
    return dtype_map[dtype]


def _validate_meta_shape(data):
    if hasattr(data, 'shape'):
        assert len(data.shape) == 1 or (
            len(data.shape) == 2 and
            (data.shape[1] == 0 or data.shape[1] == 1))


def _meta_from_numpy(data, field, dtype, handle):
    data = _maybe_np_slice(data, dtype)
    if dtype == 'uint32':
        c_data = c_array(ctypes.c_uint32, data)
        _check_call(_LIB.XGDMatrixSetUIntInfo(handle,
                                              c_str(field),
                                              c_array(ctypes.c_uint, data),
                                              c_bst_ulong(len(data))))
    elif dtype == 'float':
        c_data = c_array(ctypes.c_float, data)
        _check_call(_LIB.XGDMatrixSetFloatInfo(handle,
                                               c_str(field),
                                               c_data,
                                               c_bst_ulong(len(data))))
    else:
        raise TypeError('Unsupported type ' + str(dtype) + ' for:' + field)
    interface = data.__array_interface__
    assert interface.get('mask', None) is None, 'Masked array is not supported'
    size = data.shape[0]

    c_type = _to_data_type(str(data.dtype), field)
    ptr = interface['data'][0]
    ptr = ctypes.c_void_p(ptr)
    _check_call(_LIB.XGDMatrixSetDenseInfo(
        handle,
        c_str(field),
        ptr,
        c_bst_ulong(size),
        c_type
    ))


def _meta_from_list(data, field, dtype, handle):
@@ -595,6 +611,7 @@ def _meta_from_dt(data, field, dtype, handle):
def dispatch_meta_backend(matrix: DMatrix, data, name: str, dtype: str = None):
    '''Dispatch for meta info.'''
    handle = matrix.handle
    _validate_meta_shape(data)
    if data is None:
        return
    if _is_list(data):
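
Taken together, the helpers above define the dispatch path that ``DMatrix.set_info`` now takes for dense metadata; a minimal usage sketch (with made-up data) of the public entry point they serve:

import numpy as np
import xgboost

dtrain = xgboost.DMatrix(np.random.randn(32, 3), np.random.randn(32))
# _validate_meta_shape accepts 1-D data (or an (n, 1) column), and
# _meta_from_numpy forwards the raw pointer to XGDMatrixSetDenseInfo.
dtrain.set_info(feature_weights=np.array([1.0, 2.0, 3.0], dtype=np.float32))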
