
[pyspark] make the model saved by pyspark compatible #8219

Merged: 9 commits, Sep 20, 2022
39 changes: 14 additions & 25 deletions python-package/xgboost/spark/model.py
@@ -21,34 +21,28 @@ def _get_or_create_tmp_dir():
return xgb_tmp_dir


-def serialize_xgb_model(model):
+def dump_model_to_json_file(model) -> str:
Member: Please use the term save. Dump has a specific meaning in XGBoost's code base.

Contributor Author: Done

"""
Serialize the input model to a string.
Dump the input model to a local file in driver and return the path.

Parameters
----------
model:
an xgboost.XGBModel instance, such as
xgboost.XGBClassifier or xgboost.XGBRegressor instance
"""
# TODO: change to use string io
# Dump the model to json format
tmp_file_name = os.path.join(_get_or_create_tmp_dir(), f"{uuid.uuid4()}.json")
model.save_model(tmp_file_name)
with open(tmp_file_name, "r", encoding="utf-8") as f:
ser_model_string = f.read()
return ser_model_string
return tmp_file_name


-def deserialize_xgb_model(ser_model_string, xgb_model_creator):
+def deserialize_xgb_model(model_string, xgb_model_creator):
    """
-    Deserialize an xgboost.XGBModel instance from the input ser_model_string.
+    Deserialize an xgboost.XGBModel instance from the input model_string.
    """
    xgb_model = xgb_model_creator()
-    # TODO: change to use string io
-    tmp_file_name = os.path.join(_get_or_create_tmp_dir(), f"{uuid.uuid4()}.json")
-    with open(tmp_file_name, "w", encoding="utf-8") as f:
-        f.write(ser_model_string)
-    xgb_model.load_model(tmp_file_name)
+    xgb_model.load_model(bytearray(model_string.encode("utf-8")))
    return xgb_model
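
For context, a minimal stand-alone sketch (not part of this diff) of the round trip the new deserialize_xgb_model relies on: XGBModel.load_model accepts an in-memory bytearray of model JSON as well as a file path, which is what lets the temp-file write above be dropped.

import xgboost as xgb

# Train a small model, serialize it to a JSON string, then restore it from
# an in-memory bytearray, mirroring the load_model call in the diff above.
X, y = [[1.0], [2.0], [3.0], [4.0]], [0.0, 1.0, 0.0, 1.0]
reg = xgb.XGBRegressor(n_estimators=2, max_depth=2)
reg.fit(X, y)

model_string = reg.get_booster().save_raw("json").decode("utf-8")

restored = xgb.XGBRegressor()
restored.load_model(bytearray(model_string.encode("utf-8")))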


@@ -222,11 +216,11 @@ def saveImpl(self, path):
"""
xgb_model = self.instance._xgb_sklearn_model
_SparkXGBSharedReadWrite.saveMetadata(self.instance, path, self.sc, self.logger)
model_save_path = os.path.join(path, "model.json")
ser_xgb_model = serialize_xgb_model(xgb_model)
_get_spark_session().createDataFrame(
[(ser_xgb_model,)], ["xgb_sklearn_model"]
).write.parquet(model_save_path)
model_save_path = os.path.join(path, "model")
xgb_model_file = dump_model_to_json_file(xgb_model)
# The json file written by Spark base on `booster.save_raw("json").decode("utf-8")`
Member: Why not?

Contributor Author (@wbo4958, Sep 15, 2022): There are some `\"` escapes in the json file which can't be loaded by xgboost. Do you want to check more?

Member: I will take a look tomorrow.

Contributor Author: @trivialfis No need anymore, I just found another way to do it.
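
To make the quoting problem concrete, here is a stand-alone sketch of the escaping described above (the model_json value is a stand-in, not real booster output): embedding the model JSON as a string field of another serialized row escapes its quotes, and the escaped text is no longer something XGBoost can parse directly.

import json

# A placeholder for the booster JSON produced by save_raw("json").
model_json = '{"learner": {"objective": "reg:squarederror"}}'
# Wrapping it as a string field escapes every inner quote.
row = json.dumps({"xgb_sklearn_model": model_json})
print(row)
# {"xgb_sklearn_model": "{\"learner\": {\"objective\": \"reg:squarederror\"}}"}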

Contributor (quoting `_get_spark_session().read.text(xgb_model_file)`): This line is not correct. For spark.read.text(path), the path must be a distributed file system path which all Spark executors can access.

Contributor: You can use a distributed FS API to copy the local file xgb_model_file into the model save path (a Hadoop FS path).

Contributor Author: Wow, right, you're correct @WeichenXu123, good finding. Could you point me to the "distributed FS API"? Really appreciate it.

Contributor: You can use this:
https://arrow.apache.org/docs/python/generated/pyarrow.fs.HadoopFileSystem.html

But this does not support DBFS (the Databricks filesystem); we need to support the Databricks case as well. Databricks mounts dbfs:/xxx/xxx to the local file system at /dbfs/xxx/xxx.

Contributor: The example code in the PR description

import xgboost as xgb
bst = xgb.Booster()

# Basically, YOUR_MODEL_PATH should be like "xxxx/model/xxx.txt"
YOUR_MODEL_PATH = "xxx"
bst.load_model(YOUR_MODEL_PATH)

seems not to work if the path is a distributed FS path?

Contributor Author: @WeichenXu123, I use the RDD to save the text file; it should work with all kinds of Hadoop-compatible FS.
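
A hypothetical sketch of the RDD approach mentioned above (it assumes a live SparkSession named spark and a fitted estimator named model, both placeholders): parallelizing the JSON string on the driver and writing it with saveAsTextFile routes the write through the Hadoop FileSystem layer, so file:/, hdfs://, and dbfs:/ destinations should all work.

# Serialize the booster to a JSON string on the driver, then let Spark
# write it as a single-partition text file via the Hadoop FS API.
booster_json = model.get_booster().save_raw("json").decode("utf-8")
spark.sparkContext.parallelize([booster_json], 1).saveAsTextFile(
    "file:/tmp/model-out/model"
)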



class SparkXGBModelReader(MLReader):
@@ -252,14 +246,9 @@ def load(self, path):
        xgb_sklearn_params = py_model._gen_xgb_params_dict(
            gen_xgb_sklearn_estimator_param=True
        )
-        model_load_path = os.path.join(path, "model.json")
+        model_load_path = os.path.join(path, "model")

-        ser_xgb_model = (
-            _get_spark_session()
-            .read.parquet(model_load_path)
-            .collect()[0]
-            .xgb_sklearn_model
-        )
+        ser_xgb_model = _get_spark_session().read.text(model_load_path).collect()[0][0]

        def create_xgb_model():
            return self.cls._xgb_cls()(**xgb_sklearn_params)
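
Putting the writer and reader together, a hypothetical end-to-end check of what this PR enables (saved_dir is an assumed example path): the text part file Spark writes under <path>/model is plain booster JSON that vanilla XGBoost can load, matching the example in the PR description.

import glob

import xgboost as xgb

saved_dir = "/tmp/spark-xgb-model"  # assumed: the path passed to model.save()
# Spark's text writer emits part files with a .txt extension.
part_file = glob.glob(f"{saved_dir}/model/part-*.txt")[0]

bst = xgb.Booster()
bst.load_model(part_file)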
38 changes: 32 additions & 6 deletions tests/python/test_spark/test_spark_local.py
@@ -1,3 +1,4 @@
+import glob
import logging
import random
import sys
@@ -7,6 +8,8 @@
import pytest
import testing as tm

+import xgboost as xgb

if tm.no_spark()["condition"]:
pytest.skip(msg=tm.no_spark()["reason"], allow_module_level=True)
if sys.platform.startswith("win") or sys.platform.startswith("darwin"):
@@ -30,7 +33,7 @@
)
from xgboost.spark.core import _non_booster_params

-from xgboost import XGBClassifier, XGBRegressor
+from xgboost import XGBClassifier, XGBModel, XGBRegressor

from .utils import SparkTestCase

@@ -62,7 +65,12 @@ def setUp(self):
# >>> reg2.fit(X, y)
# >>> reg2.predict(X, ntree_limit=5)
# array([0.22185266, 0.77814734], dtype=float32)
self.reg_params = {"max_depth": 5, "n_estimators": 10, "ntree_limit": 5}
self.reg_params = {
"max_depth": 5,
"n_estimators": 10,
"ntree_limit": 5,
"max_bin": 9,
}
self.reg_df_train = self.session.createDataFrame(
[
(Vectors.dense(1.0, 2.0, 3.0), 0),
@@ -427,6 +435,14 @@ def setUp(self):
def get_local_tmp_dir(self):
return self.tempdir + str(uuid.uuid4())

+    def assert_model_compatible(self, model: XGBModel, model_path: str):
+        bst = xgb.Booster()
+        path = glob.glob(f"{model_path}/**/*.txt", recursive=True)[0]
+        bst.load_model(path)
+        # The model is saved by XGBModel which will add an extra scikit_learn attribute
+        bst.set_attr(scikit_learn=None)
+        self.assertEqual(model.get_booster().save_raw("json"), bst.save_raw("json"))
Contributor: Add a test to assert that the model file does not include a \n char.

Contributor Author (@wbo4958, Sep 17, 2022): Yeah, per my understanding we don't need to do this: if there were a "\n", either the assertion self.assertEqual(model.get_booster().save_raw("json"), bst.save_raw("json")) would fail or bst.load_model(path) would fail.
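
As a side note on the set_attr(scikit_learn=None) call above, a minimal sketch of why it is needed: saving through the sklearn wrapper embeds the wrapper's configuration as an extra booster attribute, which a byte-for-byte comparison against the in-memory booster would otherwise trip over.

import xgboost as xgb

clf = xgb.XGBClassifier(n_estimators=1, max_depth=1)
clf.fit([[0.0], [1.0]], [0, 1])
clf.save_model("/tmp/wrapper-model.json")  # the sklearn wrapper adds metadata

bst = xgb.Booster()
bst.load_model("/tmp/wrapper-model.json")
print(bst.attr("scikit_learn") is not None)  # True: extra attribute present
bst.set_attr(scikit_learn=None)  # clear it before comparing raw model bytes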


def test_regressor_params_basic(self):
py_reg = SparkXGBRegressor()
self.assertTrue(hasattr(py_reg, "n_estimators"))
@@ -591,7 +607,8 @@ def test_classifier_with_params(self):
)

def test_regressor_model_save_load(self):
path = "file:" + self.get_local_tmp_dir()
tmp_dir = self.get_local_tmp_dir()
path = "file:" + tmp_dir
regressor = SparkXGBRegressor(**self.reg_params)
model = regressor.fit(self.reg_df_train)
model.save(path)
@@ -611,8 +628,11 @@ def test_regressor_model_save_load(self):
with self.assertRaisesRegex(AssertionError, "Expected class name"):
SparkXGBClassifierModel.load(path)

+        self.assert_model_compatible(model, tmp_dir)

def test_classifier_model_save_load(self):
path = "file:" + self.get_local_tmp_dir()
tmp_dir = self.get_local_tmp_dir()
path = "file:" + tmp_dir
regressor = SparkXGBClassifier(**self.cls_params)
model = regressor.fit(self.cls_df_train)
model.save(path)
@@ -632,12 +652,15 @@ def test_classifier_model_save_load(self):
with self.assertRaisesRegex(AssertionError, "Expected class name"):
SparkXGBRegressorModel.load(path)

+        self.assert_model_compatible(model, tmp_dir)

@staticmethod
def _get_params_map(params_kv, estimator):
return {getattr(estimator, k): v for k, v in params_kv.items()}

def test_regressor_model_pipeline_save_load(self):
path = "file:" + self.get_local_tmp_dir()
tmp_dir = self.get_local_tmp_dir()
path = "file:" + tmp_dir
regressor = SparkXGBRegressor()
pipeline = Pipeline(stages=[regressor])
pipeline = pipeline.copy(extra=self._get_params_map(self.reg_params, regressor))
@@ -655,9 +678,11 @@ def test_regressor_model_pipeline_save_load(self):
row.prediction, row.expected_prediction_with_params, atol=1e-3
)
)
+        self.assert_model_compatible(model.stages[0], tmp_dir)

def test_classifier_model_pipeline_save_load(self):
path = "file:" + self.get_local_tmp_dir()
tmp_dir = self.get_local_tmp_dir()
path = "file:" + tmp_dir
classifier = SparkXGBClassifier()
pipeline = Pipeline(stages=[classifier])
pipeline = pipeline.copy(
@@ -677,6 +702,7 @@ def test_classifier_model_pipeline_save_load(self):
row.probability, row.expected_probability_with_params, atol=1e-3
)
)
+        self.assert_model_compatible(model.stages[0], tmp_dir)

def test_classifier_with_cross_validator(self):
xgb_classifer = SparkXGBClassifier()