diff --git a/python-package/xgboost/spark/model.py b/python-package/xgboost/spark/model.py index 95a96051a4e6..6b050a468357 100644 --- a/python-package/xgboost/spark/model.py +++ b/python-package/xgboost/spark/model.py @@ -21,34 +21,12 @@ def _get_or_create_tmp_dir(): return xgb_tmp_dir -def serialize_xgb_model(model): +def deserialize_xgb_model(model_string, xgb_model_creator): """ - Serialize the input model to a string. - - Parameters - ---------- - model: - an xgboost.XGBModel instance, such as - xgboost.XGBClassifier or xgboost.XGBRegressor instance - """ - # TODO: change to use string io - tmp_file_name = os.path.join(_get_or_create_tmp_dir(), f"{uuid.uuid4()}.json") - model.save_model(tmp_file_name) - with open(tmp_file_name, "r", encoding="utf-8") as f: - ser_model_string = f.read() - return ser_model_string - - -def deserialize_xgb_model(ser_model_string, xgb_model_creator): - """ - Deserialize an xgboost.XGBModel instance from the input ser_model_string. + Deserialize an xgboost.XGBModel instance from the input model_string. """ xgb_model = xgb_model_creator() - # TODO: change to use string io - tmp_file_name = os.path.join(_get_or_create_tmp_dir(), f"{uuid.uuid4()}.json") - with open(tmp_file_name, "w", encoding="utf-8") as f: - f.write(ser_model_string) - xgb_model.load_model(tmp_file_name) + xgb_model.load_model(bytearray(model_string.encode("utf-8"))) return xgb_model @@ -222,11 +200,11 @@ def saveImpl(self, path): """ xgb_model = self.instance._xgb_sklearn_model _SparkXGBSharedReadWrite.saveMetadata(self.instance, path, self.sc, self.logger) - model_save_path = os.path.join(path, "model.json") - ser_xgb_model = serialize_xgb_model(xgb_model) - _get_spark_session().createDataFrame( - [(ser_xgb_model,)], ["xgb_sklearn_model"] - ).write.parquet(model_save_path) + model_save_path = os.path.join(path, "model") + booster = xgb_model.get_booster().save_raw("json").decode("utf-8") + _get_spark_session().sparkContext.parallelize([booster], 1).saveAsTextFile( + model_save_path + ) class SparkXGBModelReader(MLReader): @@ -252,13 +230,10 @@ def load(self, path): xgb_sklearn_params = py_model._gen_xgb_params_dict( gen_xgb_sklearn_estimator_param=True ) - model_load_path = os.path.join(path, "model.json") + model_load_path = os.path.join(path, "model") ser_xgb_model = ( - _get_spark_session() - .read.parquet(model_load_path) - .collect()[0] - .xgb_sklearn_model + _get_spark_session().sparkContext.textFile(model_load_path).collect()[0] ) def create_xgb_model(): diff --git a/tests/python/test_spark/test_spark_local.py b/tests/python/test_spark/test_spark_local.py index 58c313ea043f..0e5bded06c96 100644 --- a/tests/python/test_spark/test_spark_local.py +++ b/tests/python/test_spark/test_spark_local.py @@ -1,3 +1,4 @@ +import glob import logging import random import sys @@ -7,6 +8,8 @@ import pytest import testing as tm +import xgboost as xgb + if tm.no_spark()["condition"]: pytest.skip(msg=tm.no_spark()["reason"], allow_module_level=True) if sys.platform.startswith("win") or sys.platform.startswith("darwin"): @@ -30,7 +33,7 @@ ) from xgboost.spark.core import _non_booster_params -from xgboost import XGBClassifier, XGBRegressor +from xgboost import XGBClassifier, XGBModel, XGBRegressor from .utils import SparkTestCase @@ -62,7 +65,12 @@ def setUp(self): # >>> reg2.fit(X, y) # >>> reg2.predict(X, ntree_limit=5) # array([0.22185266, 0.77814734], dtype=float32) - self.reg_params = {"max_depth": 5, "n_estimators": 10, "ntree_limit": 5} + self.reg_params = { + "max_depth": 5, + "n_estimators": 10, 
+ "ntree_limit": 5, + "max_bin": 9, + } self.reg_df_train = self.session.createDataFrame( [ (Vectors.dense(1.0, 2.0, 3.0), 0), @@ -427,6 +435,12 @@ def setUp(self): def get_local_tmp_dir(self): return self.tempdir + str(uuid.uuid4()) + def assert_model_compatible(self, model: XGBModel, model_path: str): + bst = xgb.Booster() + path = glob.glob(f"{model_path}/**/model/part-00000", recursive=True)[0] + bst.load_model(path) + self.assertEqual(model.get_booster().save_raw("json"), bst.save_raw("json")) + def test_regressor_params_basic(self): py_reg = SparkXGBRegressor() self.assertTrue(hasattr(py_reg, "n_estimators")) @@ -591,7 +605,8 @@ def test_classifier_with_params(self): ) def test_regressor_model_save_load(self): - path = "file:" + self.get_local_tmp_dir() + tmp_dir = self.get_local_tmp_dir() + path = "file:" + tmp_dir regressor = SparkXGBRegressor(**self.reg_params) model = regressor.fit(self.reg_df_train) model.save(path) @@ -611,8 +626,11 @@ def test_regressor_model_save_load(self): with self.assertRaisesRegex(AssertionError, "Expected class name"): SparkXGBClassifierModel.load(path) + self.assert_model_compatible(model, tmp_dir) + def test_classifier_model_save_load(self): - path = "file:" + self.get_local_tmp_dir() + tmp_dir = self.get_local_tmp_dir() + path = "file:" + tmp_dir regressor = SparkXGBClassifier(**self.cls_params) model = regressor.fit(self.cls_df_train) model.save(path) @@ -632,12 +650,15 @@ def test_classifier_model_save_load(self): with self.assertRaisesRegex(AssertionError, "Expected class name"): SparkXGBRegressorModel.load(path) + self.assert_model_compatible(model, tmp_dir) + @staticmethod def _get_params_map(params_kv, estimator): return {getattr(estimator, k): v for k, v in params_kv.items()} def test_regressor_model_pipeline_save_load(self): - path = "file:" + self.get_local_tmp_dir() + tmp_dir = self.get_local_tmp_dir() + path = "file:" + tmp_dir regressor = SparkXGBRegressor() pipeline = Pipeline(stages=[regressor]) pipeline = pipeline.copy(extra=self._get_params_map(self.reg_params, regressor)) @@ -655,9 +676,11 @@ def test_regressor_model_pipeline_save_load(self): row.prediction, row.expected_prediction_with_params, atol=1e-3 ) ) + self.assert_model_compatible(model.stages[0], tmp_dir) def test_classifier_model_pipeline_save_load(self): - path = "file:" + self.get_local_tmp_dir() + tmp_dir = self.get_local_tmp_dir() + path = "file:" + tmp_dir classifier = SparkXGBClassifier() pipeline = Pipeline(stages=[classifier]) pipeline = pipeline.copy( @@ -677,6 +700,7 @@ def test_classifier_model_pipeline_save_load(self): row.probability, row.expected_probability_with_params, atol=1e-3 ) ) + self.assert_model_compatible(model.stages[0], tmp_dir) def test_classifier_with_cross_validator(self): xgb_classifer = SparkXGBClassifier()
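The core simplification in this patch is that an XGBoost model now round-trips through an in-memory JSON string instead of a temporary file: save_raw("json") yields the booster's JSON bytes, and load_model accepts a bytearray directly. A minimal sketch of that round trip (the synthetic data and variable names are illustrative, not part of the patch):

    import numpy as np
    import xgboost as xgb

    # Tiny synthetic regression problem, purely for illustration.
    X, y = np.random.rand(32, 3), np.random.rand(32)
    reg = xgb.XGBRegressor(n_estimators=4, max_depth=2).fit(X, y)

    # Serialize the booster to a JSON string, as the new saveImpl does.
    booster_json = reg.get_booster().save_raw("json").decode("utf-8")

    # Restore it through an in-memory bytearray, as the new
    # deserialize_xgb_model does; no temporary file is involved.
    restored = xgb.XGBRegressor()
    restored.load_model(bytearray(booster_json.encode("utf-8")))

    np.testing.assert_array_almost_equal(reg.predict(X), restored.predict(X))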
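On the Spark side, saveImpl now writes that JSON string with saveAsTextFile into a single partition, and the reader fetches it back with textFile. Because the JSON dump is a single line, the model survives the line-oriented text format as exactly one record, and the resulting part file is a valid XGBoost model file on its own, which is what the new assert_model_compatible test verifies. A sketch of the same flow, assuming a local Spark session and a hypothetical output directory that does not exist yet:

    import os

    import numpy as np
    import xgboost as xgb
    from pyspark.sql import SparkSession

    spark = SparkSession.builder.master("local[1]").getOrCreate()
    model_save_path = "/tmp/xgb-spark-demo/model"  # hypothetical path

    X, y = np.random.rand(32, 3), np.random.rand(32)
    reg = xgb.XGBRegressor(n_estimators=4, max_depth=2).fit(X, y)
    booster_json = reg.get_booster().save_raw("json").decode("utf-8")

    # Write the one-line JSON dump as a single text partition, as saveImpl does.
    spark.sparkContext.parallelize([booster_json], 1).saveAsTextFile(model_save_path)

    # Read it back the way the reader does: one partition, one record.
    loaded_json = spark.sparkContext.textFile(model_save_path).collect()[0]
    assert loaded_json == booster_json

    # The part file alone is loadable by a plain Booster, mirroring
    # assert_model_compatible in the tests.
    bst = xgb.Booster()
    bst.load_model(os.path.join(model_save_path, "part-00000"))

Pinning numSlices to 1 in parallelize guarantees a single part-00000 output file, which is why the test's glob pattern can assume that name.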