feat: get_dataset_from_params (config api and hydra) added
bagxi committed Jun 7, 2021
1 parent e8adb3c commit ef2c857
Showing 9 changed files with 179 additions and 77 deletions.
21 changes: 20 additions & 1 deletion CHANGELOG.md
@@ -23,7 +23,26 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
-


## [21.05] - YYYY-MM-DD
## [21.06] - YYYY-MM-DD

### Added

- `get_dataset_from_params` support in the Config and Hydra APIs ([#1231](https://github.com/catalyst-team/catalyst/pull/1231))

### Changed

-

### Removed

-

### Fixed

-


## [21.05] - 2021-05-31

### Added

40 changes: 33 additions & 7 deletions catalyst/runners/config.py
@@ -1,12 +1,11 @@
from typing import Any, Dict, List
from collections import OrderedDict
from copy import deepcopy
from functools import partial
import logging
import os

from torch import nn
from torch.utils.data import DataLoader
from torch.utils.data import DataLoader, Dataset

from catalyst.callbacks import CheckpointCallback, ICheckpointCallback
from catalyst.callbacks.batch_overfit import BatchOverfitCallback
@@ -216,6 +215,34 @@ def get_loggers(self) -> Dict[str, ILogger]:

return loggers

@staticmethod
def get_dataset_from_params(**params) -> "Dataset":
"""Creates dataset from ``**params`` parameters."""
params = deepcopy(params)

dataset = REGISTRY.get_from_params(**params)

return dataset

def get_datasets(self, stage: str) -> "OrderedDict[str, Dataset]":
"""
Returns datasets for a given stage.
Args:
stage: stage name
Returns:
Dict: datasets objects
"""
params = deepcopy(self._stage_config[stage]["loaders"]["datasets"])

datasets = [
(key, self.get_dataset_from_params(**dataset_params))
for key, dataset_params in params.items()
]
return OrderedDict(datasets)

def get_loaders(self, stage: str) -> "OrderedDict[str, DataLoader]":
"""
Returns loaders for a given stage.
@@ -227,12 +254,11 @@ def get_loaders(self, stage: str) -> "OrderedDict[str, DataLoader]":
Dict: loaders objects
"""
loaders_params = dict(self._stage_config[stage]["loaders"])
loaders_params = deepcopy(self._stage_config[stage]["loaders"])
loaders_params.pop("datasets", None)

loaders = get_loaders_from_params(
datasets_fn=partial(self.get_datasets, stage=stage),
initial_seed=self.seed,
stage=stage,
**loaders_params,
datasets=self.get_datasets(stage=stage), initial_seed=self.seed, **loaders_params,
)
return loaders

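A minimal sketch of how the new Config API helper can be exercised on its own (illustration only, not part of the diff; it assumes the enclosing class is `ConfigRunner` from `catalyst.runners.config` and that `MNIST` is resolvable through the Catalyst REGISTRY, as the commented `_target_: MNIST` lines in the example configs below suggest):

from catalyst.runners.config import ConfigRunner

# get_dataset_from_params forwards **params to REGISTRY.get_from_params:
# "_target_" selects the dataset class, the remaining keys become its kwargs.
dataset = ConfigRunner.get_dataset_from_params(
    _target_="MNIST",
    root="./data",
    train=True,
    download=True,
)
# `dataset` is now a torch.utils.data.Dataset instance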
35 changes: 29 additions & 6 deletions catalyst/runners/hydra.py
@@ -1,13 +1,12 @@
from typing import Any, Dict, List
from collections import OrderedDict
from copy import deepcopy
from functools import partial
import logging
import os

import hydra
from omegaconf import DictConfig, OmegaConf
from torch.utils.data import DataLoader
from torch.utils.data import DataLoader, Dataset

from catalyst.callbacks import CheckpointCallback, ICheckpointCallback
from catalyst.callbacks.batch_overfit import BatchOverfitCallback
@@ -177,6 +176,31 @@ def get_loggers(self) -> Dict[str, ILogger]:

return loggers

@staticmethod
def get_dataset_from_params(params: DictConfig) -> "Dataset":
"""Creates dataset from ``**params`` parameters."""
dataset = hydra.utils.instantiate(params)
return dataset

def get_datasets(self, stage: str) -> "OrderedDict[str, Dataset]":
"""
Returns datasets for a given stage.
Args:
stage: stage name
Returns:
Dict: datasets objects
"""
datasets_params = self._config.stages[stage].loaders.datasets
datasets_params = OmegaConf.to_container(datasets_params, resolve=True)

datasets = {
key: self.get_dataset_from_params(params) for key, params in datasets_params.items()
}
return OrderedDict(datasets)

def get_loaders(self, stage: str) -> Dict[str, DataLoader]:
"""
Returns loaders for a given stage.
@@ -190,11 +214,10 @@ def get_loaders(self, stage: str) -> Dict[str, DataLoader]:
"""
loaders_params = self._config.stages[stage].loaders
loaders_params = OmegaConf.to_container(loaders_params, resolve=True)
loaders_params.pop("datasets", None)

loaders = get_loaders_from_params(
datasets_fn=partial(self.get_datasets, stage=stage),
initial_seed=self.seed,
stage=stage,
**loaders_params,
datasets=self.get_datasets(stage=stage), initial_seed=self.seed, **loaders_params,
)
return loaders

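The Hydra-side counterpart relies on `hydra.utils.instantiate` instead of the REGISTRY; a self-contained sketch (illustrative values mirroring the `datasets` node of `config_hydra.yaml` below):

import hydra
from omegaconf import OmegaConf

# A node shaped like the entries under stages.<stage>.loaders.datasets.
params = OmegaConf.create({
    "_target_": "catalyst.contrib.datasets.MNIST",
    "root": "./data",
    "train": True,
    "download": True,
})
dataset = hydra.utils.instantiate(params)  # -> a torch.utils.data.Dataset instance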
23 changes: 7 additions & 16 deletions catalyst/utils/data.py
@@ -1,4 +1,4 @@
from typing import Any, Callable, Dict, Iterable
from typing import Any, Callable, Dict, Iterable, Union
from collections import OrderedDict
import copy
from functools import partial
@@ -82,9 +82,8 @@ def get_loaders_from_params(
per_gpu_scaling: bool = False,
loaders_params: Dict[str, Any] = None,
samplers_params: Dict[str, Any] = None,
datasets: "OrderedDict[str, Union[Dataset, dict]]" = None,
initial_seed: int = 42,
datasets_fn: Callable = None,
**data_params,
) -> "OrderedDict[str, DataLoader]":
"""
Creates pytorch dataloaders from datasets and additional parameters.
@@ -98,22 +97,18 @@ def get_loaders_from_params(
from ``torch.utils.data.DataLoader``
per_gpu_scaling: boolean flag,
if ``True``, scales batch_size in proportion to the number of GPUs
loaders_params (Dict[str, Any]): additional loaders parameters
samplers_params (Dict[str, Any]): additional sampler parameters
loaders_params: additional loaders parameters
samplers_params: additional sampler parameters
initial_seed: initial seed for ``torch.utils.data.DataLoader``
workers
datasets_fn(Callable): callable function to get dictionary with
``torch.utils.data.Datasets``
**data_params: additional data parameters
or dictionary with ``torch.utils.data.Datasets`` to use for
pytorch dataloaders creation
datasets: ordered dictionary with ``torch.utils.data.Dataset``
Returns:
OrderedDict[str, DataLoader]: dictionary with
``torch.utils.data.DataLoader``
Raises:
NotImplementedError: if datasource is out of `Dataset` or dict
NotImplementedError: if datasource is out of ``Dataset`` or dict
ValueError: if batch_sampler option is mutually
exclusive with distributed
"""
@@ -129,15 +124,11 @@
assert isinstance(
samplers_params, dict
), f"`samplers_params` should be a Dict. Got: {samplers_params}"
datasets = datasets if datasets is not None else {}

distributed_rank = get_rank()
distributed = distributed_rank > -1

if datasets_fn is not None:
datasets = datasets_fn(**data_params)
else:
datasets = dict(**data_params)

loaders = OrderedDict()
for name, datasource in datasets.items(): # noqa: WPS426
assert isinstance(
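With the `datasets_fn` callback gone, callers now pass datasets in directly. A self-contained sketch of the reworked call contract (toy tensors; it assumes the usual `batch_size`/`num_workers` keywords documented above the shown hunk):

import torch
from torch.utils.data import TensorDataset

from catalyst.utils.data import get_loaders_from_params

# Each value may be a Dataset or a {"dataset": ..., "sampler": ...} dict.
datasets = {
    "train": TensorDataset(torch.randn(64, 10), torch.randint(0, 2, (64,))),
    "valid": TensorDataset(torch.randn(16, 10), torch.randint(0, 2, (16,))),
}
loaders = get_loaders_from_params(
    batch_size=8,
    num_workers=0,
    datasets=datasets,
    initial_seed=42,
)
assert set(loaders) == {"train", "valid"}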
15 changes: 13 additions & 2 deletions examples/mnist_stages/config.yml
@@ -51,8 +51,19 @@ stages:
# default kwargs for `runner.get_loaders`:
batch_size: 32
num_workers: 1
# kwargs for `runner.get_datasets`:
num_samples_per_class: 320

datasets:
train:
# _target_: MNIST
root: *dataset_root
train: True
download: True
num_samples_per_class: 320
valid:
# _target_: MNIST
root: *dataset_root
train: False
download: True

criterion:
_target_: CrossEntropyLoss
14 changes: 12 additions & 2 deletions examples/mnist_stages/config_hydra.yaml
@@ -50,8 +50,18 @@ stages:
# default kwargs for `runner.get_loaders`:
batch_size: 32
num_workers: 1
# kwargs for `runner.get_datasets`:
num_samples_per_class: 320

datasets:
train:
_target_: catalyst.contrib.datasets.MNIST
root: ${shared.dataset_root}
train: true
download: true
valid:
_target_: catalyst.contrib.datasets.MNIST
root: ${shared.dataset_root}
train: false
download: true

criterion:
_target_: torch.nn.CrossEntropyLoss
22 changes: 11 additions & 11 deletions examples/mnist_stages/config_tune.yml
@@ -60,17 +60,17 @@ stages:
batch_size: 32
num_workers: 1

# datasets:
# train:
# _target_: MNIST
# root: *dataset_root
# train: True
# download: True
# valid:
# _target_: MNIST
# root: *dataset_root
# train: False
# download: True
datasets:
train:
# _target_: MNIST
root: *dataset_root
train: True
download: True
valid:
# _target_: MNIST
root: *dataset_root
train: False
download: True
#
# samplers:
# train:
56 changes: 30 additions & 26 deletions examples/mnist_stages/runner.py
@@ -28,40 +28,44 @@ def get_model(self, stage: str):
def get_transform(self, stage: str = None, mode: str = None):
return ToTensor()

def get_datasets(
self, stage: str, num_samples_per_class: int = None
) -> "OrderedDict[str, Dataset]":
"""Provides train/validation datasets from MNIST dataset."""
num_samples_per_class = num_samples_per_class or 320
datasets = OrderedDict()
for mode in ("train", "valid"):
dataset = MNIST(
"./data",
train=(mode == "train"),
download=True,
transform=self.get_transform(stage=stage, mode=mode),
)
if mode == "train":
dataset = {
"dataset": dataset,
"sampler": BalanceClassSampler(
labels=dataset.targets, mode=num_samples_per_class
),
}
datasets[mode] = dataset

return datasets


class CustomSupervisedConfigRunner(IRunnerMixin, SupervisedConfigRunner):
pass
def get_dataset_from_params(
self,
root: str = "./data",
train: bool = True,
download: bool = False,
num_samples_per_class: int = 320,
):
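"""Creates MNIST dataset; the train dataset is paired with a BalanceClassSampler."""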
dataset = MNIST(root, train=train, download=download, transform=self.get_transform())
if train:
dataset = {
"dataset": dataset,
"sampler": BalanceClassSampler(labels=dataset.targets, mode=num_samples_per_class),
}

return dataset


if SETTINGS.hydra_required:
import hydra

from catalyst.dl import SupervisedHydraRunner

class CustomSupervisedHydraRunner(IRunnerMixin, SupervisedHydraRunner):
pass
def get_dataset_from_params(self, params):
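"""Creates MNIST dataset from Hydra ``params``; the train dataset is paired with a BalanceClassSampler."""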
num_samples_per_class = 320

dataset = hydra.utils.instantiate(params, transform=self.get_transform())
if params["train"]:
dataset = {
"dataset": dataset,
"sampler": BalanceClassSampler(
labels=dataset.targets, mode=num_samples_per_class
),
}

return dataset

__all__ = ["CustomSupervisedConfigRunner", "CustomSupervisedHydraRunner"]
else:
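For the `train` entry, the overrides above pair the dataset with a `BalanceClassSampler`. A standalone sketch of that pairing (the import paths are my assumption, since the runner's own imports fall outside the shown hunk):

from torch.utils.data import DataLoader

from catalyst.contrib.datasets import MNIST
from catalyst.data import BalanceClassSampler

dataset = MNIST("./data", train=True, download=True)
# mode=<int> draws that many samples per class, matching num_samples_per_class=320.
sampler = BalanceClassSampler(labels=dataset.targets, mode=320)
loader = DataLoader(dataset, sampler=sampler, batch_size=32)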
