Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[jvm-packages] re-write xgboost read/write #7956

Merged
merged 1 commit into from
Jun 1, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,9 @@ import scala.collection.{AbstractIterator, Iterator, mutable}

import ml.dmlc.xgboost4j.java.Rabit
import ml.dmlc.xgboost4j.scala.{Booster, DMatrix}
import ml.dmlc.xgboost4j.scala.spark.DataUtils.PackedParams
import ml.dmlc.xgboost4j.scala.spark.util.DataUtils.PackedParams
import ml.dmlc.xgboost4j.scala.spark.params.XGBoostEstimatorCommon
import ml.dmlc.xgboost4j.scala.spark.util.DataUtils

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Dataset, Row}
Expand All @@ -35,10 +36,8 @@ import org.apache.commons.logging.LogFactory

import org.apache.spark.TaskContext
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.ml.{Estimator, Model, PipelineStage}
import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.ml.linalg.xgboost.XGBoostSchemaUtils
import org.apache.spark.sql.types.{ArrayType, FloatType, StructField, StructType}
import org.apache.spark.storage.StorageLevel

Expand Down Expand Up @@ -272,7 +271,7 @@ object PreXGBoost extends PreXGBoostProvider {

val features = batchRow.iterator.map(row => row.getAs[Vector](featuresCol))

import DataUtils._
import ml.dmlc.xgboost4j.scala.spark.util.DataUtils._
val cacheInfo = {
if (useExternalMemory) {
s"$appName-${TaskContext.get().stageId()}-dtest_cache-" +
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,13 @@ import org.apache.hadoop.fs.Path
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.ml.classification._
import org.apache.spark.ml.linalg._
import org.apache.spark.ml.param._
import org.apache.spark.ml.util._
import org.apache.spark.sql._
import org.apache.spark.sql.functions._
import org.json4s.DefaultFormats
import scala.collection.{Iterator, mutable}

import ml.dmlc.xgboost4j.scala.spark.utils.XGBoostWriter

import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.util.{DefaultXGBoostParamsReader, DefaultXGBoostParamsWriter, XGBoostWriter}
import org.apache.spark.sql.types.StructType

class XGBoostClassifier (
Expand Down Expand Up @@ -274,7 +272,7 @@ class XGBoostClassificationModel private[ml](
* Note: The performance is not ideal, use it carefully!
*/
override def predict(features: Vector): Double = {
import DataUtils._
import ml.dmlc.xgboost4j.scala.spark.util.DataUtils._
val dm = new DMatrix(processMissingValues(
Iterator(features.asXGB),
$(missing),
Expand Down Expand Up @@ -469,10 +467,8 @@ object XGBoostClassificationModel extends MLReadable[XGBoostClassificationModel]

override protected def saveImpl(path: String): Unit = {
// Save metadata and Params
implicit val format = DefaultFormats
implicit val sc = super.sparkSession.sparkContext

DefaultXGBoostParamsWriter.saveMetadata(instance, path, sc)

// Save model data
val dataPath = new Path(path, "data").toString
val internalPath = new Path(dataPath, "XGBoostClassificationModel")
Expand All @@ -495,18 +491,7 @@ object XGBoostClassificationModel extends MLReadable[XGBoostClassificationModel]
val dataPath = new Path(path, "data").toString
val internalPath = new Path(dataPath, "XGBoostClassificationModel")
val dataInStream = internalPath.getFileSystem(sc.hadoopConfiguration).open(internalPath)

// The xgboostVersion field in the metadata indicates whether the model was saved by an
// old (incompatible) XGBoost version or by a new (compatible) one.
val numClasses = metadata.xgboostVersion.map { _ =>
implicit val format = DefaultFormats
// For binary:logistic, the numClass parameter can't be set to 2 or not be set.
// For multi:softprob or multi:softmax, the numClass parameter must be set correctly,
// or else, XGBoost will throw exception.
// So it's safe to get numClass from meta data.
(metadata.params \ "numClass").extractOpt[Int].getOrElse(2)
}.getOrElse(dataInStream.readInt())

val numClasses = DefaultXGBoostParamsReader.getNumClass(metadata, dataInStream)
val booster = SXGBoost.loadModel(dataInStream)
val model = new XGBoostClassificationModel(metadata.uid, numClasses, booster)
DefaultXGBoostParamsReader.getAndSetParams(model, metadata)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,7 @@ package ml.dmlc.xgboost4j.scala.spark

import scala.collection.{Iterator, mutable}

import ml.dmlc.xgboost4j.scala.spark.params.{DefaultXGBoostParamsReader, _}
import ml.dmlc.xgboost4j.scala.spark.utils.XGBoostWriter
import ml.dmlc.xgboost4j.scala.spark.params._
import ml.dmlc.xgboost4j.scala.{Booster, DMatrix, XGBoost => SXGBoost}
import ml.dmlc.xgboost4j.scala.{EvalTrait, ObjectiveTrait}
import org.apache.hadoop.fs.Path
Expand All @@ -30,9 +29,9 @@ import org.apache.spark.ml._
import org.apache.spark.ml.param._
import org.apache.spark.sql._
import org.apache.spark.sql.functions._
import org.json4s.DefaultFormats

import org.apache.spark.broadcast.Broadcast
import org.apache.spark.ml.util.{DefaultXGBoostParamsReader, DefaultXGBoostParamsWriter, XGBoostWriter}
import org.apache.spark.sql.types.StructType

class XGBoostRegressor (
Expand Down Expand Up @@ -260,7 +259,7 @@ class XGBoostRegressionModel private[ml] (
* Note: The performance is not ideal, use it carefully!
*/
override def predict(features: Vector): Double = {
import DataUtils._
import ml.dmlc.xgboost4j.scala.spark.util.DataUtils._
val dm = new DMatrix(processMissingValues(
Iterator(features.asXGB),
$(missing),
Expand Down Expand Up @@ -384,8 +383,6 @@ object XGBoostRegressionModel extends MLReadable[XGBoostRegressionModel] {

override protected def saveImpl(path: String): Unit = {
// Save metadata and Params
implicit val format = DefaultFormats
implicit val sc = super.sparkSession.sparkContext
DefaultXGBoostParamsWriter.saveMetadata(instance, path, sc)
// Save model data
val dataPath = new Path(path, "data").toString
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
Copyright (c) 2014,2021 by Contributors
Copyright (c) 2014-2022 by Contributors

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
Expand All @@ -18,6 +18,8 @@ package ml.dmlc.xgboost4j.scala.spark.params

import ml.dmlc.xgboost4j.scala.{EvalTrait, ObjectiveTrait}
import ml.dmlc.xgboost4j.scala.spark.TrackerConf
import ml.dmlc.xgboost4j.scala.spark.util.Utils

import org.apache.spark.ml.param.{Param, ParamPair, Params}
import org.json4s.{DefaultFormats, Extraction, NoTypeHints}
import org.json4s.jackson.JsonMethods.{compact, parse, render}
Expand Down

This file was deleted.