From e68610acce941846b1a7784105d32ff3a3b67f0f Mon Sep 17 00:00:00 2001 From: Sergey Kolesnikov Date: Wed, 4 Nov 2020 16:26:12 +0300 Subject: [PATCH 1/5] prefetch mvp --- catalyst/data/__init__.py | 5 +- catalyst/data/loader.py | 161 +++++++++++++++++++++++++++++--------- catalyst/utils/torch.py | 2 +- 3 files changed, 131 insertions(+), 37 deletions(-) diff --git a/catalyst/data/__init__.py b/catalyst/data/__init__.py index 70c382009f..3a65f834ad 100644 --- a/catalyst/data/__init__.py +++ b/catalyst/data/__init__.py @@ -10,7 +10,10 @@ MetricLearningTrainDataset, QueryGalleryDataset, ) -from catalyst.data.loader import BatchLimitLoaderWrapper +from catalyst.data.loader import ( + BatchLimitLoaderWrapper, + BatchPrefetchLoaderWrapper, +) from catalyst.data.reader import ( ReaderSpec, ScalarReader, diff --git a/catalyst/data/loader.py b/catalyst/data/loader.py index 02a7c54ce9..94ceb46c8d 100644 --- a/catalyst/data/loader.py +++ b/catalyst/data/loader.py @@ -1,9 +1,53 @@ -from typing import Union +from typing import Any, Callable, Iterable, Union +import queue +import sys +import threading +import numpy as np + +import torch from torch.utils.data import DataLoader -class BatchLimitLoaderWrapper: +class ILoaderWrapper: + def __init__(self, loader: DataLoader): + self.origin = loader + + def __getattr__(self, key): + """ + Gets attribute by ``key``. + Firstly, looks at the ``origin`` for the appropriate ``key``. + If none founds - looks at the wrappers attributes. + If could not found anything - raises ``NotImplementedError``. + + Args: + key: attribute key + + Returns: + attribute value + + Raises: + NotImplementedError: if could not find attribute in ``origin`` + or ``wrapper`` + """ + value = getattr(self.origin, key, None) + if value is not None: + return value + value = getattr(self, key, None) + if value is not None: + return value + raise NotImplementedError() + + def __len__(self) -> int: + """Returns length of the wrapper loader. + + Returns: + int: length of the wrapper loader + """ + return len(self.origin) + + +class BatchLimitLoaderWrapper(ILoaderWrapper): """ Loader wrapper. Limits number of batches used per each iteration. @@ -50,6 +94,7 @@ def __init__(self, loader: DataLoader, num_batches: Union[int, float]): num_batches (Union[int, float]): number of batches to use (int), or portion of iterator (float, should be in [0;1] range) """ + super().__init__(loader) assert isinstance(num_batches, (int, float)), ( "Expected ``num_batches`` type is int/float" f"but got {type(num_batches)}" @@ -61,36 +106,10 @@ def __init__(self, loader: DataLoader, num_batches: Union[int, float]): ) num_batches = int(len(loader) * num_batches) - self.origin = loader self.iterator = iter(self.origin) self.iteration_index = 0 self.num_batches = num_batches - def __getattr__(self, key): - """ - Gets attribute by ``key``. - Firstly, looks at the ``origin`` for the appropriate ``key``. - If none founds - looks at the wrappers attributes. - If could not found anything - raises ``NotImplementedError``. - - Args: - key: attribute key - - Returns: - attribute value - - Raises: - NotImplementedError: if could not find attribute in ``origin`` - or ``wrapper`` - """ - value = getattr(self.origin, key, None) - if value is not None: - return value - value = getattr(self, key, None) - if value is not None: - return value - raise NotImplementedError() - def __iter__(self): """Iterator. @@ -115,13 +134,85 @@ def __next__(self): batch = next(self.iterator) return batch - def __len__(self) -> int: - """Returns length of the wrapper loader. - Returns: - int: length of the wrapper loader - """ - return len(self.origin) +def _any2cuda_non_blocking(value: Any): + # based on catalyst.utils.torch.any2device + # but with cuda non_blocking trick + if isinstance(value, dict): + return {k: _any2cuda_non_blocking(v) for k, v in value.items()} + elif isinstance(value, (tuple, list)): + return [_any2cuda_non_blocking(v) for v in value] + elif torch.is_tensor(value): + return value.cuda(non_blocking=True) + elif ( + isinstance(value, (np.ndarray, np.void)) + and value.dtype.fields is not None + ): + return { + k: _any2cuda_non_blocking(value[k]) + for k in value.dtype.fields.keys() + } + elif isinstance(value, np.ndarray): + return torch.tensor(value).cuda(non_blocking=True) + + +def _map_loop( + func: Callable, + iterable: Iterable, + result_queue: queue.Queue, + error_queue: queue.Queue, + done_event: threading.Event, +): + try: + for x in iterable: + result = func(x) + result_queue.put(result) + except BaseException: + error_queue.put(sys.exc_info()) + finally: + done_event.set() + + +def _prefetch_map( + func: Callable, + iterable: Iterable, + num_prefetches: int = 1, + timeout: int = 2, +) -> Iterable: + result_queue = queue.Queue(num_prefetches) + error_queue = queue.Queue(1) + done_event = threading.Event() + map_thread = threading.Thread( + target=_map_loop, + args=(func, iterable, result_queue, error_queue, done_event), + ) + map_thread.daemon = True + map_thread.start() + while not (done_event.is_set() and result_queue.empty()): + try: + result = result_queue.get(timeout=timeout) + except queue.Empty: + continue + yield result + if error_queue.full(): + raise error_queue.get()[1] + + +def _prefetch_loader(loader: DataLoader, num_prefetches: int) -> Iterable: + if torch.cuda.is_available(): + loader = _prefetch_map( + _any2cuda_non_blocking, loader, num_prefetches=num_prefetches, + ) + return loader + + +class BatchPrefetchLoaderWrapper(ILoaderWrapper): + def __init__(self, loader: DataLoader, num_prefetches: int = None): + super().__init__(loader) + self.num_prefetches = num_prefetches or loader.batch_size + + def __iter__(self): + return _prefetch_loader(self.origin, self.num_prefetches) -__all__ = ["BatchLimitLoaderWrapper"] +__all__ = ["BatchLimitLoaderWrapper", "BatchPrefetchLoaderWrapper"] diff --git a/catalyst/utils/torch.py b/catalyst/utils/torch.py index 0372f7bf69..407ab55c15 100644 --- a/catalyst/utils/torch.py +++ b/catalyst/utils/torch.py @@ -143,7 +143,7 @@ def any2device(value, device: Device): k: any2device(value[k], device) for k in value.dtype.fields.keys() } elif isinstance(value, np.ndarray): - return torch.Tensor(value).to(device) + return torch.tensor(value, device=device) return value From 3cda19cc08895166e2edb732a6a66ecc7d189721 Mon Sep 17 00:00:00 2001 From: Sergey Kolesnikov Date: Sun, 8 Nov 2020 11:00:21 +0300 Subject: [PATCH 2/5] docs draft --- catalyst/data/loader.py | 80 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 80 insertions(+) diff --git a/catalyst/data/loader.py b/catalyst/data/loader.py index 94ceb46c8d..f9de13d111 100644 --- a/catalyst/data/loader.py +++ b/catalyst/data/loader.py @@ -207,6 +207,86 @@ def _prefetch_loader(loader: DataLoader, num_prefetches: int) -> Iterable: class BatchPrefetchLoaderWrapper(ILoaderWrapper): + """ + + Base usage: + + .. code-block:: python + + import torch + from torch.utils.data import DataLoader, TensorDataset + from catalyst.data import BatchPrefetchLoaderWrapper + + num_samples, num_features = int(1e4), int(1e1) + X, y = torch.rand(num_samples, num_features), torch.rand(num_samples) + dataset = TensorDataset(X, y) + loader = DataLoader(dataset, batch_size=32, num_workers=1) + loader = BatchPrefetchLoaderWrapper(loader) + + Minimal working example: + + .. code-block:: python + + import os + import torch + from torch.nn import functional as F + from torch.utils.data import DataLoader + from catalyst import dl, metrics + from catalyst.data.cv import ToTensor + from catalyst.contrib.datasets import MNIST + from catalyst.data import BatchPrefetchLoaderWrapper + + class CustomRunner(dl.Runner): + + def predict_batch(self, batch): + # model inference step + return self.model(batch[0].to(self.device).view(batch[0].size(0), -1)) + + def _handle_batch(self, batch): + # model train/valid step + x, y = batch + y_hat = self.model(x.view(x.size(0), -1)) + + loss = F.cross_entropy(y_hat, y) + accuracy01, accuracy03 = metrics.accuracy(y_hat, y, topk=(1, 3)) + self.batch_metrics.update( + {"loss": loss, "accuracy01": accuracy01, "accuracy03": accuracy03} + ) + + if self.is_train_loader: + loss.backward() + self.optimizer.step() + self.optimizer.zero_grad() + + model = torch.nn.Linear(28 * 28, 10) + optimizer = torch.optim.Adam(model.parameters(), lr=0.02) + + batch_size=32 + loaders = { + "train": DataLoader(MNIST(os.getcwd(), train=True, download=True, transform=ToTensor()), batch_size=batch_size), + "valid": DataLoader(MNIST(os.getcwd(), train=False, download=True, transform=ToTensor()), batch_size=batch_size), + } + loaders = {k: BatchPrefetchLoaderWrapper(v) for k, v in loaders.items()} + + runner = CustomRunner() + # model training + runner.train( + model=model, + optimizer=optimizer, + loaders=loaders, + logdir="./logs", + num_epochs=5, + verbose=True, + load_best_on_end=True, + ) + # model inference + for prediction in runner.predict_loader(loader=loaders["valid"]): + assert prediction.detach().cpu().numpy().shape[-1] == 10 + # model tracing + traced_model = runner.trace(loader=loaders["valid"]) + + """ + def __init__(self, loader: DataLoader, num_prefetches: int = None): super().__init__(loader) self.num_prefetches = num_prefetches or loader.batch_size From 7035bc10a19ebc9a0a88d6c188228eafe14a3a4f Mon Sep 17 00:00:00 2001 From: Sergey Kolesnikov Date: Tue, 10 Nov 2020 09:02:38 +0300 Subject: [PATCH 3/5] cleanup --- .github/workflows/codestyle.yml | 2 +- catalyst/contrib/utils/tests/test_misc.py | 67 ----------------------- catalyst/core/runner.py | 3 +- catalyst/dl/scripts/run.py | 2 +- catalyst/dl/scripts/tune.py | 3 +- catalyst/runners/runner.py | 3 +- catalyst/utils/__init__.py | 3 +- catalyst/utils/loaders.py | 2 +- catalyst/utils/misc.py | 54 +++++++++++------- catalyst/utils/seed.py | 35 ------------ docs/api/utils.rst | 7 --- 11 files changed, 41 insertions(+), 140 deletions(-) delete mode 100644 catalyst/utils/seed.py diff --git a/.github/workflows/codestyle.yml b/.github/workflows/codestyle.yml index b871d796f8..79f9068baf 100644 --- a/.github/workflows/codestyle.yml +++ b/.github/workflows/codestyle.yml @@ -66,7 +66,7 @@ jobs: flake8 . | reviewdog -f=pep8 -reporter=github-pr-review build: - name: codestyle + name: codestyle-pytest-docs runs-on: ${{ matrix.os }} strategy: matrix: diff --git a/catalyst/contrib/utils/tests/test_misc.py b/catalyst/contrib/utils/tests/test_misc.py index e97dd0c88f..fd9566f1d4 100644 --- a/catalyst/contrib/utils/tests/test_misc.py +++ b/catalyst/contrib/utils/tests/test_misc.py @@ -50,70 +50,3 @@ def forward(self, *, x, y=None): ) ) assert params_predicted == params_true - - -def test_fn_ends_with_pass(): - def useless_fn(): - pass - - def usefull_fn(): - print("I am useful!") - - assert utils.fn_ends_with_pass(useless_fn) is True - assert utils.fn_ends_with_pass(usefull_fn) is False - - -def test_fn_ends_with_pass_on_callbacks(): - def test_fn_ends_with_pass_on_callback( - callback, events, - ): - for event in events["covered"]: - fn_name = f"on_{event}" - assert ( - utils.fn_ends_with_pass(getattr(callback.__class__, fn_name)) - is False - ) - for event in events["non-covered"]: - fn_name = f"on_{event}" - assert ( - utils.fn_ends_with_pass(getattr(callback.__class__, fn_name)) - is True - ) - - # Callback test - from catalyst.dl import Callback - - callback = Callback(order=1) - start_events = [ - "stage_start", - "epoch_start", - "batch_start", - "loader_start", - ] - end_events = [ - "stage_end", - "epoch_end", - "batch_end", - "loader_end", - "exception", - ] - events = {"covered": [], "non-covered": [*start_events, *end_events]} - test_fn_ends_with_pass_on_callback(callback=callback, events=events) - - # CriterionCallback test - from catalyst.dl import CriterionCallback - - callback = CriterionCallback() - covered_events = ["stage_start", "batch_end"] - non_covered_start_events = ["epoch_start", "batch_start", "loader_start"] - non_covered_end_events = [ - "stage_end", - "epoch_end", - "loader_end", - "exception", - ] - events = { - "covered": [*covered_events], - "non-covered": [*non_covered_start_events, *non_covered_end_events], - } - test_fn_ends_with_pass_on_callback(callback=callback, events=events) diff --git a/catalyst/core/runner.py b/catalyst/core/runner.py index 48ca2ca725..825f419d4f 100644 --- a/catalyst/core/runner.py +++ b/catalyst/core/runner.py @@ -30,8 +30,7 @@ from catalyst.utils.components import process_components from catalyst.utils.distributed import get_rank from catalyst.utils.loaders import validate_loaders -from catalyst.utils.misc import maybe_recursive_call -from catalyst.utils.seed import set_global_seed +from catalyst.utils.misc import maybe_recursive_call, set_global_seed from catalyst.utils.torch import any2device diff --git a/catalyst/dl/scripts/run.py b/catalyst/dl/scripts/run.py index fc596b6031..726314520f 100755 --- a/catalyst/dl/scripts/run.py +++ b/catalyst/dl/scripts/run.py @@ -7,13 +7,13 @@ from catalyst.contrib.utils.argparse import boolean_flag from catalyst.utils.distributed import get_rank +from catalyst.utils.misc import set_global_seed from catalyst.utils.parser import parse_args_uargs from catalyst.utils.scripts import ( distributed_cmd_run, dump_code, prepare_config_api_components, ) -from catalyst.utils.seed import set_global_seed from catalyst.utils.sys import dump_environment from catalyst.utils.torch import prepare_cudnn diff --git a/catalyst/dl/scripts/tune.py b/catalyst/dl/scripts/tune.py index 38b7b4ed52..f6ecead43f 100755 --- a/catalyst/dl/scripts/tune.py +++ b/catalyst/dl/scripts/tune.py @@ -10,10 +10,9 @@ from catalyst.contrib.utils.argparse import boolean_flag from catalyst.utils.distributed import get_rank -from catalyst.utils.misc import maybe_recursive_call +from catalyst.utils.misc import maybe_recursive_call, set_global_seed from catalyst.utils.parser import parse_args_uargs from catalyst.utils.scripts import dump_code, prepare_config_api_components -from catalyst.utils.seed import set_global_seed from catalyst.utils.sys import dump_environment from catalyst.utils.torch import prepare_cudnn diff --git a/catalyst/runners/runner.py b/catalyst/runners/runner.py index 1764a98d00..115b810fd7 100644 --- a/catalyst/runners/runner.py +++ b/catalyst/runners/runner.py @@ -13,9 +13,8 @@ from catalyst.typing import Criterion, Device, Model, Optimizer, Scheduler from catalyst.utils.checkpoint import load_checkpoint, unpack_checkpoint from catalyst.utils.components import process_components -from catalyst.utils.misc import maybe_recursive_call +from catalyst.utils.misc import maybe_recursive_call, set_global_seed from catalyst.utils.scripts import distributed_cmd_run -from catalyst.utils.seed import set_global_seed from catalyst.utils.torch import ( get_device, get_requires_grad, diff --git a/catalyst/utils/__init__.py b/catalyst/utils/__init__.py index 0f247ff75d..23961197af 100644 --- a/catalyst/utils/__init__.py +++ b/catalyst/utils/__init__.py @@ -62,7 +62,7 @@ get_utcnow_time, is_exception, maybe_recursive_call, - fn_ends_with_pass, + set_global_seed, ) from catalyst.utils.numpy import get_one_hot from catalyst.utils.parser import parse_config_args, parse_args_uargs @@ -74,7 +74,6 @@ dump_experiment_code, distributed_cmd_run, ) -from catalyst.utils.seed import set_global_seed from catalyst.utils.swa import ( average_weights, get_averaged_weights_by_path_mask, diff --git a/catalyst/utils/loaders.py b/catalyst/utils/loaders.py index fbce0e1219..fd40a158fe 100644 --- a/catalyst/utils/loaders.py +++ b/catalyst/utils/loaders.py @@ -10,7 +10,7 @@ from catalyst.registry import SAMPLER from catalyst.utils.dict import merge_dicts from catalyst.utils.distributed import get_rank -from catalyst.utils.seed import set_global_seed +from catalyst.utils.misc import set_global_seed def get_loader( diff --git a/catalyst/utils/misc.py b/catalyst/utils/misc.py index d7824f258d..acca7cda14 100644 --- a/catalyst/utils/misc.py +++ b/catalyst/utils/misc.py @@ -2,8 +2,41 @@ from datetime import datetime import inspect from pathlib import Path +import random import shutil +import numpy as np +from packaging.version import parse, Version + + +def set_global_seed(seed: int) -> None: + """Sets random seed into Numpy and Random, PyTorch and TensorFlow. + + Args: + seed: random seed + """ + random.seed(seed) + np.random.seed(seed) + try: + import torch + except ImportError: + pass + else: + torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(seed) + try: + import tensorflow as tf + except ImportError: + pass + else: + if parse(tf.__version__) >= Version("2.0.0"): + tf.random.set_seed(seed) + elif parse(tf.__version__) <= Version("1.13.2"): + tf.set_random_seed(seed) + else: + tf.compat.v1.set_random_seed(seed) + def maybe_recursive_call( object_or_dict, @@ -142,25 +175,6 @@ def get_fn_argsnames(fn: Callable[..., Any], exclude: List[str] = None): return params -def fn_ends_with_pass(fn: Callable[..., Any]): - """ - Check that function end with pass statement - (probably does nothing in any way). - Mainly used to filter callbacks with empty on_{event} methods. - - Args: - fn (Callable[..., Any]): target Callable - - Returns: - bool: True if there is pass in the first indentation level of fn - and nothing happens before it, False in any other case. - """ - source_lines = inspect.getsourcelines(fn)[0] - if source_lines[-1].strip() == "pass": - return True - return False - - __all__ = [ "copy_directory", "format_metric", @@ -169,5 +183,5 @@ def fn_ends_with_pass(fn: Callable[..., Any]): "get_utcnow_time", "is_exception", "maybe_recursive_call", - "fn_ends_with_pass", + "set_global_seed", ] diff --git a/catalyst/utils/seed.py b/catalyst/utils/seed.py deleted file mode 100644 index e3116f8a99..0000000000 --- a/catalyst/utils/seed.py +++ /dev/null @@ -1,35 +0,0 @@ -import random - -import numpy as np -from packaging.version import parse, Version - - -def set_global_seed(seed: int) -> None: - """Sets random seed into PyTorch, TensorFlow, Numpy and Random. - - Args: - seed: random seed - """ - try: - import torch - except ImportError: - pass - else: - torch.manual_seed(seed) - torch.cuda.manual_seed_all(seed) - try: - import tensorflow as tf - except ImportError: - pass - else: - if parse(tf.__version__) >= Version("2.0.0"): - tf.random.set_seed(seed) - elif parse(tf.__version__) <= Version("1.13.2"): - tf.set_random_seed(seed) - else: - tf.compat.v1.set_random_seed(seed) - random.seed(seed) - np.random.seed(seed) - - -__all__ = ["set_global_seed"] diff --git a/docs/api/utils.rst b/docs/api/utils.rst index aba018d9fc..48bedee611 100644 --- a/docs/api/utils.rst +++ b/docs/api/utils.rst @@ -112,13 +112,6 @@ Scripts :undoc-members: :show-inheritance: -Seed -~~~~~~~~~~~~~~~~~~~~~~ -.. automodule:: catalyst.utils.seed - :members: - :undoc-members: - :show-inheritance: - Stochastic Weights Averaging (SWA) ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. automodule:: catalyst.utils.swa From 9f66c349bd3a6664c240e8df2152b7df3e5263c1 Mon Sep 17 00:00:00 2001 From: Sergey Kolesnikov Date: Tue, 10 Nov 2020 09:06:17 +0300 Subject: [PATCH 4/5] logs --- CHANGELOG.md | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index b1ad206ac3..0a8793c7f2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,16 +15,20 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - `OneOf` and `OneOfV2` batch transforms ([#951](https://github.com/catalyst-team/catalyst/pull/951)) - ``precision_recall_fbeta_support`` metric ([#971](https://github.com/catalyst-team/catalyst/pull/971)) - Pruning tutorial ([#987](https://github.com/catalyst-team/catalyst/pull/987)) +- BatchPrefechLoaderWrapper ([#986](https://github.com/catalyst-team/catalyst/pull/986)) +- DynamicBalanceClassSampler ([#954](https://github.com/catalyst-team/catalyst/pull/954)) ### Changed - update Catalyst version to `20.10.1` for tutorials ([#967](https://github.com/catalyst-team/catalyst/pull/967)) - added link to dl-course ([#967](https://github.com/catalyst-team/catalyst/pull/967)) - docs were restructured ([#985](https://github.com/catalyst-team/catalyst/pull/985)) +- `set_global_seed` moved from `utils.seed` to `utils.misc` ([#986](https://github.com/catalyst-team/catalyst/pull/986)) ### Removed - several deprecated tutorials ([#967](https://github.com/catalyst-team/catalyst/pull/967)) +- several deprecated func from utils.misc ([#986](https://github.com/catalyst-team/catalyst/pull/986)) ### Fixed @@ -74,7 +78,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Added - Runner registry support for Config API ([#936](https://github.com/catalyst-team/catalyst/pull/936)) - - `catalyst-dl tune` command - Optuna with Config API integration for AutoML hyperparameters optimization ([#937](https://github.com/catalyst-team/catalyst/pull/937)) - `OptunaPruningCallback` alias for `OptunaCallback` ([#937](https://github.com/catalyst-team/catalyst/pull/937)) - AdamP and SGDP to `catalyst.contrib.nn.criterion` ([#942](https://github.com/catalyst-team/catalyst/pull/942)) @@ -90,7 +93,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Fixed - Logging double logging :) ([#936](https://github.com/catalyst-team/catalyst/pull/936)) - - CMCCallback ([#941](https://github.com/catalyst-team/catalyst/pull/941)) ## [20.09] - 2020-09-07 From fd03a8bffe78222ed81cd669f174dcf6a670dbca Mon Sep 17 00:00:00 2001 From: Sergey Kolesnikov Date: Tue, 10 Nov 2020 09:41:38 +0300 Subject: [PATCH 5/5] docs --- CHANGELOG.md | 2 +- catalyst/data/__init__.py | 1 + catalyst/data/loader.py | 68 ++++++++++++++++++++++++++------------- 3 files changed, 47 insertions(+), 24 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 0a8793c7f2..55f0c35bc9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,7 +15,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - `OneOf` and `OneOfV2` batch transforms ([#951](https://github.com/catalyst-team/catalyst/pull/951)) - ``precision_recall_fbeta_support`` metric ([#971](https://github.com/catalyst-team/catalyst/pull/971)) - Pruning tutorial ([#987](https://github.com/catalyst-team/catalyst/pull/987)) -- BatchPrefechLoaderWrapper ([#986](https://github.com/catalyst-team/catalyst/pull/986)) +- BatchPrefetchLoaderWrapper ([#986](https://github.com/catalyst-team/catalyst/pull/986)) - DynamicBalanceClassSampler ([#954](https://github.com/catalyst-team/catalyst/pull/954)) ### Changed diff --git a/catalyst/data/__init__.py b/catalyst/data/__init__.py index 8aa60e2b95..acb49cd881 100644 --- a/catalyst/data/__init__.py +++ b/catalyst/data/__init__.py @@ -11,6 +11,7 @@ QueryGalleryDataset, ) from catalyst.data.loader import ( + ILoaderWrapper, BatchLimitLoaderWrapper, BatchPrefetchLoaderWrapper, ) diff --git a/catalyst/data/loader.py b/catalyst/data/loader.py index f9de13d111..70e79e7ec9 100644 --- a/catalyst/data/loader.py +++ b/catalyst/data/loader.py @@ -11,6 +11,11 @@ class ILoaderWrapper: def __init__(self, loader: DataLoader): + """Loader wrapper interface. + + Args: + loader: torch dataloader. + """ self.origin = loader def __getattr__(self, key): @@ -48,8 +53,7 @@ def __len__(self) -> int: class BatchLimitLoaderWrapper(ILoaderWrapper): - """ - Loader wrapper. Limits number of batches used per each iteration. + """Loader wrapper. Limits number of batches used per each iteration. For example, if you have some loader and want to use only first 5 bathes: @@ -86,8 +90,7 @@ class BatchLimitLoaderWrapper(ILoaderWrapper): """ def __init__(self, loader: DataLoader, num_batches: Union[int, float]): - """ - Loader wrapper. Limits number of batches used per each iteration. + """Loader wrapper. Limits number of batches used per each iteration. Args: loader: torch dataloader. @@ -167,7 +170,7 @@ def _map_loop( for x in iterable: result = func(x) result_queue.put(result) - except BaseException: + except BaseException: # noqa: WPS424 error_queue.put(sys.exc_info()) finally: done_event.set() @@ -200,14 +203,15 @@ def _prefetch_map( def _prefetch_loader(loader: DataLoader, num_prefetches: int) -> Iterable: if torch.cuda.is_available(): - loader = _prefetch_map( + return _prefetch_map( _any2cuda_non_blocking, loader, num_prefetches=num_prefetches, ) - return loader + else: + return iter(loader) class BatchPrefetchLoaderWrapper(ILoaderWrapper): - """ + """Loader wrapper. Prefetches specified number of batches on the GPU. Base usage: @@ -238,19 +242,15 @@ class BatchPrefetchLoaderWrapper(ILoaderWrapper): class CustomRunner(dl.Runner): - def predict_batch(self, batch): - # model inference step - return self.model(batch[0].to(self.device).view(batch[0].size(0), -1)) - def _handle_batch(self, batch): # model train/valid step x, y = batch y_hat = self.model(x.view(x.size(0), -1)) loss = F.cross_entropy(y_hat, y) - accuracy01, accuracy03 = metrics.accuracy(y_hat, y, topk=(1, 3)) + accuracy01 = metrics.accuracy(y_hat, y, topk=(1, )) self.batch_metrics.update( - {"loss": loss, "accuracy01": accuracy01, "accuracy03": accuracy03} + {"loss": loss, "accuracy01": accuracy01} ) if self.is_train_loader: @@ -263,10 +263,26 @@ def _handle_batch(self, batch): batch_size=32 loaders = { - "train": DataLoader(MNIST(os.getcwd(), train=True, download=True, transform=ToTensor()), batch_size=batch_size), - "valid": DataLoader(MNIST(os.getcwd(), train=False, download=True, transform=ToTensor()), batch_size=batch_size), + "train": DataLoader( + MNIST( + os.getcwd(), + train=True, + download=True, + transform=ToTensor() + ), + batch_size=batch_size), + "valid": DataLoader( + MNIST( + os.getcwd(), + train=False, + download=True, + transform=ToTensor() + ), + batch_size=batch_size), + } + loaders = { + k: BatchPrefetchLoaderWrapper(v) for k, v in loaders.items() } - loaders = {k: BatchPrefetchLoaderWrapper(v) for k, v in loaders.items()} runner = CustomRunner() # model training @@ -279,19 +295,25 @@ def _handle_batch(self, batch): verbose=True, load_best_on_end=True, ) - # model inference - for prediction in runner.predict_loader(loader=loaders["valid"]): - assert prediction.detach().cpu().numpy().shape[-1] == 10 - # model tracing - traced_model = runner.trace(loader=loaders["valid"]) """ def __init__(self, loader: DataLoader, num_prefetches: int = None): + """Loader wrapper. Prefetches specified number of batches on the GPU. + + Args: + loader: torch dataloader. + num_prefetches: number of batches to prefetch on the GPU. + """ super().__init__(loader) - self.num_prefetches = num_prefetches or loader.batch_size + self.num_prefetches = num_prefetches or 1 def __iter__(self): + """Iterator. + + Returns: + iterator object + """ return _prefetch_loader(self.origin, self.num_prefetches)