Fix Horovod pyarrow IndexError: list index out of range #3274

Merged
merged 11 commits on Nov 22, 2021
Changes from 7 commits
50 changes: 50 additions & 0 deletions horovod/spark/common/util.py
@@ -17,7 +17,9 @@

import contextlib
import os
import time

from multiprocessing.pool import ThreadPool
import pyarrow as pa
import numpy as np
import pyspark.sql.functions as f
@@ -30,6 +32,8 @@
except ImportError:
from pyspark.sql.types import from_arrow_type

from pyspark.sql import SparkSession

from horovod.runner.common.util import codec, host_hash as hh
from horovod.spark.common import cache, constants

@@ -539,6 +543,46 @@ def _train_val_split(df, validation):
return train_df, val_df, validation_ratio


_FILE_AVAILABILITY_WAIT_TIMEOUT_SECS = \
int(os.environ.get('FILE_AVAILABILITY_WAIT_TIMEOUT_SECS', '30'))
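# The wait timeout (in seconds) can be overridden through the FILE_AVAILABILITY_WAIT_TIMEOUT_SECS environment variable.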


def _wait_file_available(store, url_list):
"""Waiting about _FILE_AVAILABILITY_WAIT_TIMEOUT_SECS seconds (default 30 seconds) to make sure
all files are available for reading. This is useful in some filesystems, such as S3 which only
providing eventually consistency.
"""
# Import LocalStore here to avoid circular import
from horovod.spark.common.store import LocalStore
if isinstance(store, LocalStore):
return

def wait_for_file(path):
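# Poll the store until the file becomes visible, or give up once the timeout expires.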
end_time = time.time() + _FILE_AVAILABILITY_WAIT_TIMEOUT_SECS
while time.time() < end_time:
if store.exists(path):
return True
time.sleep(0.1)
return False

pool = ThreadPool(min(len(url_list), 64))
try:
results = pool.map(wait_for_file, url_list)
failed_list = [url for url, result in zip(url_list, results) if not result]
if failed_list:
raise RuntimeError('Timeout while waiting for all parquet-store files to appear at urls {failed_list}. '
'Please check whether these files were saved successfully when materializing the dataframe.'
.format(failed_list=','.join(failed_list)))
finally:
pool.close()
pool.join()


def _get_spark_df_saved_file_list(saved_path):
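# Read the parquet directory back and ask the underlying JVM DataFrame for the concrete part-file URLs that back it.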
spark_session = SparkSession.builder.getOrCreate()
return list(spark_session.read.parquet(saved_path)._jdf.inputFiles())


def _get_or_create_dataset(key, store, df, feature_columns, label_columns,
validation, sample_weight_col, compress_sparse,
num_partitions, num_processes, verbose):
@@ -590,6 +634,8 @@ def _get_or_create_dataset(key, store, df, feature_columns, label_columns,
.mode('overwrite') \
.parquet(train_data_path)

saved_file_list = _get_spark_df_saved_file_list(train_data_path)
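# Collect the part-file URLs that were just written so we can later wait for them to become readable.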

if val_df:
val_partitions = max(int(num_partitions * validation_ratio),
num_processes)
@@ -602,6 +648,10 @@
.mode('overwrite') \
.parquet(val_data_path)

saved_file_list += _get_spark_df_saved_file_list(val_data_path)

_wait_file_available(store, saved_file_list)
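# No-op for LocalStore; for remote stores such as S3 this guards against eventual-consistency delays before the metadata read below.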

train_rows, val_rows, pq_metadata, avg_row_size = get_simple_meta_from_parquet(
store, label_columns, feature_columns, sample_weight_col, dataset_idx)

55 changes: 55 additions & 0 deletions test/utils/spark_common.py
@@ -16,8 +16,11 @@
import contextlib
import os
import platform
import pytest
import stat
import sys
import threading
import time

from tempfile import TemporaryDirectory
from types import SimpleNamespace

@@ -30,6 +33,7 @@

from horovod.runner.common.util import secret
from horovod.spark.common.store import LocalStore
from horovod.spark.common.util import _wait_file_available
from horovod.spark.driver.driver_service import SparkDriverService, SparkDriverClient
from horovod.spark.task.task_service import SparkTaskService, SparkTaskClient

@@ -232,3 +236,54 @@ def create_mnist_data(spark):

def create_test_data_from_schema(spark, data, schema):
return spark.createDataFrame(data, schema=schema)


def test_wait_file_available():
with tempdir() as d:
pq_dir = os.path.join(d, 'test_ev')
os.makedirs(pq_dir)
file1_path = os.path.join(pq_dir, 'file1')
file2_path = os.path.join(pq_dir, 'file2')
url1 = 'file://' + file1_path.replace(os.sep, '/')
url2 = 'file://' + file2_path.replace(os.sep, '/')

url_list = [url1, url2]

def create_file(p):
with open(p, 'w'):
pass

# Minimal store stand-in (test-only assumption): exists() resolves 'file://' URLs against the local filesystem,
# so _wait_file_available performs its real polling instead of returning early as it would for a LocalStore.
store = SimpleNamespace(exists=lambda url: os.path.exists(url[len('file://'):]))

# 1. Test that all files exist.
create_file(file1_path)
create_file(file2_path)
_wait_file_available(store, url_list)

# 2. Test that a missing file raises an error after the timeout.
os.remove(file2_path)
with pytest.raises(
RuntimeError,
match='Timeout while waiting for all parquet-store files to appear at urls'
):
_wait_file_available(store, url_list)

# 3. Test a file that only becomes available after 1 second.
def delay_create_file2():
time.sleep(1)
create_file(file2_path)

# Pass the callable itself; calling it here would create the file synchronously before the thread starts.
threading.Thread(target=delay_create_file2).start()

_wait_file_available(store, url_list)


def test_get_spark_df_input_files():
with tempdir() as d:
pq_dir = os.path.join(d, 'test_spark_df_output')
with spark_session('test_get_spark_df_input_files') as spark:
spark.range(100).repartition(4).write.parquet(pq_dir)
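# Writing with four partitions produces four part files, named part-00000 through part-00003 (with a suffix appended by Spark).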

pq_files = sorted([filename for filename in os.listdir(pq_dir)
if filename.endswith('.parquet')])
assert len(pq_files) == 4
for i in range(4):
assert pq_files[i].startswith('part-0000' + str(i))