New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[pyspark] make the model saved by pyspark compatible #8219
Changes from 8 commits
93f05c2
768235a
4b2ce00
a1e1bc7
ff265ab
c9e9020
5a37c54
cc653b6
fdbfa22
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -21,34 +21,28 @@ def _get_or_create_tmp_dir(): | |
return xgb_tmp_dir | ||
|
||
|
||
def serialize_xgb_model(model): | ||
def save_model_to_json_file(model) -> str: | ||
""" | ||
Serialize the input model to a string. | ||
Save the input model to a local file in driver side and return the path. | ||
|
||
Parameters | ||
---------- | ||
model: | ||
an xgboost.XGBModel instance, such as | ||
xgboost.XGBClassifier or xgboost.XGBRegressor instance | ||
""" | ||
# TODO: change to use string io | ||
# Save the model to json format | ||
tmp_file_name = os.path.join(_get_or_create_tmp_dir(), f"{uuid.uuid4()}.json") | ||
model.save_model(tmp_file_name) | ||
with open(tmp_file_name, "r", encoding="utf-8") as f: | ||
ser_model_string = f.read() | ||
return ser_model_string | ||
return tmp_file_name | ||
|
||
|
||
def deserialize_xgb_model(ser_model_string, xgb_model_creator): | ||
def deserialize_xgb_model(model_string, xgb_model_creator): | ||
""" | ||
Deserialize an xgboost.XGBModel instance from the input ser_model_string. | ||
Deserialize an xgboost.XGBModel instance from the input model_string. | ||
""" | ||
xgb_model = xgb_model_creator() | ||
# TODO: change to use string io | ||
tmp_file_name = os.path.join(_get_or_create_tmp_dir(), f"{uuid.uuid4()}.json") | ||
with open(tmp_file_name, "w", encoding="utf-8") as f: | ||
f.write(ser_model_string) | ||
xgb_model.load_model(tmp_file_name) | ||
xgb_model.load_model(bytearray(model_string.encode("utf-8"))) | ||
return xgb_model | ||
|
||
|
||
|
@@ -222,11 +216,11 @@ def saveImpl(self, path): | |
""" | ||
xgb_model = self.instance._xgb_sklearn_model | ||
_SparkXGBSharedReadWrite.saveMetadata(self.instance, path, self.sc, self.logger) | ||
model_save_path = os.path.join(path, "model.json") | ||
ser_xgb_model = serialize_xgb_model(xgb_model) | ||
_get_spark_session().createDataFrame( | ||
[(ser_xgb_model,)], ["xgb_sklearn_model"] | ||
).write.parquet(model_save_path) | ||
model_save_path = os.path.join(path, "model") | ||
xgb_model_file = save_model_to_json_file(xgb_model) | ||
# The json file written by Spark based on `booster.save_raw("json").decode("utf-8")` | ||
# can't be loaded by XGBoost directly. | ||
_get_spark_session().read.text(xgb_model_file).write.text(model_save_path) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
This line is not correct. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You can use distributed FS API to copy local file There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. wow, right. you're correct, @WeichenXu123 Good findings. Could you point me to what is the "distributed FS API"? Really appreciate it. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You can use this: But, this does not support DBFS (databricks filesystem), we need support databricks case as well. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The example code in the PR description
seems it does not work? If the path is a distributed FS path? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @WeichenXu123, I use the RDD to save the text file, it should work with all kinds of hadoop-compatible FS. |
||
|
||
|
||
class SparkXGBModelReader(MLReader): | ||
|
@@ -252,14 +246,9 @@ def load(self, path): | |
xgb_sklearn_params = py_model._gen_xgb_params_dict( | ||
gen_xgb_sklearn_estimator_param=True | ||
) | ||
model_load_path = os.path.join(path, "model.json") | ||
model_load_path = os.path.join(path, "model") | ||
|
||
ser_xgb_model = ( | ||
_get_spark_session() | ||
.read.parquet(model_load_path) | ||
.collect()[0] | ||
.xgb_sklearn_model | ||
) | ||
ser_xgb_model = _get_spark_session().read.text(model_load_path).collect()[0][0] | ||
|
||
def create_xgb_model(): | ||
return self.cls._xgb_cls()(**xgb_sklearn_params) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,4 @@ | ||
import glob | ||
import logging | ||
import random | ||
import sys | ||
|
@@ -7,6 +8,8 @@ | |
import pytest | ||
import testing as tm | ||
|
||
import xgboost as xgb | ||
|
||
if tm.no_spark()["condition"]: | ||
pytest.skip(msg=tm.no_spark()["reason"], allow_module_level=True) | ||
if sys.platform.startswith("win") or sys.platform.startswith("darwin"): | ||
|
@@ -30,7 +33,7 @@ | |
) | ||
from xgboost.spark.core import _non_booster_params | ||
|
||
from xgboost import XGBClassifier, XGBRegressor | ||
from xgboost import XGBClassifier, XGBModel, XGBRegressor | ||
|
||
from .utils import SparkTestCase | ||
|
||
|
@@ -62,7 +65,12 @@ def setUp(self): | |
# >>> reg2.fit(X, y) | ||
# >>> reg2.predict(X, ntree_limit=5) | ||
# array([0.22185266, 0.77814734], dtype=float32) | ||
self.reg_params = {"max_depth": 5, "n_estimators": 10, "ntree_limit": 5} | ||
self.reg_params = { | ||
"max_depth": 5, | ||
"n_estimators": 10, | ||
"ntree_limit": 5, | ||
"max_bin": 9, | ||
} | ||
self.reg_df_train = self.session.createDataFrame( | ||
[ | ||
(Vectors.dense(1.0, 2.0, 3.0), 0), | ||
|
@@ -427,6 +435,14 @@ def setUp(self): | |
def get_local_tmp_dir(self): | ||
return self.tempdir + str(uuid.uuid4()) | ||
|
||
def assert_model_compatible(self, model: XGBModel, model_path: str): | ||
bst = xgb.Booster() | ||
path = glob.glob(f"{model_path}/**/*.txt", recursive=True)[0] | ||
bst.load_model(path) | ||
# The model is saved by XGBModel which will add an extra scikit_learn attribute | ||
bst.set_attr(scikit_learn=None) | ||
self.assertEqual(model.get_booster().save_raw("json"), bst.save_raw("json")) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. add a test to assert the model file does not include There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah, per my understanding, it seems we don't need to do this, since if there is a "\n", the assertion would fail |
||
|
||
def test_regressor_params_basic(self): | ||
py_reg = SparkXGBRegressor() | ||
self.assertTrue(hasattr(py_reg, "n_estimators")) | ||
|
@@ -591,7 +607,8 @@ def test_classifier_with_params(self): | |
) | ||
|
||
def test_regressor_model_save_load(self): | ||
path = "file:" + self.get_local_tmp_dir() | ||
tmp_dir = self.get_local_tmp_dir() | ||
path = "file:" + tmp_dir | ||
regressor = SparkXGBRegressor(**self.reg_params) | ||
model = regressor.fit(self.reg_df_train) | ||
model.save(path) | ||
|
@@ -611,8 +628,11 @@ def test_regressor_model_save_load(self): | |
with self.assertRaisesRegex(AssertionError, "Expected class name"): | ||
SparkXGBClassifierModel.load(path) | ||
|
||
self.assert_model_compatible(model, tmp_dir) | ||
|
||
def test_classifier_model_save_load(self): | ||
path = "file:" + self.get_local_tmp_dir() | ||
tmp_dir = self.get_local_tmp_dir() | ||
path = "file:" + tmp_dir | ||
regressor = SparkXGBClassifier(**self.cls_params) | ||
model = regressor.fit(self.cls_df_train) | ||
model.save(path) | ||
|
@@ -632,12 +652,15 @@ def test_classifier_model_save_load(self): | |
with self.assertRaisesRegex(AssertionError, "Expected class name"): | ||
SparkXGBRegressorModel.load(path) | ||
|
||
self.assert_model_compatible(model, tmp_dir) | ||
|
||
@staticmethod | ||
def _get_params_map(params_kv, estimator): | ||
return {getattr(estimator, k): v for k, v in params_kv.items()} | ||
|
||
def test_regressor_model_pipeline_save_load(self): | ||
path = "file:" + self.get_local_tmp_dir() | ||
tmp_dir = self.get_local_tmp_dir() | ||
path = "file:" + tmp_dir | ||
regressor = SparkXGBRegressor() | ||
pipeline = Pipeline(stages=[regressor]) | ||
pipeline = pipeline.copy(extra=self._get_params_map(self.reg_params, regressor)) | ||
|
@@ -655,9 +678,11 @@ def test_regressor_model_pipeline_save_load(self): | |
row.prediction, row.expected_prediction_with_params, atol=1e-3 | ||
) | ||
) | ||
self.assert_model_compatible(model.stages[0], tmp_dir) | ||
|
||
def test_classifier_model_pipeline_save_load(self): | ||
path = "file:" + self.get_local_tmp_dir() | ||
tmp_dir = self.get_local_tmp_dir() | ||
path = "file:" + tmp_dir | ||
classifier = SparkXGBClassifier() | ||
pipeline = Pipeline(stages=[classifier]) | ||
pipeline = pipeline.copy( | ||
|
@@ -677,6 +702,7 @@ def test_classifier_model_pipeline_save_load(self): | |
row.probability, row.expected_probability_with_params, atol=1e-3 | ||
) | ||
) | ||
self.assert_model_compatible(model.stages[0], tmp_dir) | ||
|
||
def test_classifier_with_cross_validator(self): | ||
xgb_classifer = SparkXGBClassifier() | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why not?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There are some escaped quote characters (`\"`) in the json file which can't be loaded by xgboost. Do you want to check more?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I will take a look tomorrow
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@trivialfis No need anymore, I just found another way to do it.