Skip to content

Commit

Permalink
Make checkpoint name optional so that user can save to h5 format. (#3411
Browse files Browse the repository at this point in the history
)

Signed-off-by: Peng Zhang <pengz@uber.com>
  • Loading branch information
irasit committed Mar 1, 2022
1 parent b553974 commit 7bf9b04
Showing 1 changed file with 8 additions and 6 deletions.
14 changes: 8 additions & 6 deletions horovod/spark/common/store.py
Expand Up @@ -168,14 +168,16 @@ class AbstractFilesystemStore(Store):
"""Abstract class for stores that use a filesystem for underlying storage."""

def __init__(self, prefix_path, train_path=None, val_path=None, test_path=None,
runs_path=None, save_runs=True, storage_options=None, **kwargs):
runs_path=None, save_runs=True, storage_options=None, checkpoint_filename=None,
**kwargs):
self.prefix_path = self.get_full_path(prefix_path)
self._train_path = self._get_full_path_or_default(train_path, 'intermediate_train_data')
self._val_path = self._get_full_path_or_default(val_path, 'intermediate_val_data')
self._test_path = self._get_full_path_or_default(test_path, 'intermediate_test_data')
self._runs_path = self._get_full_path_or_default(runs_path, 'runs')
self._save_runs = save_runs
self.storage_options = storage_options
self.checkpoint_filename = checkpoint_filename if checkpoint_filename else 'checkpoint'
super().__init__()

def exists(self, path):
Expand Down Expand Up @@ -262,7 +264,7 @@ def get_logs_path(self, run_id):
if self._save_runs else None

def get_checkpoint_filename(self):
return 'checkpoint'
return self.checkpoint_filename

def get_logs_subdir(self):
return 'logs'
Expand Down Expand Up @@ -303,7 +305,7 @@ def __init__(self, prefix_path, *args, **kwargs):
self.storage_options = kwargs['storage_options'] if 'storage_options' in kwargs else {}
self.prefix_path = prefix_path
self._fs, self.protocol = self._get_fs_and_protocol()
std_params = ['train_path', 'val_path', 'test_path', 'runs_path', 'save_runs', 'storage_options']
std_params = ['train_path', 'val_path', 'test_path', 'runs_path', 'save_runs', 'storage_options']
params = dict((k, kwargs[k]) for k in std_params if k in kwargs)
super().__init__(prefix_path, *args, **params)

Expand All @@ -317,9 +319,9 @@ def fn(local_run_path):
return fn

def copy(self, lpath, rpath, recursive=False, callback=_DEFAULT_CALLBACK,**kwargs):
"""
"""
This method copies the contents of the local source directory to the target directory.
This is different from the fsspec's put() because it does not copy the source folder
This is different from the fsspec's put() because it does not copy the source folder
to the target directory in the case when target directory already exists.
"""

Expand Down Expand Up @@ -519,7 +521,7 @@ def _check_url(self, url, prefix, path):

if not path:
raise ValueError('Failed to parse path from URL: {}'.format(url))

def get_localized_path(self, path):
if self.matches(path):
return path[len(self._url_prefix):]
Expand Down

0 comments on commit 7bf9b04

Please sign in to comment.