Config API + Optuna PoC = AutoML 🖤 #937

Merged
merged 13 commits on Sep 23, 2020
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -10,6 +10,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
### Added

- Runner registry support for Config API ([#936](https://github.com/catalyst-team/catalyst/pull/936))
- `catalyst-dl tune` command - Optuna integration with the Config API for AutoML hyperparameter optimization ([#937](https://github.com/catalyst-team/catalyst/pull/937))
- `OptunaPruningCallback` alias for `OptunaCallback` ([#937](https://github.com/catalyst-team/catalyst/pull/937))

### Changed

5 changes: 4 additions & 1 deletion catalyst/contrib/dl/callbacks/__init__.py
@@ -93,7 +93,10 @@

try:
import optuna
from catalyst.contrib.dl.callbacks.optuna_callback import OptunaCallback
from catalyst.contrib.dl.callbacks.optuna_callback import (
OptunaPruningCallback,
OptunaCallback,
)
except ImportError as ex:
if settings.optuna_required:
logger.warning(
45 changes: 36 additions & 9 deletions catalyst/contrib/dl/callbacks/optuna_callback.py
@@ -3,7 +3,7 @@
from catalyst.core import Callback, CallbackOrder, IRunner


class OptunaCallback(Callback):
class OptunaPruningCallback(Callback):
"""
Optuna callback for pruning unpromising runs

@@ -25,7 +25,7 @@ def objective(trial: optuna.Trial):
criterion=criterion,
optimizer=optimizer,
callbacks=[
OptunaCallback(trial)
OptunaPruningCallback(trial)
# some other callbacks ...
],
num_epochs=num_epochs,
@@ -35,34 +35,61 @@ def objective(trial: optuna.Trial):
study = optuna.create_study()
study.optimize(objective, n_trials=100, timeout=600)

Config API is not supported.
Config API is supported via the `catalyst-dl tune` command.
"""

def __init__(self, trial: optuna.Trial):
def __init__(self, trial: optuna.Trial = None):
"""
This callback can be used for early stopping (pruning)
of unpromising runs.

Args:
trial: Optuna trial for the current experiment;
if ``None``, the trial is taken from ``runner.experiment.trial``
on stage start.
"""
super(OptunaCallback, self).__init__(CallbackOrder.External)
super().__init__(CallbackOrder.External)
self.trial = trial

def on_epoch_end(self, runner: "IRunner"):
def on_stage_start(self, runner: "IRunner"):
"""
On epoch end action.
On stage start hook.
Takes the Optuna ``trial`` from the ``Experiment`` for later use
if the callback was constructed without one.

Considering prune or not to prune current run at current epoch.
Args:
runner: runner for current experiment

Raises:
TrialPruned: if current run should be pruned
NotImplementedError: if no Optuna trial was found on stage start.
"""
trial = runner.experiment.trial
if (
self.trial is None
and trial is not None
and isinstance(trial, optuna.Trial)
):
self.trial = trial

if self.trial is None:
raise NotImplementedError("No Optuna trial found for logging")

def on_epoch_end(self, runner: "IRunner"):
"""
On epoch end hook.

Decides whether to prune the current run at the current epoch.

Args:
runner: runner for current experiment

Raises:
TrialPruned: if current run should be pruned
"""
metric_value = runner.valid_metrics[runner.main_metric]
self.trial.report(metric_value, step=runner.epoch)
if self.trial.should_prune():
message = "Trial was pruned at epoch {}.".format(runner.epoch)
raise optuna.TrialPruned(message)


OptunaCallback = OptunaPruningCallback

__all__ = ["OptunaCallback", "OptunaPruningCallback"]
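For reviewers, the docstring example above expands into a complete pruning-enabled study roughly as follows. This is a sketch assembled from the docstring and the test below, not code from this PR; the search space, network size, epoch count, and the `MedianPruner` are illustrative assumptions.

```python
import optuna
import torch
from torch import nn
from torch.utils.data import DataLoader

from catalyst import dl
from catalyst.contrib.datasets import MNIST
from catalyst.contrib.dl.callbacks import OptunaPruningCallback
from catalyst.contrib.nn import Flatten
from catalyst.data.cv.transforms.torch import ToTensor


def objective(trial: optuna.Trial) -> float:
    # sampled hyperparameters (ranges are arbitrary, for illustration only)
    lr = trial.suggest_loguniform("lr", 1e-4, 1e-1)
    num_hidden = trial.suggest_int("num_hidden", 32, 128)

    loaders = {
        "train": DataLoader(
            MNIST("./data", train=True, download=True, transform=ToTensor()),
            batch_size=32,
        ),
        "valid": DataLoader(
            MNIST("./data", train=False, download=True, transform=ToTensor()),
            batch_size=32,
        ),
    }
    model = nn.Sequential(
        Flatten(), nn.Linear(28 * 28, num_hidden), nn.ReLU(), nn.Linear(num_hidden, 10)
    )
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()

    runner = dl.SupervisedRunner()
    runner.train(
        model=model,
        loaders=loaders,
        criterion=criterion,
        optimizer=optimizer,
        callbacks=[OptunaPruningCallback(trial)],  # reports metrics, prunes bad trials
        num_epochs=3,
        main_metric="loss",
        minimize_metric=True,
        verbose=False,
    )
    # Same attribute the callback reads on epoch end, i.e. the last reported value.
    return runner.valid_metrics[runner.main_metric]


study = optuna.create_study(direction="minimize", pruner=optuna.pruners.MedianPruner())
study.optimize(objective, n_trials=10, timeout=600)
```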
4 changes: 2 additions & 2 deletions catalyst/contrib/dl/callbacks/tests/test_optuna_callback.py
@@ -7,7 +7,7 @@

from catalyst import dl
from catalyst.contrib.datasets import MNIST
from catalyst.contrib.dl.callbacks import OptunaCallback
from catalyst.contrib.dl.callbacks import OptunaPruningCallback
from catalyst.contrib.nn import Flatten
from catalyst.data.cv.transforms.torch import ToTensor
from catalyst.dl import AccuracyCallback
@@ -39,7 +39,7 @@ def objective(trial):
criterion=criterion,
optimizer=optimizer,
callbacks=[
OptunaCallback(trial),
OptunaPruningCallback(trial),
AccuracyCallback(num_classes=10),
],
num_epochs=10,
34 changes: 25 additions & 9 deletions catalyst/core/experiment.py
@@ -63,7 +63,8 @@ def logdir(self) -> str:
@property
@abstractmethod
def hparams(self) -> OrderedDict:
"""Returns hyper-parameters
"""
Returns hyper-parameters for the current experiment.

Example::
>>> experiment.hparams
@@ -78,17 +79,16 @@ def hparams(self) -> OrderedDict:

@property
@abstractmethod
def stages(self) -> Iterable[str]:
"""Experiment's stage names.
def trial(self) -> Any:
"""
Returns the hyperparameter trial for the current experiment.
Could be useful for hyperparameter optimizers
such as Optuna/HyperOpt/Ray.tune.

Example::

>>> experiment.stages
["pretraining", "training", "finetuning"]

.. note::
To understand stages concept, please follow Catalyst documentation,
for example, :py:mod:`catalyst.core.callback.Callback`
>>> experiment.trial
optuna.trial._trial.Trial # Optuna variant
"""
pass

@@ -113,6 +113,22 @@ def distributed_params(self) -> Dict:
"""
pass

@property
@abstractmethod
def stages(self) -> Iterable[str]:
"""Experiment's stage names.

Example::

>>> experiment.stages
["pretraining", "training", "finetuning"]

.. note::
To understand stages concept, please follow Catalyst documentation,
for example, :py:mod:`catalyst.core.callback.Callback`
"""
pass

@abstractmethod
def get_stage_params(self, stage: str) -> Mapping[str, Any]:
"""Returns extra stage parameters for a given stage.
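The new abstract `trial` property is what lets `OptunaPruningCallback` run with no constructor argument under the Config API: the experiment carries the trial and the callback picks it up in `on_stage_start`. A schematic sketch of that handoff, using stand-in objects (the `_Fake*` classes are illustrative, not part of Catalyst):

```python
import optuna

from catalyst.contrib.dl.callbacks import OptunaPruningCallback


class _FakeExperiment:  # stand-in for an experiment that exposes `trial`
    def __init__(self, trial):
        self.trial = trial


class _FakeRunner:  # stand-in exposing only what the callback touches here
    def __init__(self, experiment):
        self.experiment = experiment


def objective(trial: optuna.Trial) -> float:
    callback = OptunaPruningCallback()  # no trial passed explicitly
    runner = _FakeRunner(_FakeExperiment(trial))
    callback.on_stage_start(runner)  # the callback grabs experiment.trial here
    assert callback.trial is trial
    return 0.0


optuna.create_study().optimize(objective, n_trials=1)
```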
6 changes: 5 additions & 1 deletion catalyst/dl/__main__.py
@@ -3,7 +3,7 @@

from catalyst.__version__ import __version__
from catalyst.dl.scripts import quantize, run, trace
from catalyst.tools.settings import IS_GIT_AVAILABLE
from catalyst.tools.settings import IS_GIT_AVAILABLE, IS_OPTUNA_AVAILABLE

COMMANDS = OrderedDict(
[("run", run), ("trace", trace), ("quantize", quantize)]
@@ -12,6 +12,10 @@
from catalyst.dl.scripts import init

COMMANDS["init"] = init
if IS_OPTUNA_AVAILABLE:
from catalyst.dl.scripts import tune

COMMANDS["tune"] = tune


def build_parser() -> ArgumentParser:
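Since the registration is conditional, a quick way to confirm the sub-command is actually available in a given environment (a throwaway check, not part of the PR):

```python
# `tune` is registered only when Optuna is installed (and `init` only when git is).
from catalyst.dl.__main__ import COMMANDS

print(sorted(COMMANDS))  # e.g. ['init', 'quantize', 'run', 'trace', 'tune']
```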
49 changes: 32 additions & 17 deletions catalyst/dl/experiment/config.py
@@ -63,6 +63,7 @@ def __init__(self, config: Dict):
config (dict): dictionary with parameters
"""
self._config: Dict = deepcopy(config)
self._trial = None
self._initial_seed: int = self._config.get("args", {}).get("seed", 42)
self._verbose: bool = self._config.get("args", {}).get(
"verbose", False
@@ -77,7 +78,7 @@ def __init__(self, config: Dict):
"overfit", False
)

self.__prepare_logdir()
self._prepare_logdir()

self._config["stages"]["stage_params"] = utils.merge_dicts(
deepcopy(
@@ -91,7 +92,13 @@
self._config["stages"]
)

def __prepare_logdir(self): # noqa: WPS112
def _get_logdir(self, config: Dict) -> str:
timestamp = utils.get_utcnow_time()
config_hash = utils.get_short_hash(config)
logdir = f"{timestamp}.{config_hash}"
return logdir

def _prepare_logdir(self): # noqa: WPS112
exclude_tag = "none"

logdir = self._config.get("args", {}).get("logdir", None)
@@ -105,11 +112,6 @@ def __prepare_logdir(self): # noqa: WPS112
else:
self._logdir = None

@property
def hparams(self) -> OrderedDict:
"""Returns hyperparameters"""
return OrderedDict(self._config)

def _get_stages_config(self, stages_config: Dict):
stages_defaults = {}
stages_config_out = OrderedDict()
@@ -147,12 +149,6 @@ def _get_stages_config(self, stages_config: Dict):

return stages_config_out

def _get_logdir(self, config: Dict) -> str:
timestamp = utils.get_utcnow_time()
config_hash = utils.get_short_hash(config)
logdir = f"{timestamp}.{config_hash}"
return logdir

@property
def initial_seed(self) -> int:
"""Experiment's initial seed value."""
Expand All @@ -164,16 +160,35 @@ def logdir(self):
return self._logdir

@property
def stages(self) -> List[str]:
"""Experiment's stage names."""
stages_keys = list(self.stages_config.keys())
return stages_keys
def hparams(self) -> OrderedDict:
"""Returns hyperparameters"""
return OrderedDict(self._config)

@property
def trial(self) -> Any:
"""
Returns the hyperparameter trial for the current experiment.
Could be useful for hyperparameter optimizers
such as Optuna/HyperOpt/Ray.tune.

Example::

>>> experiment.trial
optuna.trial._trial.Trial # Optuna variant
"""
return self._trial

@property
def distributed_params(self) -> Dict:
"""Dict with the parameters for distributed and FP16 methond."""
return self._config.get("distributed_params", {})

@property
def stages(self) -> List[str]:
"""Experiment's stage names."""
stages_keys = list(self.stages_config.keys())
return stages_keys

def get_stage_params(self, stage: str) -> Mapping[str, Any]:
"""Returns the state parameters for a given stage."""
return self.stages_config[stage].get("stage_params", {})
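Side note on `_get_logdir` (unchanged, only moved up next to `_prepare_logdir`): auto-generated logdir names are a UTC timestamp plus a short hash of the config. A tiny sketch of that naming scheme, assuming the same helpers are exposed from `catalyst.utils` as they are used above:

```python
from catalyst import utils

config = {"model_params": {"model": "SimpleNet"}}  # any config-like dict
logdir = f"{utils.get_utcnow_time()}.{utils.get_short_hash(config)}"
print(logdir)  # e.g. "200923.131415.1a2b3c" -- exact format comes from the helpers
```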
29 changes: 24 additions & 5 deletions catalyst/dl/experiment/experiment.py
@@ -44,6 +44,7 @@ def __init__(
criterion: Criterion = None,
optimizer: Optimizer = None,
scheduler: Scheduler = None,
trial: Any = None,
num_epochs: int = 1,
valid_loader: str = "valid",
main_metric: str = "loss",
@@ -75,6 +76,8 @@
criterion (Criterion): criterion function
optimizer (Optimizer): optimizer
scheduler (Scheduler): scheduler
trial (Any): hyperparameter optimization trial.
Used for integration with Optuna/HyperOpt/Ray.tune.
num_epochs (int): number of experiment's epochs
valid_loader (str): loader name used to calculate
the metrics and save the checkpoints. For example,
@@ -118,6 +121,8 @@ def __init__(
self._optimizer = optimizer
self._scheduler = scheduler

self._trial = trial

self._initial_seed = initial_seed
self._logdir = logdir
self._stage = stage
@@ -147,11 +152,6 @@ def stages(self) -> Iterable[str]:
"""Experiment's stage names (array with one value)."""
return [self._stage]

@property
def distributed_params(self) -> Dict:
"""Dict with the parameters for distributed and FP16 method."""
return self._distributed_params

@property
def hparams(self) -> OrderedDict:
"""Returns hyper parameters"""
@@ -169,6 +169,25 @@ def hparams(self) -> OrderedDict:
hparams[f"{k}_batch_size"] = v.batch_size
return hparams

@property
def trial(self) -> Any:
"""
Returns the hyperparameter trial for the current experiment.
Could be useful for hyperparameter optimizers
such as Optuna/HyperOpt/Ray.tune.

Example::

>>> experiment.trial
optuna.trial._trial.Trial # Optuna variant
"""
return self._trial

@property
def distributed_params(self) -> Dict:
"""Dict with the parameters for distributed and FP16 method."""
return self._distributed_params

@staticmethod
def _get_loaders(
loaders: "OrderedDict[str, DataLoader]",
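On the Notebook-API side, the new `trial` argument simply rides along on the `Experiment` and is exposed through the property, so callbacks can reach it via `runner.experiment.trial`. A minimal sketch (the random dataset and `FixedTrial` are illustrative stand-ins):

```python
import optuna
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

from catalyst.dl import Experiment

trial = optuna.trial.FixedTrial({"lr": 0.02})  # stand-in for a real study trial
loaders = {
    "train": DataLoader(
        TensorDataset(torch.randn(8, 4), torch.randint(0, 2, (8,))), batch_size=4
    )
}
experiment = Experiment(model=nn.Linear(4, 2), loaders=loaders, trial=trial)
assert experiment.trial is trial  # available to callbacks as runner.experiment.trial
```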