add_history_in_lightning_estimator #3214

Merged
merged 1 commit on Oct 11, 2021
24 changes: 17 additions & 7 deletions horovod/spark/lightning/estimator.py
@@ -160,6 +160,7 @@ class TorchEstimator(HorovodEstimator, TorchEstimatorParamsWritable,
train_steps_per_epoch: (Optional) Number of steps to train each epoch. Useful for testing
that model trains successfully. Defaults to training the entire
dataset each epoch.
trainer_args: (Optional) Dict of args to pass to the Lightning Trainer; entries are added to, and override, the args the estimator itself passes to the Trainer.
transformation_fn: (Optional) function that takes a row as its parameter and returns a
modified row that is then fed into the train or validation step.
This transformation is applied after batching. See Petastorm
@@ -206,6 +207,10 @@ class TorchEstimator(HorovodEstimator, TorchEstimatorParamsWritable,

profiler = Param(Params._dummy(), 'profiler', 'lightning profiler to use')

trainer_args = Param(Params._dummy(), 'trainer_args',
                     'Dict of args to pass to the Lightning Trainer; entries are added to, and override, the args the estimator itself passes to the Trainer.')


@keyword_only
def __init__(self,
num_proc=None,
@@ -236,6 +241,7 @@ def __init__(self,
validation_steps_per_epoch=None,
transformation_fn=None,
train_reader_num_workers=None,
trainer_args=None,
val_reader_num_workers=None,
reader_pool_type=None,
label_shapes=None,
@@ -260,7 +266,8 @@ def __init__(self,
data_module=None,
loader_num_epochs=None,
terminate_on_nan=False,
profiler=None)
profiler=None,
trainer_args=None)

kwargs = self._input_kwargs

@@ -333,6 +340,12 @@ def setTerminateOnNan(self, value):
def getTerminateOnNan(self):
return self.getOrDefault(self.terminate_on_nan)

def setTrainerArgs(self, value):
return self._set(trainer_args=value)

def getTrainerArgs(self):
return self.getOrDefault(self.trainer_args)

def getProfiler(self):
return self.getOrDefault(self.profiler)

@@ -426,19 +439,16 @@ def _read_checkpoint(self, run_id):
return store.read(last_ckpt_path)

def _create_model(self, run_results, run_id, metadata):
serialized_checkpoint = run_results[0]
serialized_checkpoint, history = run_results[0]
serialized_checkpoint.seek(0)
best_checkpoint = torch.load(serialized_checkpoint, map_location=torch.device('cpu'))

model = copy.deepcopy(self.getModel())
# optimizer = copy.deepcopy(self.getOptimizer())

model.load_state_dict(best_checkpoint['model'])

model.eval()

# optimizer.load_state_dict(best_checkpoint['optimizer'])
history = None
# Optimizer is part of the model no need to return it to transformer.
# TODO: (Pengz) Update the latest state of the optimizer in the model for retraining.
optimizer = None

return self.get_model_class()(**self._get_model_kwargs(
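Usage note: trainer_args is a pass-through to pytorch_lightning.Trainer. The dict supplied here is merged on top of the kwargs the estimator builds in remote.py, so entries can both add new Trainer arguments and override the estimator's defaults. A minimal sketch of the intended use, in the style of the existing tests (the store, model, and specific Trainer argument values are illustrative placeholders):

    torch_estimator = hvd_spark.TorchEstimator(
        num_proc=2,
        store=store,
        model=model,
        input_shapes=[[-1, 2]],
        feature_cols=['features'],
        label_cols=['y'],
        batch_size=4,
        epochs=2,
        # Forwarded to pytorch_lightning.Trainer(**kwargs); keys here win on clashes.
        trainer_args={'stochastic_weight_avg': True,
                      'progress_bar_refresh_rate': 10})

    # The accessors added in this diff can also be used after construction:
    torch_estimator.setTrainerArgs({'gradient_clip_val': 0.5})
    assert torch_estimator.getTrainerArgs() == {'gradient_clip_val': 0.5}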
16 changes: 14 additions & 2 deletions horovod/spark/lightning/remote.py
@@ -59,6 +59,7 @@ def RemoteTrainer(estimator, metadata, ckpt_bytes, run_id, dataset_idx, train_ro
data_module = estimator.getDataModule() if estimator.getDataModule() else PetastormDataModule
loader_num_epochs = estimator.getLoaderNumEpochs()
verbose = (estimator.getVerbose() > 0)
trainer_args = estimator.getTrainerArgs()

# get logger
logger = estimator.getLogger()
@@ -184,6 +185,10 @@ def on_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -
if _num_gpus is None:
_num_gpus = 1 if cuda_available else 0

# Set bar refresh to once per epoch; detailed loss and metrics are available in the logger,
# so there is no need to print them to the screen here. Users can still override this via trainer_args.
progress_bar_refresh_rate = _train_steps_per_epoch

kwargs = {'accelerator': 'horovod',
'gpus': _num_gpus,
'callbacks': callbacks,
@@ -192,10 +197,13 @@ def on_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -
'log_every_n_steps': log_every_n_steps,
'num_sanity_val_steps': 0,
'reload_dataloaders_every_epoch': False,
'progress_bar_refresh_rate': _train_steps_per_epoch // 10,
'progress_bar_refresh_rate': progress_bar_refresh_rate,
'terminate_on_nan': terminate_on_nan,
'profiler': profiler
}
if trainer_args:
kwargs.update(trainer_args)

print("Creating trainer with: \n ", kwargs)
trainer = Trainer(**kwargs)

@@ -244,7 +252,11 @@ def on_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -
output = {'model': module.state_dict()}

torch.save(output, serialized_checkpoint)
return serialized_checkpoint

# Save the logged metrics as history, which will be saved in the transformer.
history = trainer.logged_metrics

return serialized_checkpoint, history
return train


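Note on the merge above: kwargs.update(trainer_args) is a plain dict update, so a key supplied by the user silently replaces the estimator's default for that key, and any key the estimator does not set is simply added. A minimal sketch of the override semantics, with illustrative values:

    defaults = {'accelerator': 'horovod',
                'max_epochs': 2,
                'progress_bar_refresh_rate': 100,  # once per epoch, as computed above
                'terminate_on_nan': False}
    user_trainer_args = {'progress_bar_refresh_rate': 10,  # overrides the default
                         'gradient_clip_val': 0.5}         # adds a new Trainer arg
    defaults.update(user_trainer_args)
    # defaults is now:
    # {'accelerator': 'horovod', 'max_epochs': 2,
    #  'progress_bar_refresh_rate': 10, 'terminate_on_nan': False,
    #  'gradient_clip_val': 0.5}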
33 changes: 33 additions & 0 deletions test/integration/test_spark_lightning.py
@@ -913,6 +913,39 @@ def val_dataloader(self):
assert len(pred) == 1
assert pred.dtype == torch.float32

"""
Test override trainer args.
"""
def test_model_override_trainer_args(self):
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint

with spark_session('test_fit_model') as spark:
df = create_noisy_xor_data(spark)
model = create_xor_model()

with tempdir() as dir:

with local_store() as store:
torch_estimator = hvd_spark.TorchEstimator(
num_proc=2,
store=store,
model=model,
input_shapes=[[-1, 2]],
feature_cols=['features'],
label_cols=['y'],
validation=0.2,
batch_size=4,
epochs=2,
verbose=2,
trainer_args={'stochastic_weight_avg': True})

torch_model = torch_estimator.fit(df)

# TODO: Find a way to pass logged metrics back from the remote trainer and assert based on the logger.
trained_model = torch_model.getModel()
pred = trained_model(torch.ones([1, 2], dtype=torch.int32))
assert len(pred) == 1
assert pred.dtype == torch.float32

def check_fail(dir, rank, epoch, batch):
if dir:
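On the TODO in the new test: the history that remote.py now returns alongside the checkpoint is whatever Lightning accumulated in trainer.logged_metrics, i.e. the values the LightningModule logs via self.log. A minimal sketch of where those values come from; the module, metric name, and loss are hypothetical illustrations, not part of this diff:

    import pytorch_lightning as pl
    import torch.nn.functional as F

    class ExampleModule(pl.LightningModule):
        def training_step(self, batch, batch_idx):
            x, y = batch
            loss = F.binary_cross_entropy(self(x), y)
            self.log('train_loss', loss)   # ends up in trainer.logged_metrics
            return loss

    # After trainer.fit(...) on the worker, remote.py picks up
    #     history = trainer.logged_metrics   # e.g. {'train_loss': tensor(0.42), ...}
    # and returns it with the serialized checkpoint, so _create_model can pass it
    # through to the transformer as the model's history.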