Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow python script for hyperparameter configuration #318

Merged
merged 12 commits into from
Nov 30, 2022
4 changes: 3 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
## Release 1.7.0a2 (WIP)
## Release 1.7.0a3 (WIP)

### Breaking Changes
- `--yaml-file` argument was renamed to `-conf` (`--conf-file`) as now python files are supported too

### New Features
- Specifying custom policies in yaml file is now supported (@Rick-v-E)
- Added ``monitor_kwargs`` parameter
- Handle the `env_kwargs` of `render:True` under the hood for panda-gym v1 envs in `enjoy` replay to match visualization behavior of other envs
- Added support for python config file

### Bug fixes
- Allow `python -m rl_zoo3.cli` to be called directly
Expand Down
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
LINT_PATHS = *.py tests/ scripts/ rl_zoo3/
LINT_PATHS = *.py tests/ scripts/ rl_zoo3/ hyperparams/python/*.py

# Run pytest and coverage report
pytest:
Expand Down
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ You can use `-P` (`--progress`) option to display a progress bar.

Using a custom yaml file (which contains a `env_id` entry):
```
python train.py --algo algo_name --env env_id --yaml-file my_yaml.yml
python train.py --algo algo_name --env env_id --conf-file my_yaml.yml
```

For example (with tensorboard support):
Expand Down Expand Up @@ -139,7 +139,7 @@ Remark: plotting with the `--rliable` option is usually slow as confidence inter

## Custom Environment

The easiest way to add support for a custom environment is to edit `rl_zoo3/import_envs.py` and register your environment here. Then, you need to add a section for it in the hyperparameters file (`hyperparams/algo.yml` or a custom yaml file that you can specify using `--yaml-file` argument).
The easiest way to add support for a custom environment is to edit `rl_zoo3/import_envs.py` and register your environment here. Then, you need to add a section for it in the hyperparameters file (`hyperparams/algo.yml` or a custom yaml file that you can specify using `--conf-file` argument).

## Enjoy a Trained Agent

Expand Down
29 changes: 29 additions & 0 deletions hyperparams/python/ppo_config_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
"""This file just serves as an example on how to configure the zoo
using python scripts instead of yaml files."""
import torch

hyperparams = {
"MountainCarContinuous-v0": dict(
env_wrapper=[{"gym.wrappers.TimeLimit": {"max_episode_steps": 100}}],
normalize=True,
n_envs=1,
n_timesteps=20000.0,
policy="MlpPolicy",
batch_size=8,
n_steps=8,
gamma=0.9999,
learning_rate=7.77e-05,
ent_coef=0.00429,
clip_range=0.1,
n_epochs=2,
gae_lambda=0.9,
max_grad_norm=5,
vf_coef=0.19,
use_sde=True,
policy_kwargs=dict(
log_std_init=-3.29,
ortho_init=False,
activation_fn=torch.nn.ReLU,
),
)
}
41 changes: 29 additions & 12 deletions rl_zoo3/exp_manager.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import argparse
import importlib
import os
import pickle as pkl
import time
Expand Down Expand Up @@ -93,7 +94,7 @@ def __init__(
n_eval_envs: int = 1,
no_optim_plots: bool = False,
device: Union[th.device, str] = "auto",
yaml_file: Optional[str] = None,
config: Optional[str] = None,
show_progress: bool = False,
):
super().__init__()
Expand All @@ -108,7 +109,7 @@ def __init__(
# Take the root folder
default_path = Path(__file__).parent.parent

self.yaml_file = yaml_file or str(default_path / f"hyperparams/{self.algo}.yml")
self.config = config or str(default_path / f"hyperparams/{self.algo}.yml")
self.env_kwargs = {} if env_kwargs is None else env_kwargs
self.n_timesteps = n_timesteps
self.normalize = False
Expand Down Expand Up @@ -281,16 +282,28 @@ def _save_config(self, saved_hyperparams: Dict[str, Any]) -> None:
print(f"Log path: {self.save_path}")

def read_hyperparameters(self) -> Tuple[Dict[str, Any], Dict[str, Any]]:
# Load hyperparameters from yaml file
print(f"Loading hyperparameters from: {self.yaml_file}")
with open(self.yaml_file) as f:
hyperparams_dict = yaml.safe_load(f)
if self.env_name.gym_id in list(hyperparams_dict.keys()):
hyperparams = hyperparams_dict[self.env_name.gym_id]
elif self._is_atari:
hyperparams = hyperparams_dict["atari"]
else:
raise ValueError(f"Hyperparameters not found for {self.algo}-{self.env_name.gym_id}")
print(f"Loading hyperparameters from: {self.config}")

if self.config.endswith(".yml") or self.config.endswith(".yaml"):
# Load hyperparameters from yaml file
with open(self.config) as f:
hyperparams_dict = yaml.safe_load(f)
elif self.config.endswith(".py"):
global_variables = {}
# Load hyperparameters from python file
exec(Path(self.config).read_text(), global_variables)
hyperparams_dict = global_variables["hyperparams"]
else:
# Load hyperparameters from python package
hyperparams_dict = importlib.import_module(self.config).hyperparams
# raise ValueError(f"Unsupported config file format: {self.config}")

if self.env_name.gym_id in list(hyperparams_dict.keys()):
hyperparams = hyperparams_dict[self.env_name.gym_id]
elif self._is_atari:
hyperparams = hyperparams_dict["atari"]
else:
raise ValueError(f"Hyperparameters not found for {self.algo}-{self.env_name.gym_id} in {self.config}")

if self.custom_hyperparams is not None:
# Overwrite hyperparams if needed
Expand Down Expand Up @@ -336,6 +349,10 @@ def _preprocess_normalization(self, hyperparams: Dict[str, Any]) -> Dict[str, An
self.normalize_kwargs = eval(self.normalize)
self.normalize = True

if isinstance(self.normalize, dict):
self.normalize_kwargs = self.normalize
self.normalize = True

# Use the same discount factor as for the algorithm
if "gamma" in hyperparams:
self.normalize_kwargs["gamma"] = hyperparams["gamma"]
Expand Down
21 changes: 19 additions & 2 deletions rl_zoo3/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,19 @@ def train():
help="Overwrite hyperparameter (e.g. learning_rate:0.01 train_freq:10)",
)
parser.add_argument(
"-yaml", "--yaml-file", type=str, default=None, help="Custom yaml file from which the hyperparameters will be loaded"
"-conf",
"--conf-file",
type=str,
default=None,
help="Custom yaml file or python package from which the hyperparameters will be loaded."
"We expect that python packages contain a dictionary called 'hyperparams' which contains a key for each environment.",
)
parser.add_argument(
"-yaml",
"--yaml-file",
type=str,
default=None,
help="This parameter is deprecated, please use `--conf-file` instead",
)
parser.add_argument("-uuid", "--uuid", action="store_true", default=False, help="Ensure that the run has a unique ID")
parser.add_argument(
Expand Down Expand Up @@ -150,6 +162,11 @@ def train():
env_id = args.env
registered_envs = set(gym.envs.registry.env_specs.keys()) # pytype: disable=module-attr

if args.yaml_file is not None:
raise ValueError(
"The`--yaml-file` parameter is deprecated and will be removed in RL Zoo3 v1.8, please use `--conf-file` instead",
)

# If the environment is not found, suggest the closest match
if env_id not in registered_envs:
try:
Expand Down Expand Up @@ -234,7 +251,7 @@ def train():
n_eval_envs=args.n_eval_envs,
no_optim_plots=args.no_optim_plots,
device=args.device,
yaml_file=args.yaml_file,
config=args.conf_file,
show_progress=args.progress,
)

Expand Down
2 changes: 1 addition & 1 deletion rl_zoo3/version.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
1.7.0a2
1.7.0a3
3 changes: 2 additions & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@ markers =
inputs = .

[flake8]
ignore = W503,W504,E203,E231 # line breaks before and after binary operators
# line breaks before and after binary operators
ignore = W503,W504,E203,E231
# Ignore import not used when aliases are defined
per-file-ignores =
./rl_zoo3/__init__.py:F401
Expand Down
22 changes: 21 additions & 1 deletion tests/test_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ def test_custom_yaml(tmp_path):
"CartPole-v1",
"--log-folder",
tmp_path,
"-yaml",
"-conf",
"hyperparams/a2c.yml",
"-params",
"n_envs:2",
Expand All @@ -129,3 +129,23 @@ def test_custom_yaml(tmp_path):

return_code = subprocess.call(["python", "train.py"] + args)
_assert_eq(return_code, 0)


@pytest.mark.parametrize("config_file", ["hyperparams.python.ppo_config_example", "hyperparams/python/ppo_config_example.py"])
def test_python_config_file(tmp_path, config_file):
# Use the example python config file for training
args = [
"-n",
str(N_STEPS),
"--algo",
"ppo",
"--env",
"MountainCarContinuous-v0",
"--log-folder",
tmp_path,
"-conf",
config_file,
]

return_code = subprocess.call(["python", "train.py"] + args)
_assert_eq(return_code, 0)