
Passing a dataloader to save_hyperparams hangs trainer.fit #19785

Open
lsc64 opened this issue Apr 17, 2024 · 0 comments
Labels: bug, needs triage

Comments


lsc64 commented Apr 17, 2024

Bug description

If you pass a torch.utils.data.dataloader.DataLoader object as part of your hyperparameters and save it in your LightningModule via self.save_hyperparameters(), trainer.fit will hang.
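
A likely workaround (my suggestion; untested beyond this pattern) is to exclude the loader with the ignore argument that save_hyperparameters() accepts, so the DataLoader is never captured or serialized:

class BoringModel(LightningModule):
    def __init__(self, dl):
        super().__init__()
        self.dl = dl  # still usable as a plain attribute
        self.layer = torch.nn.Linear(32, 2)
        # Exclude the DataLoader from the captured hyperparameters;
        # every other __init__ argument is still saved.
        self.save_hyperparameters(ignore=["dl"])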

What version are you seeing the problem on?

v2.2

How to reproduce the bug

import os

import torch
from lightning.pytorch import LightningModule, Trainer
from torch.utils.data import DataLoader, Dataset


class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


class BoringModel(LightningModule):
    def __init__(self, dl):
        super().__init__()
        self.dl = dl
        self.layer = torch.nn.Linear(32, 2)
        # Captures all __init__ arguments, including the DataLoader
        # passed as `dl`, as hyperparameters.
        self.save_hyperparameters()

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("train_loss", loss)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("valid_loss", loss)

    def test_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("test_loss", loss)

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)


def run():
    train_data = DataLoader(RandomDataset(32, 64), batch_size=2)
    val_data = DataLoader(RandomDataset(32, 64), batch_size=2)
    test_data = DataLoader(RandomDataset(1000000, 64), batch_size=2)

    # Passing the DataLoader as a constructor argument makes it part of
    # the saved hyperparameters and triggers the hang.
    model = BoringModel(
        dl=test_data,
    )
    trainer = Trainer(
        default_root_dir=os.getcwd(),
        devices=[1],
        limit_train_batches=1,
        limit_val_batches=1,
        limit_test_batches=1,
        num_sanity_val_steps=0,
        max_epochs=1,
        enable_model_summary=False,
    )
    trainer.fit(model, train_dataloaders=train_data, val_dataloaders=val_data)
    trainer.test(model, dataloaders=test_data)


if __name__ == "__main__":
    run()
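
To confirm the loader really gets captured (a quick check, not part of the original script), inspect model.hparams after construction; the raw DataLoader object shows up there, and serializing/logging it is presumably what trainer.fit then chokes on:

model = BoringModel(dl=test_data)
print(model.hparams)  # contains the DataLoader object under the key "dl"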

Error messages and logs

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
`Trainer(limit_train_batches=1)` was configured so 1 batch per epoch will be used.
`Trainer(limit_val_batches=1)` was configured so 1 batch will be used.
`Trainer(limit_test_batches=1)` was configured so 1 batch will be used.
You are using a CUDA device ('NVIDIA A100-SXM4-40GB') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]
# hangs for 500+ minutes
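
Since there is no traceback, one way to see where the process is stuck (a diagnostic suggestion, not part of the original report) is to register a faulthandler signal before calling trainer.fit, then signal the hung process from another shell:

import faulthandler
import signal

# Dump the stack of every thread to stderr when the process
# receives SIGUSR1, e.g. via `kill -USR1 <pid>` (Unix only).
faulthandler.register(signal.SIGUSR1)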

Environment

Current environment
  • CUDA:
    - GPU:
    - NVIDIA A100-SXM4-40GB
    - NVIDIA A100-SXM4-40GB
    - NVIDIA A100-SXM4-40GB
    - NVIDIA A100-SXM4-40GB
    - NVIDIA A100-SXM4-40GB
    - NVIDIA A100-SXM4-40GB
    - NVIDIA A100-SXM4-40GB
    - NVIDIA A100-SXM4-40GB
    - available: True
    - version: 12.1
  • Lightning:
    - lightning: 2.2.1
    - lightning-bolts: 0.7.0
    - lightning-utilities: 0.11.0
    - pytorch-lightning: 2.2.1
    - torch: 2.2.1
    - torchaudio: 2.2.1
    - torchmetrics: 1.3.2
    - torchvision: 0.17.1
  • Packages:
    - absl-py: 2.1.0
    - aiohttp: 3.9.3
    - aiohttp-cors: 0.7.0
    - aiosignal: 1.3.1
    - alembic: 1.13.1
    - aniso8601: 9.0.1
    - annotated-types: 0.6.0
    - antlr4-python3-runtime: 4.9.3
    - anyio: 4.3.0
    - archspec: 0.2.2
    - argon2-cffi: 23.1.0
    - argon2-cffi-bindings: 21.2.0
    - arrow: 1.3.0
    - asciitree: 0.3.3
    - asttokens: 2.4.1
    - async-lru: 2.0.4
    - attrs: 23.2.0
    - autodp: 0.2.3.1
    - babel: 2.14.0
    - bcrypt: 4.1.2
    - beautifulsoup4: 4.12.3
    - bleach: 6.1.0
    - blessed: 1.19.1
    - blinker: 1.7.0
    - huggingface-hub: 0.21.4
    - hydra-core: 1.3.2
    - identify: 2.5.35
    - idna: 3.6
    - importlib-metadata: 7.1.0
    - importlib-resources: 6.1.2
    - install: 1.3.5
    - ipykernel: 6.29.3
    - ipython: 8.22.2
    - ipywidgets: 8.1.2
    - isoduration: 20.11.0
    - itsdangerous: 2.1.2
    - jax: 0.4.25
    - jaxlib: 0.4.25
    - jaxopt: 0.8.3
    - jaxtyping: 0.2.28
    - jedi: 0.19.1
    - jinja2: 3.1.3
    - joblib: 1.3.2
    - json5: 0.9.24
    - jsonpatch: 1.33
    - jsonpointer: 2.4
    - jsonschema: 4.21.1
    - jsonschema-specifications: 2023.12.1
    - jupyter-client: 8.6.0
    - jupyter-core: 5.7.1
    - jupyter-events: 0.9.0
    - jupyter-lsp: 2.2.4
    - jupyter-server: 2.13.0
    - jupyter-server-mathjax: 0.2.6
    - jupyter-server-terminals: 0.5.2
    - jupyterlab: 4.1.5
    - jupyterlab-git: 0.50.0
    - jupyterlab-pygments: 0.3.0
    - jupyterlab-server: 2.25.4
    - jupyterlab-widgets: 3.0.10
    - kiwisolver: 1.4.5
    - libarchive-c: 5.0
    - libmambapy: 1.5.7
    - lightning: 2.2.1
    - lightning-bolts: 0.7.0
    - lightning-utilities: 0.11.0
    - llvmlite: 0.42.0
    - locket: 1.0.0
    - mako: 1.3.2
    - mamba: 1.5.7
    - markdown: 3.5.2
    - markdown-it-py: 3.0.0
    - markupsafe: 2.1.5
    - matplotlib: 3.8.3
    - matplotlib-inline: 0.1.6
    - mdurl: 0.1.2
    - memory-tempfile: 2.2.3
    - menuinst: 2.0.2
    - mistune: 3.0.2
    - ml-dtypes: 0.3.2
    - mlflow: 2.11.0
    - mlflow-skinny: 2.11.0
    - more-itertools: …
    - python-dateutil: 2.9.0.post0
    - python-dp: 1.1.4
    - python-json-logger: 2.0.7
    - pytorch-lightning: 2.2.1
    - pytz: 2024.1
    - pyyaml: 6.0.1
    - pyzmq: 25.1.2
    - querystring-parser: 1.2.4
    - ray: 2.9.3
    - referencing: 0.33.0
    - regex: 2023.12.25
    - requests: 2.31.0
    - rfc3339-validator: 0.1.4
    - rfc3986-validator: 0.1.1
    - rich: 13.7.1
    - rpds-py: 0.18.0
    - rsa: 4.9
    - ruamel.yaml: 0.18.6
    - ruamel.yaml.clib: 0.2.8
    - ruff: 0.3.3
    - safetensors: 0.4.2
    - scikit-learn: 1.4.1.post1
    - scipy: 1.12.0
    - seaborn: 0.13.2
    - send2trash: 1.8.2
    - setuptools: 69.2.0
    - six: 1.16.0
    - skorch: 0.15.0
    - smart-open: 7.0.1
    - smmap: 5.0.0
    - sniffio: 1.3.1
    - sortedcontainers: 2.4.0
    - soupsieve: 2.5
    - sparse: 0.15.1
    - sqlalchemy: 2.0.28
    - sqlparse: 0.4.4
    - stack-data: 0.6.2
    - sympy: 1.12
    - tabulate: 0.9.0
    - tblib: 3.0.0
    - tensorboard: 2.16.2
    - tensorboard-data-server: 0.7.0
    - tensorstore: 0.1.56
    - termcolor: 2.4.0
    - terminado: 0.18.0
    - tfp-nightly: 0.25.0.dev20240318
    - threadpoolctl: 3.4.0
    - timm: 0.9.16
    - tinycss2: 1.2.1
    - tokenizers: 0.15.2
    - toolz: 0.12.1
    - torch: 2.2.1
    - torchaudio: 2.2.1
    - torchmetrics: 1.3.2
    - torchvision: 0.17.1
    - tornado: 6.4
    - tqdm: 4.66.2
    - traitlets: 5.14.1
    - transformers: 4.38.2
    - triton: 2.2.0
    - truststore: 0.8.0
    - typeguard: 2.13.3
    - types-python-dateutil: 2.8.19.20240106
    - typing-extensions: 4.10.0
    - typing-utils: 0.1.0
    - tzdata: 2024.1
    - uri-template: 1.3.0
    - urllib3: 2.2.1
    - uv: 0.1.22
    - virtualenv: 20.25.1
    - vit-proto: 0.0.0
    - wcwidth: 0.2.13
    - webcolors: 1.13
    - webencodings: 0.5.1
    - websocket-client: 1.7.0
    - werkzeug: 3.0.1
    - wheel: 0.42.0
    - widgetsnbextension: 4.0.10
    - wrapt: 1.16.0
    - yarl: 1.9.4
    - zarr: 2.17.1
    - zict: 3.0.0
    - zipp: 3.18.1
    - zstandard: 0.22.0
  • System:
    - OS: Linux
    - architecture: 64bit
    - processor: x86_64
    - python: 3.11.8
    - release: 5.4.0-173-generic
    - version: #191-Ubuntu SMP Fri Feb 2 13:55:07 UTC 2024

More info

There should probably be a warning regarding this, because logging the entire dataloader was definitely not intended.
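
For illustration, a hypothetical guard (names and placement are mine; nothing like this exists in Lightning today) that could emit such a warning before the hyperparameters are serialized:

import warnings

from torch.utils.data import DataLoader


def warn_on_unserializable_hparams(hparams: dict) -> None:
    # Flag hyperparameter values that are DataLoaders, since trying to
    # log or serialize them is almost certainly unintended.
    for name, value in hparams.items():
        if isinstance(value, DataLoader):
            warnings.warn(
                f"Hyperparameter {name!r} is a DataLoader; it will not "
                "serialize cleanly and may hang logging. Consider "
                "save_hyperparameters(ignore=[...]).",
                category=UserWarning,
            )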

@lsc64 added the bug and needs triage labels on Apr 17, 2024
@lsc64 changed the title from “Passing a dataloader to save_hyperparams hangs the trainer” to “Passing a dataloader to save_hyperparams hangs trainer.fit” on Apr 17, 2024