Spark/Lightning: don't overwrite model with checkpoint by default (#3201)

The Lightning estimator saves the model by default if no checkpoint callback
is specified. However, the model is not overwritten with the checkpoint file
in that case.

Signed-off-by: Chongxiao Cao <chongxiaoc@uber.com>
chongxiaoc committed Oct 7, 2021
1 parent 81340ee commit dadca53
Showing 3 changed files with 26 additions and 24 deletions.
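
For reference, a minimal usage sketch of the behavior after this change, assuming the horovod.spark.lightning TorchEstimator API; `backend`, `store`, and `MNISTModel` are placeholders, not part of this commit:

from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
import horovod.spark.lightning as hvd

# With an explicit ModelCheckpoint in `callbacks`, the estimator tracks it and
# rank 0 reloads the best checkpoint into the returned model.
with_checkpoint = hvd.TorchEstimator(
    backend=backend, store=store, model=MNISTModel(), epochs=5,
    callbacks=[ModelCheckpoint(monitor='val_loss', mode='min', save_top_k=1)])

# Without one, a default last-epoch checkpoint is still written so Lightning
# has a checkpoint on every rank, but the trained model is returned as-is
# rather than being overwritten from the checkpoint file.
without_checkpoint = hvd.TorchEstimator(
    backend=backend, store=store, model=MNISTModel(), epochs=5)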
7 changes: 4 additions & 3 deletions examples/spark/pytorch/pytorch_lightning_spark_mnist.py
@@ -177,14 +177,15 @@ def on_train_end(self, trainer, model):
 
     # added EarlyStopping and ModelCheckpoint
     from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
-    callbacks.append(ModelCheckpoint(dirpath=args.work_dir))
+    callbacks.append(ModelCheckpoint(monitor='val_loss', mode="min",
+                                     save_top_k=1, verbose=True))
 
     from pytorch_lightning.callbacks.early_stopping import EarlyStopping
     callbacks.append(EarlyStopping(monitor='val_loss',
-                                   min_delta=0.00,
+                                   min_delta=0.001,
                                    patience=3,
                                    verbose=True,
-                                   mode='max'))
+                                   mode='min'))
 
     torch_estimator = hvd.TorchEstimator(backend=backend,
                                          store=store,
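
For context, a sketch of how the updated example callbacks behave in a plain Lightning Trainer (the model and data module here are placeholders): both callbacks now watch val_loss and treat lower values as better, so the best checkpoint and the early-stopping decision agree.

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from pytorch_lightning.callbacks.early_stopping import EarlyStopping

# Keep only the single best checkpoint by validation loss ...
checkpoint_cb = ModelCheckpoint(monitor='val_loss', mode='min', save_top_k=1, verbose=True)
# ... and stop once val_loss has failed to improve by at least 0.001 for 3 epochs.
early_stop_cb = EarlyStopping(monitor='val_loss', min_delta=0.001, patience=3,
                              verbose=True, mode='min')

trainer = Trainer(max_epochs=10, callbacks=[checkpoint_cb, early_stop_cb])
# trainer.fit(model, datamodule=data_module)  # placeholders, not part of this diff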
15 changes: 2 additions & 13 deletions horovod/spark/lightning/estimator.py
@@ -206,9 +206,6 @@ class TorchEstimator(HorovodEstimator, TorchEstimatorParamsWritable,
 
     profiler = Param(Params._dummy(), 'profiler', 'lightning profiler to use')
 
-    checkpoint_callback = Param(Params._dummy(), 'checkpoint_callback',
-                                'model checkpointing callback')
-
     @keyword_only
     def __init__(self,
                  num_proc=None,
@@ -249,8 +246,7 @@ def __init__(self,
                  data_module=None,
                  loader_num_epochs=None,
                  terminate_on_nan=False,
-                 profiler=None,
-                 checkpoint_callback=None):
+                 profiler=None):
 
         super(TorchEstimator, self).__init__()
         self._setDefault(loss_constructors=None,
@@ -264,8 +260,7 @@ def __init__(self,
                          data_module=None,
                          loader_num_epochs=None,
                          terminate_on_nan=False,
-                         profiler=None,
-                         checkpoint_callback=None)
+                         profiler=None)
 
         kwargs = self._input_kwargs
 
@@ -338,12 +333,6 @@ def setTerminateOnNan(self, value):
     def getTerminateOnNan(self):
         return self.getOrDefault(self.terminate_on_nan)
 
-    def setCheckpointCallback(self, value):
-        return self._set(checkpoint_callback=value)
-
-    def getCheckpointCallback(self):
-        return self.getOrDefault(self.checkpoint_callback)
-
     def getProfiler(self):
         return self.getOrDefault(self.profiler)
 
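
With the dedicated checkpoint_callback param and its setCheckpointCallback/getCheckpointCallback accessors removed, a checkpoint callback would now be passed through the estimator's generic callbacks param instead; a minimal sketch (the estimator construction is abbreviated, and the setter usage is an assumption based on the remaining callbacks param):

from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint

ckpt_cb = ModelCheckpoint(monitor='val_loss', mode='min', save_top_k=1)

# Either at construction time ...
# torch_estimator = hvd.TorchEstimator(..., callbacks=[ckpt_cb])
# ... or via the generic callbacks setter the estimator still exposes:
torch_estimator.setCallbacks([ckpt_cb])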
28 changes: 20 additions & 8 deletions horovod/spark/lightning/remote.py
@@ -53,7 +53,6 @@ def RemoteTrainer(estimator, metadata, ckpt_bytes, run_id, dataset_idx, train_ro
     transformation = transformation_fn if transformation_fn else None
     inmemory_cache_all = estimator.getInMemoryCacheAll()
     callbacks = estimator.getCallbacks() or []
-    checkpoint_callback = estimator.getCheckpointCallback()
     train_steps_per_epoch = estimator.getTrainStepsPerEpoch()
     val_steps_per_epoch = estimator.getValidationStepsPerEpoch()
     num_gpus = estimator.getNumGPUs()
@@ -99,6 +98,8 @@ def train(serialized_model):
         import horovod.torch as hvd
         # Horovod: initialize library.
         hvd.init()
+        _checkpoint_callback = None
+        require_checkpoint = False
 
         with remote_store.get_local_output_dir() as run_output_dir:
             logs_path = os.path.join(run_output_dir, remote_store.logs_subdir)
@@ -115,6 +116,7 @@ def train(serialized_model):
             elif isinstance(logger, CometLogger) and logger._experiment_key is None:
                 # Resume logger experiment key if passed correctly from CPU.
                 train_logger = CometLogger(
+                    save_dir=logs_path,
                     api_key=logger.api_key,
                     experiment_key=logger_experiment_key,
                 )
@@ -123,20 +125,24 @@ def train(serialized_model):
             else:
                 # use logger passed in.
                 train_logger = logger
+                train_logger.save_dir = logs_path
                 print(f"Setup logger: Using logger passed from estimator: {train_logger}")
 
             # Lightning requires to add checkpoint callbacks for all ranks.
             # Otherwise we are seeing hanging in training.
-            _checkpoint_callback = checkpoint_callback
-            if _checkpoint_callback:
-                _checkpoint_callback.dir_path = ckpt_dir
-                _checkpoint_callback.filename = ckpt_filename
-            else:
+            for cb in callbacks:
+                if isinstance(cb, ModelCheckpoint):
+                    cb.dir_path = ckpt_dir
+                    cb.filename = ckpt_filename
+                    _checkpoint_callback = cb
+                    require_checkpoint = True
+                    break
+            if not _checkpoint_callback:
                 # By default 'monitor'=None which saves a checkpoint only for the last epoch.
                 _checkpoint_callback = ModelCheckpoint(dirpath=ckpt_dir,
                                                        filename=ckpt_filename,
                                                        verbose=True)
-            callbacks.append(_checkpoint_callback)
+                callbacks.append(_checkpoint_callback)
 
             if remote_store.saving_runs and hvd.rank() == 0:
                 # Horovod: sync checkpoint and logging files only on rank 0 to
@@ -224,7 +230,13 @@ def on_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -
                     remote_store.sync(logs_path)
 
         # rank 0 overwrites model with best checkpoint and returns.
-        best_model = model.load_from_checkpoint(_checkpoint_callback.best_model_path)
+        if require_checkpoint:
+            if verbose:
+                print("load from checkpoint best model path:",
+                      _checkpoint_callback.best_model_path)
+            best_model = model.load_from_checkpoint(_checkpoint_callback.best_model_path)
+        else:
+            best_model = model
         serialized_checkpoint = io.BytesIO()
         module = best_model if not is_legacy else best_model._model
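
The decision logic added in remote.py, restated as a standalone sketch for readability (callbacks, ckpt_dir, ckpt_filename, and model stand in for the variables inside train()): a user-supplied ModelCheckpoint marks the run as requiring a best-checkpoint reload, while the default callback exists only so every rank has one, in which case the trained model is returned unchanged.

from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint

def resolve_checkpoint_callback(callbacks, ckpt_dir, ckpt_filename):
    # Reuse a user-provided ModelCheckpoint if there is one; the second return
    # value (require_checkpoint) is True only in that case, i.e. only then
    # should the best checkpoint later overwrite the returned model.
    for cb in callbacks:
        if isinstance(cb, ModelCheckpoint):
            cb.dir_path = ckpt_dir
            cb.filename = ckpt_filename
            return cb, True
    # Otherwise add a default callback; 'monitor' is None, so it only saves
    # a checkpoint for the last epoch.
    default_cb = ModelCheckpoint(dirpath=ckpt_dir, filename=ckpt_filename, verbose=True)
    callbacks.append(default_cb)
    return default_cb, False

# After training, rank 0 decides what to return:
# best_model = model.load_from_checkpoint(cb.best_model_path) if require_checkpoint else model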
