
[pyspark] make the model saved by pyspark compatible #8219

Merged
merged 9 commits into dmlc:master on Sep 20, 2022

Conversation

@wbo4958 (Contributor) commented Sep 2, 2022

Users can't directly load a model trained by PySpark using the xgboost Python package; it requires much effort to do that, see #8186. This PR first saves the model in JSON format and then writes it to a text file, so the user can easily load the model with:

import xgboost as xgb

bst = xgb.Booster()
YOUR_MODEL_PATH = "xxx"
bst.load_model(f"{YOUR_MODEL_PATH}/model/part-00000")
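For context, a minimal sketch of the save side that produces this layout, assuming the xgboost.spark estimator API (the training DataFrame and column names are illustrative):

from xgboost.spark import SparkXGBClassifier

# train a distributed model; `train_df` is an illustrative DataFrame
# with "features" (vector) and "label" columns
model = SparkXGBClassifier(features_col="features", label_col="label").fit(train_df)

# standard PySpark ML persistence; with this PR the booster is written
# to <path>/model/part-00000 as a single line of JSON text
model.write().overwrite().save("YOUR_MODEL_PATH")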

@wbo4958 wbo4958 changed the title [WIP][pyspark] make the model saved by pyspark compatible [pyspark] make the model saved by pyspark compatible Sep 5, 2022
@wbo4958 (Contributor Author) commented Sep 6, 2022

@trivialfis Could you help check a failed Python test?

[2022-09-05T12:59:12.999Z] =================================== FAILURES ===================================

[2022-09-05T12:59:12.999Z] ____________________________ test_gpu_data_iterator ____________________________

[2022-09-05T12:59:12.999Z] 

[2022-09-05T12:59:12.999Z] cls = <class '_pytest.runner.CallInfo'>

@wbo4958 (Contributor Author) commented Sep 8, 2022

@WeichenXu123 @trivialfis Could you help review this PR?

@trivialfis (Member):

Will look into it tomorrow.

@wbo4958 (Contributor Author) commented Sep 13, 2022

@WeichenXu123 @trivialfis Any feedback on this PR?

@wbo4958 (Contributor Author) commented Sep 14, 2022

Hi @WeichenXu123 @trivialfis, could you help review it?

@trivialfis (Member) commented Sep 14, 2022

Can we document the function get_booster(self) and let the user extract the booster? I think it's easier.
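A minimal sketch of that alternative, assuming the xgboost.spark estimator API (the DataFrame, column names, and path are illustrative):

from xgboost.spark import SparkXGBClassifier

model = SparkXGBClassifier(features_col="features", label_col="label").fit(train_df)

# extract the single-node Booster from the distributed model and save it
# in XGBoost's native format, loadable via xgb.Booster().load_model(...)
model.get_booster().save_model("/tmp/xgboost_model.json")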

@wbo4958 (Contributor Author) commented Sep 14, 2022

Yes, we can, but the issue would be the same one the JVM package previously encountered. Most users save the model the Spark way and may not want to do another get_booster().save_model() call. The model may then be moved to another machine (or to another team without any knowledge of Spark) where no Spark cluster is deployed, since users may just want to load the model with the Python package and do some prediction. In that case it's really inconvenient for users. This PR is not supposed to introduce any side effects, so I think it's OK to be merged.

@@ -21,34 +21,28 @@ def _get_or_create_tmp_dir():
     return xgb_tmp_dir


-def serialize_xgb_model(model):
+def dump_model_to_json_file(model) -> str:
Member:

Please use the term save. Dump has a specific meaning in XGBoost's code base.

Contributor Author:

Done

).write.parquet(model_save_path)
model_save_path = os.path.join(path, "model")
xgb_model_file = dump_model_to_json_file(xgb_model)
# The json file written by Spark based on `booster.save_raw("json").decode("utf-8")`
Member:

Why not?

@wbo4958 (Contributor Author), Sep 15, 2022:

There are some escaped quote sequences (\") in the JSON file which can't be loaded by XGBoost. Do you want to check further?

Member:

I will take a look tomorrow

Contributor Author:

@trivialfis No need anymore, I just found another way to do it.

xgb_model_file = save_model_to_json_file(xgb_model)
# The json file written by Spark based on `booster.save_raw("json").decode("utf-8")`
# can't be loaded by XGBoost directly.
_get_spark_session().read.text(xgb_model_file).write.text(model_save_path)
Contributor:

_get_spark_session().read.text(xgb_model_file)

This line is not correct: with spark.read.text(path), the path must be a distributed file system path that all Spark executors can access.

Contributor:

You can use a distributed FS API to copy the local file xgb_model_file into the model save path (a Hadoop FS path).

Contributor Author:

Wow, right, you're correct @WeichenXu123, good finding. Could you point me to what the "distributed FS API" is? Really appreciate it.

Contributor:

You can use this:
https://arrow.apache.org/docs/python/generated/pyarrow.fs.HadoopFileSystem.html

But this does not support DBFS (the Databricks filesystem), and we need to support the Databricks case as well. Databricks mounts dbfs:/xxx/xxx to the local file system at /dbfs/xxx/xxx.
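A minimal sketch of that suggestion using the linked pyarrow API (the namenode host, port, and paths are placeholders):

from pyarrow import fs

# connect to HDFS (placeholder host/port); requires a configured Hadoop client
hdfs = fs.HadoopFileSystem(host="namenode", port=8020)

# copy the JSON file written locally by the driver into the model save path
fs.copy_files(
    "/tmp/xgb_model.json",            # local file produced on the driver
    "/user/me/model/xgb_model.json",  # destination on the distributed FS
    source_filesystem=fs.LocalFileSystem(),
    destination_filesystem=hdfs,
)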

Contributor:

The example code in the PR description

import xgboost as xgb
bst = xgb.Booster()

# Basically, YOUR_MODEL_PATH should be like "xxxx/model/xxx.txt"
YOUR_MODEL_PATH = "xxx"
bst.load_model(YOUR_MODEL_PATH)

seems like it does not work if the path is a distributed FS path?

Contributor Author:

@WeichenXu123, I use the RDD to save the text file, so it should work with all kinds of Hadoop-compatible FS.

@WeichenXu123 (Contributor):

Do we really need this PR? A user can load the pyspark model and then call pyspark_model.booster to get the raw booster model.

@wbo4958 (Contributor Author) commented Sep 15, 2022

Do we really need this PR? A user can load the pyspark model and then call pyspark_model.booster to get the raw booster model.

Yeah, consider the scenario: a data scientist who does not know Spark gets a model saved by xgboost-spark and wants to load it with the xgboost Python package. What can he/she do?

Although we can document it, trust me, not everyone will read the whole doc carefully. Previously, XGBoost-JVM had the same issue, so I changed that.

@wbo4958 (Contributor Author) commented Sep 16, 2022

Hi @hcho3, what does "Pending" mean for pipelines like xgboost-ci/pr?

Comment on lines +205 to +206
_get_spark_session().sparkContext.parallelize([booster], 1).saveAsTextFile(
    model_save_path
)
Contributor:

Interesting idea, but how do we control the saved file name?

Contributor:

Does the booster string contain a "\n" character? If yes, when loading back (via sparkContext.textFile(model_load_path)), each line will become one RDD element, and these lines might be split into multiple RDD partitions.
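To illustrate the concern, a hypothetical load-back sketch (path illustrative): textFile() yields one element per line, so the model only survives the round trip unchanged if it was saved as a single line:

# collect the saved lines; with one partition and no embedded "\n" in the
# booster string, this list has exactly one element holding the full model
lines = spark.sparkContext.textFile(model_load_path).collect()
booster_json = lines[0] if len(lines) == 1 else "\n".join(lines)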

@wbo4958 (Contributor Author), Sep 16, 2022:

I tested, and it is always part-00000. There seems to be a pattern for the generated file name based on the task id; since we only have 1 partition, the id should be 00000.

Contributor:

Let's document that the file name "part-00000" is the model JSON file.

And please add a test to ensure the model JSON file does not contain the \n character, and document the reason.

Contributor Author:

Just checked the code: the file name is defined in https://github.com/apache/spark/blob/master/core/src/main/scala/org/apache/spark/internal/io/SparkHadoopWriter.scala#L225.

  override def initWriter(taskContext: NewTaskAttemptContext, splitId: Int): Unit = {
    val numfmt = NumberFormat.getInstance(Locale.US)
    numfmt.setMinimumIntegerDigits(5)
    numfmt.setGroupingUsed(false)

    val outputName = "part-" + numfmt.format(splitId)
    val path = FileOutputFormat.getOutputPath(getConf)
    val fs: FileSystem = {
      if (path != null) {
        path.getFileSystem(getConf)
      } else {
        // scalastyle:off FileSystemGet
        FileSystem.get(getConf)
        // scalastyle:on FileSystemGet
      }
    }
...

Here the splitId is TaskContext.partitionId(). In our case there is only 1 partition, so the file name is "part-00000".

Contributor:

Yes, I know that. My point is: can we customize the file name to make it more user-friendly? Not a must, though.

Member:

That's the internal behavior of pyspark; not sure if it's a good idea to rely on it.

Contributor Author:

Yeah, if you guys insist, I can use the FileSystem Java API to achieve it via py4j.

Contributor:

Yeah, if you guys insist, I can use the FileSystem Java API to achieve it via py4j.

No need to do that; it makes the code hard to maintain. Your current code is fine.

@hcho3 (Collaborator) commented Sep 16, 2022

what does "Pending" mean for pipelines like xgboost-ci/pr?

The CI pipeline doesn't run until one of the admins (like me) gives approval. We do this to save on CI costs.

bst = xgb.Booster()
path = glob.glob(f"{model_path}/**/model/part-00000", recursive=True)[0]
bst.load_model(path)
self.assertEqual(model.get_booster().save_raw("json"), bst.save_raw("json"))
Contributor:

Add a test to assert the model file does not include the \n char.

@wbo4958 (Contributor Author), Sep 17, 2022:

Yeah, per my understanding, we don't need to do this: if there were a "\n", the assertion self.assertEqual(model.get_booster().save_raw("json"), bst.save_raw("json")) would fail, or bst.load_model(path) would fail.
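For reference, a hypothetical version of the explicit check requested above (`path` is the local part-00000 file from the glob in the test snippet earlier):

# saveAsTextFile terminates each element with "\n", so after stripping that
# trailing newline the model JSON itself must contain no "\n" at all
with open(path, "r") as f:
    content = f.read()
assert "\n" not in content.rstrip("\n")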

@trivialfis (Member):

I will leave the approval to @WeichenXu123. Could you please add documentation as well, about get_booster and your workaround for the model serialization?

@wbo4958 (Contributor Author) commented Sep 18, 2022

Sure, I will add the doc in a follow-up PR, along with how to leverage RAPIDS to accelerate XGBoost PySpark.

@trivialfis @hcho3 could you trigger the CI of this PR?

@trivialfis trivialfis merged commit 4f42aa5 into dmlc:master Sep 20, 2022
@wbo4958 wbo4958 deleted the model-format branch April 23, 2024 09:26