[pyspark] make the model saved by pyspark compatible #8219

Merged (9 commits, Sep 20, 2022)
45 changes: 10 additions & 35 deletions python-package/xgboost/spark/model.py
@@ -21,34 +21,12 @@ def _get_or_create_tmp_dir():
return xgb_tmp_dir


-def serialize_xgb_model(model):
+def deserialize_xgb_model(model_string, xgb_model_creator):
     """
-    Serialize the input model to a string.
-
-    Parameters
-    ----------
-    model:
-        an xgboost.XGBModel instance, such as
-        xgboost.XGBClassifier or xgboost.XGBRegressor instance
-    """
-    # TODO: change to use string io
-    tmp_file_name = os.path.join(_get_or_create_tmp_dir(), f"{uuid.uuid4()}.json")
-    model.save_model(tmp_file_name)
-    with open(tmp_file_name, "r", encoding="utf-8") as f:
-        ser_model_string = f.read()
-    return ser_model_string
-
-
-def deserialize_xgb_model(ser_model_string, xgb_model_creator):
-    """
-    Deserialize an xgboost.XGBModel instance from the input ser_model_string.
+    Deserialize an xgboost.XGBModel instance from the input model_string.
     """
     xgb_model = xgb_model_creator()
-    # TODO: change to use string io
-    tmp_file_name = os.path.join(_get_or_create_tmp_dir(), f"{uuid.uuid4()}.json")
-    with open(tmp_file_name, "w", encoding="utf-8") as f:
-        f.write(ser_model_string)
-    xgb_model.load_model(tmp_file_name)
+    xgb_model.load_model(bytearray(model_string.encode("utf-8")))
     return xgb_model
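
For context, the rewritten deserialize_xgb_model above no longer round-trips through a temporary file: XGBoost's load_model() also accepts an in-memory bytearray containing the JSON document. A small self-contained sketch of that round trip (not taken from this PR; the toy data and estimator are only for illustration):

    import numpy as np
    import xgboost as xgb

    X = np.random.rand(32, 4)
    y = np.random.randint(2, size=32)

    # Train a tiny model and serialize it to a JSON string in memory ...
    reg = xgb.XGBRegressor(n_estimators=2, max_depth=2).fit(X, y)
    model_string = reg.get_booster().save_raw("json").decode("utf-8")

    # ... then load it back from the same in-memory buffer, no temp file needed.
    restored = xgb.XGBRegressor()
    restored.load_model(bytearray(model_string.encode("utf-8")))
    assert np.allclose(reg.predict(X), restored.predict(X))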


@@ -222,11 +200,11 @@ def saveImpl(self, path):
         """
         xgb_model = self.instance._xgb_sklearn_model
         _SparkXGBSharedReadWrite.saveMetadata(self.instance, path, self.sc, self.logger)
-        model_save_path = os.path.join(path, "model.json")
-        ser_xgb_model = serialize_xgb_model(xgb_model)
-        _get_spark_session().createDataFrame(
-            [(ser_xgb_model,)], ["xgb_sklearn_model"]
-        ).write.parquet(model_save_path)
+        model_save_path = os.path.join(path, "model")
+        booster = xgb_model.get_booster().save_raw("json").decode("utf-8")
+        _get_spark_session().sparkContext.parallelize([booster], 1).saveAsTextFile(
+            model_save_path
+        )
Comment on lines +205 to +206

Contributor: Interesting idea, but how do we control the saved file name?

Contributor: Does the booster string contain a "\n" character? If it does, then when loading back via sparkContext.textFile(model_load_path), each line becomes one RDD element, and those lines might be split across multiple RDD partitions.

Author (@wbo4958, Sep 16, 2022): I tested it, and the file is always part-00000. There seems to be a pattern for the generated file name based on the task id; since we only have one partition, the id is 00000.

Contributor: Let's document that the file named "part-00000" is the model JSON file. And please add a test to ensure the model JSON file does not contain a \n character, and document the reason.

Author (@wbo4958): Just checked the code; the file name is defined by https://github.com/apache/spark/blob/master/core/src/main/scala/org/apache/spark/internal/io/SparkHadoopWriter.scala#L225:

  override def initWriter(taskContext: NewTaskAttemptContext, splitId: Int): Unit = {
    val numfmt = NumberFormat.getInstance(Locale.US)
    numfmt.setMinimumIntegerDigits(5)
    numfmt.setGroupingUsed(false)

    val outputName = "part-" + numfmt.format(splitId)
    val path = FileOutputFormat.getOutputPath(getConf)
    val fs: FileSystem = {
      if (path != null) {
        path.getFileSystem(getConf)
      } else {
        // scalastyle:off FileSystemGet
        FileSystem.get(getConf)
        // scalastyle:on FileSystemGet
      }
    }
  ...

Here the splitId is TaskContext.partitionId(). In our case there is only one partition, so the file name is "part-00000".

Contributor: Yes, I know that. My point is whether we can customize the file name to make it more user-friendly. Not a must, though.

Member: That's the internal behavior of pyspark; I'm not sure it's a good idea to rely on it.

Author (@wbo4958): Yeah, if you insist, I can use the FileSystem Java API via py4j to achieve it.

Contributor: No need to do that; it would make the code hard to maintain. Your current code is fine.
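
Taken together with the reader below, the effect of the new writer is that the persisted artifact is an ordinary XGBoost JSON document rather than a string wrapped in Parquet. A minimal sketch of what that enables, assuming a model was already saved by the pyspark estimator under a hypothetical saved_path (the model/part-00000 layout follows from the code and the discussion above):

    import xgboost as xgb

    # saved_path is hypothetical; a SparkXGBRegressorModel (or classifier model)
    # is assumed to have been saved there beforehand.
    saved_path = "/tmp/spark_xgb_model"

    # The single text "part" file written by saveAsTextFile() is plain XGBoost
    # JSON, so a vanilla Booster can read it with no pyspark involved.
    bst = xgb.Booster()
    bst.load_model(f"{saved_path}/model/part-00000")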


class SparkXGBModelReader(MLReader):
@@ -252,13 +230,10 @@ def load(self, path):
         xgb_sklearn_params = py_model._gen_xgb_params_dict(
             gen_xgb_sklearn_estimator_param=True
         )
-        model_load_path = os.path.join(path, "model.json")
+        model_load_path = os.path.join(path, "model")

         ser_xgb_model = (
-            _get_spark_session()
-            .read.parquet(model_load_path)
-            .collect()[0]
-            .xgb_sklearn_model
+            _get_spark_session().sparkContext.textFile(model_load_path).collect()[0]
         )

         def create_xgb_model():
36 changes: 30 additions & 6 deletions tests/python/test_spark/test_spark_local.py
@@ -1,3 +1,4 @@
import glob
import logging
import random
import sys
@@ -7,6 +8,8 @@
import pytest
import testing as tm

import xgboost as xgb

if tm.no_spark()["condition"]:
pytest.skip(msg=tm.no_spark()["reason"], allow_module_level=True)
if sys.platform.startswith("win") or sys.platform.startswith("darwin"):
@@ -30,7 +33,7 @@
)
from xgboost.spark.core import _non_booster_params

-from xgboost import XGBClassifier, XGBRegressor
+from xgboost import XGBClassifier, XGBModel, XGBRegressor

from .utils import SparkTestCase

@@ -62,7 +65,12 @@ def setUp(self):
         # >>> reg2.fit(X, y)
         # >>> reg2.predict(X, ntree_limit=5)
         # array([0.22185266, 0.77814734], dtype=float32)
-        self.reg_params = {"max_depth": 5, "n_estimators": 10, "ntree_limit": 5}
+        self.reg_params = {
+            "max_depth": 5,
+            "n_estimators": 10,
+            "ntree_limit": 5,
+            "max_bin": 9,
+        }
         self.reg_df_train = self.session.createDataFrame(
             [
                 (Vectors.dense(1.0, 2.0, 3.0), 0),
@@ -427,6 +435,12 @@ def setUp(self):
     def get_local_tmp_dir(self):
         return self.tempdir + str(uuid.uuid4())

+    def assert_model_compatible(self, model: XGBModel, model_path: str):
+        bst = xgb.Booster()
+        path = glob.glob(f"{model_path}/**/model/part-00000", recursive=True)[0]
+        bst.load_model(path)
+        self.assertEqual(model.get_booster().save_raw("json"), bst.save_raw("json"))
Contributor: Add a test to assert that the model file does not include a \n character.

Author (@wbo4958, Sep 17, 2022): Per my understanding, we don't need to: if there were a "\n" in the file, bst.load_model(path) would fail, or the assertion self.assertEqual(model.get_booster().save_raw("json"), bst.save_raw("json")) would fail.
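
For reference, the check asked for above could be as small as the following hypothetical addition to assert_model_compatible (not part of this PR): since saveAsTextFile() writes the serialized booster as one text line, the JSON must not contain a newline, or sparkContext.textFile() would split it into several RDD elements on load.

    # Hypothetical extra assertion, not part of the PR: the JSON produced by
    # save_raw("json") has to stay on a single line for textFile() to read it
    # back as one element.
    raw = model.get_booster().save_raw("json").decode("utf-8")
    self.assertNotIn("\n", raw)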


def test_regressor_params_basic(self):
py_reg = SparkXGBRegressor()
self.assertTrue(hasattr(py_reg, "n_estimators"))
@@ -591,7 +605,8 @@ def test_classifier_with_params(self):
)

def test_regressor_model_save_load(self):
path = "file:" + self.get_local_tmp_dir()
tmp_dir = self.get_local_tmp_dir()
path = "file:" + tmp_dir
regressor = SparkXGBRegressor(**self.reg_params)
model = regressor.fit(self.reg_df_train)
model.save(path)
@@ -611,8 +626,11 @@ def test_regressor_model_save_load(self):
with self.assertRaisesRegex(AssertionError, "Expected class name"):
SparkXGBClassifierModel.load(path)

self.assert_model_compatible(model, tmp_dir)

def test_classifier_model_save_load(self):
path = "file:" + self.get_local_tmp_dir()
tmp_dir = self.get_local_tmp_dir()
path = "file:" + tmp_dir
regressor = SparkXGBClassifier(**self.cls_params)
model = regressor.fit(self.cls_df_train)
model.save(path)
@@ -632,12 +650,15 @@ def test_classifier_model_save_load(self):
with self.assertRaisesRegex(AssertionError, "Expected class name"):
SparkXGBRegressorModel.load(path)

self.assert_model_compatible(model, tmp_dir)

@staticmethod
def _get_params_map(params_kv, estimator):
return {getattr(estimator, k): v for k, v in params_kv.items()}

def test_regressor_model_pipeline_save_load(self):
path = "file:" + self.get_local_tmp_dir()
tmp_dir = self.get_local_tmp_dir()
path = "file:" + tmp_dir
regressor = SparkXGBRegressor()
pipeline = Pipeline(stages=[regressor])
pipeline = pipeline.copy(extra=self._get_params_map(self.reg_params, regressor))
@@ -655,9 +676,11 @@ def test_regressor_model_pipeline_save_load(self):
row.prediction, row.expected_prediction_with_params, atol=1e-3
)
)
self.assert_model_compatible(model.stages[0], tmp_dir)

def test_classifier_model_pipeline_save_load(self):
path = "file:" + self.get_local_tmp_dir()
tmp_dir = self.get_local_tmp_dir()
path = "file:" + tmp_dir
classifier = SparkXGBClassifier()
pipeline = Pipeline(stages=[classifier])
pipeline = pipeline.copy(
@@ -677,6 +700,7 @@ def test_classifier_model_pipeline_save_load(self):
row.probability, row.expected_probability_with_params, atol=1e-3
)
)
self.assert_model_compatible(model.stages[0], tmp_dir)

def test_classifier_with_cross_validator(self):
xgb_classifer = SparkXGBClassifier()