diff --git a/CHANGELOG.md b/CHANGELOG.md index 46ed80623e..723f59c6d0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,6 +14,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - `utils.ddp_sync_run` function for synchronous ddp run - CIFAR10 and CIFAR100 datasets from torchvision (no cv-based requirements) - [Catalyst Engines demo](https://github.com/catalyst-team/catalyst/tree/master/examples/engines) +- `dataset_from_params` support in config API ([#1231](https://github.com/catalyst-team/catalyst/pull/1231)) +- transform from params support in config API ([#1236](https://github.com/catalyst-team/catalyst/pull/1236)) ### Changed @@ -35,25 +37,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Hydra hotfix due to `1.1.0` version changes -## [21.06] - YYYY-MM-DD -### Added -- `dataset_from_params` support in config API ([#1231](https://github.com/catalyst-team/catalyst/pull/1231)) -### Changed -- -### Removed -- -### Fixed -- batch overfit test fixed ([#1232](https://github.com/catalyst-team/catalyst/pull/1232/files)) - ## [21.05] - 2021-05-31 ### Added diff --git a/catalyst/registry.py b/catalyst/registry.py index 0ac1bc590d..d0925542d1 100644 --- a/catalyst/registry.py +++ b/catalyst/registry.py @@ -12,7 +12,8 @@ def _transforms_loader(r: registry.Registry): from catalyst.data import transforms as t - r.add_from_module(t, prefix=["catalyst.", "C."]) + # add `'transform.'` prefix to avoid naming conflicts with other catalyst modules + r.add_from_module(t, prefix=["transform."]) REGISTRY.late_add(_transforms_loader) diff --git a/catalyst/runners/config.py b/catalyst/runners/config.py index 7b7438bd39..2dee7e9657 100644 --- a/catalyst/runners/config.py +++ b/catalyst/runners/config.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List +from typing import Any, Callable, Dict, List from collections import OrderedDict from copy import deepcopy import logging @@ -215,12 
+215,44 @@ def get_loggers(self) -> Dict[str, ILogger]: return loggers - def get_dataset_from_params(self, **params) -> "Dataset": + def _get_transform_from_params(self, **params) -> Callable: + """Creates transformation from ``**params`` parameters.""" + recursion_keys = params.pop("_transforms_", ("transforms",)) + for key in recursion_keys: + if key in params: + params[key] = [ + self._get_transform_from_params(**transform_params) + for transform_params in params[key] + ] + + transform = REGISTRY.get_from_params(**params) + return transform + + def get_transform(self, **params) -> Callable: + """ + Returns the data transforms for a given dataset. + + Args: + **params: parameters of the transformation + + Returns: + Data transformation to use + """ + # make a copy of params since we don't want to modify config + params = deepcopy(params) + + transform = self._get_transform_from_params(**params) + return transform + + def _get_dataset_from_params(self, **params) -> "Dataset": """Creates dataset from ``**params`` parameters.""" params = deepcopy(params) - dataset = REGISTRY.get_from_params(**params) + transform_params: dict = params.pop("transform", None) + if transform_params is not None: + params["transform"] = self.get_transform(**transform_params) + dataset = REGISTRY.get_from_params(**params) return dataset def get_datasets(self, stage: str) -> "OrderedDict[str, Dataset]": @@ -232,12 +264,11 @@ def get_datasets(self, stage: str) -> "OrderedDict[str, Dataset]": Returns: Dict: datasets objects - """ params = deepcopy(self._stage_config[stage]["loaders"]["datasets"]) datasets = [ - (key, self.get_dataset_from_params(**dataset_params)) + (key, self._get_dataset_from_params(**dataset_params)) for key, dataset_params in params.items() ] return OrderedDict(datasets) @@ -251,7 +282,6 @@ def get_loaders(self, stage: str) -> "OrderedDict[str, DataLoader]": Returns: Dict: loaders objects - """ loaders_params = deepcopy(self._stage_config[stage]["loaders"]) 
loaders_params.pop("datasets", None) diff --git a/catalyst/runners/hydra.py b/catalyst/runners/hydra.py index 4307d8451e..b9d9cf90e4 100644 --- a/catalyst/runners/hydra.py +++ b/catalyst/runners/hydra.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List +from typing import Any, Callable, Dict, List from collections import OrderedDict from copy import deepcopy import logging @@ -176,10 +176,32 @@ def get_loggers(self) -> Dict[str, ILogger]: return loggers - def get_dataset_from_params(self, params: DictConfig) -> "Dataset": + def _get_transform_from_params(self, params: DictConfig) -> Callable: + transform: Callable = hydra.utils.instantiate(params) + return transform + + def get_transform(self, params: DictConfig) -> Callable: + """ + Returns the data transforms for a dataset. + + Args: + params: parameters of the transformation + + Returns: + Data transformations to use + """ + transform = self._get_transform_from_params(params) + return transform + + def _get_dataset_from_params(self, params: DictConfig) -> "Dataset": """Creates dataset from ``**params`` parameters.""" - dataset = hydra.utils.instantiate(params) - raise dataset + transform_params = params.pop("transform", None) + if transform_params: + transform = self.get_transform(transform_params) + dataset = hydra.utils.instantiate(params, transform=transform) + else: + dataset = hydra.utils.instantiate(params) + return dataset def get_datasets(self, stage: str) -> "OrderedDict[str, Dataset]": """ @@ -190,13 +212,12 @@ def get_datasets(self, stage: str) -> "OrderedDict[str, Dataset]": Returns: Dict: datasets objects - """ datasets_params = self._config.stages[stage].loaders.datasets datasets_params = OmegaConf.to_container(datasets_params, resolve=True) datasets = { - key: self.get_dataset_from_params(params) for key, params in datasets_params.items() + key: self._get_dataset_from_params(params) for key, params in datasets_params.items() } return OrderedDict(datasets) diff --git 
a/examples/mnist_stages/config.yml b/examples/mnist_stages/config.yml index 49f88722ac..34ffb886a3 100644 --- a/examples/mnist_stages/config.yml +++ b/examples/mnist_stages/config.yml @@ -57,12 +57,16 @@ stages: _target_: MNIST root: *dataset_root train: True + transform: + _target_: transform.ToTensor download: True num_samples_per_class: 320 valid: _target_: MNIST root: *dataset_root train: False + transform: + _target_: transform.ToTensor download: True criterion: diff --git a/examples/mnist_stages/config_hydra.yaml b/examples/mnist_stages/config_hydra.yaml index d98bd6e42d..daf44fa53b 100644 --- a/examples/mnist_stages/config_hydra.yaml +++ b/examples/mnist_stages/config_hydra.yaml @@ -56,11 +56,15 @@ stages: _target_: catalyst.contrib.datasets.MNIST root: ${shared.dataset_root} train: true + transform: + _target_: catalyst.data.transforms.ToTensor download: true valid: _target_: catalyst.contrib.datasets.MNIST root: ${shared.dataset_root} train: false + transform: + _target_: catalyst.data.transforms.ToTensor download: true criterion: diff --git a/examples/mnist_stages/config_tune.yml b/examples/mnist_stages/config_tune.yml index 0908c4e5fe..7024a950c3 100644 --- a/examples/mnist_stages/config_tune.yml +++ b/examples/mnist_stages/config_tune.yml @@ -65,11 +65,19 @@ stages: _target_: MNIST root: *dataset_root train: True + transform: &transform + _target_: transform.Compose + transforms: + - _target_: transform.ToTensor + - _target_: transform.Normalize + mean: [0] + std: [1] download: True valid: _target_: MNIST root: *dataset_root train: False + transform: *transform download: True # # samplers: diff --git a/examples/mnist_stages/runner.py b/examples/mnist_stages/runner.py index e41a545106..c6cab073b7 100644 --- a/examples/mnist_stages/runner.py +++ b/examples/mnist_stages/runner.py @@ -25,13 +25,10 @@ def get_model(self, stage: str): utils.set_requires_grad(layer, requires_grad=False) return model - def get_transform(self, stage: str = None, mode: str = 
None): - return ToTensor() - class CustomSupervisedConfigRunner(IRunnerMixin, SupervisedConfigRunner): - def get_dataset_from_params(self, num_samples_per_class=320, **kwargs): - dataset = super().get_dataset_from_params(transform=self.get_transform(), **kwargs) + def _get_dataset_from_params(self, num_samples_per_class=320, **kwargs): + dataset = super()._get_dataset_from_params(**kwargs) if kwargs.get("train", True): dataset = { "dataset": dataset, @@ -47,10 +44,10 @@ def get_dataset_from_params(self, num_samples_per_class=320, **kwargs): from catalyst.dl import SupervisedHydraRunner class CustomSupervisedHydraRunner(IRunnerMixin, SupervisedHydraRunner): - def get_dataset_from_params(self, params): - num_samples_per_class = 320 + def _get_dataset_from_params(self, params): + num_samples_per_class = params.pop("num_samples_per_class", 320) - dataset = hydra.utils.instantiate(params, transform=self.get_transform()) + dataset = super()._get_dataset_from_params(params) if params["train"]: dataset = { "dataset": dataset,