Spark/Lightning: don't add checkpoint callback by default #3201

Merged
merged 1 commit on Oct 7, 2021
7 changes: 4 additions & 3 deletions examples/spark/pytorch/pytorch_lightning_spark_mnist.py
@@ -177,14 +177,15 @@ def on_train_end(self, trainer, model):
 
     # added EarlyStopping and ModelCheckpoint
     from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
-    callbacks.append(ModelCheckpoint(dirpath=args.work_dir))
+    callbacks.append(ModelCheckpoint(monitor='val_loss', mode="min",
+                                     save_top_k=1, verbose=True))
 
     from pytorch_lightning.callbacks.early_stopping import EarlyStopping
     callbacks.append(EarlyStopping(monitor='val_loss',
-                                   min_delta=0.00,
+                                   min_delta=0.001,
                                    patience=3,
                                    verbose=True,
-                                   mode='max'))
+                                   mode='min'))
 
     torch_estimator = hvd.TorchEstimator(backend=backend,
                                          store=store,
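With this change the example owns its checkpointing: it keeps the single best checkpoint by `val_loss` and corrects `EarlyStopping` to `mode='min'` (validation loss should decrease). Below is a minimal sketch of how these callbacks reach the estimator; it assumes the example's existing `backend`, `store`, `model`, and `args` objects and the `callbacks=` argument of `hvd.TorchEstimator`, so treat the non-callback arguments as placeholders rather than the exact example code.

```python
# Sketch only: callback wiring mirroring the updated example.
import horovod.spark.lightning as hvd
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from pytorch_lightning.callbacks.early_stopping import EarlyStopping

callbacks = [
    # Keep only the best checkpoint, judged by decreasing validation loss.
    ModelCheckpoint(monitor='val_loss', mode='min', save_top_k=1, verbose=True),
    # Stop after 3 epochs without at least a 0.001 improvement in val_loss.
    EarlyStopping(monitor='val_loss', min_delta=0.001, patience=3,
                  verbose=True, mode='min'),
]

torch_estimator = hvd.TorchEstimator(backend=backend,    # Spark backend from the example
                                     store=store,        # Horovod store from the example
                                     model=model,        # LightningModule from the example
                                     callbacks=callbacks,
                                     epochs=args.epochs,
                                     verbose=1)
```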
15 changes: 2 additions & 13 deletions horovod/spark/lightning/estimator.py
@@ -206,9 +206,6 @@ class TorchEstimator(HorovodEstimator, TorchEstimatorParamsWritable,
 
     profiler = Param(Params._dummy(), 'profiler', 'lightning profiler to use')
 
-    checkpoint_callback = Param(Params._dummy(), 'checkpoint_callback',
-                                'model checkpointing callback')
-
     @keyword_only
     def __init__(self,
                  num_proc=None,
@@ -249,8 +246,7 @@ def __init__(self,
                  data_module=None,
                  loader_num_epochs=None,
                  terminate_on_nan=False,
-                 profiler=None,
-                 checkpoint_callback=None):
+                 profiler=None):
 
         super(TorchEstimator, self).__init__()
         self._setDefault(loss_constructors=None,
@@ -264,8 +260,7 @@
                          data_module=None,
                          loader_num_epochs=None,
                          terminate_on_nan=False,
-                         profiler=None,
-                         checkpoint_callback=None)
+                         profiler=None)
 
         kwargs = self._input_kwargs
 
@@ -338,12 +333,6 @@ def setTerminateOnNan(self, value):
     def getTerminateOnNan(self):
         return self.getOrDefault(self.terminate_on_nan)
 
-    def setCheckpointCallback(self, value):
-        return self._set(checkpoint_callback=value)
-
-    def getCheckpointCallback(self):
-        return self.getOrDefault(self.checkpoint_callback)
-
     def getProfiler(self):
         return self.getOrDefault(self.profiler)
 
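Because `checkpoint_callback` is no longer a `Param` on `TorchEstimator`, code that previously passed it to the constructor or called `setCheckpointCallback` has to hand its `ModelCheckpoint` to the estimator through the regular callbacks list. A migration sketch, assuming an otherwise unchanged estimator construction (the `backend`/`store`/`model` names are placeholders):

```python
# Migration sketch: the checkpoint_callback parameter and its accessors are removed.
import horovod.spark.lightning as hvd
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint

ckpt = ModelCheckpoint(monitor='val_loss', mode='min', save_top_k=1)

# Before (no longer supported after this change):
#   est = hvd.TorchEstimator(..., checkpoint_callback=ckpt)
#   est.setCheckpointCallback(ckpt)
# After: include the checkpoint callback in the ordinary callbacks list.
est = hvd.TorchEstimator(backend=backend,   # same backend/store/model as before
                         store=store,
                         model=model,
                         callbacks=[ckpt])
```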
28 changes: 20 additions & 8 deletions horovod/spark/lightning/remote.py
@@ -53,7 +53,6 @@ def RemoteTrainer(estimator, metadata, ckpt_bytes, run_id, dataset_idx, train_ro
     transformation = transformation_fn if transformation_fn else None
     inmemory_cache_all = estimator.getInMemoryCacheAll()
     callbacks = estimator.getCallbacks() or []
-    checkpoint_callback = estimator.getCheckpointCallback()
     train_steps_per_epoch = estimator.getTrainStepsPerEpoch()
     val_steps_per_epoch = estimator.getValidationStepsPerEpoch()
     num_gpus = estimator.getNumGPUs()
@@ -99,6 +98,8 @@ def train(serialized_model):
         import horovod.torch as hvd
         # Horovod: initialize library.
         hvd.init()
+        _checkpoint_callback = None
+        require_checkpoint = False
 
         with remote_store.get_local_output_dir() as run_output_dir:
             logs_path = os.path.join(run_output_dir, remote_store.logs_subdir)
@@ -115,6 +116,7 @@
                 elif isinstance(logger, CometLogger) and logger._experiment_key is None:
                     # Resume logger experiment key if passed correctly from CPU.
                     train_logger = CometLogger(
+                        save_dir=logs_path,
                         api_key=logger.api_key,
                         experiment_key=logger_experiment_key,
                     )
@@ -123,20 +125,24 @@
                 else:
                     # use logger passed in.
                     train_logger = logger
+                    train_logger.save_dir = logs_path
                     print(f"Setup logger: Using logger passed from estimator: {train_logger}")
 
             # Lightning requires to add checkpoint callbacks for all ranks.
             # Otherwise we are seeing hanging in training.
-            _checkpoint_callback = checkpoint_callback
-            if _checkpoint_callback:
-                _checkpoint_callback.dir_path = ckpt_dir
-                _checkpoint_callback.filename = ckpt_filename
-            else:
+            for cb in callbacks:
+                if isinstance(cb, ModelCheckpoint):
+                    cb.dir_path = ckpt_dir
+                    cb.filename = ckpt_filename
+                    _checkpoint_callback = cb
+                    require_checkpoint = True
+                    break
+            if not _checkpoint_callback:
                 # By default 'monitor'=None which saves a checkpoint only for the last epoch.
                 _checkpoint_callback = ModelCheckpoint(dirpath=ckpt_dir,
                                                        filename=ckpt_filename,
                                                        verbose=True)
-            callbacks.append(_checkpoint_callback)
+                callbacks.append(_checkpoint_callback)
 
             if remote_store.saving_runs and hvd.rank() == 0:
                 # Horovod: sync checkpoint and logging files only on rank 0 to
@@ -224,7 +230,13 @@ def on_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -
                     remote_store.sync(logs_path)
 
             # rank 0 overwrites model with best checkpoint and returns.
-            best_model = model.load_from_checkpoint(_checkpoint_callback.best_model_path)
+            if require_checkpoint:
+                if verbose:
+                    print("load from checkpoint best model path:",
+                          _checkpoint_callback.best_model_path)
+                best_model = model.load_from_checkpoint(_checkpoint_callback.best_model_path)
+            else:
+                best_model = model
             serialized_checkpoint = io.BytesIO()
             module = best_model if not is_legacy else best_model._model
 
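To make the new rank-side behaviour easier to follow, here is an illustrative, hypothetical helper (not part of the PR) that mirrors the selection logic added to `train()`: a user-supplied `ModelCheckpoint` is redirected to the run's checkpoint location and flags that the best checkpoint should be reloaded afterwards; otherwise a default last-epoch-only checkpoint is appended and the trained model object is returned as-is.

```python
# Hypothetical helper sketch; function name and signature are illustrative only.
from typing import List, Tuple
from pytorch_lightning.callbacks import Callback, ModelCheckpoint


def resolve_checkpoint_callback(callbacks: List[Callback],
                                ckpt_dir: str,
                                ckpt_filename: str) -> Tuple[ModelCheckpoint, bool]:
    for cb in callbacks:
        if isinstance(cb, ModelCheckpoint):
            # User-provided callback: point it at the run's checkpoint location
            # (attribute names mirror the assignments in the PR) and signal that
            # the best checkpoint should be reloaded on rank 0.
            cb.dir_path = ckpt_dir
            cb.filename = ckpt_filename
            return cb, True
    # Default: monitor=None, so only the last epoch is saved; the trained model
    # object itself is returned instead of reloading a "best" checkpoint.
    default_cb = ModelCheckpoint(dirpath=ckpt_dir, filename=ckpt_filename, verbose=True)
    callbacks.append(default_cb)
    return default_cb, False
```

The returned boolean plays the role of `require_checkpoint` above, which decides whether rank 0 calls `load_from_checkpoint` on the best model path or simply returns the in-memory model.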