Spark/Lightning: don't add checkpoint callback by default
Signed-off-by: Chongxiao Cao <chongxiaoc@uber.com>
chongxiaoc committed Oct 6, 2021
1 parent 062aaa0 commit ab2d9a7
Showing 2 changed files with 15 additions and 24 deletions.
15 changes: 2 additions & 13 deletions horovod/spark/lightning/estimator.py
@@ -206,9 +206,6 @@ class TorchEstimator(HorovodEstimator, TorchEstimatorParamsWritable,
 
     profiler = Param(Params._dummy(), 'profiler', 'lightning profiler to use')
 
-    checkpoint_callback = Param(Params._dummy(), 'checkpoint_callback',
-                                'model checkpointing callback')
-
     @keyword_only
     def __init__(self,
                  num_proc=None,
@@ -249,8 +246,7 @@ def __init__(self,
                  data_module=None,
                  loader_num_epochs=None,
                  terminate_on_nan=False,
-                 profiler=None,
-                 checkpoint_callback=None):
+                 profiler=None):
 
         super(TorchEstimator, self).__init__()
         self._setDefault(loss_constructors=None,
@@ -264,8 +260,7 @@ def __init__(self,
                          data_module=None,
                          loader_num_epochs=None,
                          terminate_on_nan=False,
-                         profiler=None,
-                         checkpoint_callback=None)
+                         profiler=None)
 
         kwargs = self._input_kwargs
 
@@ -338,12 +333,6 @@ def setTerminateOnNan(self, value):
     def getTerminateOnNan(self):
         return self.getOrDefault(self.terminate_on_nan)
 
-    def setCheckpointCallback(self, value):
-        return self._set(checkpoint_callback=value)
-
-    def getCheckpointCallback(self):
-        return self.getOrDefault(self.checkpoint_callback)
-
     def getProfiler(self):
         return self.getOrDefault(self.profiler)
 
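With the checkpoint_callback parameter removed, checkpointing on the Spark/Lightning estimator is now configured through the existing callbacks parameter. A minimal usage sketch, assuming a previously created Horovod store and a LightningModule subclass MyModel; every argument other than callbacks is illustrative and may not match the real TorchEstimator signature:

from pytorch_lightning.callbacks import ModelCheckpoint
import horovod.spark.lightning as hvd_lightning

# Pass a ModelCheckpoint through the regular `callbacks` list; remote.py now
# detects it and repoints it at the run's checkpoint directory. If no
# ModelCheckpoint is supplied, no default one is added anymore.
estimator = hvd_lightning.TorchEstimator(
    num_proc=2,                  # assumption: number of Horovod workers
    store=store,                 # assumption: an existing horovod.spark Store
    model=MyModel(),             # assumption: a pytorch_lightning.LightningModule
    feature_cols=['features'],   # illustrative column names
    label_cols=['label'],
    batch_size=32,
    epochs=5,
    callbacks=[ModelCheckpoint(monitor='val_loss', verbose=True)],
)
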
24 changes: 13 additions & 11 deletions horovod/spark/lightning/remote.py
@@ -53,7 +53,6 @@ def RemoteTrainer(estimator, metadata, ckpt_bytes, run_id, dataset_idx, train_ro
     transformation = transformation_fn if transformation_fn else None
     inmemory_cache_all = estimator.getInMemoryCacheAll()
     callbacks = estimator.getCallbacks() or []
-    checkpoint_callback = estimator.getCheckpointCallback()
     train_steps_per_epoch = estimator.getTrainStepsPerEpoch()
     val_steps_per_epoch = estimator.getValidationStepsPerEpoch()
     num_gpus = estimator.getNumGPUs()
@@ -99,6 +98,7 @@ def train(serialized_model):
         import horovod.torch as hvd
         # Horovod: initialize library.
         hvd.init()
+        _checkpoint_callback = None
 
         with remote_store.get_local_output_dir() as run_output_dir:
             logs_path = os.path.join(run_output_dir, remote_store.logs_subdir)
@@ -127,16 +127,15 @@ def train(serialized_model):
 
             # Lightning requires to add checkpoint callbacks for all ranks.
             # Otherwise we are seeing hanging in training.
-            _checkpoint_callback = checkpoint_callback
+            for i, cb in enumerate(callbacks):
+                if isinstance(cb, ModelCheckpoint):
+                    _checkpoint_callback = cb
+                    _checkpoint_callback.dir_path = ckpt_dir
+                    _checkpoint_callback.filename = ckpt_filename
+                    break
             if _checkpoint_callback:
-                _checkpoint_callback.dir_path = ckpt_dir
-                _checkpoint_callback.filename = ckpt_filename
-            else:
-                # By default 'monitor'=None which saves a checkpoint only for the last epoch.
-                _checkpoint_callback = ModelCheckpoint(dirpath=ckpt_dir,
-                                                       filename=ckpt_filename,
-                                                       verbose=True)
-            callbacks.append(_checkpoint_callback)
+                callbacks.pop(i)
+                callbacks.append(_checkpoint_callback)
 
             if remote_store.saving_runs and hvd.rank() == 0:
                 # Horovod: sync checkpoint and logging files only on rank 0 to
@@ -224,7 +223,10 @@ def on_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -
                     remote_store.sync(logs_path)
 
             # rank 0 overwrites model with best checkpoint and returns.
-            best_model = model.load_from_checkpoint(_checkpoint_callback.best_model_path)
+            if _checkpoint_callback:
+                best_model = model.load_from_checkpoint(_checkpoint_callback.best_model_path)
+            else:
+                best_model = model
             serialized_checkpoint = io.BytesIO()
             module = best_model if not is_legacy else best_model._model
 
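For reference, a condensed, self-contained sketch of the callback handling that the remote.py hunks above implement; callbacks, ckpt_dir, ckpt_filename, and the dir_path/filename attribute assignments mirror the names in the diff, while the surrounding trainer and store scaffolding is omitted:

from pytorch_lightning.callbacks import ModelCheckpoint

def resolve_checkpoint_callback(callbacks, ckpt_dir, ckpt_filename):
    """Reuse a user-supplied ModelCheckpoint if present; otherwise return None."""
    _checkpoint_callback = None
    for i, cb in enumerate(callbacks):
        if isinstance(cb, ModelCheckpoint):
            _checkpoint_callback = cb
            # Redirect the user's callback to the run's checkpoint location,
            # mirroring the attribute assignments in the diff.
            _checkpoint_callback.dir_path = ckpt_dir
            _checkpoint_callback.filename = ckpt_filename
            # Move it to the end of the callback list, as the diff does.
            callbacks.pop(i)
            callbacks.append(_checkpoint_callback)
            break
    return _checkpoint_callback

# After training, rank 0 only reloads a best checkpoint when a ModelCheckpoint
# was actually supplied; otherwise the trained model is returned unchanged.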
