
Fixed the inconsistency issue with put
Signed-off-by: Kamal Sharma <kamalbhardwaj020@gmail.com>
kamalsharma2 committed Jan 13, 2022
1 parent 69c3329 commit 05149d2
Showing 4 changed files with 40 additions and 10 deletions.
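Note for context (not part of the commit): the "inconsistency" in the title appears to be fsspec's handling of put() when the destination directory already exists, where, depending on the fsspec version and filesystem, the source directory may be copied into the destination instead of having its contents merged. Below is a minimal, hypothetical sketch of the wildcard workaround this commit adopts, run against the local filesystem; all paths and file names are placeholders invented for illustration.

    import os
    import tempfile

    import fsspec

    # Hypothetical illustration only; this is not code from the commit.
    fs = fsspec.filesystem("file")   # local FS stands in for the store's remote filesystem

    src = tempfile.mkdtemp()         # stand-in for the local run directory
    dst = tempfile.mkdtemp()         # stand-in for the remote run path, which already exists
    open(os.path.join(src, "checkpoint.ckpt"), "w").close()

    # fs.put(src, dst, recursive=True) may copy `src` *into* `dst` on some fsspec
    # versions (producing dst/<basename(src)>/checkpoint.ckpt). Appending a
    # wildcard copies the directory contents instead, keeping the layout flat.
    fs.put(os.path.join(src, "*"), dst, recursive=True)
    assert fs.exists(os.path.join(dst, "checkpoint.ckpt"))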
32 changes: 30 additions & 2 deletions horovod/spark/common/store.py
@@ -30,6 +30,7 @@
 import fsspec
 from fsspec.core import split_protocol
 from fsspec.utils import update_storage_options
+from fsspec.callbacks import _DEFAULT_CALLBACK

 from horovod.spark.common.util import is_databricks, host_hash

@@ -248,7 +249,7 @@ def get_run_path(self, run_id):
         return os.path.join(self.get_runs_path(), run_id)

     def get_checkpoint_path(self, run_id):
-        return os.path.join(self.get_run_path(run_id), self.get_checkpoint_filename()) \
+        return self.get_run_path(run_id) \
             if self._save_runs else None

     def get_checkpoints(self, run_id, suffix='.ckpt'):
@@ -266,6 +267,9 @@ def get_checkpoint_filename(self):
     def get_logs_subdir(self):
         return 'logs'

+    def get_wildcard(self):
+        return '*'
+
     def _get_full_path_or_default(self, path, default_key):
         if path is not None:
             return self.get_full_path(path)
@@ -311,10 +315,34 @@ def sync_fn(self, run_id):

         def fn(local_run_path):
             print(f"Syncing dir {local_run_path} to dir {run_path}")
-            self.fs.put(local_run_path, run_path, recursive=True, overwrite=True)
+            if self.fs.exists(run_path):
+                local_run_path = os.path.join(local_run_path,self.get_wildcard())
+            self.copy(local_run_path, run_path, recursive=True, overwrite=True)

         return fn

+    def copy(self,lpath,rpath,recursive=False,callback=_DEFAULT_CALLBACK,**kwargs):
+        from fsspec.implementations.local import LocalFileSystem, make_path_posix
+        from fsspec.utils import other_paths
+
+        rpath = (
+            self.fs._strip_protocol(rpath)
+            if isinstance(rpath, str)
+            else [self.fs._strip_protocol(p) for p in rpath]
+        )
+        if isinstance(lpath, str):
+            lpath = make_path_posix(lpath)
+        fs = LocalFileSystem()
+        lpaths = fs.expand_path(lpath, recursive=recursive)
+        rpaths = other_paths(
+            lpaths, rpath, exists=isinstance(rpath, str) and self.fs.isdir(rpath)
+        )
+
+        callback.set_size(len(rpaths))
+        for lpath, rpath in callback.wrap(zip(lpaths, rpaths)):
+            callback.branch(lpath, rpath, kwargs)
+            self.fs.put_file(lpath, rpath, **kwargs)
+
     def get_filesystem(self):
         return self.fs

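As a reading aid (again, not part of the commit): a standalone restatement of the new fn() body in sync_fn, assuming a local filesystem; the helper name sync and every path below are invented for illustration, and the real code goes through the store's filesystem and the copy() helper added above, which also forwards extra keyword arguments such as overwrite to put_file.

    import os
    import tempfile

    import fsspec

    fs = fsspec.filesystem("file")   # stand-in for the store's filesystem

    def sync(local_run_path, run_path):
        # Mirrors the new fn(): once the remote run path exists, copy the
        # directory *contents* (wildcard) rather than the directory itself,
        # so repeated syncs do not nest a copy of the run dir inside run_path.
        if fs.exists(run_path):
            local_run_path = os.path.join(local_run_path, "*")
        fs.put(local_run_path, run_path, recursive=True)

    local_dir = tempfile.mkdtemp()
    open(os.path.join(local_dir, "model.ckpt"), "w").close()
    remote_dir = os.path.join(tempfile.mkdtemp(), "run_001")

    sync(local_dir, remote_dir)   # first sync: run_path does not exist yet
    sync(local_dir, remote_dir)   # later syncs take the wildcard branch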
4 changes: 4 additions & 0 deletions horovod/spark/lightning/estimator.py
@@ -463,6 +463,10 @@ def _fit_on_prepared_data(self, backend, train_rows, val_rows, metadata, avg_row
     def _read_checkpoint(self, run_id):
         store = self.getStore()
         checkpoints = store.get_checkpoints(run_id, suffix='.ckpt')
+
+        if not checkpoints:
+            return None
+
         last_ckpt_path = checkpoints[-1]

         if self.getVerbose():
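A minimal, standalone restatement of the guard added to _read_checkpoint above (illustrative, not Horovod code), assuming get_checkpoints() returns an ordered list of checkpoint paths: a run with no checkpoints now yields None instead of raising IndexError on checkpoints[-1].

    from typing import List, Optional

    def last_checkpoint(checkpoints: List[str]) -> Optional[str]:
        # Mirrors the new early return: no checkpoints means nothing to resume.
        if not checkpoints:
            return None
        return checkpoints[-1]

    assert last_checkpoint([]) is None
    assert last_checkpoint(['run_1/ckpt_0.ckpt', 'run_1/ckpt_1.ckpt']) == 'run_1/ckpt_1.ckpt'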
6 changes: 5 additions & 1 deletion horovod/spark/torch/estimator.py
@@ -19,6 +19,7 @@
 import io
 import numbers
 import time
+import os

 from pyspark import keyword_only
 from pyspark.ml.param.shared import Param, Params, TypeConverters
@@ -287,7 +288,10 @@ def _fit_on_prepared_data(self, backend, train_rows, val_rows, metadata, avg_row

     def _load_checkpoint(self, run_id):
         store = self.getStore()
-        last_ckpt_path = store.get_checkpoint_path(run_id)
+        last_ckpt_path = os.path.join(store.get_checkpoint_path(run_id),store.get_checkpoint_filename())
+
+        if not store.fs.exists(last_ckpt_path):
+            return None

         if self.getVerbose():
             print('Resuming training from last checkpoint: {}'.format(last_ckpt_path))
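For orientation (not part of the commit): since get_checkpoint_path() now returns the run directory (see the store.py change above), the torch estimator joins the checkpoint file name itself and checks for existence before resuming. A standalone sketch of that lookup, using os.path in place of the store's fsspec filesystem; the file name checkpoint.ckpt is a placeholder, as the diff does not show what get_checkpoint_filename() returns.

    import os
    import tempfile

    def resolve_checkpoint(run_dir, filename='checkpoint.ckpt'):
        # Join the run directory with the checkpoint file name and return None
        # when no checkpoint exists yet, i.e. there is nothing to resume from.
        path = os.path.join(run_dir, filename)
        return path if os.path.exists(path) else None

    run_dir = tempfile.mkdtemp()
    assert resolve_checkpoint(run_dir) is None                 # fresh run
    open(os.path.join(run_dir, 'checkpoint.ckpt'), 'w').close()
    assert resolve_checkpoint(run_dir) is not None             # resumable run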
8 changes: 1 addition & 7 deletions test/integration/test_spark_lightning.py
@@ -215,11 +215,8 @@ def test_legacy_fit_model(self):
                 assert len(pred) == 1
                 assert pred.dtype == torch.float32

-    # TODO: Add this test back after checkpoint call back is supported
     def test_restore_from_checkpoint(self):
-        self.skipTest('There is a deadlock bug for checkpoint call back. ' +
-                      'Will add this test back when it is solved.')
-

         model = create_xor_model()

         with spark_session('test_restore_from_checkpoint') as spark:
@@ -253,10 +250,7 @@ def test_restore_from_checkpoint(self):
                 torch_estimator.fit(df)
                 torch_estimator._read_checkpoint.assert_called()

-    # TODO: Add this test back after checkpoint call back is supported
     def test_legacy_restore_from_checkpoint(self):
-        self.skipTest('There is a deadlock bug for checkpoint call back. ' +
-                      'Will add this test back when it is solved.')

         model = create_legacy_xor_model()
         optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
