Skip to content

Commit

Permalink
[Breaking][jvm-packages] make classification model be xgboost-compati…
Browse files Browse the repository at this point in the history
…ble (#7896)
  • Loading branch information
wbo4958 committed May 14, 2022
1 parent 1b6538b commit 11e46e4
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 8 deletions.
Expand Up @@ -463,7 +463,6 @@ object XGBoostClassificationModel extends MLReadable[XGBoostClassificationModel]
val dataPath = new Path(path, "data").toString
val internalPath = new Path(dataPath, "XGBoostClassificationModel")
val outputStream = internalPath.getFileSystem(sc.hadoopConfiguration).create(internalPath)
outputStream.writeInt(instance.numClasses)
instance._booster.saveModel(outputStream)
outputStream.close()
}
Expand All @@ -477,13 +476,22 @@ object XGBoostClassificationModel extends MLReadable[XGBoostClassificationModel]
override def load(path: String): XGBoostClassificationModel = {
implicit val sc = super.sparkSession.sparkContext


val metadata = DefaultXGBoostParamsReader.loadMetadata(path, sc, className)

val dataPath = new Path(path, "data").toString
val internalPath = new Path(dataPath, "XGBoostClassificationModel")
val dataInStream = internalPath.getFileSystem(sc.hadoopConfiguration).open(internalPath)
val numClasses = dataInStream.readInt()

// The xgboostVersion in the meta can specify if the model is the old xgboost in-compatible
// or the new xgboost compatible.
val numClasses = metadata.xgboostVersion.map { _ =>
implicit val format = DefaultFormats
// For binary:logistic, the numClass parameter can't be set to 2 or not be set.
// For multi:softprob or multi:softmax, the numClass parameter must be set correctly,
// or else, XGBoost will throw exception.
// So it's safe to get numClass from meta data.
(metadata.params \ "numClass").extractOpt[Int].getOrElse(2)
}.getOrElse(dataInStream.readInt())

val booster = SXGBoost.loadModel(dataInStream)
val model = new XGBoostClassificationModel(metadata.uid, numClasses, booster)
Expand Down
Expand Up @@ -51,7 +51,8 @@ private[spark] object DefaultXGBoostParamsReader {
sparkVersion: String,
params: JValue,
metadata: JValue,
metadataJson: String) {
metadataJson: String,
xgboostVersion: Option[String] = None) {

/**
* Get the JSON value of the [[org.apache.spark.ml.param.Param]] of the given name.
Expand Down Expand Up @@ -108,8 +109,8 @@ private[spark] object DefaultXGBoostParamsReader {
require(className == expectedClassName, s"Error loading metadata: Expected class name" +
s" $expectedClassName but found class name $className")
}

Metadata(className, uid, timestamp, sparkVersion, params, metadata, metadataStr)
val xgboostVersion = (metadata \ "xgboostVersion").extractOpt[String]
Metadata(className, uid, timestamp, sparkVersion, params, metadata, metadataStr, xgboostVersion)
}

private def handleBrokenlyChangedValue[T](paramName: String, value: T): T = {
Expand Down
Expand Up @@ -22,8 +22,8 @@ import org.apache.spark.SparkContext
import org.apache.spark.ml.param.{ParamPair, Params}
import org.json4s.jackson.JsonMethods._
import org.json4s.{JArray, JBool, JDouble, JField, JInt, JNothing, JObject, JString, JValue}

import JsonDSLXGBoost._
import ml.dmlc.xgboost4j.scala.spark

// This originates from apache-spark DefaultPramsWriter copy paste
private[spark] object DefaultXGBoostParamsWriter {
Expand Down Expand Up @@ -78,6 +78,7 @@ private[spark] object DefaultXGBoostParamsWriter {
("timestamp" -> System.currentTimeMillis()) ~
("sparkVersion" -> sc.version) ~
("uid" -> uid) ~
("xgboostVersion" -> spark.VERSION) ~
("paramMap" -> jsonParams)
val metadata = extraMetadata match {
case Some(jObject) =>
Expand Down
Expand Up @@ -16,16 +16,19 @@

package ml.dmlc.xgboost4j.scala.spark

import java.io.{File, FileInputStream}

import ml.dmlc.xgboost4j.scala.{DMatrix, XGBoost => ScalaXGBoost}

import org.apache.spark.ml.linalg._
import org.apache.spark.sql._
import org.scalatest.FunSuite
import org.apache.commons.io.IOUtils

import org.apache.spark.Partitioner
import org.apache.spark.ml.feature.VectorAssembler

class XGBoostClassifierSuite extends FunSuite with PerTest {
class XGBoostClassifierSuite extends FunSuite with PerTest with TmpFolderPerSuite {

protected val treeMethod: String = "auto"

Expand Down Expand Up @@ -391,4 +394,37 @@ class XGBoostClassifierSuite extends FunSuite with PerTest {
df1.show()
}

test("XGBoostClassificationModel should be compatible") {
val paramMap = Map("eta" -> "0.1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "multi:softprob", "num_class" -> "6", "num_round" -> 5,
"num_workers" -> numWorkers, "tree_method" -> treeMethod)
val trainingDF = buildDataFrame(MultiClassification.train)
val xgb = new XGBoostClassifier(paramMap)
val model = xgb.fit(trainingDF)
val modelPath = new File(tempDir.toFile, "xgbc").getPath
model.write.overwrite().save(modelPath)
val nativeModelPath = new File(tempDir.toFile, "nativeModel").getPath
model.nativeBooster.saveModel(nativeModelPath)

assert(compareTwoFiles(new File(modelPath, "data/XGBoostClassificationModel").getPath,
nativeModelPath))
}

private def compareTwoFiles(lhs: String, rhs: String): Boolean = {
withResource(new FileInputStream(lhs)) { lfis =>
withResource(new FileInputStream(rhs)) { rfis =>
IOUtils.contentEquals(lfis, rfis)
}
}
}

/** Executes the provided code block and then closes the resource */
private def withResource[T <: AutoCloseable, V](r: T)(block: T => V): V = {
try {
block(r)
} finally {
r.close()
}
}

}

0 comments on commit 11e46e4

Please sign in to comment.