Skip to content

Commit

Permalink
Allow loading configuration from python package in addition to yaml f…
Browse files Browse the repository at this point in the history
…iles.
  • Loading branch information
ernestum committed Nov 24, 2022
1 parent 168c34e commit d5a53e5
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 14 deletions.
34 changes: 22 additions & 12 deletions rl_zoo3/exp_manager.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import argparse
import importlib
import os
import pickle as pkl
import time
Expand Down Expand Up @@ -93,7 +94,7 @@ def __init__(
n_eval_envs: int = 1,
no_optim_plots: bool = False,
device: Union[th.device, str] = "auto",
yaml_file: Optional[str] = None,
config: Optional[str] = None,
show_progress: bool = False,
):
super().__init__()
Expand All @@ -108,7 +109,7 @@ def __init__(
# Take the root folder
default_path = Path(__file__).parent.parent

self.yaml_file = yaml_file or str(default_path / f"hyperparams/{self.algo}.yml")
self.config = config or str(default_path / f"hyperparams/{self.algo}.yml")
self.env_kwargs = {} if env_kwargs is None else env_kwargs
self.n_timesteps = n_timesteps
self.normalize = False
Expand Down Expand Up @@ -281,16 +282,21 @@ def _save_config(self, saved_hyperparams: Dict[str, Any]) -> None:
print(f"Log path: {self.save_path}")

def read_hyperparameters(self) -> Tuple[Dict[str, Any], Dict[str, Any]]:
# Load hyperparameters from yaml file
print(f"Loading hyperparameters from: {self.yaml_file}")
with open(self.yaml_file) as f:
hyperparams_dict = yaml.safe_load(f)
if self.env_name.gym_id in list(hyperparams_dict.keys()):
hyperparams = hyperparams_dict[self.env_name.gym_id]
elif self._is_atari:
hyperparams = hyperparams_dict["atari"]
else:
raise ValueError(f"Hyperparameters not found for {self.algo}-{self.env_name.gym_id}")
print(f"Loading hyperparameters from: {self.config}")

if self.config.endswith(".yml"):
# Load hyperparameters from yaml file
with open(self.config) as f:
hyperparams_dict = yaml.safe_load(f)
else:
# Load hyperparameters from python package
hyperparams_dict = importlib.import_module(self.config).hyperparams
if self.env_name.gym_id in list(hyperparams_dict.keys()):
hyperparams = hyperparams_dict[self.env_name.gym_id]
elif self._is_atari:
hyperparams = hyperparams_dict["atari"]
else:
raise ValueError(f"Hyperparameters not found for {self.algo}-{self.env_name.gym_id} in {self.config}")

if self.custom_hyperparams is not None:
# Overwrite hyperparams if needed
Expand Down Expand Up @@ -336,6 +342,10 @@ def _preprocess_normalization(self, hyperparams: Dict[str, Any]) -> Dict[str, An
self.normalize_kwargs = eval(self.normalize)
self.normalize = True

if isinstance(self.normalize, dict):
self.normalize_kwargs = self.normalize
self.normalize = True

# Use the same discount factor as for the algorithm
if "gamma" in hyperparams:
self.normalize_kwargs["gamma"] = hyperparams["gamma"]
Expand Down
5 changes: 3 additions & 2 deletions rl_zoo3/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,8 @@ def train():
help="Overwrite hyperparameter (e.g. learning_rate:0.01 train_freq:10)",
)
parser.add_argument(
"-yaml", "--yaml-file", type=str, default=None, help="Custom yaml file from which the hyperparameters will be loaded"
"-conf", "--conf-file", type=str, default=None, help="Custom yaml file or python package from which the hyperparameters will be loaded."
"We expect that python packages contain a dictionary called 'hyperparams' which contains a key for each environment."
)
parser.add_argument("-uuid", "--uuid", action="store_true", default=False, help="Ensure that the run has a unique ID")
parser.add_argument(
Expand Down Expand Up @@ -234,7 +235,7 @@ def train():
n_eval_envs=args.n_eval_envs,
no_optim_plots=args.no_optim_plots,
device=args.device,
yaml_file=args.yaml_file,
config=args.conf_file,
show_progress=args.progress,
)

Expand Down

0 comments on commit d5a53e5

Please sign in to comment.