[pyspark] make the model saved by pyspark compatible #8219
@@ -1,3 +1,4 @@
+import glob
 import logging
 import random
 import sys
@@ -7,6 +8,8 @@
 import pytest
 import testing as tm

+import xgboost as xgb
+
 if tm.no_spark()["condition"]:
     pytest.skip(msg=tm.no_spark()["reason"], allow_module_level=True)
 if sys.platform.startswith("win") or sys.platform.startswith("darwin"):
@@ -30,7 +33,7 @@
 )
 from xgboost.spark.core import _non_booster_params

-from xgboost import XGBClassifier, XGBRegressor
+from xgboost import XGBClassifier, XGBModel, XGBRegressor

 from .utils import SparkTestCase

@@ -62,7 +65,12 @@ def setUp(self):
         # >>> reg2.fit(X, y)
         # >>> reg2.predict(X, ntree_limit=5)
         # array([0.22185266, 0.77814734], dtype=float32)
-        self.reg_params = {"max_depth": 5, "n_estimators": 10, "ntree_limit": 5}
+        self.reg_params = {
+            "max_depth": 5,
+            "n_estimators": 10,
+            "ntree_limit": 5,
+            "max_bin": 9,
+        }
         self.reg_df_train = self.session.createDataFrame(
             [
                 (Vectors.dense(1.0, 2.0, 3.0), 0),
@@ -427,6 +435,12 @@ def setUp(self):
     def get_local_tmp_dir(self):
         return self.tempdir + str(uuid.uuid4())

+    def assert_model_compatible(self, model: XGBModel, model_path: str):
+        bst = xgb.Booster()
+        path = glob.glob(f"{model_path}/**/model/part-00000", recursive=True)[0]
+        bst.load_model(path)
+        self.assertEqual(model.get_booster().save_raw("json"), bst.save_raw("json"))
+
     def test_regressor_params_basic(self):
         py_reg = SparkXGBRegressor()
         self.assertTrue(hasattr(py_reg, "n_estimators"))

Review thread on this hunk:

Reviewer: Add a test to assert that the model file does not include "\n".

Author: Yeah, per my understanding we don't need to do this: if there were a "\n", this assertion would already fail.
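For context on what `assert_model_compatible` verifies: the file Spark writes under `.../model/part-00000` is itself a plain XGBoost JSON model, so native XGBoost can load it without Spark. A minimal round-trip sketch of that guarantee (the `/tmp/xgb_model` path is illustrative, and the tiny DataFrame mirrors the test fixtures):

```python
import glob

from pyspark.ml.linalg import Vectors
from pyspark.sql import SparkSession

import xgboost as xgb
from xgboost.spark import SparkXGBRegressor

spark = SparkSession.builder.master("local[2]").getOrCreate()
train_df = spark.createDataFrame(
    [(Vectors.dense(1.0, 2.0, 3.0), 0), (Vectors.dense(4.0, 5.0, 6.0), 1)],
    ["features", "label"],
)

# Fit and save with the PySpark estimator; the target dir must not exist yet.
model = SparkXGBRegressor(max_depth=5, n_estimators=10).fit(train_df)
model.save("/tmp/xgb_model")

# The booster is persisted as a single text file named part-00000;
# it loads directly into a native Booster, no Spark required.
bst = xgb.Booster()
path = glob.glob("/tmp/xgb_model/**/model/part-00000", recursive=True)[0]
bst.load_model(path)
```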
@@ -591,7 +605,8 @@ def test_classifier_with_params(self):
         )

     def test_regressor_model_save_load(self):
-        path = "file:" + self.get_local_tmp_dir()
+        tmp_dir = self.get_local_tmp_dir()
+        path = "file:" + tmp_dir
         regressor = SparkXGBRegressor(**self.reg_params)
         model = regressor.fit(self.reg_df_train)
         model.save(path)
@@ -611,8 +626,11 @@ def test_regressor_model_save_load(self):
         with self.assertRaisesRegex(AssertionError, "Expected class name"):
             SparkXGBClassifierModel.load(path)

+        self.assert_model_compatible(model, tmp_dir)
+
     def test_classifier_model_save_load(self):
-        path = "file:" + self.get_local_tmp_dir()
+        tmp_dir = self.get_local_tmp_dir()
+        path = "file:" + tmp_dir
         regressor = SparkXGBClassifier(**self.cls_params)
         model = regressor.fit(self.cls_df_train)
         model.save(path)
@@ -632,12 +650,15 @@ def test_classifier_model_save_load(self):
         with self.assertRaisesRegex(AssertionError, "Expected class name"):
             SparkXGBRegressorModel.load(path)

+        self.assert_model_compatible(model, tmp_dir)
+
     @staticmethod
     def _get_params_map(params_kv, estimator):
         return {getattr(estimator, k): v for k, v in params_kv.items()}

     def test_regressor_model_pipeline_save_load(self):
-        path = "file:" + self.get_local_tmp_dir()
+        tmp_dir = self.get_local_tmp_dir()
+        path = "file:" + tmp_dir
         regressor = SparkXGBRegressor()
         pipeline = Pipeline(stages=[regressor])
         pipeline = pipeline.copy(extra=self._get_params_map(self.reg_params, regressor))
@@ -655,9 +676,11 @@ def test_regressor_model_pipeline_save_load(self):
                     row.prediction, row.expected_prediction_with_params, atol=1e-3
                 )
             )
+        self.assert_model_compatible(model.stages[0], tmp_dir)

     def test_classifier_model_pipeline_save_load(self):
-        path = "file:" + self.get_local_tmp_dir()
+        tmp_dir = self.get_local_tmp_dir()
+        path = "file:" + tmp_dir
         classifier = SparkXGBClassifier()
         pipeline = Pipeline(stages=[classifier])
         pipeline = pipeline.copy(
@@ -677,6 +700,7 @@ def test_classifier_model_pipeline_save_load(self):
                     row.probability, row.expected_probability_with_params, atol=1e-3
                 )
             )
+        self.assert_model_compatible(model.stages[0], tmp_dir)

     def test_classifier_with_cross_validator(self):
         xgb_classifer = SparkXGBClassifier()
Review discussion on the saved model file:

Reviewer: Interesting idea, but how do we control the saved file name?
Reviewer: Does the booster string contain a "\n" character? If yes, when loading back via sparkContext.textFile(model_load_path), each line will become one RDD element, and those lines might be split across multiple RDD partitions.
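This concern is easy to demonstrate in isolation (a standalone sketch with an illustrative payload and output path; the path must not already exist):

```python
from pyspark.sql import SparkSession

sc = SparkSession.builder.master("local[2]").getOrCreate().sparkContext

# One element with an embedded newline, standing in for a
# hypothetical booster string that is not newline-free.
payload = "first-half\nsecond-half"
sc.parallelize([payload], numSlices=1).saveAsTextFile("/tmp/newline_demo")

# textFile() yields one record per line, so the single saved element
# comes back as two records -- hence the requirement that the model
# JSON contain no "\n" when persisted and reloaded this way.
print(sc.textFile("/tmp/newline_demo").collect())  # ['first-half', 'second-half']
```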
Author: I tested, and it is always part-00000. There seems to be a pattern for the generated file name based on the task id; since we only have one partition, the id is 00000.
Reviewer: Let's document that the file named "part-00000" is the model JSON file, and please add a test to ensure the model JSON file does not contain a "\n" character, documenting the reason.
Author: Just checked the code: the file name is defined in https://github.com/apache/spark/blob/master/core/src/main/scala/org/apache/spark/internal/io/SparkHadoopWriter.scala#L225, where the splitId is TaskContext.partitionId(). In our case there is only one partition, so the file name is "part-00000".
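The naming can be observed directly (reusing the local `sc` from the sketch above; the path is again illustrative):

```python
import os

# One partition -> one write task -> TaskContext.partitionId() == 0,
# so the Hadoop writer emits a single split named part-00000.
sc.parallelize(["model-json-string"], numSlices=1).saveAsTextFile("/tmp/naming_demo")
print(sorted(os.listdir("/tmp/naming_demo")))
# includes '_SUCCESS' and 'part-00000' (plus hidden .crc files on a local FS)
```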
Reviewer: Yes, I know that. My point is: can we customize the file name to make it more user-friendly? Not a must, though.
Reviewer: That's internal behavior of PySpark; I'm not sure it's a good idea to rely on it.
Author: Yeah, if you insist, I can use the Hadoop FileSystem Java API via py4j to achieve that.
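For the record, the py4j route would look roughly like this. It was not adopted in the PR, and it leans on private attributes (`_jvm`, `_jsc`) and illustrative paths, so treat it purely as a hypothetical sketch:

```python
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

# Reach the Hadoop FileSystem API through the py4j gateway.
jvm = spark.sparkContext._jvm
conf = spark.sparkContext._jsc.hadoopConfiguration()
fs = jvm.org.apache.hadoop.fs.FileSystem.get(conf)

# Rename the Spark-generated split to a friendlier, fixed name.
src = jvm.org.apache.hadoop.fs.Path("/tmp/xgb_model/model/part-00000")
dst = jvm.org.apache.hadoop.fs.Path("/tmp/xgb_model/model/model.json")
fs.rename(src, dst)
```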
Reviewer: No need to do that; it would make the code harder to maintain. Your current code is fine.