Set feature_names and feature_types in jvm-packages (#9364)
* 1. Add parameters for setting feature names and feature types.
  2. Save feature names and feature types into the native JSON model.

* Change the serialization and deserialization format to UBJ.
jinmfeng001 committed Jul 12, 2023
1 parent 3632242 commit a1367ea
Showing 12 changed files with 295 additions and 8 deletions.
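For orientation, a minimal usage sketch of the new Spark-level API (not part of the diff; the column names, feature types, and parameter values are illustrative, and the training DataFrame is omitted):

import ml.dmlc.xgboost4j.scala.spark.XGBoostClassifier

// Hypothetical 3-feature schema: "q" marks numeric columns, "c" marks categorical ones.
val names = Array("age", "income", "segment")
val types = Array("q", "q", "c")

val classifier = new XGBoostClassifier(Map(
    "objective" -> "binary:logistic",
    "num_round" -> 10,
    "num_workers" -> 2))
  .setFeatureNames(names)
  .setFeatureTypes(types)

// classifier.fit(trainingDF) then copies the names/types onto each DMatrix,
// so they end up in the saved native model.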
@@ -74,7 +74,9 @@ private[scala] case class XGBoostExecutionParams(
earlyStoppingParams: XGBoostExecutionEarlyStoppingParams,
cacheTrainingSet: Boolean,
treeMethod: Option[String],
isLocal: Boolean) {
isLocal: Boolean,
featureNames: Option[Array[String]],
featureTypes: Option[Array[String]]) {

private var rawParamMap: Map[String, Any] = _

@@ -213,14 +215,24 @@ private[this] class XGBoostExecutionParamsFactory(rawParams: Map[String, Any], s
val cacheTrainingSet = overridedParams.getOrElse("cache_training_set", false)
.asInstanceOf[Boolean]

val featureNames = if (overridedParams.contains("feature_names")) {
Some(overridedParams("feature_names").asInstanceOf[Array[String]])
} else None
val featureTypes = if (overridedParams.contains("feature_types")) {
Some(overridedParams("feature_types").asInstanceOf[Array[String]])
} else None

val xgbExecParam = XGBoostExecutionParams(nWorkers, round, useExternalMemory, obj, eval,
missing, allowNonZeroForMissing, trackerConf,
checkpointParam,
inputParams,
xgbExecEarlyStoppingParams,
cacheTrainingSet,
treeMethod,
isLocal)
isLocal,
featureNames,
featureTypes
)
xgbExecParam.setRawParamMap(overridedParams)
xgbExecParam
}
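Side note on the extraction above: the contains/apply pair can be collapsed with Option.map; a behaviour-preserving sketch, not part of the diff:

val featureNames: Option[Array[String]] =
  overridedParams.get("feature_names").map(_.asInstanceOf[Array[String]])
val featureTypes: Option[Array[String]] =
  overridedParams.get("feature_types").map(_.asInstanceOf[Array[String]])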
@@ -531,6 +543,16 @@ private object Watches {
if (trainMargin.isDefined) trainMatrix.setBaseMargin(trainMargin.get)
if (testMargin.isDefined) testMatrix.setBaseMargin(testMargin.get)

if (xgbExecutionParams.featureNames.isDefined) {
trainMatrix.setFeatureNames(xgbExecutionParams.featureNames.get)
testMatrix.setFeatureNames(xgbExecutionParams.featureNames.get)
}

if (xgbExecutionParams.featureTypes.isDefined) {
trainMatrix.setFeatureTypes(xgbExecutionParams.featureTypes.get)
testMatrix.setFeatureTypes(xgbExecutionParams.featureTypes.get)
}

new Watches(Array(trainMatrix, testMatrix), Array("train", "test"), cacheDirName)
}

@@ -643,6 +665,15 @@ private object Watches {
if (trainMargin.isDefined) trainMatrix.setBaseMargin(trainMargin.get)
if (testMargin.isDefined) testMatrix.setBaseMargin(testMargin.get)

if (xgbExecutionParams.featureNames.isDefined) {
trainMatrix.setFeatureNames(xgbExecutionParams.featureNames.get)
testMatrix.setFeatureNames(xgbExecutionParams.featureNames.get)
}
if (xgbExecutionParams.featureTypes.isDefined) {
trainMatrix.setFeatureTypes(xgbExecutionParams.featureTypes.get)
testMatrix.setFeatureTypes(xgbExecutionParams.featureTypes.get)
}

new Watches(Array(trainMatrix, testMatrix), Array("train", "test"), cacheDirName)
}
}
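The two hunks above repeat the same guarded assignment for both matrices; an equivalent Option.foreach sketch (illustrative only, not part of the diff):

// Apply the optional feature metadata to every matrix in the watch list.
val matrices = Seq(trainMatrix, testMatrix)
xgbExecutionParams.featureNames.foreach(names => matrices.foreach(_.setFeatureNames(names)))
xgbExecutionParams.featureTypes.foreach(types => matrices.foreach(_.setFeatureTypes(types)))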
@@ -139,6 +139,12 @@ class XGBoostClassifier (
def setSinglePrecisionHistogram(value: Boolean): this.type =
set(singlePrecisionHistogram, value)

def setFeatureNames(value: Array[String]): this.type =
set(featureNames, value)

def setFeatureTypes(value: Array[String]): this.type =
set(featureTypes, value)

// called at the start of fit/train when 'eval_metric' is not defined
private def setupDefaultEvalMetric(): String = {
require(isDefined(objective), "Users must set \'objective\' via xgboostParams.")
@@ -141,6 +141,12 @@ class XGBoostRegressor (
def setSinglePrecisionHistogram(value: Boolean): this.type =
set(singlePrecisionHistogram, value)

def setFeatureNames(value: Array[String]): this.type =
set(featureNames, value)

def setFeatureTypes(value: Array[String]): this.type =
set(featureTypes, value)

// called at the start of fit/train when 'eval_metric' is not defined
private def setupDefaultEvalMetric(): String = {
require(isDefined(objective), "Users must set \'objective\' via xgboostParams.")
@@ -177,6 +177,21 @@ private[spark] trait GeneralParams extends Params {

final def getSeed: Long = $(seed)

/** Feature names. They are set on the DMatrix and the Booster, and stored in the
  * final native JSON model. In native code the parameter name is feature_name.
  */
final val featureNames = new StringArrayParam(this, "feature_names",
"an array of feature names")

final def getFeatureNames: Array[String] = $(featureNames)

/** Feature types: "q" for numeric and "c" for categorical.
  * In native code the parameter name is feature_type.
  */
final val featureTypes = new StringArrayParam(this, "feature_types",
"an array of feature types")

final def getFeatureTypes: Array[String] = $(featureTypes)
}

trait HasLeafPredictionCol extends Params {
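To illustrate the encoding documented above ("q" numeric, "c" categorical), a small sketch that builds a type array for a hypothetical schema where only the last column is categorical:

val numFeatures = 33
val categoricalIdx = Set(32) // hypothetical: last column is categorical
val featureTypes: Array[String] =
  Array.tabulate(numFeatures)(i => if (categoricalIdx.contains(i)) "c" else "q")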
@@ -27,6 +27,8 @@ import org.apache.commons.io.IOUtils

import org.apache.spark.Partitioner
import org.apache.spark.ml.feature.VectorAssembler
import org.json4s.{DefaultFormats, Formats}
import org.json4s.jackson.parseJson

class XGBoostClassifierSuite extends AnyFunSuite with PerTest with TmpFolderPerSuite {

@@ -453,4 +455,26 @@ class XGBoostClassifierSuite extends AnyFunSuite with PerTest with TmpFolderPerSuite {
assert(!compareTwoFiles(new File(modelJsonPath, "data/XGBoostClassificationModel").getPath,
nativeUbjModelPath))
}

test("native json model file should store feature_name and feature_type") {
val featureNames = (1 to 33).map(idx => s"feature_${idx}").toArray
val featureTypes = (1 to 33).map(_ => "q").toArray
val paramMap = Map("eta" -> "0.1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "multi:softprob", "num_class" -> "6", "num_round" -> 5,
"num_workers" -> numWorkers, "tree_method" -> treeMethod
)
val trainingDF = buildDataFrame(MultiClassification.train)
val xgb = new XGBoostClassifier(paramMap)
.setFeatureNames(featureNames)
.setFeatureTypes(featureTypes)
val model = xgb.fit(trainingDF)
val modelStr = new String(model._booster.toByteArray("json"))
System.out.println(modelStr)
val jsonModel = parseJson(modelStr)
implicit val formats: Formats = DefaultFormats
val featureNamesInModel = (jsonModel \ "learner" \ "feature_names").extract[List[String]]
val featureTypesInModel = (jsonModel \ "learner" \ "feature_types").extract[List[String]]
assert(featureNamesInModel.length == 33)
assert(featureTypesInModel.length == 33)
}
}
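The test above only asserts the lengths; assuming the set/get round trip preserves the input order in learner/feature_names and learner/feature_types, the check could be tightened to compare contents as well — a sketch, not part of the diff:

assert(featureNamesInModel == featureNames.toList)
assert(featureTypesInModel == featureTypes.toList)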
@@ -162,6 +162,51 @@ public void setAttrs(Map<String, String> attrs) throws XGBoostError {
}
}

/**
 * Get the feature names from the Booster.
 *
 * @return an array of feature names
 * @throws XGBoostError native error
 */
public final String[] getFeatureNames() throws XGBoostError {
int numFeature = (int) getNumFeature();
String[] out = new String[numFeature];
XGBoostJNI.checkCall(XGBoostJNI.XGBoosterGetStrFeatureInfo(handle, "feature_name", out));
return out;
}

/**
 * Set feature names on the Booster.
 *
 * @param featureNames an array of feature names
 * @throws XGBoostError native error
 */
public void setFeatureNames(String[] featureNames) throws XGBoostError {
XGBoostJNI.checkCall(XGBoostJNI.XGBoosterSetStrFeatureInfo(
handle, "feature_name", featureNames));
}

/**
 * Get the feature types from the Booster.
 *
 * @return an array of feature types
 * @throws XGBoostError native error
 */
public final String[] getFeatureTypes() throws XGBoostError {
int numFeature = (int) getNumFeature();
String[] out = new String[numFeature];
XGBoostJNI.checkCall(XGBoostJNI.XGBoosterGetStrFeatureInfo(handle, "feature_type", out));
return out;
}

/**
 * Set feature types on the Booster.
 *
 * @param featureTypes an array of feature types
 * @throws XGBoostError native error
 */
public void setFeatureTypes(String[] featureTypes) throws XGBoostError {
XGBoostJNI.checkCall(XGBoostJNI.XGBoosterSetStrFeatureInfo(
handle, "feature_type", featureTypes));
}

/**
* Update the booster for one iteration.
*
@@ -744,7 +789,7 @@ private static long[] dmatrixsToHandles(DMatrix[] dmatrixs) {
private void writeObject(java.io.ObjectOutputStream out) throws IOException {
try {
out.writeInt(version);
out.writeObject(this.toByteArray());
out.writeObject(this.toByteArray("ubj"));
} catch (XGBoostError ex) {
ex.printStackTrace();
logger.error(ex.getMessage());
@@ -780,7 +825,7 @@ public synchronized void dispose() {
@Override
public void write(Kryo kryo, Output output) {
try {
byte[] serObj = this.toByteArray();
byte[] serObj = this.toByteArray("ubj");
int serObjSize = serObj.length;
output.writeInt(serObjSize);
output.writeInt(version);
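A round-trip sketch of the new Booster accessors, called from Scala against the Java class (the trained booster is assumed and not shown here):

import ml.dmlc.xgboost4j.java.{Booster => JBooster}

def tagFeatures(booster: JBooster, names: Array[String], types: Array[String]): Unit = {
  booster.setFeatureNames(names) // stored under "feature_name" via XGBoosterSetStrFeatureInfo
  booster.setFeatureTypes(types) // stored under "feature_type"
  require(booster.getFeatureNames.sameElements(names))
  require(booster.getFeatureTypes.sameElements(types))
}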
@@ -198,6 +198,8 @@ public static Booster trainAndSaveCheckpoint(
if (booster == null) {
// Start training on a new booster
booster = new Booster(params, allMats);
booster.setFeatureNames(dtrain.getFeatureNames());
booster.setFeatureTypes(dtrain.getFeatureTypes());
booster.loadRabitCheckpoint();
} else {
// Start training on an existing booster
@@ -164,4 +164,8 @@ public final static native int XGQuantileDMatrixCreateFromCallback(
public final static native int XGDMatrixCreateFromArrayInterfaceColumns(
String featureJson, float missing, int nthread, long[] out);

public final static native int XGBoosterSetStrFeatureInfo(long handle, String field, String[] features);

public final static native int XGBoosterGetStrFeatureInfo(long handle, String field, String[] out);

}
@@ -205,6 +205,26 @@ class DMatrix private[scala](private[scala] val jDMatrix: JDMatrix) {
jDMatrix.setBaseMargin(column)
}

/**
* set feature names
* @param values feature names
* @throws ml.dmlc.xgboost4j.java.XGBoostError
*/
@throws(classOf[XGBoostError])
def setFeatureNames(values: Array[String]): Unit = {
jDMatrix.setFeatureNames(values)
}

/**
* set feature types
* @param values feature types
* @throws ml.dmlc.xgboost4j.java.XGBoostError
*/
@throws(classOf[XGBoostError])
def setFeatureTypes(values: Array[String]): Unit = {
jDMatrix.setFeatureTypes(values)
}

/**
* Get group sizes of DMatrix (used for ranking)
*/
@@ -243,6 +263,26 @@ class DMatrix private[scala](private[scala] val jDMatrix: JDMatrix) {
jDMatrix.getBaseMargin
}

/**
 * Get feature names.
 * @throws ml.dmlc.xgboost4j.java.XGBoostError
 * @return an array of feature names
 */
@throws(classOf[XGBoostError])
def getFeatureNames: Array[String] = {
jDMatrix.getFeatureNames
}

/**
 * Get feature types.
 * @throws ml.dmlc.xgboost4j.java.XGBoostError
 * @return an array of feature types
 */
@throws(classOf[XGBoostError])
def getFeatureTypes: Array[String] = {
jDMatrix.getFeatureTypes
}

/**
* Slice the DMatrix and return a new DMatrix that only contains `rowIndex`.
*
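A DMatrix-level sketch of the new Scala wrappers (synthetic dense data; the 4-argument constructor with a missing-value marker is assumed to be available):

import ml.dmlc.xgboost4j.scala.DMatrix

// Two rows, three columns, dense row-major layout.
val data = Array(1.0f, 2.0f, 0.0f, 3.0f, 4.0f, 1.0f)
val dmat = new DMatrix(data, 2, 3, Float.NaN)
dmat.setFeatureNames(Array("f0", "f1", "f2"))
dmat.setFeatureTypes(Array("q", "q", "c"))
assert(dmat.getFeatureNames.sameElements(Array("f0", "f1", "f2")))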
65 changes: 65 additions & 0 deletions jvm-packages/xgboost4j/src/native/xgboost4j.cpp
@@ -1148,3 +1148,68 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixGetStrFea
if (field) jenv->ReleaseStringUTFChars(jfield, field);
return ret;
}

/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: XGBoosterSetStrFeatureInfo
* Signature: (JLjava/lang/String;[Ljava/lang/String;)I
*/
JNIEXPORT jint JNICALL
Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterSetStrFeatureInfo(
JNIEnv *jenv, jclass jclz, jlong jhandle, jstring jfield,
jobjectArray jfeatures) {
BoosterHandle handle = (BoosterHandle)jhandle;

const char *field = jenv->GetStringUTFChars(jfield, nullptr);

bst_ulong feature_num = (bst_ulong)jenv->GetArrayLength(jfeatures);

std::vector<std::string> features;
std::vector<char const*> features_char;

for (bst_ulong i = 0; i < feature_num; ++i) {
jstring jfeature = (jstring)jenv->GetObjectArrayElement(jfeatures, i);
const char *s = jenv->GetStringUTFChars(jfeature, nullptr);
if (s != nullptr) {
// Modified UTF-8 from JNI is null-terminated; copy the whole string.
features.push_back(std::string(s));
jenv->ReleaseStringUTFChars(jfeature, s);
}
}

for (size_t i = 0; i < features.size(); ++i) {
features_char.push_back(features[i].c_str());
}

int ret = XGBoosterSetStrFeatureInfo(
handle, field, dmlc::BeginPtr(features_char),
static_cast<bst_ulong>(features_char.size()));
if (field) jenv->ReleaseStringUTFChars(jfield, field);
JVM_CHECK_CALL(ret);
return ret;
}

/*
 * Class: ml_dmlc_xgboost4j_java_XGBoostJNI
 * Method: XGBoosterGetStrFeatureInfo
 * Signature: (JLjava/lang/String;[Ljava/lang/String;)I
 */
JNIEXPORT jint JNICALL
Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterGetStrFeatureInfo(
JNIEnv *jenv, jclass jclz, jlong jhandle, jstring jfield,
jobjectArray jout) {
BoosterHandle handle = (BoosterHandle)jhandle;

const char *field = jenv->GetStringUTFChars(jfield, nullptr);

bst_ulong feature_num = 0;
const char **features = nullptr;

// The returned strings are owned by the booster; no need to free them here.
int ret = XGBoosterGetStrFeatureInfo(handle, field, &feature_num, &features);
if (field) jenv->ReleaseStringUTFChars(jfield, field);
JVM_CHECK_CALL(ret);

// Copy at most as many entries as the caller-provided array can hold.
jsize out_len = jenv->GetArrayLength(jout);
for (bst_ulong i = 0; i < feature_num && i < (bst_ulong)out_len; i++) {
jstring jfeature = jenv->NewStringUTF(features[i]);
jenv->SetObjectArrayElement(jout, (jsize)i, jfeature);
}

return ret;
}
18 changes: 18 additions & 0 deletions jvm-packages/xgboost4j/src/native/xgboost4j.h

(Generated file; diff not rendered.)
