Skip to content

Commit

Permalink
Make the experiment manager load hyperparameters from a python file i…
Browse files Browse the repository at this point in the history
…nstead of a yaml file.
  • Loading branch information
ernestum committed Nov 18, 2022
1 parent d1ab5cf commit ff4ea0f
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 20 deletions.
38 changes: 20 additions & 18 deletions rl_zoo3/exp_manager.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import argparse
import importlib
import os
import pickle as pkl
import time
import warnings
from collections import OrderedDict
from pathlib import Path
from pprint import pprint
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

Expand Down Expand Up @@ -93,22 +93,16 @@ def __init__(
n_eval_envs: int = 1,
no_optim_plots: bool = False,
device: Union[th.device, str] = "auto",
yaml_file: Optional[str] = None,
config_file: Optional[str] = None,
show_progress: bool = False,
):
super().__init__()
self.algo = algo
self.env_name = EnvironmentName(env_id)
# Custom params
self.custom_hyperparams = hyperparams
if (Path(__file__).parent / "hyperparams").is_dir():
# Package version
default_path = Path(__file__).parent
else:
# Take the root folder
default_path = Path(__file__).parent.parent

self.yaml_file = yaml_file or str(default_path / f"hyperparams/{self.algo}.yml")
self.config_file = config_file or f"rl_zoo3.hyperparams.{self.algo}"
self.env_kwargs = {} if env_kwargs is None else env_kwargs
self.n_timesteps = n_timesteps
self.normalize = False
Expand Down Expand Up @@ -282,15 +276,16 @@ def _save_config(self, saved_hyperparams: Dict[str, Any]) -> None:

def read_hyperparameters(self) -> Tuple[Dict[str, Any], Dict[str, Any]]:
# Load hyperparameters from yaml file
print(f"Loading hyperparameters from: {self.yaml_file}")
with open(self.yaml_file) as f:
hyperparams_dict = yaml.safe_load(f)
if self.env_name.gym_id in list(hyperparams_dict.keys()):
hyperparams = hyperparams_dict[self.env_name.gym_id]
elif self._is_atari:
hyperparams = hyperparams_dict["atari"]
else:
raise ValueError(f"Hyperparameters not found for {self.algo}-{self.env_name.gym_id}")
print(f"Loading hyperparameters from: {self.config_file}")
configurations = importlib.import_module(self.config_file)

hyperparams_dict = configurations.hyperparams
if self.env_name.gym_id in list(hyperparams_dict.keys()):
hyperparams = hyperparams_dict[self.env_name.gym_id]
elif self._is_atari:
hyperparams = hyperparams_dict["atari"]
else:
raise ValueError(f"Hyperparameters not found for {self.algo}-{self.env_name.gym_id}")

if self.custom_hyperparams is not None:
# Overwrite hyperparams if needed
Expand Down Expand Up @@ -333,9 +328,14 @@ def _preprocess_normalization(self, hyperparams: Dict[str, Any]) -> Dict[str, An
# that can be evaluated as python,
# ex: "dict(norm_obs=False, norm_reward=True)"
if isinstance(self.normalize, str):
warnings.warn("Using deprecated string format for `normalize`")
self.normalize_kwargs = eval(self.normalize)
self.normalize = True

if isinstance(self.normalize, dict):
self.normalize_kwargs = self.normalize
self.normalize = True

# Use the same discount factor as for the algorithm
if "gamma" in hyperparams:
self.normalize_kwargs["gamma"] = hyperparams["gamma"]
Expand Down Expand Up @@ -380,13 +380,15 @@ def _preprocess_hyperparams(
# Convert to python object if needed
for kwargs_key in {"policy_kwargs", "replay_buffer_class", "replay_buffer_kwargs"}:
if kwargs_key in hyperparams.keys() and isinstance(hyperparams[kwargs_key], str):
warnings.warn("Loading policy/buffer kwargs from string is deprecated", DeprecationWarning)
hyperparams[kwargs_key] = eval(hyperparams[kwargs_key])

# Preprocess monitor kwargs
if "monitor_kwargs" in hyperparams.keys():
self.monitor_kwargs = hyperparams["monitor_kwargs"]
# Convert str to python code
if isinstance(self.monitor_kwargs, str):
warnings.warn("Loading monitor kwargs from string is deprecated", DeprecationWarning)
self.monitor_kwargs = eval(self.monitor_kwargs)
del hyperparams["monitor_kwargs"]

Expand Down
5 changes: 3 additions & 2 deletions rl_zoo3/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,8 @@ def train():
help="Overwrite hyperparameter (e.g. learning_rate:0.01 train_freq:10)",
)
parser.add_argument(
"-yaml", "--yaml-file", type=str, default=None, help="Custom yaml file from which the hyperparameters will be loaded"
"-conf", "--conf-file", type=str, default=None, help="Custom python file from which the hyperparameters will be loaded."
"We expect that it contains a dictionary called 'hyperparams' which contains a key for each environment."
)
parser.add_argument("-uuid", "--uuid", action="store_true", default=False, help="Ensure that the run has a unique ID")
parser.add_argument(
Expand Down Expand Up @@ -234,7 +235,7 @@ def train():
n_eval_envs=args.n_eval_envs,
no_optim_plots=args.no_optim_plots,
device=args.device,
yaml_file=args.yaml_file,
config_file=args.conf_file,
show_progress=args.progress,
)

Expand Down

0 comments on commit ff4ea0f

Please sign in to comment.