Utils callbacks #1127

Merged: 18 commits, Mar 28, 2021
Changes from 9 commits
7 changes: 7 additions & 0 deletions catalyst/callbacks/__init__.py
@@ -20,6 +20,7 @@
EarlyStoppingCallback,
)
from catalyst.callbacks.optimizer import IOptimizerCallback, OptimizerCallback
from catalyst.callbacks.tracing import TracingCallback

if SETTINGS.optuna_required:
from catalyst.callbacks.optuna import OptunaPruningCallback
@@ -39,6 +40,12 @@
LRFinder,
)

if SETTINGS.quantization_required:
from catalyst.callbacks.quantization import QuantizationCallback

if SETTINGS.onnx_required:
from catalyst.callbacks.onnx import OnnxCallback

# from catalyst.callbacks.tracing import TracingCallback


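The guarded imports above keep heavy extras optional: a callback is only exposed when its dependency is installed. A minimal sketch of that pattern, using hypothetical names (`HAS_ONNX`, `mypackage`) rather than catalyst's actual SETTINGS object:

# optional-dependency import guard (hypothetical names, not catalyst's SETTINGS)
try:
    import onnx  # noqa: F401

    HAS_ONNX = True
except ImportError:
    HAS_ONNX = False

if HAS_ONNX:
    from mypackage.callbacks.onnx import OnnxCallback  # noqa: F401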
103 changes: 103 additions & 0 deletions catalyst/callbacks/onnx.py
@@ -0,0 +1,103 @@
from typing import Dict, Iterable, List, TYPE_CHECKING, Union
from pathlib import Path

from torch import Tensor

from catalyst.core import Callback, CallbackNode, CallbackOrder
from catalyst.utils import onnx_export

if TYPE_CHECKING:
from catalyst.core import IRunner


class OnnxCallback(Callback):
"""
Callback for converting model to onnx runtime.

Args:
logdir: path to folder for saving
filename: filename
batch: input tensor for model. If None will take batch from train loader.
method_name (str, optional): Forward pass method to be converted. Defaults to "forward".
input_names (Iterable, optional): name of inputs in graph. Defaults to None.
output_names (List[str], optional): name of outputs in graph. Defaults to None.
dynamic_axes (Union[Dict[str, int], Dict[str, Dict[str, int]]], optional): axes
with dynamic shapes. Defaults to None.
opset_version (int, optional): Defaults to 9.
do_constant_folding (bool, optional): If True, the constant-folding optimization
is applied to the model during export. Defaults to False.
verbose (bool, default False): if specified, we will print out a debug
description of the trace being exported.
"""

def __init__(
self,
logdir: Union[str, Path] = None,
        filename: str = "model.onnx",
batch: Tensor = None,
method_name: str = "forward",
input_names: Iterable = None,
output_names: List[str] = None,
dynamic_axes: Union[Dict[str, int], Dict[str, Dict[str, int]]] = None,
opset_version: int = 9,
do_constant_folding: bool = False,
verbose: bool = False,
):
"""
Callback for converting model to onnx runtime.

Args:
logdir: path to folder for saving
filename: filename
batch: input tensor for model. If None will take batch from train loader.
method_name (str, optional): Forward pass method to be converted.
Defaults to "forward".
input_names (Iterable, optional): name of inputs in graph. Defaults to None.
output_names (List[str], optional): name of outputs in graph. Defaults to None.
dynamic_axes (Union[Dict[str, int], Dict[str, Dict[str, int]]], optional): axes
with dynamic shapes. Defaults to None.
opset_version (int, optional): Defaults to 9.
do_constant_folding (bool, optional): If True, the constant-folding optimization
is applied to the model during export. Defaults to False.
verbose (bool, default False): if specified, we will print out a debug
description of the trace being exported.
"""
        super().__init__(order=CallbackOrder.ExternalExtra, node=CallbackNode.Master)
        if logdir is not None:
            self.filename = Path(logdir) / filename
        else:
            self.filename = filename

self.method_name = method_name
self.input_names = input_names
self.output_names = output_names
self.dynamic_axes = dynamic_axes
self.opset_version = opset_version
self.do_constant_folding = do_constant_folding
self.verbose = verbose
self.batch = batch

def on_stage_end(self, runner: "IRunner") -> None:
"""
On stage end action.

Args:
runner: runner for experiment
"""
        model = runner.model.cpu()
        # `or` would call bool() on a Tensor and raise for multi-element tensors,
        # so check for None explicitly
        batch = self.batch if self.batch is not None else next(iter(runner.loaders["train"]))
onnx_export(
model=model,
file=self.filename,
batch=batch,
method_name=self.method_name,
input_names=self.input_names,
output_names=self.output_names,
dynamic_axes=self.dynamic_axes,
opset_version=self.opset_version,
do_constant_folding=self.do_constant_folding,
verbose=self.verbose,
)


__all__ = ["OnnxCallback"]
211 changes: 61 additions & 150 deletions catalyst/callbacks/quantization.py
@@ -1,150 +1,61 @@
# # @TODO: make the same API for tracing/onnx/pruning/quantization
# from typing import Dict, Optional, Set, TYPE_CHECKING, Union
# from pathlib import Path
#
# import torch
# from torch import quantization
#
# from catalyst.core.callback import Callback, CallbackOrder
# from catalyst.utils.quantization import save_quantized_model
#
# if TYPE_CHECKING:
# from catalyst.core.runner import IRunner
#
#
# class DynamicQuantizationCallback(Callback):
# """Dynamic Quantization Callback
#
# This callback applying dynamic quantization to the model.
# """
#
# def __init__(
# self,
# metric: str = "loss",
# minimize: bool = True,
# min_delta: float = 1e-6,
# mode: str = "best",
# do_once: bool = True,
# qconfig_spec: Optional[Union[Set, Dict]] = None,
# dtype: Optional[torch.dtype] = torch.qint8,
# out_dir: Union[str, Path] = None,
# out_model: Union[str, Path] = None,
# backend: str = None,
# ):
# """Init method for callback
#
# Args:
# metric: Metric key we should trace model based on
# minimize: Whether do we minimize metric or not
# min_delta: Minimum value of change for metric to be
# considered as improved
# mode: One of `best` or `last`
# do_once: Whether do we trace once per stage or every epoch
# qconfig_spec: torch.quantization.quantize_dynamic
# parameter, you can define layers to be quantize
# dtype: type of the model parameters, default int8
# out_dir (Union[str, Path]): Directory to save model to
# out_model (Union[str, Path]): Path to save model to
# (overrides `out_dir` argument)
# backend: defines backend for quantization
# """
# super().__init__(order=CallbackOrder.external)
#
# if mode not in ["best", "last"]:
# raise ValueError(
# f"Unknown `mode` '{mode}'. " f"Must be 'best' or 'last'"
# )
#
# self.metric = metric
# self.mode = mode
# self.do_once = do_once
# self.best_score = None
# self.is_better = None
# self.first_time = True
# if minimize:
# self.is_better = lambda score, best: score <= (best - min_delta)
# else:
# self.is_better = lambda score, best: score >= (best + min_delta)
#
# self.opt_level = None
#
# if out_model is not None:
# out_model = Path(out_model)
# self.out_model = out_model
#
# if out_dir is not None:
# out_dir = Path(out_dir)
# self.out_dir = out_dir
# self.qconfig_spec = qconfig_spec
# self.dtype = dtype
#
# if backend is not None:
# torch.backends.quantized.engine = backend
#
# def on_epoch_end(self, runner: "IRunner"):
# """
# Performing model quantization on epoch end if condition metric is
# improved
#
# Args:
# runner: current runner
# """
# if not self.do_once:
# if self.mode == "best":
# score = runner.valid_metrics[self.metric]
#
# if self.best_score is None:
# self.best_score = score
#
# if self.is_better(score, self.best_score) or self.first_time:
# self.best_score = score
# quantized_model = quantization.quantize_dynamic(
# runner.model.cpu(),
# qconfig_spec=self.qconfig_spec,
# dtype=self.dtype,
# )
# save_quantized_model(
# model=quantized_model,
# logdir=runner.logdir,
# checkpoint_name=self.mode,
# out_model=self.out_model,
# out_dir=self.out_dir,
# )
# self.first_time = False
# else:
# quantized_model = quantization.quantize_dynamic(
# runner.model.cpu(),
# qconfig_spec=self.qconfig_spec,
# dtype=self.dtype,
# )
# save_quantized_model(
# model=quantized_model,
# logdir=runner.logdir,
# checkpoint_name=self.mode,
# out_model=self.out_model,
# out_dir=self.out_dir,
# )
#
# def on_stage_end(self, runner: "IRunner") -> None:
# """
# On stage end action.
#
# Args:
# runner: runner of your experiment
# """
# if self.do_once:
# quantized_model = quantization.quantize_dynamic(
# runner.model.cpu(),
# qconfig_spec=self.qconfig_spec,
# dtype=self.dtype,
# )
# save_quantized_model(
# model=quantized_model,
# logdir=runner.logdir,
# checkpoint_name=self.mode,
# out_model=self.out_model,
# out_dir=self.out_dir,
# )
#
#
# __all__ = ["DynamicQuantizationCallback"]
from typing import Dict, Optional, TYPE_CHECKING, Union
from pathlib import Path

import torch

from catalyst.core import Callback, CallbackNode, CallbackOrder
from catalyst.utils import quantize_model

if TYPE_CHECKING:
from catalyst.core import IRunner


class QuantizationCallback(Callback):
"""
Callback for model quantiztion.

Args:
logdir: path to folder for saving
filename: filename
qconfig_spec (Dict, optional): quantization config in PyTorch format. Defaults to None.
dtype (Union[str, Optional[torch.dtype]], optional): Type of weights after quantization.
Defaults to "qint8".
"""

def __init__(
self,
logdir: Union[str, Path] = None,
filename: str = "quantized.pth",
qconfig_spec: Dict = None,
dtype: Union[str, Optional[torch.dtype]] = "qint8",
    ):
"""
Callback for model quantiztion.

Args:
logdir: path to folder for saving
filename: filename
qconfig_spec (Dict, optional): quantization config in PyTorch format.
Defaults to None.
dtype (Union[str, Optional[torch.dtype]], optional):
Type of weights after quantization.
Defaults to "qint8".
"""
        super().__init__(
            order=CallbackOrder.ExternalExtra, node=CallbackNode.Master
        )  # ExternalExtra so it runs after CheckpointCallback; Master node for saving
self.qconfig_spec = qconfig_spec
self.dtype = dtype
if logdir is not None:
self.filename = Path(logdir) / filename
else:
self.filename = filename

    def on_stage_end(self, runner: "IRunner") -> None:
        """
        On stage end action.

        Args:
            runner: runner for the experiment
        """
        # model is already on CPU here, so no second .cpu() call is needed
        model = runner.model.cpu()
        q_model = quantize_model(model, qconfig_spec=self.qconfig_spec, dtype=self.dtype)
        torch.save(q_model.state_dict(), self.filename)


__all__ = ["QuantizationCallback"]