Fix Horovod pyarrow IndexError: list index out of range #3274

Merged: 11 commits, Nov 22, 2021
67 changes: 67 additions & 0 deletions horovod/spark/common/util.py
@@ -17,7 +9,9 @@

import contextlib
import os
import time

from multiprocessing.pool import ThreadPool
import pyarrow as pa
import numpy as np
import pyspark.sql.functions as f
@@ -30,6 +32,8 @@
except ImportError:
from pyspark.sql.types import from_arrow_type

from pyspark.sql import SparkSession

from horovod.runner.common.util import codec, host_hash as hh
from horovod.spark.common import cache, constants

@@ -539,6 +543,57 @@ def _train_val_split(df, validation):
return train_df, val_df, validation_ratio


_DATABRICKS_FILE_AVAILABILITY_WAIT_TIMEOUT_SECS = \
int(os.environ.get('DATABRICKS_FILE_AVAILABILITY_WAIT_TIMEOUT_SECS', '30'))


_DATABRICKS_FILE_AVAILABILITY_CHECK_INTERVAL_SECS = \
float(os.environ.get('DATABRICKS_FILE_AVAILABILITY_CHECK_INTERVAL_SECS', '0.1'))


def _wait_file_available_on_dbfs(store, url_list):
"""
On databricks runtime, Waiting about DATABRICKS_FILE_AVAILABILITY_WAIT_TIMEOUT_SECS seconds
(default 30 seconds) to make sure all files are available for reading.
This is because Databricks filesystem backend storage such as S3 which only providing
eventually consistency.
"""
# Import LocalStore here to avoid circular import
from horovod.spark.common.store import LocalStore
if isinstance(store, LocalStore):
return

if not is_databricks():
return

def wait_for_file(path):
end_time = time.time() + _DATABRICKS_FILE_AVAILABILITY_WAIT_TIMEOUT_SECS
while time.time() < end_time:
if store.exists(path):
return True
time.sleep(_DATABRICKS_FILE_AVAILABILITY_CHECK_INTERVAL_SECS)
return False

if len(url_list) == 0:
raise ValueError('Input url_list argument is empty.')

pool = ThreadPool(min(len(url_list), 64))
try:
results = pool.map(wait_for_file, url_list)
failed_list = [url for url, result in zip(url_list, results) if not result]
if failed_list:
raise TimeoutError('Timeout while waiting for all files to appear at urls {failed_list}.'
.format(failed_list=','.join(failed_list)))
finally:
pool.close()
pool.join()
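
A note on tuning: both availability constants above are read from the environment when horovod.spark.common.util is first imported, so they can only be adjusted before that import happens. A minimal sketch of what that could look like on the driver; the values are arbitrary examples, not recommendations:

import os

# Hypothetical tuning sketch: set the variables before horovod.spark.common.util
# is imported, since the module reads them at import time (see the constants above).
os.environ['DATABRICKS_FILE_AVAILABILITY_WAIT_TIMEOUT_SECS'] = '60'     # wait up to 60 seconds
os.environ['DATABRICKS_FILE_AVAILABILITY_CHECK_INTERVAL_SECS'] = '0.5'  # poll every 0.5 seconds

import horovod.spark.common.util  # noqa: E402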


def _get_spark_df_saved_file_list(saved_path):
spark_session = SparkSession.builder.getOrCreate()
return list(spark_session.read.parquet(saved_path)._jdf.inputFiles())
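
For reference, PySpark 3.1 and later exposes the same information through the public DataFrame.inputFiles() method, which simply wraps the _jdf.inputFiles() call used above. A rough, self-contained sketch against a hypothetical local path and a local SparkSession:

from pyspark.sql import SparkSession

# Hypothetical illustration: write a tiny DataFrame and ask Spark which parquet
# files back it. The output path below is a placeholder.
spark = SparkSession.builder.master('local[2]').appName('input-files-demo').getOrCreate()
spark.range(100).repartition(4).write.mode('overwrite').parquet('/tmp/input_files_demo')

# On PySpark 3.1+ this is equivalent to _get_spark_df_saved_file_list above.
files = spark.read.parquet('/tmp/input_files_demo').inputFiles()
print(sorted(files))  # fully qualified URIs such as 'file:/tmp/input_files_demo/part-00000-....parquet'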


def _get_or_create_dataset(key, store, df, feature_columns, label_columns,
validation, sample_weight_col, compress_sparse,
num_partitions, num_processes, verbose):
@@ -590,6 +645,8 @@ def _get_or_create_dataset(key, store, df, feature_columns, label_columns,
.mode('overwrite') \
.parquet(train_data_path)

saved_file_list = _get_spark_df_saved_file_list(train_data_path)

if val_df:
val_partitions = max(int(num_partitions * validation_ratio),
num_processes)
@@ -602,6 +659,16 @@
.mode('overwrite') \
.parquet(val_data_path)

saved_file_list += _get_spark_df_saved_file_list(val_data_path)

try:
_wait_file_available_on_dbfs(store, saved_file_list)
except TimeoutError as e:
                err_msg = 'Timeout while waiting for all parquet-store files to appear. Please ' \
                          'check whether these files were saved successfully when materializing ' \
                          'the dataframe. Internal error: {e}'.format(e=str(e))
raise RuntimeError(err_msg)

train_rows, val_rows, pq_metadata, avg_row_size = get_simple_meta_from_parquet(
store, label_columns, feature_columns, sample_weight_col, dataset_idx)
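
The two pieces this hunk wires together are a bounded poll on file availability and the translation of a timeout into a RuntimeError for the caller. A self-contained sketch of that pattern in isolation, decoupled from Spark and the Store API; the predicate is a stand-in for store.exists and never succeeds, so the failure path is taken deliberately:

import time

def wait_until(predicate, timeout_secs=0.5, interval_secs=0.1):
    # Poll until predicate() is true or the deadline passes, mirroring wait_for_file above.
    end_time = time.time() + timeout_secs
    while time.time() < end_time:
        if predicate():
            return True
        time.sleep(interval_secs)
    return False

try:
    if not wait_until(lambda: False):
        raise TimeoutError('Timeout while waiting for files to appear.')
except TimeoutError as e:
    # The code above wraps this into a RuntimeError with a hint about materialization.
    print('Would be re-raised as RuntimeError: {e}'.format(e=e))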

55 changes: 55 additions & 0 deletions test/utils/spark_common.py
@@ -16,8 +16,11 @@
import contextlib
import os
import platform
import pytest
import stat
import sys
import threading
import time

from tempfile import TemporaryDirectory

@@ -30,6 +33,7 @@

from horovod.runner.common.util import secret
from horovod.spark.common.store import LocalStore
from horovod.spark.common.util import _wait_file_available_on_dbfs, _get_spark_df_saved_file_list
from horovod.spark.driver.driver_service import SparkDriverService, SparkDriverClient
from horovod.spark.task.task_service import SparkTaskService, SparkTaskClient

@@ -232,3 +236,54 @@ def create_mnist_data(spark):

def create_test_data_from_schema(spark, data, schema):
return spark.createDataFrame(data, schema=schema)


def test_wait_file_available_on_dbfs():
with tempdir() as d:
pq_dir = os.path.join(d, 'test_ev')
os.makedirs(pq_dir)
file1_path = os.path.join(pq_dir, 'file1')
file2_path = os.path.join(pq_dir, 'file2')
url1 = 'file://' + file1_path.replace(os.sep, '/')
url2 = 'file://' + file2_path.replace(os.sep, '/')

url_list = [url1, url2]

def create_file(p):
with open(p, 'w'):
pass

        # 1. Test when all files already exist.
create_file(file1_path)
create_file(file2_path)
_wait_file_available_on_dbfs(url_list)

        # 2. Test that a missing file raises an error.
os.remove(file2_path)
with pytest.raises(
RuntimeError,
match='Timeout while waiting for all parquet-store files to appear'
):
_wait_file_available_on_dbfs(url_list)

        # 3. Test a file that becomes available after about 1 second.
def delay_create_file2():
time.sleep(1)
create_file(file2_path)

        threading.Thread(target=delay_create_file2).start()

_wait_file_available_on_dbfs(url_list)


def test_get_spark_df_input_files(spark):
with tempdir() as d:
pq_dir = os.path.join(d, 'test_spark_df_output')
with spark_session('test_get_spark_df_input_files') as spark:
spark.range(100).repartition(4).write.parquet(pq_dir)

pq_files = _get_spark_df_saved_file_list(pq_dir)
pq_files = sorted(pq_files)
assert len(pq_files) == 4
for i in range(4):
                assert os.path.basename(pq_files[i]).startswith('part-0000' + str(i))