Skip to content

Commit

Permalink
add_history_in_lightning_estimator
Browse files Browse the repository at this point in the history
Signed-off-by: Peng Zhang <pengz@uber.com>
  • Loading branch information
irasit committed Oct 11, 2021
1 parent d5a90dc commit 60d5f65
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 9 deletions.
24 changes: 17 additions & 7 deletions horovod/spark/lightning/estimator.py
Expand Up @@ -160,6 +160,7 @@ class TorchEstimator(HorovodEstimator, TorchEstimatorParamsWritable,
train_steps_per_epoch: (Optional) Number of steps to train each epoch. Useful for testing
that model trains successfully. Defaults to training the entire
dataset each epoch.
trainer_args: (Optional) Dict of args to pass to trainer, it will be used to add/override the args which estimator gives to trainer.
transformation_fn: (Optional) function that takes a row as its parameter and returns a
modified row that is then fed into the train or validation step.
This transformation is applied after batching. See Petastorm
Expand Down Expand Up @@ -206,6 +207,10 @@ class TorchEstimator(HorovodEstimator, TorchEstimatorParamsWritable,

profiler = Param(Params._dummy(), 'profiler', 'lightning profiler to use')

trainer_args = Param(Params._dummy(), 'trainer_args',
'Dict of args to pass to trainer, it will be used to add/override the args which estimator gives to trainer. ')


@keyword_only
def __init__(self,
num_proc=None,
Expand Down Expand Up @@ -236,6 +241,7 @@ def __init__(self,
validation_steps_per_epoch=None,
transformation_fn=None,
train_reader_num_workers=None,
trainer_args=None,
val_reader_num_workers=None,
reader_pool_type=None,
label_shapes=None,
Expand All @@ -260,7 +266,8 @@ def __init__(self,
data_module=None,
loader_num_epochs=None,
terminate_on_nan=False,
profiler=None)
profiler=None,
trainer_args=None)

kwargs = self._input_kwargs

Expand Down Expand Up @@ -333,6 +340,12 @@ def setTerminateOnNan(self, value):
def getTerminateOnNan(self):
return self.getOrDefault(self.terminate_on_nan)

def setTrainerArgs(self, value):
return self._set(trainer_args=value)

def getTrainerArgs(self):
return self.getOrDefault(self.trainer_args)

def getProfiler(self):
return self.getOrDefault(self.profiler)

Expand Down Expand Up @@ -426,19 +439,16 @@ def _read_checkpoint(self, run_id):
return store.read(last_ckpt_path)

def _create_model(self, run_results, run_id, metadata):
serialized_checkpoint = run_results[0]
serialized_checkpoint, history = run_results[0]
serialized_checkpoint.seek(0)
best_checkpoint = torch.load(serialized_checkpoint, map_location=torch.device('cpu'))

model = copy.deepcopy(self.getModel())
# optimizer = copy.deepcopy(self.getOptimizer())

model.load_state_dict(best_checkpoint['model'])

model.eval()

# optimizer.load_state_dict(best_checkpoint['optimizer'])
history = None
# Optimizer is part of the model no need to return it to transformer.
# TODO: (Pengz) Update the latest state of the optimizer in the model for retraining.
optimizer = None

return self.get_model_class()(**self._get_model_kwargs(
Expand Down
16 changes: 14 additions & 2 deletions horovod/spark/lightning/remote.py
Expand Up @@ -59,6 +59,7 @@ def RemoteTrainer(estimator, metadata, ckpt_bytes, run_id, dataset_idx, train_ro
data_module = estimator.getDataModule() if estimator.getDataModule() else PetastormDataModule
loader_num_epochs = estimator.getLoaderNumEpochs()
verbose = (estimator.getVerbose() > 0)
trainer_args = estimator.getTrainerArgs()

# get logger
logger = estimator.getLogger()
Expand Down Expand Up @@ -184,6 +185,10 @@ def on_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -
if _num_gpus is None:
_num_gpus = 1 if cuda_available else 0

# Set bar refresh to 1 / epoch, detailed loss and metrics is avaialbe in logger,
# no need to print in screen here. User can still override this in trainer_args
progress_bar_refresh_rate = _train_steps_per_epoch

kwargs = {'accelerator': 'horovod',
'gpus': _num_gpus,
'callbacks': callbacks,
Expand All @@ -192,10 +197,13 @@ def on_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -
'log_every_n_steps': log_every_n_steps,
'num_sanity_val_steps': 0,
'reload_dataloaders_every_epoch': False,
'progress_bar_refresh_rate': _train_steps_per_epoch // 10,
'progress_bar_refresh_rate': progress_bar_refresh_rate,
'terminate_on_nan': terminate_on_nan,
'profiler': profiler
}
if trainer_args:
kwargs.update(trainer_args)

print("Creating trainer with: \n ", kwargs)
trainer = Trainer(**kwargs)

Expand Down Expand Up @@ -244,7 +252,11 @@ def on_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -
output = {'model': module.state_dict()}

torch.save(output, serialized_checkpoint)
return serialized_checkpoint

# Save logged metrics as history, which will saved in transformer.
history = trainer.logged_metrics

return serialized_checkpoint, history
return train


Expand Down
33 changes: 33 additions & 0 deletions test/integration/test_spark_lightning.py
Expand Up @@ -913,6 +913,39 @@ def val_dataloader(self):
assert len(pred) == 1
assert pred.dtype == torch.float32

"""
Test override trainer args.
"""
def test_model_override_trainer_args(self):
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint

with spark_session('test_fit_model') as spark:
df = create_noisy_xor_data(spark)
model = create_xor_model()

with tempdir() as dir:

with local_store() as store:
torch_estimator = hvd_spark.TorchEstimator(
num_proc=2,
store=store,
model=model,
input_shapes=[[-1, 2]],
feature_cols=['features'],
label_cols=['y'],
validation=0.2,
batch_size=4,
epochs=2,
verbose=2,
trainer_args={'stochastic_weight_avg': True})

torch_model = torch_estimator.fit(df)

# TODO: Find a way to pass log metrics from remote, and assert base on the logger.
trained_model = torch_model.getModel()
pred = trained_model(torch.ones([1, 2], dtype=torch.int32))
assert len(pred) == 1
assert pred.dtype == torch.float32

def check_fail(dir, rank, epoch, batch):
if dir:
Expand Down

0 comments on commit 60d5f65

Please sign in to comment.