Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make checkpoint name optional so that users can save to h5 format. #3411

Merged
Merged 2 commits on Mar 1, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
14 changes: 8 additions & 6 deletions horovod/spark/common/store.py
Expand Up @@ -168,14 +168,16 @@ class AbstractFilesystemStore(Store):
"""Abstract class for stores that use a filesystem for underlying storage."""

def __init__(self, prefix_path, train_path=None, val_path=None, test_path=None,
runs_path=None, save_runs=True, storage_options=None, **kwargs):
runs_path=None, save_runs=True, storage_options=None, checkpoint_filename=None,
**kwargs):
self.prefix_path = self.get_full_path(prefix_path)
self._train_path = self._get_full_path_or_default(train_path, 'intermediate_train_data')
self._val_path = self._get_full_path_or_default(val_path, 'intermediate_val_data')
self._test_path = self._get_full_path_or_default(test_path, 'intermediate_test_data')
self._runs_path = self._get_full_path_or_default(runs_path, 'runs')
self._save_runs = save_runs
self.storage_options = storage_options
self.checkpoint_filename = checkpoint_filename if checkpoint_filename else 'checkpoint'
super().__init__()

def exists(self, path):
Expand Down Expand Up @@ -262,7 +264,7 @@ def get_logs_path(self, run_id):
if self._save_runs else None

def get_checkpoint_filename(self):
return 'checkpoint'
return self.checkpoint_filename

def get_logs_subdir(self):
return 'logs'
Expand Down Expand Up @@ -303,7 +305,7 @@ def __init__(self, prefix_path, *args, **kwargs):
self.storage_options = kwargs['storage_options'] if 'storage_options' in kwargs else {}
self.prefix_path = prefix_path
self._fs, self.protocol = self._get_fs_and_protocol()
std_params = ['train_path', 'val_path', 'test_path', 'runs_path', 'save_runs', 'storage_options']
std_params = ['train_path', 'val_path', 'test_path', 'runs_path', 'save_runs', 'storage_options']
params = dict((k, kwargs[k]) for k in std_params if k in kwargs)
super().__init__(prefix_path, *args, **params)

Expand All @@ -317,9 +319,9 @@ def fn(local_run_path):
return fn

def copy(self, lpath, rpath, recursive=False, callback=_DEFAULT_CALLBACK,**kwargs):
"""
"""
This method copies the contents of the local source directory to the target directory.
This is different from the fsspec's put() because it does not copy the source folder
This is different from the fsspec's put() because it does not copy the source folder
to the target directory in the case when target directory already exists.
"""

Expand Down Expand Up @@ -519,7 +521,7 @@ def _check_url(self, url, prefix, path):

if not path:
raise ValueError('Failed to parse path from URL: {}'.format(url))

def get_localized_path(self, path):
if self.matches(path):
return path[len(self._url_prefix):]
Expand Down