
Commit

update
WeichenXu123 committed Nov 12, 2021
1 parent de0b894 commit 1c69a9a
Showing 1 changed file with 10 additions and 3 deletions.
13 changes: 10 additions & 3 deletions horovod/spark/common/util.py
@@ -32,6 +32,8 @@
 except ImportError:
     from pyspark.sql.types import from_arrow_type
 
+from pyspark.sql import SparkSession
+
 from horovod.runner.common.util import codec, host_hash as hh
 from horovod.spark.common import cache, constants
 
@@ -576,6 +578,11 @@ def wait_for_file(path):
     pool.join()
 
 
+def get_spark_df_saved_file_list(saved_path):
+    spark_session = SparkSession.builder.getOrCreate()
+    return list(spark_session.read.parquet(saved_path)._jdf.inputFiles())
+
+
 def _get_or_create_dataset(key, store, df, feature_columns, label_columns,
                            validation, sample_weight_col, compress_sparse,
                            num_partitions, num_processes, verbose):
@@ -627,6 +634,8 @@ def _get_or_create_dataset(key, store, df, feature_columns, label_columns,
             .mode('overwrite') \
             .parquet(train_data_path)
 
+        saved_file_list = get_spark_df_saved_file_list(train_data_path)
+
         if val_df:
             val_partitions = max(int(num_partitions * validation_ratio),
                                  num_processes)
@@ -639,9 +648,7 @@ def _get_or_create_dataset(key, store, df, feature_columns, label_columns,
                 .mode('overwrite') \
                 .parquet(val_data_path)
 
-        saved_file_list = list(train_df._jdf.inputFiles())
-        if val_df:
-            saved_file_list += list(val_df._jdf.inputFiles())
+            saved_file_list += get_spark_df_saved_file_list(val_data_path)
 
         _wait_file_available(store, saved_file_list)
 
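Before this change, _get_or_create_dataset built the list of files to wait on by calling inputFiles() on the in-memory train_df and val_df. PySpark's inputFiles() is a best-effort snapshot of the files a DataFrame was read from, so on a derived DataFrame it names the source files, not the parquet part files the write just produced. The new get_spark_df_saved_file_list helper re-reads the saved path and lists the files that actually landed on storage, which is what _wait_file_available needs to poll. Below is a minimal sketch of the helper's behavior, runnable outside Horovod; the local[2] session, temp directory, and toy DataFrame are illustrative assumptions, while the helper body matches the commit.

# A minimal, self-contained sketch (not part of the commit) showing what the
# helper returns. The local[2] session, temp directory, and toy DataFrame are
# illustrative assumptions; the helper body is taken verbatim from the commit.
import tempfile

from pyspark.sql import SparkSession


def get_spark_df_saved_file_list(saved_path):
    # Re-read the saved parquet directory: inputFiles() on the re-read
    # DataFrame is a best-effort snapshot of the concrete part files on
    # storage, which is what callers like _wait_file_available should poll.
    spark_session = SparkSession.builder.getOrCreate()
    return list(spark_session.read.parquet(saved_path)._jdf.inputFiles())


if __name__ == '__main__':
    spark = SparkSession.builder.master('local[2]').getOrCreate()

    # Write a small DataFrame as 4 parquet part files to a temp directory.
    out_dir = tempfile.mkdtemp(suffix='.parquet')
    spark.range(100).repartition(4).write.mode('overwrite').parquet(out_dir)

    # Prints one URI per written part file,
    # e.g. file:///tmp/xxxx.parquet/part-00000-<uuid>.parquet
    for path in get_spark_df_saved_file_list(out_dir):
        print(path)

    spark.stop()

Re-reading the saved path also keeps the helper independent of the writing DataFrame's query plan, so the same call works for both the training and validation outputs.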