Skip to content

Commit

Permalink
feat: get_transforms PoC
Browse files Browse the repository at this point in the history
  • Loading branch information
bagxi committed Jun 16, 2021
1 parent b33eaad commit 14a7e7e
Show file tree
Hide file tree
Showing 8 changed files with 88 additions and 40 deletions.
21 changes: 2 additions & 19 deletions CHANGELOG.md
Expand Up @@ -14,6 +14,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- `utils.ddp_sync_run` function for synchronous ddp run
- CIFAR10 and CIFAR100 datasets from torchvision (no cv-based requirements)
- [Catalyst Engines demo](https://github.com/catalyst-team/catalyst/tree/master/examples/engines)
- `dataset_from_params` support in config API ([#1231](https://github.com/catalyst-team/catalyst/pull/1231))
- transforms-from-params support in config API ([#1236](https://github.com/catalyst-team/catalyst/pull/1236))

### Changed

Expand All @@ -35,25 +37,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Hydra hotfix due to `1.1.0` version changes


## [21.06] - YYYY-MM-DD

### Added

- `dataset_from_params` support in config API ([#1231](https://github.com/catalyst-team/catalyst/pull/1231))

### Changed

-

### Removed

-

### Fixed

- batch overfit test fixed ([#1232](https://github.com/catalyst-team/catalyst/pull/1232/files))


## [21.05] - 2021-05-31

### Added
Expand Down
3 changes: 2 additions & 1 deletion catalyst/registry.py
Expand Up @@ -12,7 +12,8 @@
def _transforms_loader(r: registry.Registry):
    """Register ``catalyst.data.transforms`` factories in the registry.

    Args:
        r: registry instance to populate with transform entries
    """
    from catalyst.data import transforms as t

    # register only under the `"transform."` prefix to avoid naming
    # conflicts with other catalyst modules
    r.add_from_module(t, prefix=["transform."])


REGISTRY.late_add(_transforms_loader)
Expand Down
42 changes: 36 additions & 6 deletions catalyst/runners/config.py
@@ -1,4 +1,4 @@
from typing import Any, Dict, List
from typing import Any, Callable, Dict, List
from collections import OrderedDict
from copy import deepcopy
import logging
Expand Down Expand Up @@ -215,12 +215,44 @@ def get_loggers(self) -> Dict[str, ILogger]:

return loggers

def get_dataset_from_params(self, **params) -> "Dataset":
def _get_transform_from_params(self, **params) -> Callable:
    """Instantiate a transform described by config ``params``.

    Keys named in ``params["_transforms_"]`` (default: ``("transforms",)``)
    are treated as lists of nested transform configs; each nested config is
    built recursively before the enclosing transform is created.

    Args:
        **params: transform config; mutated in place by the recursion keys

    Returns:
        Instantiated transform callable
    """
    nested_keys = params.pop("_transforms_", ("transforms",))
    for nested_key in nested_keys:
        if nested_key not in params:
            continue
        # build every child transform first, then replace the raw configs
        built_children = []
        for child_params in params[nested_key]:
            built_children.append(self._get_transform_from_params(**child_params))
        params[nested_key] = built_children

    return REGISTRY.get_from_params(**params)

def get_transform(self, **params) -> Callable:
    """
    Returns the data transforms for a given dataset.

    Args:
        **params: parameters of the transformation

    Returns:
        Data transformation to use
    """
    # deepcopy: `_get_transform_from_params` pops keys, and the original
    # config must stay untouched
    return self._get_transform_from_params(**deepcopy(params))

def _get_dataset_from_params(self, **params) -> "Dataset":
    """Creates dataset from ``**params`` parameters.

    A ``transform`` sub-config, when present, is instantiated via
    ``get_transform`` before the dataset itself is built, so the dataset
    constructor receives a ready-to-use callable rather than a raw dict.

    Args:
        **params: dataset config, optionally containing a ``transform`` key

    Returns:
        Instantiated dataset
    """
    # copy so the caller's config is not mutated by the pop below
    params = deepcopy(params)

    transform_params: dict = params.pop("transform", None)
    if transform_params is not None:
        params["transform"] = self.get_transform(**transform_params)

    # build the dataset exactly once, after the transform is resolved
    dataset = REGISTRY.get_from_params(**params)
    return dataset

def get_datasets(self, stage: str) -> "OrderedDict[str, Dataset]":
Expand All @@ -232,12 +264,11 @@ def get_datasets(self, stage: str) -> "OrderedDict[str, Dataset]":
Returns:
Dict: datasets objects
"""
params = deepcopy(self._stage_config[stage]["loaders"]["datasets"])

datasets = [
(key, self.get_dataset_from_params(**dataset_params))
(key, self._get_dataset_from_params(**dataset_params))
for key, dataset_params in params.items()
]
return OrderedDict(datasets)
Expand All @@ -251,7 +282,6 @@ def get_loaders(self, stage: str) -> "OrderedDict[str, DataLoader]":
Returns:
Dict: loaders objects
"""
loaders_params = deepcopy(self._stage_config[stage]["loaders"])
loaders_params.pop("datasets", None)
Expand Down
33 changes: 27 additions & 6 deletions catalyst/runners/hydra.py
@@ -1,4 +1,4 @@
from typing import Any, Dict, List
from typing import Any, Callable, Dict, List
from collections import OrderedDict
from copy import deepcopy
import logging
Expand Down Expand Up @@ -176,10 +176,32 @@ def get_loggers(self) -> Dict[str, ILogger]:

return loggers

def get_dataset_from_params(self, params: DictConfig) -> "Dataset":
def _get_transform_from_params(self, params: DictConfig) -> Callable:
    """Instantiate a transform object from a Hydra config node."""
    return hydra.utils.instantiate(params)

def get_transform(self, params: DictConfig) -> Callable:
    """
    Returns the data transforms for a dataset.

    Args:
        params: parameters of the transformation

    Returns:
        Data transformations to use
    """
    return self._get_transform_from_params(params)

def _get_dataset_from_params(self, params: DictConfig) -> "Dataset":
    """Creates dataset from ``params`` parameters.

    A ``transform`` sub-config, when present, is instantiated first and
    passed to the dataset constructor as the ``transform`` keyword.

    Args:
        params: Hydra config node describing the dataset

    Returns:
        Instantiated dataset
    """
    transform_params = params.pop("transform", None)
    if transform_params:
        transform = self.get_transform(transform_params)
        dataset = hydra.utils.instantiate(params, transform=transform)
    else:
        dataset = hydra.utils.instantiate(params)
    return dataset

def get_datasets(self, stage: str) -> "OrderedDict[str, Dataset]":
"""
Expand All @@ -190,13 +212,12 @@ def get_datasets(self, stage: str) -> "OrderedDict[str, Dataset]":
Returns:
Dict: datasets objects
"""
datasets_params = self._config.stages[stage].loaders.datasets
datasets_params = OmegaConf.to_container(datasets_params, resolve=True)

datasets = {
key: self.get_dataset_from_params(params) for key, params in datasets_params.items()
key: self._get_dataset_from_params(params) for key, params in datasets_params.items()
}
return OrderedDict(datasets)

Expand Down
4 changes: 4 additions & 0 deletions examples/mnist_stages/config.yml
Expand Up @@ -57,12 +57,16 @@ stages:
_target_: MNIST
root: *dataset_root
train: True
transform:
_target_: transform.ToTensor
download: True
num_samples_per_class: 320
valid:
_target_: MNIST
root: *dataset_root
train: False
transform:
_target_: transform.ToTensor
download: True

criterion:
Expand Down
4 changes: 4 additions & 0 deletions examples/mnist_stages/config_hydra.yaml
Expand Up @@ -56,11 +56,15 @@ stages:
_target_: catalyst.contrib.datasets.MNIST
root: ${shared.dataset_root}
train: true
transform:
_target_: catalyst.data.transforms.ToTensor
download: true
valid:
_target_: catalyst.contrib.datasets.MNIST
root: ${shared.dataset_root}
train: false
transform:
_target_: catalyst.data.transforms.ToTensor
download: true

criterion:
Expand Down
8 changes: 8 additions & 0 deletions examples/mnist_stages/config_tune.yml
Expand Up @@ -65,11 +65,19 @@ stages:
_target_: MNIST
root: *dataset_root
train: True
transform: &transform
_target_: transform.Compose
transforms:
- _target_: transform.ToTensor
- _target_: transform.Normalize
mean: [0]
std: [1]
download: True
valid:
_target_: MNIST
root: *dataset_root
train: False
transform: *transform
download: True
#
# samplers:
Expand Down
13 changes: 5 additions & 8 deletions examples/mnist_stages/runner.py
Expand Up @@ -25,13 +25,10 @@ def get_model(self, stage: str):
utils.set_requires_grad(layer, requires_grad=False)
return model

def get_transform(self, stage: str = None, mode: str = None):
    """Return the data transform for the given stage/mode.

    Args:
        stage: run stage name (unused — same transform for every stage)
        mode: dataset mode, e.g. train/valid (unused)

    Returns:
        a ``ToTensor`` transform instance
    """
    return ToTensor()


class CustomSupervisedConfigRunner(IRunnerMixin, SupervisedConfigRunner):
def get_dataset_from_params(self, num_samples_per_class=320, **kwargs):
dataset = super().get_dataset_from_params(transform=self.get_transform(), **kwargs)
def _get_dataset_from_params(self, num_samples_per_class=320, **kwargs):
dataset = super()._get_dataset_from_params(**kwargs)
if kwargs.get("train", True):
dataset = {
"dataset": dataset,
Expand All @@ -47,10 +44,10 @@ def get_dataset_from_params(self, num_samples_per_class=320, **kwargs):
from catalyst.dl import SupervisedHydraRunner

class CustomSupervisedHydraRunner(IRunnerMixin, SupervisedHydraRunner):
def get_dataset_from_params(self, params):
num_samples_per_class = 320
def _get_dataset_from_params(self, params):
num_samples_per_class = params.pop("num_samples_per_class", 320)

dataset = hydra.utils.instantiate(params, transform=self.get_transform())
dataset = super()._get_dataset_from_params(params)
if params["train"]:
dataset = {
"dataset": dataset,
Expand Down

0 comments on commit 14a7e7e

Please sign in to comment.