Spark/Lightning: don't overwrite model with checkpoint by default
The Lightning estimator still saves a checkpoint by default when no
checkpoint callback is specified. In that case, however, the returned
model is no longer overwritten with the checkpoint file; the best
checkpoint is only loaded back when the user supplies a ModelCheckpoint
callback.

Signed-off-by: Chongxiao Cao <chongxiaoc@uber.com>
chongxiaoc committed Oct 6, 2021
1 parent 81340ee commit 38bcb88
Showing 3 changed files with 26 additions and 24 deletions.
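
To illustrate the change described in the commit message, here is a minimal sketch from the caller's side. It assumes an estimator wired up roughly as in the example script below; the backend/store/model objects and the argument list are placeholders rather than the full signature:

    from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
    import horovod.spark.lightning as hvd

    # With a user-supplied ModelCheckpoint in `callbacks`, rank 0 reloads the
    # best checkpoint before returning the fitted model.
    ckpt = ModelCheckpoint(monitor='val_loss', mode='min', save_top_k=1, verbose=True)
    torch_estimator = hvd.TorchEstimator(backend=backend,   # placeholder objects
                                         store=store,
                                         model=model,
                                         callbacks=[ckpt])

    # With no ModelCheckpoint in `callbacks`, a default one is still added
    # internally (it saves only the last epoch), but the trained model is
    # returned as-is instead of being overwritten from that checkpoint file.
    torch_estimator = hvd.TorchEstimator(backend=backend,
                                         store=store,
                                         model=model,
                                         callbacks=[])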
7 changes: 4 additions & 3 deletions examples/spark/pytorch/pytorch_lightning_spark_mnist.py
@@ -177,14 +177,15 @@ def on_train_end(self, trainer, model):

# added EarlyStopping and ModelCheckpoint
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
callbacks.append(ModelCheckpoint(dirpath=args.work_dir))
callbacks.append(ModelCheckpoint(monitor='val_loss', mode="min",
save_top_k=1, verbose=True))

from pytorch_lightning.callbacks.early_stopping import EarlyStopping
callbacks.append(EarlyStopping(monitor='val_loss',
min_delta=0.00,
min_delta=0.001,
patience=3,
verbose=True,
mode='max'))
mode='min'))

torch_estimator = hvd.TorchEstimator(backend=backend,
store=store,
15 changes: 2 additions & 13 deletions horovod/spark/lightning/estimator.py
@@ -206,9 +206,6 @@ class TorchEstimator(HorovodEstimator, TorchEstimatorParamsWritable,

profiler = Param(Params._dummy(), 'profiler', 'lightning profiler to use')

checkpoint_callback = Param(Params._dummy(), 'checkpoint_callback',
'model checkpointing callback')

@keyword_only
def __init__(self,
num_proc=None,
@@ -249,8 +246,7 @@ def __init__(self,
data_module=None,
loader_num_epochs=None,
terminate_on_nan=False,
profiler=None,
checkpoint_callback=None):
profiler=None):

super(TorchEstimator, self).__init__()
self._setDefault(loss_constructors=None,
@@ -264,8 +260,7 @@ def __init__(self,
data_module=None,
loader_num_epochs=None,
terminate_on_nan=False,
profiler=None,
checkpoint_callback=None)
profiler=None)

kwargs = self._input_kwargs

@@ -338,12 +333,6 @@ def setTerminateOnNan(self, value):
def getTerminateOnNan(self):
return self.getOrDefault(self.terminate_on_nan)

def setCheckpointCallback(self, value):
return self._set(checkpoint_callback=value)

def getCheckpointCallback(self):
return self.getOrDefault(self.checkpoint_callback)

def getProfiler(self):
return self.getOrDefault(self.profiler)

28 changes: 20 additions & 8 deletions horovod/spark/lightning/remote.py
@@ -53,7 +53,6 @@ def RemoteTrainer(estimator, metadata, ckpt_bytes, run_id, dataset_idx, train_ro
transformation = transformation_fn if transformation_fn else None
inmemory_cache_all = estimator.getInMemoryCacheAll()
callbacks = estimator.getCallbacks() or []
checkpoint_callback = estimator.getCheckpointCallback()
train_steps_per_epoch = estimator.getTrainStepsPerEpoch()
val_steps_per_epoch = estimator.getValidationStepsPerEpoch()
num_gpus = estimator.getNumGPUs()
@@ -99,6 +98,8 @@ def train(serialized_model):
import horovod.torch as hvd
# Horovod: initialize library.
hvd.init()
_checkpoint_callback = None
require_checkpoint = False

with remote_store.get_local_output_dir() as run_output_dir:
logs_path = os.path.join(run_output_dir, remote_store.logs_subdir)
@@ -115,6 +116,7 @@ def train(serialized_model):
elif isinstance(logger, CometLogger) and logger._experiment_key is None:
# Resume logger experiment key if passed correctly from CPU.
train_logger = CometLogger(
save_dir=logs_path,
api_key=logger.api_key,
experiment_key=logger_experiment_key,
)
@@ -123,20 +125,24 @@ def train(serialized_model):
else:
# use logger passed in.
train_logger = logger
train_logger.save_dir = logs_path
print(f"Setup logger: Using logger passed from estimator: {train_logger}")

# Lightning requires to add checkpoint callbacks for all ranks.
# Otherwise we are seeing hanging in training.
_checkpoint_callback = checkpoint_callback
if _checkpoint_callback:
_checkpoint_callback.dir_path = ckpt_dir
_checkpoint_callback.filename = ckpt_filename
else:
for cb in callbacks:
if isinstance(cb, ModelCheckpoint):
cb.dir_path = ckpt_dir
cb.filename = ckpt_filename
_checkpoint_callback = cb
require_checkpoint = True
break
if not _checkpoint_callback:
# By default 'monitor'=None which saves a checkpoint only for the last epoch.
_checkpoint_callback = ModelCheckpoint(dirpath=ckpt_dir,
filename=ckpt_filename,
verbose=True)
callbacks.append(_checkpoint_callback)
callbacks.append(_checkpoint_callback)

if remote_store.saving_runs and hvd.rank() == 0:
# Horovod: sync checkpoint and logging files only on rank 0 to
@@ -224,7 +230,13 @@ def on_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -
remote_store.sync(logs_path)

# rank 0 overwrites model with best checkpoint and returns.
best_model = model.load_from_checkpoint(_checkpoint_callback.best_model_path)
if require_checkpoint:
if verbose:
print("load from checkpoint best model path:",
_checkpoint_callback.best_model_path)
best_model = model.load_from_checkpoint(_checkpoint_callback.best_model_path)
else:
best_model = model
serialized_checkpoint = io.BytesIO()
module = best_model if not is_legacy else best_model._model

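
Taken together, the remote.py hunks reduce to the following checkpoint-selection logic inside train(). This is a condensed sketch for readability, with the trainer setup, remote-store syncing, and rank handling around it omitted:

    from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint

    _checkpoint_callback = None
    require_checkpoint = False

    # Reuse a user-provided ModelCheckpoint if one is present in `callbacks`
    # and remember that the best checkpoint should be loaded back afterwards.
    for cb in callbacks:
        if isinstance(cb, ModelCheckpoint):
            cb.dir_path = ckpt_dir
            cb.filename = ckpt_filename
            _checkpoint_callback = cb
            require_checkpoint = True
            break

    # Otherwise fall back to a default callback (monitor=None, so only the
    # last epoch is saved) so that every rank still runs with a checkpoint
    # callback and training does not hang.
    if not _checkpoint_callback:
        _checkpoint_callback = ModelCheckpoint(dirpath=ckpt_dir,
                                               filename=ckpt_filename,
                                               verbose=True)
        callbacks.append(_checkpoint_callback)

    # ... training runs ...

    # Rank 0 reloads the best checkpoint only when the user asked for one;
    # with the default callback the trained model is returned unchanged.
    if require_checkpoint:
        best_model = model.load_from_checkpoint(_checkpoint_callback.best_model_path)
    else:
        best_model = model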
