Fix Horovod pyarrow IndexError: list index out of range (#3274)
Signed-off-by: Weichen Xu <weichen.xu@databricks.com>
WeichenXu123 committed Nov 22, 2021
1 parent e1ddf3d commit 6e7464a
Showing 2 changed files with 122 additions and 0 deletions.
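
Context for the fix: Databricks-backed storage (e.g. S3) is only eventually consistent, so parquet files that Spark has just written may not yet be visible when pyarrow inspects the dataset metadata, which is what surfaces as the IndexError in the commit title. The commit therefore records the files Spark wrote and polls until each one is readable before the metadata is read. The snippet below is a minimal, Horovod-independent sketch of that wait-and-retry pattern; the helper name, the os.path.exists check, and the constants are illustrative stand-ins, not the committed code (which consults the Horovod store and the environment variables shown further down).

import os
import time

WAIT_TIMEOUT_SECS = 30       # illustrative; the commit reads the real value from an env var
CHECK_INTERVAL_SECS = 0.1    # illustrative polling interval


def wait_until_available(paths, exists=os.path.exists):
    """Block until every path is visible, or raise after a per-file timeout."""
    missing = []
    for path in paths:
        deadline = time.time() + WAIT_TIMEOUT_SECS
        while not exists(path) and time.time() < deadline:
            time.sleep(CHECK_INTERVAL_SECS)
        if not exists(path):
            missing.append(path)
    if missing:
        raise TimeoutError('Files did not appear in time: ' + ','.join(missing))

The committed implementation additionally fans the checks out over a ThreadPool (up to 64 threads) so that slow paths are polled concurrently.
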
67 changes: 67 additions & 0 deletions horovod/spark/common/util.py
@@ -17,7 +17,9 @@

import contextlib
import os
import time

from multiprocessing.pool import ThreadPool
import pyarrow as pa
import numpy as np
import pyspark.sql.functions as f
@@ -30,6 +32,8 @@
except ImportError:
from pyspark.sql.types import from_arrow_type

from pyspark.sql import SparkSession

from horovod.runner.common.util import codec, host_hash as hh
from horovod.spark.common import cache, constants

@@ -539,6 +543,57 @@ def _train_val_split(df, validation):
return train_df, val_df, validation_ratio


_DATABRICKS_FILE_AVAILABILITY_WAIT_TIMEOUT_SECS = \
int(os.environ.get('DATABRICKS_FILE_AVAILABILITY_WAIT_TIMEOUT_SECS', '30'))


_DATABRICKS_FILE_AVAILABILITY_CHECK_INTERVAL_SECS = \
float(os.environ.get('DATABRICKS_FILE_AVAILABILITY_CHECK_INTERVAL_SECS', '0.1'))


def _wait_file_available_on_dbfs(store, url_list):
"""
On databricks runtime, Waiting about DATABRICKS_FILE_AVAILABILITY_WAIT_TIMEOUT_SECS seconds
(default 30 seconds) to make sure all files are available for reading.
This is because Databricks filesystem backend storage such as S3 which only providing
eventually consistency.
"""
# Import LocalStore here to avoid circular import
from horovod.spark.common.store import LocalStore
if isinstance(store, LocalStore):
return

if not is_databricks():
return

def wait_for_file(path):
end_time = time.time() + _DATABRICKS_FILE_AVAILABILITY_WAIT_TIMEOUT_SECS
while time.time() < end_time:
if store.exists(path):
return True
time.sleep(_DATABRICKS_FILE_AVAILABILITY_CHECK_INTERVAL_SECS)
return False

if len(url_list) == 0:
raise ValueError('Input url_list argument is empty.')

pool = ThreadPool(min(len(url_list), 64))
try:
results = pool.map(wait_for_file, url_list)
failed_list = [url for url, result in zip(url_list, results) if not result]
if failed_list:
raise TimeoutError('Timeout while waiting for all files to appear at urls {failed_list}.'
.format(failed_list=','.join(failed_list)))
finally:
pool.close()
pool.join()


def _get_spark_df_saved_file_list(saved_path):
    # Return the files backing the parquet dataset at `saved_path`, as the full
    # URIs reported by DataFrame.inputFiles().
    spark_session = SparkSession.builder.getOrCreate()
    return list(spark_session.read.parquet(saved_path)._jdf.inputFiles())


def _get_or_create_dataset(key, store, df, feature_columns, label_columns,
validation, sample_weight_col, compress_sparse,
num_partitions, num_processes, verbose):
@@ -590,6 +645,8 @@ def _get_or_create_dataset(key, store, df, feature_columns, label_columns,
.mode('overwrite') \
.parquet(train_data_path)

saved_file_list = _get_spark_df_saved_file_list(train_data_path)

if val_df:
val_partitions = max(int(num_partitions * validation_ratio),
num_processes)
@@ -602,6 +659,16 @@
.mode('overwrite') \
.parquet(val_data_path)

saved_file_list += _get_spark_df_saved_file_list(val_data_path)

            try:
                _wait_file_available_on_dbfs(store, saved_file_list)
            except TimeoutError as e:
                err_msg = 'Timeout while waiting for all parquet-store files to appear. Please ' \
                          'check whether these files were saved successfully when materializing ' \
                          'the dataframe. Internal error: {e}'.format(e=str(e))
                raise RuntimeError(err_msg)

train_rows, val_rows, pq_metadata, avg_row_size = get_simple_meta_from_parquet(
store, label_columns, feature_columns, sample_weight_col, dataset_idx)

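The wait introduced above is tunable through two environment variables that util.py reads at import time, DATABRICKS_FILE_AVAILABILITY_WAIT_TIMEOUT_SECS and DATABRICKS_FILE_AVAILABILITY_CHECK_INTERVAL_SECS. A sketch of how a job might override them (the values below are examples only):

import os

# Must be set before horovod.spark.common.util is imported, because the module
# captures these values in module-level constants at import time.
os.environ['DATABRICKS_FILE_AVAILABILITY_WAIT_TIMEOUT_SECS'] = '120'
os.environ['DATABRICKS_FILE_AVAILABILITY_CHECK_INTERVAL_SECS'] = '0.5'

from horovod.spark.common import util  # noqa: E402,F401
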
55 changes: 55 additions & 0 deletions test/utils/spark_common.py
@@ -16,8 +16,11 @@
import contextlib
import os
import platform
import pytest
import stat
import sys
import threading
import time

from tempfile import TemporaryDirectory

@@ -30,6 +33,7 @@

from horovod.runner.common.util import secret
from horovod.spark.common.store import LocalStore
from horovod.spark.common.util import _wait_file_available_on_dbfs, _get_spark_df_saved_file_list
from horovod.spark.driver.driver_service import SparkDriverService, SparkDriverClient
from horovod.spark.task.task_service import SparkTaskService, SparkTaskClient

@@ -232,3 +236,54 @@ def create_mnist_data(spark):

def create_test_data_from_schema(spark, data, schema):
return spark.createDataFrame(data, schema=schema)


def test_wait_file_available_on_dbfs():
with tempdir() as d:
pq_dir = os.path.join(d, 'test_ev')
os.makedirs(pq_dir)
file1_path = os.path.join(pq_dir, 'file1')
file2_path = os.path.join(pq_dir, 'file2')
url1 = 'file://' + file1_path.replace(os.sep, '/')
url2 = 'file://' + file2_path.replace(os.sep, '/')

url_list = [url1, url2]

def create_file(p):
with open(p, 'w'):
pass

        # 1. Test that all files exist.
create_file(file1_path)
create_file(file2_path)
_wait_file_available_on_dbfs(url_list)

        # 2. Test that one file does not exist; expect a timeout error.
os.remove(file2_path)
with pytest.raises(
RuntimeError,
match='Timeout while waiting for all parquet-store files to appear'
):
_wait_file_available_on_dbfs(url_list)

        # 3. Test a file that becomes available after 1 second.
def delay_create_file2():
time.sleep(1)
create_file(file2_path)

        threading.Thread(target=delay_create_file2).start()

_wait_file_available_on_dbfs(url_list)


def test_get_spark_df_input_files(spark):
with tempdir() as d:
pq_dir = os.path.join(d, 'test_spark_df_output')
with spark_session('test_get_spark_df_input_files') as spark:
spark.range(100).repartition(4).write.parquet(pq_dir)

pq_files = _get_spark_df_saved_file_list(pq_dir)
pq_files = sorted(pq_files)
assert len(pq_files) == 4
for i in range(4):
                assert os.path.basename(pq_files[i]).startswith('part-0000' + str(i))
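
A note on the assertion above: DataFrame.inputFiles() reports fully qualified URIs (for example file:///.../part-00000-<uuid>.snappy.parquet), which is why the part-0000N prefix is checked against the basename. On PySpark 3.1+ the same list is available without going through _jdf; a quick interactive check (paths are illustrative):

from pyspark.sql import SparkSession

spark = SparkSession.builder.master('local[2]').getOrCreate()
spark.range(100).repartition(4).write.mode('overwrite').parquet('/tmp/pq_demo')

# Prints full URIs such as 'file:///tmp/pq_demo/part-00000-<uuid>.snappy.parquet'
for f in sorted(spark.read.parquet('/tmp/pq_demo').inputFiles()):
    print(f)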
