Skip to content

Commit

Permalink
Cleaner datadir management for some tests (#15791)
Browse files Browse the repository at this point in the history
  • Loading branch information
carmocca committed Nov 25, 2022
1 parent f171657 commit 0d98689
Show file tree
Hide file tree
Showing 5 changed files with 119 additions and 96 deletions.
5 changes: 1 addition & 4 deletions src/pytorch_lightning/demos/boring_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,11 +156,8 @@ def predict_dataloader(self) -> DataLoader:


class BoringDataModule(LightningDataModule):
def __init__(self, data_dir: str = "./"):
def __init__(self) -> None:
super().__init__()
self.data_dir = data_dir
self.non_picklable = None
self.checkpoint_state: Optional[str] = None
self.random_full = RandomDataset(32, 64 * 4)

def setup(self, stage: str) -> None:
Expand Down
5 changes: 3 additions & 2 deletions src/pytorch_lightning/loggers/csv_logs.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@

from torch import Tensor

from lightning_lite.utilities.types import _PATH
from pytorch_lightning.core.saving import save_hparams_to_yaml
from pytorch_lightning.loggers.logger import Logger, rank_zero_experiment
from pytorch_lightning.utilities.logger import _add_prefix, _convert_params
Expand Down Expand Up @@ -125,14 +126,14 @@ class CSVLogger(Logger):

def __init__(
self,
save_dir: str,
save_dir: _PATH,
name: str = "lightning_logs",
version: Optional[Union[int, str]] = None,
prefix: str = "",
flush_logs_every_n_steps: int = 100,
):
super().__init__()
self._save_dir = save_dir
self._save_dir = os.fspath(save_dir)
self._name = name or ""
self._version = version
self._prefix = prefix
Expand Down
12 changes: 9 additions & 3 deletions tests/tests_pytorch/core/test_datamodules.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,18 +126,24 @@ def test_helper_boringdatamodule_with_verbose_setup():
dm.setup("test")


class DataDirDataModule(BoringDataModule):
def __init__(self, data_dir: str):
super().__init__()
self.data_dir = data_dir


def test_dm_add_argparse_args(tmpdir):
parser = ArgumentParser()
parser = BoringDataModule.add_argparse_args(parser)
parser = DataDirDataModule.add_argparse_args(parser)
args = parser.parse_args(["--data_dir", str(tmpdir)])
assert args.data_dir == str(tmpdir)


def test_dm_init_from_argparse_args(tmpdir):
parser = ArgumentParser()
parser = BoringDataModule.add_argparse_args(parser)
parser = DataDirDataModule.add_argparse_args(parser)
args = parser.parse_args(["--data_dir", str(tmpdir)])
dm = BoringDataModule.from_argparse_args(args)
dm = DataDirDataModule.from_argparse_args(args)
dm.prepare_data()
dm.setup("fit")
assert dm.data_dir == args.data_dir == str(tmpdir)
Expand Down

0 comments on commit 0d98689

Please sign in to comment.