[jvm-packages] add format option when saving a model (#7940)
wbo4958 committed May 30, 2022
1 parent cc6d57a commit 6275cdc
Showing 8 changed files with 155 additions and 32 deletions.

XGBoostClassifier.scala
@@ -30,6 +30,8 @@ import org.apache.spark.sql.functions._
import org.json4s.DefaultFormats
import scala.collection.{Iterator, mutable}

import ml.dmlc.xgboost4j.scala.spark.utils.XGBoostWriter

import org.apache.spark.sql.types.StructType

class XGBoostClassifier (
@@ -462,7 +464,8 @@ object XGBoostClassificationModel extends MLReadable[XGBoostClassificationModel]
override def load(path: String): XGBoostClassificationModel = super.load(path)

private[XGBoostClassificationModel]
class XGBoostClassificationModelWriter(instance: XGBoostClassificationModel) extends MLWriter {
class XGBoostClassificationModelWriter(instance: XGBoostClassificationModel)
extends XGBoostWriter {

override protected def saveImpl(path: String): Unit = {
// Save metadata and Params
@@ -474,7 +477,7 @@ object XGBoostClassificationModel extends MLReadable[XGBoostClassificationModel]
val dataPath = new Path(path, "data").toString
val internalPath = new Path(dataPath, "XGBoostClassificationModel")
val outputStream = internalPath.getFileSystem(sc.hadoopConfiguration).create(internalPath)
instance._booster.saveModel(outputStream)
instance._booster.saveModel(outputStream, getModelFormat())
outputStream.close()
}
}

XGBoostRegressor.scala
@@ -19,6 +19,7 @@ package ml.dmlc.xgboost4j.scala.spark
import scala.collection.{Iterator, mutable}

import ml.dmlc.xgboost4j.scala.spark.params.{DefaultXGBoostParamsReader, _}
import ml.dmlc.xgboost4j.scala.spark.utils.XGBoostWriter
import ml.dmlc.xgboost4j.scala.{Booster, DMatrix, XGBoost => SXGBoost}
import ml.dmlc.xgboost4j.scala.{EvalTrait, ObjectiveTrait}
import org.apache.hadoop.fs.Path
@@ -379,7 +380,7 @@ object XGBoostRegressionModel extends MLReadable[XGBoostRegressionModel] {
override def load(path: String): XGBoostRegressionModel = super.load(path)

private[XGBoostRegressionModel]
class XGBoostRegressionModelWriter(instance: XGBoostRegressionModel) extends MLWriter {
class XGBoostRegressionModelWriter(instance: XGBoostRegressionModel) extends XGBoostWriter {

override protected def saveImpl(path: String): Unit = {
// Save metadata and Params
@@ -390,7 +391,7 @@ object XGBoostRegressionModel extends MLReadable[XGBoostRegressionModel] {
val dataPath = new Path(path, "data").toString
val internalPath = new Path(dataPath, "XGBoostRegressionModel")
val outputStream = internalPath.getFileSystem(sc.hadoopConfiguration).create(internalPath)
instance._booster.saveModel(outputStream)
instance._booster.saveModel(outputStream, getModelFormat())
outputStream.close()
}
}

XGBoostWriter.scala (new file)
@@ -0,0 +1,31 @@
/*
Copyright (c) 2022 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package ml.dmlc.xgboost4j.scala.spark.utils

import ml.dmlc.xgboost4j.java.{Booster => JBooster}

import org.apache.spark.ml.util.MLWriter

private[spark] abstract class XGBoostWriter extends MLWriter {

/** Currently the default is the "deprecated" format, which
* will be changed to `ubj` in a future release. */
def getModelFormat(): String = {
optionMap.getOrElse("format", JBooster.DEFAULT_FORMAT)
}

}
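
For orientation, here is a minimal sketch of how the new option is consumed from the Spark ML API, mirroring the tests further down in this commit; the parameter map, the training DataFrame, and the paths are hypothetical placeholders, not part of the change:

// Illustrative sketch only: trainingDF and the paths below are placeholders.
val xgb = new XGBoostClassifier(Map(
  "objective" -> "multi:softprob", "num_class" -> 3,
  "num_round" -> 5, "num_workers" -> 2))
val model = xgb.fit(trainingDF)

// Explicit format: the booster file under <path>/data/ is written as JSON.
model.write.overwrite().option("format", "json").save("/tmp/xgbc-json")

// No "format" option: getModelFormat() falls back to JBooster.DEFAULT_FORMAT ("deprecated").
model.write.overwrite().save("/tmp/xgbc-default")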

PerTest.scala
@@ -16,16 +16,18 @@

package ml.dmlc.xgboost4j.scala.spark

import java.io.File
import java.io.{File, FileInputStream}

import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}

import org.apache.spark.SparkContext
import org.apache.spark.sql._
import org.scalatest.{BeforeAndAfterEach, FunSuite}

import scala.math.min
import scala.util.Random

import org.apache.commons.io.IOUtils

trait PerTest extends BeforeAndAfterEach { self: FunSuite =>

protected val numWorkers: Int = min(Runtime.getRuntime.availableProcessors(), 4)
@@ -105,4 +107,22 @@ trait PerTest extends BeforeAndAfterEach { self: FunSuite =>
ss.createDataFrame(sc.parallelize(it.toList, numPartitions))
.toDF("id", "label", "features", "group")
}


protected def compareTwoFiles(lhs: String, rhs: String): Boolean = {
withResource(new FileInputStream(lhs)) { lfis =>
withResource(new FileInputStream(rhs)) { rfis =>
IOUtils.contentEquals(lfis, rfis)
}
}
}

/** Executes the provided code block and then closes the resource */
protected def withResource[T <: AutoCloseable, V](r: T)(block: T => V): V = {
try {
block(r)
} finally {
r.close()
}
}
}
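
A short, hypothetical illustration of the withResource helper above (the path is a placeholder): it applies the loan pattern, so close() runs even when the block throws.

// Hypothetical usage of withResource: the stream is closed in the finally block.
val firstByte: Int = withResource(new FileInputStream("/tmp/model.json")) { fis =>
  fis.read()
}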

XGBoostClassifierSuite.scala
@@ -429,30 +429,29 @@ class XGBoostClassifierSuite extends FunSuite with PerTest with TmpFolderPerSuite
val trainingDF = buildDataFrame(MultiClassification.train)
val xgb = new XGBoostClassifier(paramMap)
val model = xgb.fit(trainingDF)
val modelPath = new File(tempDir.toFile, "xgbc").getPath
model.write.overwrite().save(modelPath)
val nativeModelPath = new File(tempDir.toFile, "nativeModel").getPath
model.nativeBooster.saveModel(nativeModelPath)

val modelPath = new File(tempDir.toFile, "xgbc").getPath
model.write.option("format", "json").save(modelPath)
val nativeJsonModelPath = new File(tempDir.toFile, "nativeModel.json").getPath
model.nativeBooster.saveModel(nativeJsonModelPath)
assert(compareTwoFiles(new File(modelPath, "data/XGBoostClassificationModel").getPath,
nativeModelPath))
}

private def compareTwoFiles(lhs: String, rhs: String): Boolean = {
withResource(new FileInputStream(lhs)) { lfis =>
withResource(new FileInputStream(rhs)) { rfis =>
IOUtils.contentEquals(lfis, rfis)
}
}
}

/** Executes the provided code block and then closes the resource */
private def withResource[T <: AutoCloseable, V](r: T)(block: T => V): V = {
try {
block(r)
} finally {
r.close()
}
nativeJsonModelPath))

// test default "deprecated"
val modelUbjPath = new File(tempDir.toFile, "xgbcUbj").getPath
model.write.save(modelUbjPath)
val nativeDeprecatedModelPath = new File(tempDir.toFile, "nativeModel").getPath
model.nativeBooster.saveModel(nativeDeprecatedModelPath)
assert(compareTwoFiles(new File(modelUbjPath, "data/XGBoostClassificationModel").getPath,
nativeDeprecatedModelPath))

// the json file should be different from the ubj file
val modelJsonPath = new File(tempDir.toFile, "xgbcJson").getPath
model.write.option("format", "json").save(modelJsonPath)
val nativeUbjModelPath = new File(tempDir.toFile, "nativeModel1.ubj").getPath
model.nativeBooster.saveModel(nativeUbjModelPath)
assert(!compareTwoFiles(new File(modelJsonPath, "data/XGBoostClassificationModel").getPath,
nativeUbjModelPath))
}

}
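
Note that only the writer gains a format option; the reader path is unchanged by this commit, so a model saved in any of these formats is loaded back the usual way. A hypothetical sketch, reusing modelPath from the test above:

// Hypothetical: reload the model that was written with option("format", "json").
val restored = XGBoostClassificationModel.load(modelPath)
assert(restored.nativeBooster != null)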

XGBoostRegressorSuite.scala
@@ -16,6 +16,8 @@

package ml.dmlc.xgboost4j.scala.spark

import java.io.File

import ml.dmlc.xgboost4j.scala.{DMatrix, XGBoost => ScalaXGBoost}

import org.apache.spark.ml.linalg.{Vector, Vectors}
@@ -25,7 +27,7 @@ import org.scalatest.FunSuite

import org.apache.spark.ml.feature.VectorAssembler

class XGBoostRegressorSuite extends FunSuite with PerTest {
class XGBoostRegressorSuite extends FunSuite with PerTest with TmpFolderPerSuite {
protected val treeMethod: String = "auto"

test("XGBoost-Spark XGBoostRegressor output should match XGBoost4j") {
@@ -310,4 +312,42 @@ class XGBoostRegressorSuite extends FunSuite with PerTest {
val df1 = model.transform(vectorizedInput)
df1.show()
}

test("XGBoostRegressionModel should be compatible") {
val trainingDF = buildDataFrame(Regression.train)
val paramMap = Map(
"eta" -> "1",
"max_depth" -> "6",
"silent" -> "1",
"objective" -> "reg:squarederror",
"num_round" -> 5,
"tree_method" -> treeMethod,
"num_workers" -> numWorkers)

val model = new XGBoostRegressor(paramMap).fit(trainingDF)

val modelPath = new File(tempDir.toFile, "xgbc").getPath
model.write.option("format", "json").save(modelPath)
val nativeJsonModelPath = new File(tempDir.toFile, "nativeModel.json").getPath
model.nativeBooster.saveModel(nativeJsonModelPath)
assert(compareTwoFiles(new File(modelPath, "data/XGBoostRegressionModel").getPath,
nativeJsonModelPath))

// test default "deprecated"
val modelUbjPath = new File(tempDir.toFile, "xgbcUbj").getPath
model.write.save(modelUbjPath)
val nativeDeprecatedModelPath = new File(tempDir.toFile, "nativeModel").getPath
model.nativeBooster.saveModel(nativeDeprecatedModelPath)
assert(compareTwoFiles(new File(modelUbjPath, "data/XGBoostRegressionModel").getPath,
nativeDeprecatedModelPath))

// the json file should be different from the ubj file
val modelJsonPath = new File(tempDir.toFile, "xgbcJson").getPath
model.write.option("format", "json").save(modelJsonPath)
val nativeUbjModelPath = new File(tempDir.toFile, "nativeModel1.ubj").getPath
model.nativeBooster.saveModel(nativeUbjModelPath)
assert(!compareTwoFiles(new File(modelJsonPath, "data/XGBoostRegressionModel").getPath,
nativeUbjModelPath))
}

}

Booster.java
@@ -34,6 +34,7 @@
* Booster for xgboost, this is a model API that supports interactive building of an XGBoost model
*/
public class Booster implements Serializable, KryoSerializable {
public static final String DEFAULT_FORMAT = "deprecated";
private static final Log logger = LogFactory.getLog(Booster.class);
// handle to the booster.
private long handle = 0;
@@ -391,7 +392,22 @@ public void saveModel(String modelPath) throws XGBoostError{
* @param out The output stream
*/
public void saveModel(OutputStream out) throws XGBoostError, IOException {
out.write(this.toByteArray());
saveModel(out, DEFAULT_FORMAT);
}

/**
* Save the model to a file opened as an output stream.
* The model format is compatible with the other xgboost bindings.
* Only one xgboost model can be saved to a given output stream.
* This function closes the OutputStream after the save.
*
* @param out The output stream
* @param format The model format (ubj, json, deprecated)
* @throws XGBoostError
* @throws IOException
*/
public void saveModel(OutputStream out, String format) throws XGBoostError, IOException {
out.write(this.toByteArray(format));
out.close();
}

@@ -643,7 +659,7 @@ public void setVersion(int version) {
* @throws XGBoostError native error
*/
public byte[] toByteArray() throws XGBoostError {
return this.toByteArray("deprecated");
return this.toByteArray(DEFAULT_FORMAT);
}

/**

Booster.scala
@@ -207,6 +207,7 @@ class Booster private[xgboost4j](private[xgboost4j] var booster: JBooster)
def saveModel(modelPath: String): Unit = {
booster.saveModel(modelPath)
}

/**
* Save the model to an output stream.
*
@@ -216,6 +217,18 @@ class Booster private[xgboost4j](private[xgboost4j] var booster: JBooster)
def saveModel(out: java.io.OutputStream): Unit = {
booster.saveModel(out)
}

/**
* Save the model to an output stream.
* @param out the output stream
* @param format the model format to write (json, ubj, deprecated)
* @throws ml.dmlc.xgboost4j.java.XGBoostError
*/
@throws(classOf[XGBoostError])
def saveModel(out: java.io.OutputStream, format: String): Unit = {
booster.saveModel(out, format)
}

/**
* Dump model as Array of string
*
@@ -315,7 +328,7 @@ class Booster private[xgboost4j](private[xgboost4j] var booster: JBooster)
*/
@throws(classOf[XGBoostError])
def toByteArray: Array[Byte] = {
booster.toByteArray
booster.toByteArray()
}

/**
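
Putting the Booster-level pieces together, here is a brief hypothetical sketch of the new stream overload via the Scala wrapper (the file paths and helper name are placeholders); note that saveModel closes the stream it is given:

import java.io.FileOutputStream
import ml.dmlc.xgboost4j.scala.{Booster, XGBoost => SXGBoost}

// Hypothetical helper: write one booster twice, once per format.
def writeBoth(booster: Booster): Unit = {
  // New overload: serializes with toByteArray("ubj") and closes the stream afterwards.
  booster.saveModel(new FileOutputStream("/tmp/model.ubj"), "ubj")
  // Single-argument overload: still uses the Java Booster's DEFAULT_FORMAT ("deprecated").
  booster.saveModel(new FileOutputStream("/tmp/model.deprecated"))
}

// The saved files load back with the regular loader.
val restored = SXGBoost.loadModel("/tmp/model.ubj")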
