Bug description

I am trying to train a Lightning model that inherits from pl.LightningModule and implements a simple feed-forward network. When I run it, trainer.fit() raises the error trace below. I found this very similar issue, where downgrading to torchmetrics<=0.5.0 fixed the problem, but that is not possible in my case: pytorch-lightning v2.2.0 is not compatible with such an old version of torchmetrics. I tried downgrading to 0.7., the oldest compatible version, but that led to a different error, also in trainer.fit.

Thanks for your attention; I would appreciate any help with this.
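My reading of the trace (an assumption on my part, not confirmed): save_hyperparameters() records the losses and metrics arguments, which hold torch.nn / torchmetrics objects, and the CSV logger then tries to serialize those via save_hparams_to_yaml. The sketch below only illustrates the idea of keeping YAML-safe hyperparameters; yaml_safe_hparams is a hypothetical helper, not Lightning API (in Lightning itself, save_hyperparameters(ignore=[...]) is the documented way to exclude arguments).

```python
import yaml


def yaml_safe_hparams(hparams):
    """Hypothetical helper: keep only hyperparameters that survive YAML serialization."""
    safe = {}
    for key, value in hparams.items():
        try:
            yaml.safe_dump({key: value})
        except yaml.YAMLError:
            # torch.nn losses, torchmetrics objects, etc. would land here
            continue
        safe[key] = value
    return safe


hparams = {"learning_rate": 0.001, "weight_decay": 0.0, "metrics": object()}
print(sorted(yaml_safe_hparams(hparams)))  # ['learning_rate', 'weight_decay']
```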
What version are you seeing the problem on?
v2.2
How to reproduce the bug
Below is the model class definition

import pytorch_lightning as pl
import torch
import numpy as np
from torch.nn import MSELoss, L1Loss
from torchmetrics import R2Score

torch.random.manual_seed(123)


class LightningModelSimple(pl.LightningModule):
    def __init__(
        self,
        latent_model,
        readout_model=None,
        losses={},
        metrics=[],
        gpu=True,
        learning_rate=0.001,
        weight_decay=0.0,
    ):
        super().__init__()
        self.save_hyperparameters()
        self.latent_model = latent_model
        if readout_model is None:
            self.readout_model = torch.nn.Identity()
        else:
            self.readout_model = readout_model
        # losses
        if "target" in losses:
            self.loss_target = losses["target"]
        else:
            self.loss_target = None
        if "latent_target" in losses:
            self.loss_latent_target = losses["latent_target"]
            self.weight_loss_latent_target = losses["weight_loss_latent_target"]
        else:
            self.loss_latent_target = None
        self.gpu = gpu
        self.metrics = metrics
        self.learning_rate = learning_rate
        self.weight_decay = weight_decay

    def forward(self, x):
        x_latent = self.latent_model(x)
        y = self.readout_model(x_latent)
        return y

    def step(self, partition, batch, batch_idx):
        spectra, target_glucose = batch
        # get latent predictions
        self.pred_latent = self.latent_model(spectra.float())
        # get glucose predictions
        self.pred_glucose = self.readout_model(self.pred_latent)
        # compute losses
        loss = 0
        if self.loss_target is not None:
            loss += self.loss_target(self.pred_glucose, target_glucose)
            self.log(partition + "_loss_target", loss, on_epoch=True)
        if self.loss_latent_target is not None:
            loss_latent_target = (
                self.weight_loss_latent_target
                * self.loss_latent_target(self.pred_latent, target_glucose.unsqueeze(1))
            )
            self.log(
                partition + "_loss_latent_target", loss_latent_target, on_epoch=True
            )
            loss += loss_latent_target
        self.log(partition + "_loss_total", loss, on_epoch=True)
        for metric_name, metric in self.metrics:
            self.log(
                partition + "_" + metric_name,
                metric(self.pred_glucose, target_glucose),
                on_epoch=True,
            )
        return loss

    def training_step(self, batch, batch_idx):
        return self.step("train", batch, batch_idx)

    def validation_step(self, batch, batch_idx):
        return self.step("val", batch, batch_idx)

    def test_step(self, batch, batch_idx):
        return self.step("test", batch, batch_idx)

    def configure_optimizers(self):
        return torch.optim.Adam(
            self.parameters(),
            lr=self.hparams.learning_rate,
            weight_decay=self.hparams.weight_decay,
        )
This should go in a different file called helpers.py

# imports reconstructed from usage; initialize_datamodule and
# initialize_modules are project-specific helpers not shown here
import collections.abc
import glob
from argparse import ArgumentParser

import mlflow
import pytorch_lightning as pl
import yaml
from pytorch_lightning.utilities import CombinedLoader


def log_parameter(params, parser, param_name=""):
    if isinstance(params, dict):
        for key in params.keys():
            if key == "class_path":
                parser = log_parameter(params[key], parser, param_name)
            else:
                parser = log_parameter(params[key], parser, key)
    else:
        parser.add_argument("--" + param_name, type=type(params), default=params)
    return parser


def update(config_data, params):
    for k, v in params.items():
        if isinstance(v, collections.abc.Mapping):
            config_data[k] = update(config_data.get(k, {}), v)
        else:
            config_data[k] = v
    return config_data


def train_model(config_file, **kwargs):
    loader = yaml.SafeLoader
    with open(config_file, "r") as stream:
        config_data = yaml.load(stream, Loader=loader)
    if "params" in kwargs:
        config_data = update(config_data, kwargs["params"])
    if "latent_model" in kwargs:
        config_data["lightning_model"]["init_args"]["latent_model"] = kwargs[
            "latent_model"
        ]
    # experiment_name = config_data["experiment_name"]
    n_epochs = config_data["n_epochs"]
    pl.seed_everything(1234)
    # add arguments to parser
    parser = ArgumentParser(conflict_handler="resolve")
    parser.add_argument(
        "--auto-select-gpus", default=True, help="run automatically on GPU if available"
    )
    parser.add_argument("--max-epochs", default=n_epochs, type=int)
    parser.add_argument("gpus", type=int, default=1)
    parser = log_parameter(config_data, parser)
    # parse arguments to trainer
    args = parser.parse_args()
    if args.gpus == 1:
        device = "cuda"
    elif args.gpus == 0:
        device = "cpu"
    # create mlflow experiment if it doesn't yet exist
    try:
        current_experiment = dict(mlflow.get_experiment_by_name(args.experiment_name))
        experiment_id = current_experiment["experiment_id"]
    except:
        print("creating new experiment")
        experiment_id = mlflow.create_experiment(args.experiment_name)
    # start experiment
    with mlflow.start_run(experiment_id=experiment_id) as run:
        with open("log.txt", "a") as log_file:
            log_file.write("'" + str(run.info.run_id) + "'" + ", ")
        path_mlflow_results = (
            "mlruns/" + str(experiment_id) + "/" + str(run.info.run_id)
        )
        path_checkpoints = path_mlflow_results + "/checkpoints"
        # copy yaml file to mlflow results
        # TODO: this is a hack for now, this should automatically be logged
        # with open(path_mlflow_results + "/" + config_file, "w") as f:
        with open(path_mlflow_results + "/config.yaml", "w") as f:
            yaml.dump(config_data, f)
        # initialize dataloader
        config_data = initialize_datamodule(config_data)
        datamodule = config_data["datamodule"]
        # extract key for model selection
        loss_key = config_data["metric_model_selection"]
        if (
            config_data["datamodule"].split_label_val == "Barcode"
            and "val_" in loss_key[0]
        ):
            raise ValueError(
                "split_label_val=Barcode with metric_model_selection=",
                loss_key,
                " introduces data leakage",
            )
        # initialize lightning model
        if (
            config_data["lightning_model"]["class_path"]
            == "models.lightning_model.LightningModel"
        ):
            use_val_test_data_in_train = True
        elif (
            config_data["lightning_model"]["class_path"]
            == "models.lightning_model.LightningModelSimple"
        ):
            use_val_test_data_in_train = False
        config_data = initialize_modules(config_data)
        lightning_model = config_data["lightning_model"]
        print(type(lightning_model))
        print(type(datamodule))
        # monitor different metrics depending on loss variable
        checkpoints = []
        monitored_metrics = config_data["monitored_metrics"]
        for i, (me, mo) in enumerate(monitored_metrics):
            ckpt = pl.callbacks.ModelCheckpoint(
                monitor=me,
                mode=mo,
                dirpath=path_checkpoints,
                filename="{epoch:02d}-{" + me + ":.4f}",
                save_top_k=1,
            )
            checkpoints.append(ckpt)
        # checkpoints.append(
        #     pl.callbacks.ModelCheckpoint(
        #         dirpath=path_checkpoints,
        #         filename="every_n_{epoch:02d}",
        #         every_n_epochs=10,
        #         save_top_k=-1,  # <--- this is important!
        #     )
        # )
        # log all parameters
        mlflow.pytorch.autolog()
        for arg in vars(args):
            mlflow.log_param(arg, getattr(args, arg))
        # train model
        trainer = pl.Trainer(max_epochs=n_epochs, logger=True, callbacks=checkpoints)
        # TODO: this is very hacky and should be revisited
        # we create a combined dataloader which is the same for train/validation/test
        # batching is applied to the train dataloader, thus there will be multiple
        # batches with the batch size defined in config.yaml
        # the validation and test dataloaders only have one batch which has the size
        # of the entire validation/test set
        # inside the lightning module we read out the validation and test batch at
        # step 0 and save it as a class attribute such that all validation and test
        # data can be used in all training steps
        if use_val_test_data_in_train:
            datamodule.setup(stage="")
            iterables_train = {
                "train": datamodule.train_dataloader(),
                "val": datamodule.val_dataloader(),
                "test": datamodule.test_dataloader(),
            }
            iterables_val = {
                "train": datamodule.train_dataloader(),
                "val": datamodule.val_dataloader(),
                "test": datamodule.test_dataloader(),
            }
            iterables_test = {
                "train": datamodule.train_dataloader(),
                "val": datamodule.val_dataloader(),
                "test": datamodule.test_dataloader(),
            }
            combined_loader_train = CombinedLoader(iterables_train, mode="max_size")
            combined_loader_val = CombinedLoader(iterables_val, mode="max_size")
            combined_loader_test = CombinedLoader(iterables_test, mode="max_size")
            trainer.fit(lightning_model, combined_loader_train, combined_loader_val)
        else:
            trainer.fit(lightning_model, datamodule=datamodule)
        # evaluate tests for all monitored metrics
        ckpts = glob.glob(path_checkpoints + "/*")
        for ckpt in ckpts:
            if loss_key[0] in ckpt:
                if use_val_test_data_in_train:
                    result = trainer.test(
                        dataloaders=combined_loader_test, ckpt_path=ckpt
                    )
                else:
                    result = trainer.test(datamodule=datamodule, ckpt_path=ckpt)
                print(result)
Finally the main file

import torch

import utils.helpers as helpers

torch.random.manual_seed(123)

if __name__ == "__main__":
    # profil data
    # train_model("config_profil_latent.yaml")
    # train_model("config_profil_readout.yaml")
    # train_model("config_profil.yaml")
    # train_model("config_profil_simple.yaml")
    for weight_decay in [1.0]:
        for val_subject in range(0, 14):
            params = {
                "datamodule": {
                    "init_args": {
                        "val_index": [val_subject],
                        "test_index": [],
                    }
                },
                "lightning_model": {
                    "init_args": {
                        "weight_decay": weight_decay,
                    }
                },
            }
            helpers.train_model("config_profil_simple.yaml", params=params)
Error messages and logs
Traceback (most recent call last):
File "/opt/conda/envs/artemis/lib/python3.10/site-packages/pytorch_lightning/trainer/call.py", line 44, in _call_and_handle_interrupt
return trainer_fn(*args, **kwargs)
File "/opt/conda/envs/artemis/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 579, in _fit_impl
self._run(model, ckpt_path=ckpt_path)
File "/opt/conda/envs/artemis/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 969, in _run
_log_hyperparams(self)
File "/opt/conda/envs/artemis/lib/python3.10/site-packages/pytorch_lightning/loggers/utilities.py", line 95, in _log_hyperparams
logger.save()
File "/opt/conda/envs/artemis/lib/python3.10/site-packages/lightning_utilities/core/rank_zero.py", line 42, in wrapped_fn
return fn(*args, **kwargs)
File "/opt/conda/envs/artemis/lib/python3.10/site-packages/lightning_fabric/loggers/csv_logs.py", line 157, in save
self.experiment.save()
File "/opt/conda/envs/artemis/lib/python3.10/site-packages/pytorch_lightning/loggers/csv_logs.py", line 67, in save
save_hparams_to_yaml(hparams_file, self.hparams)
File "/opt/conda/envs/artemis/lib/python3.10/site-packages/pytorch_lightning/core/saving.py", line 354, in save_hparams_to_yaml
yaml.dump(v)
File "/opt/conda/envs/artemis/lib/python3.10/site-packages/yaml/__init__.py", line 253, in dump
return dump_all([data], stream, Dumper=Dumper, **kwds)
File "/opt/conda/envs/artemis/lib/python3.10/site-packages/yaml/__init__.py", line 241, in dump_all
dumper.represent(data)
File "/opt/conda/envs/artemis/lib/python3.10/site-packages/yaml/representer.py", line 27, in represent
node = self.represent_data(data)
File "/opt/conda/envs/artemis/lib/python3.10/site-packages/yaml/representer.py", line 48, in represent_data
node = self.yaml_representers[data_types[0]](self, data)
File "/opt/conda/envs/artemis/lib/python3.10/site-packages/yaml/representer.py", line 199, in represent_list
return self.represent_sequence('tag:yaml.org,2002:seq', data)
File "/opt/conda/envs/artemis/lib/python3.10/site-packages/yaml/representer.py", line 92, in represent_sequence
node_item = self.represent_data(item)
File "/opt/conda/envs/artemis/lib/python3.10/site-packages/yaml/representer.py", line 48, in represent_data
node = self.yaml_representers[data_types[0]](self, data)
File "/opt/conda/envs/artemis/lib/python3.10/site-packages/yaml/representer.py", line 199, in represent_list
return self.represent_sequence('tag:yaml.org,2002:seq', data)
File "/opt/conda/envs/artemis/lib/python3.10/site-packages/yaml/representer.py", line 92, in represent_sequence
node_item = self.represent_data(item)
File "/opt/conda/envs/artemis/lib/python3.10/site-packages/yaml/representer.py", line 52, in represent_data
node = self.yaml_multi_representers[data_type](self, data)
File "/opt/conda/envs/artemis/lib/python3.10/site-packages/yaml/representer.py", line 356, in represent_object
return self.represent_mapping(tag+function_name, value)
File "/opt/conda/envs/artemis/lib/python3.10/site-packages/yaml/representer.py", line 118, in represent_mapping
node_value = self.represent_data(item_value)
File "/opt/conda/envs/artemis/lib/python3.10/site-packages/yaml/representer.py", line 48, in represent_data
node = self.yaml_representers[data_types[0]](self, data)
File "/opt/conda/envs/artemis/lib/python3.10/site-packages/yaml/representer.py", line 207, in represent_dict
return self.represent_mapping('tag:yaml.org,2002:map', data)
File "/opt/conda/envs/artemis/lib/python3.10/site-packages/yaml/representer.py", line 118, in represent_mapping
node_value = self.represent_data(item_value)
File "/opt/conda/envs/artemis/lib/python3.10/site-packages/yaml/representer.py", line 52, in represent_data
node = self.yaml_multi_representers[data_type](self, data)
File "/opt/conda/envs/artemis/lib/python3.10/site-packages/yaml/representer.py", line 330, in represent_object
dictitems = dict(dictitems)
ValueError: dictionary update sequence element #0 has length 1; 2 is required
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "/home/pap_spiden_com/spiden_ds/experiments/artemis/main.py", line 28, in <module>
helpers.train_model("config_profil_simple.yaml", params=params)
File "/home/pap_spiden_com/spiden_ds/experiments/artemis/utils/helpers.py", line 191, in train_model
trainer.fit(lightning_model, datamodule=datamodule)
File "/opt/conda/envs/artemis/lib/python3.10/site-packages/mlflow/utils/autologging_utils/safety.py", line 573, in safe_patch_function
patch_function(call_original, *args, **kwargs)
File "/opt/conda/envs/artemis/lib/python3.10/site-packages/mlflow/utils/autologging_utils/safety.py", line 252, in patch_with_managed_run
result = patch_function(original, *args, **kwargs)
File "/opt/conda/envs/artemis/lib/python3.10/site-packages/mlflow/pytorch/_lightning_autolog.py", line 386, in patched_fit
result = original(self, *args, **kwargs)
File "/opt/conda/envs/artemis/lib/python3.10/site-packages/mlflow/utils/autologging_utils/safety.py", line 554, in call_original
return call_original_fn_with_event_logging(_original_fn, og_args, og_kwargs)
File "/opt/conda/envs/artemis/lib/python3.10/site-packages/mlflow/utils/autologging_utils/safety.py", line 489, in call_original_fn_with_event_logging
original_fn_result = original_fn(*og_args, **og_kwargs)
File "/opt/conda/envs/artemis/lib/python3.10/site-packages/mlflow/utils/autologging_utils/safety.py", line 551, in _original_fn
original_result = original(*_og_args, **_og_kwargs)
File "/opt/conda/envs/artemis/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 543, in fit
call._call_and_handle_interrupt(
File "/opt/conda/envs/artemis/lib/python3.10/site-packages/pytorch_lightning/trainer/call.py", line 67, in _call_and_handle_interrupt
logger.finalize("failed")
File "/opt/conda/envs/artemis/lib/python3.10/site-packages/lightning_utilities/core/rank_zero.py", line 42, in wrapped_fn
return fn(*args, **kwargs)
File "/opt/conda/envs/artemis/lib/python3.10/site-packages/lightning_fabric/loggers/csv_logs.py", line 166, in finalize
self.save()
File "/opt/conda/envs/artemis/lib/python3.10/site-packages/lightning_utilities/core/rank_zero.py", line 42, in wrapped_fn
return fn(*args, **kwargs)
File "/opt/conda/envs/artemis/lib/python3.10/site-packages/lightning_fabric/loggers/csv_logs.py", line 157, in save
self.experiment.save()
File "/opt/conda/envs/artemis/lib/python3.10/site-packages/pytorch_lightning/loggers/csv_logs.py", line 67, in save
save_hparams_to_yaml(hparams_file, self.hparams)
File "/opt/conda/envs/artemis/lib/python3.10/site-packages/pytorch_lightning/core/saving.py", line 354, in save_hparams_to_yaml
yaml.dump(v)
File "/opt/conda/envs/artemis/lib/python3.10/site-packages/yaml/__init__.py", line 253, in dump
return dump_all([data], stream, Dumper=Dumper, **kwds)
File "/opt/conda/envs/artemis/lib/python3.10/site-packages/yaml/__init__.py", line 241, in dump_all
dumper.represent(data)
File "/opt/conda/envs/artemis/lib/python3.10/site-packages/yaml/representer.py", line 27, in represent
node = self.represent_data(data)
File "/opt/conda/envs/artemis/lib/python3.10/site-packages/yaml/representer.py", line 48, in represent_data
node = self.yaml_representers[data_types[0]](self, data)
File "/opt/conda/envs/artemis/lib/python3.10/site-packages/yaml/representer.py", line 199, in represent_list
return self.represent_sequence('tag:yaml.org,2002:seq', data)
File "/opt/conda/envs/artemis/lib/python3.10/site-packages/yaml/representer.py", line 92, in represent_sequence
node_item = self.represent_data(item)
File "/opt/conda/envs/artemis/lib/python3.10/site-packages/yaml/representer.py", line 48, in represent_data
node = self.yaml_representers[data_types[0]](self, data)
File "/opt/conda/envs/artemis/lib/python3.10/site-packages/yaml/representer.py", line 199, in represent_list
return self.represent_sequence('tag:yaml.org,2002:seq', data)
File "/opt/conda/envs/artemis/lib/python3.10/site-packages/yaml/representer.py", line 92, in represent_sequence
node_item = self.represent_data(item)
File "/opt/conda/envs/artemis/lib/python3.10/site-packages/yaml/representer.py", line 52, in represent_data
node = self.yaml_multi_representers[data_type](self, data)
File "/opt/conda/envs/artemis/lib/python3.10/site-packages/yaml/representer.py", line 356, in represent_object
return self.represent_mapping(tag+function_name, value)
File "/opt/conda/envs/artemis/lib/python3.10/site-packages/yaml/representer.py", line 118, in represent_mapping
node_value = self.represent_data(item_value)
File "/opt/conda/envs/artemis/lib/python3.10/site-packages/yaml/representer.py", line 48, in represent_data
node = self.yaml_representers[data_types[0]](self, data)
File "/opt/conda/envs/artemis/lib/python3.10/site-packages/yaml/representer.py", line 207, in represent_dict
return self.represent_mapping('tag:yaml.org,2002:map', data)
File "/opt/conda/envs/artemis/lib/python3.10/site-packages/yaml/representer.py", line 118, in represent_mapping
node_value = self.represent_data(item_value)
File "/opt/conda/envs/artemis/lib/python3.10/site-packages/yaml/representer.py", line 52, in represent_data
node = self.yaml_multi_representers[data_type](self, data)
File "/opt/conda/envs/artemis/lib/python3.10/site-packages/yaml/representer.py", line 330, in represent_object
dictitems = dict(dictitems)
ValueError: dictionary update sequence element #0 has length 1; 2 is required
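Incidentally, the final ValueError is the generic error dict() raises when an item sequence contains elements that are not (key, value) pairs: PyYAML's represent_object ends with dict(dictitems) on whatever the object's __reduce_ex__ returned (representer.py line 330 in the trace). A minimal illustration of just that error, with no yaml involved:

```python
# dict() requires an iterable of (key, value) pairs; a 1-tuple element
# reproduces the exact message seen at the bottom of the trace.
try:
    dict([("only_one_field",)])
except ValueError as err:
    print(err)  # dictionary update sequence element #0 has length 1; 2 is required
```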
Environment
Current environment
#- Lightning Component (e.g. Trainer, LightningModule, LightningApp, LightningWork, LightningFlow):
#- PyTorch Lightning Version (e.g., 1.5.0):
#- Lightning App Version (e.g., 0.5.2):
#- PyTorch Version (e.g., 2.0):
#- Python version (e.g., 3.9):
#- OS (e.g., Linux):
#- CUDA/cuDNN version:
#- GPU models and configuration:
#- How you installed Lightning(`conda`, `pip`, source):
#- Running environment of LightningApp (e.g. local, cloud):
More info
No response