Implement Gymnasium-compliant PPO script #320
Merged
Changes from 13 commits (43 commits in total)

Commits:
be6da79 Add Gymnasium and dependencies (dtch1997)
d098870 Implement Gymnasium-compliant PPO script (dtch1997)
aa9ac0c Ensure pre-commit passes (dtch1997)
cb08637 Fix CI, add a `gymnasium_support` folder (vwxyzjn)
5818d7d update lock files (vwxyzjn)
51a7128 add dependencies (vwxyzjn)
363b48e update requirements.txt; fix pre-commit (vwxyzjn)
e3174ca update poetry files (vwxyzjn)
8da5f7b Support dm control action spaces (vwxyzjn)
0544fcb add dm_control support (vwxyzjn)
d30f3cf Enable num_envs>1 (dtch1997)
99f7789 Enable auto-install of torch based on CUDA version (dtch1997)
cbd83f6 Fix pre-commit (dtch1997)
8cf18e3 bump torch version (vwxyzjn)
fe81b99 bump wandb version (vwxyzjn)
dd80937 change key for mujoco_py installation (vwxyzjn)
c46f700 update CI (vwxyzjn)
0d3a5e1 update docs (vwxyzjn)
fa73a60 Merge branch 'gymnasium_ppo' of https://github.com/vwxyzjn/cleanrl in… (vwxyzjn)
0381d7a downgrade torch (vwxyzjn)
6582fab update docs (vwxyzjn)
b3f19fd update teset cases (vwxyzjn)
1e904a3 set default env = HalfCheetah-v4 (vwxyzjn)
3cd9917 directly replace `ppo_continuous_action.py` (vwxyzjn)
08f9744 deprecate pybullet dependency in ppo (vwxyzjn)
b81d207 remove pybullet test case (vwxyzjn)
de3f410 support video recording to wandb (vwxyzjn)
1b01a4f update docs (vwxyzjn)
73c0caf update depdency for test cases (vwxyzjn)
3df239a update test cases and add dm_control tests (vwxyzjn)
8a9a467 update docs (vwxyzjn)
b56fe05 update mkdocs base (vwxyzjn)
7660199 revert doc changes (vwxyzjn)
8fd5657 fix dm_control test cases (vwxyzjn)
a140a3e quick docs (vwxyzjn)
708cfad fix tests on CI (vwxyzjn)
6a41003 fix test case (vwxyzjn)
3efbc24 fix CI (vwxyzjn)
0e8df6f Fix CI (vwxyzjn)
19d08d5 update mujoco dependency (vwxyzjn)
55f2209 Fix CI (vwxyzjn)
7ef728c fix CI (vwxyzjn)
d2b0a79 remote unused seed (vwxyzjn)
@@ -0,0 +1,325 @@
# docs and experiment results can be found at https://docs.cleanrl.dev/rl-algorithms/ppo/#ppo_continuous_actionpy
import argparse
import os
import random
import time
from distutils.util import strtobool

import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions.normal import Normal
from torch.utils.tensorboard import SummaryWriter


def parse_args():
    # fmt: off
    parser = argparse.ArgumentParser()
    parser.add_argument("--exp-name", type=str, default=os.path.basename(__file__).rstrip(".py"),
        help="the name of this experiment")
    parser.add_argument("--seed", type=int, default=1,
        help="seed of the experiment")
    parser.add_argument("--torch-deterministic", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
        help="if toggled, `torch.backends.cudnn.deterministic=False`")
    parser.add_argument("--cuda", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
        help="if toggled, cuda will be enabled by default")
    parser.add_argument("--track", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
        help="if toggled, this experiment will be tracked with Weights and Biases")
    parser.add_argument("--wandb-project-name", type=str, default="cleanRL",
        help="the wandb's project name")
    parser.add_argument("--wandb-entity", type=str, default=None,
        help="the entity (team) of wandb's project")
    parser.add_argument("--capture-video", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
        help="whether to capture videos of the agent performances (check out `videos` folder)")

    # Algorithm specific arguments
    parser.add_argument("--env-id", type=str, default="HalfCheetahBulletEnv-v0",
        help="the id of the environment")
    parser.add_argument("--total-timesteps", type=int, default=1000000,
        help="total timesteps of the experiments")
    parser.add_argument("--learning-rate", type=float, default=3e-4,
        help="the learning rate of the optimizer")
    parser.add_argument("--num-envs", type=int, default=1,
        help="the number of parallel game environments")
    parser.add_argument("--num-steps", type=int, default=2048,
        help="the number of steps to run in each environment per policy rollout")
    parser.add_argument("--anneal-lr", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
        help="Toggle learning rate annealing for policy and value networks")
    parser.add_argument("--gamma", type=float, default=0.99,
        help="the discount factor gamma")
    parser.add_argument("--gae-lambda", type=float, default=0.95,
        help="the lambda for the general advantage estimation")
    parser.add_argument("--num-minibatches", type=int, default=32,
        help="the number of mini-batches")
    parser.add_argument("--update-epochs", type=int, default=10,
        help="the K epochs to update the policy")
    parser.add_argument("--norm-adv", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
        help="Toggles advantages normalization")
    parser.add_argument("--clip-coef", type=float, default=0.2,
        help="the surrogate clipping coefficient")
    parser.add_argument("--clip-vloss", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
        help="Toggles whether or not to use a clipped loss for the value function, as per the paper.")
    parser.add_argument("--ent-coef", type=float, default=0.0,
        help="coefficient of the entropy")
    parser.add_argument("--vf-coef", type=float, default=0.5,
        help="coefficient of the value function")
    parser.add_argument("--max-grad-norm", type=float, default=0.5,
        help="the maximum norm for the gradient clipping")
    parser.add_argument("--target-kl", type=float, default=None,
        help="the target KL divergence threshold")
    args = parser.parse_args()
    args.batch_size = int(args.num_envs * args.num_steps)
    args.minibatch_size = int(args.batch_size // args.num_minibatches)
    # fmt: on
    return args


def make_env(env_id, seed, idx, capture_video, run_name, gamma):
    def thunk():
        if capture_video:
            env = gym.make(env_id, render_mode="rgb_array")
        else:
            env = gym.make(env_id)
        env = gym.wrappers.FlattenObservation(env)  # deal with dm_control's Dict observation space
        env = gym.wrappers.RecordEpisodeStatistics(env)
        if capture_video:
            if idx == 0:
                env = gym.wrappers.RecordVideo(env, f"videos/{run_name}")
        env = gym.wrappers.ClipAction(env)
        env = gym.wrappers.NormalizeObservation(env)
        env = gym.wrappers.TransformObservation(env, lambda obs: np.clip(obs, -10, 10))
        env = gym.wrappers.NormalizeReward(env, gamma=gamma)
        env = gym.wrappers.TransformReward(env, lambda reward: np.clip(reward, -10, 10))
        return env

    return thunk


def layer_init(layer, std=np.sqrt(2), bias_const=0.0):
    torch.nn.init.orthogonal_(layer.weight, std)
    torch.nn.init.constant_(layer.bias, bias_const)
    return layer


class Agent(nn.Module):
    def __init__(self, envs):
        super().__init__()
        self.critic = nn.Sequential(
            layer_init(nn.Linear(np.array(envs.single_observation_space.shape).prod(), 64)),
            nn.Tanh(),
            layer_init(nn.Linear(64, 64)),
            nn.Tanh(),
            layer_init(nn.Linear(64, 1), std=1.0),
        )
        self.actor_mean = nn.Sequential(
            layer_init(nn.Linear(np.array(envs.single_observation_space.shape).prod(), 64)),
            nn.Tanh(),
            layer_init(nn.Linear(64, 64)),
            nn.Tanh(),
            layer_init(nn.Linear(64, np.prod(envs.single_action_space.shape)), std=0.01),
        )
        self.actor_logstd = nn.Parameter(torch.zeros(1, np.prod(envs.single_action_space.shape)))

    def get_value(self, x):
        return self.critic(x)

    def get_action_and_value(self, x, action=None):
        action_mean = self.actor_mean(x)
        action_logstd = self.actor_logstd.expand_as(action_mean)
        action_std = torch.exp(action_logstd)
        probs = Normal(action_mean, action_std)
        if action is None:
            action = probs.sample()
        return action, probs.log_prob(action).sum(1), probs.entropy().sum(1), self.critic(x)


if __name__ == "__main__":
    args = parse_args()
    run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{int(time.time())}"
    if args.track:
        import wandb

        wandb.init(
            project=args.wandb_project_name,
            entity=args.wandb_entity,
            sync_tensorboard=True,
            config=vars(args),
            name=run_name,
            monitor_gym=True,
            save_code=True,
        )
    writer = SummaryWriter(f"runs/{run_name}")
    writer.add_text(
        "hyperparameters",
        "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
    )

    # TRY NOT TO MODIFY: seeding
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.backends.cudnn.deterministic = args.torch_deterministic

    device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu")

    # env setup
    envs = gym.vector.SyncVectorEnv(
        [make_env(args.env_id, args.seed + i, i, args.capture_video, run_name, args.gamma) for i in range(args.num_envs)]
    )
    assert isinstance(envs.single_action_space, gym.spaces.Box), "only continuous action space is supported"

    agent = Agent(envs).to(device)
    optimizer = optim.Adam(agent.parameters(), lr=args.learning_rate, eps=1e-5)

    # ALGO Logic: Storage setup
    obs = torch.zeros((args.num_steps, args.num_envs) + envs.single_observation_space.shape).to(device)
    actions = torch.zeros((args.num_steps, args.num_envs) + envs.single_action_space.shape).to(device)
    logprobs = torch.zeros((args.num_steps, args.num_envs)).to(device)
    rewards = torch.zeros((args.num_steps, args.num_envs)).to(device)
    dones = torch.zeros((args.num_steps, args.num_envs)).to(device)
    values = torch.zeros((args.num_steps, args.num_envs)).to(device)

    # TRY NOT TO MODIFY: start the game
    global_step = 0
    start_time = time.time()
    next_obs, _ = envs.reset(seed=args.seed)
    next_obs = torch.Tensor(next_obs).to(device)
    next_done = torch.zeros(args.num_envs).to(device)
    num_updates = args.total_timesteps // args.batch_size

    for update in range(1, num_updates + 1):
        # Annealing the rate if instructed to do so.
        if args.anneal_lr:
            frac = 1.0 - (update - 1.0) / num_updates
            lrnow = frac * args.learning_rate
            optimizer.param_groups[0]["lr"] = lrnow

        for step in range(0, args.num_steps):
            global_step += 1 * args.num_envs
            obs[step] = next_obs
            dones[step] = next_done

            # ALGO LOGIC: action logic
            with torch.no_grad():
                action, logprob, _, value = agent.get_action_and_value(next_obs)
                values[step] = value.flatten()
            actions[step] = action
            logprobs[step] = logprob

            # TRY NOT TO MODIFY: execute the game and log data.
            next_obs, reward, terminated, truncated, infos = envs.step(action.cpu().numpy())
            done = np.logical_or(terminated, truncated)
            rewards[step] = torch.tensor(reward).to(device).view(-1)
            next_obs, next_done = torch.Tensor(next_obs).to(device), torch.Tensor(done).to(device)

            # Only print when at least 1 env is done
            if "final_info" not in infos:
                continue

            for info in infos["final_info"]:
                # Skip the envs that are not done
                if info is None:
                    continue
                print(f"global_step={global_step}, episodic_return={info['episode']['r']}")
                writer.add_scalar("charts/episodic_return", info["episode"]["r"], global_step)
                writer.add_scalar("charts/episodic_length", info["episode"]["l"], global_step)

        # bootstrap value if not done
        with torch.no_grad():
            next_value = agent.get_value(next_obs).reshape(1, -1)
            advantages = torch.zeros_like(rewards).to(device)
            lastgaelam = 0
            for t in reversed(range(args.num_steps)):
                if t == args.num_steps - 1:
                    nextnonterminal = 1.0 - next_done
                    nextvalues = next_value
                else:
                    nextnonterminal = 1.0 - dones[t + 1]
                    nextvalues = values[t + 1]
                delta = rewards[t] + args.gamma * nextvalues * nextnonterminal - values[t]
                advantages[t] = lastgaelam = delta + args.gamma * args.gae_lambda * nextnonterminal * lastgaelam
            returns = advantages + values

        # flatten the batch
        b_obs = obs.reshape((-1,) + envs.single_observation_space.shape)
        b_logprobs = logprobs.reshape(-1)
        b_actions = actions.reshape((-1,) + envs.single_action_space.shape)
        b_advantages = advantages.reshape(-1)
        b_returns = returns.reshape(-1)
        b_values = values.reshape(-1)

        # Optimizing the policy and value network
        b_inds = np.arange(args.batch_size)
        clipfracs = []
        for epoch in range(args.update_epochs):
            np.random.shuffle(b_inds)
            for start in range(0, args.batch_size, args.minibatch_size):
                end = start + args.minibatch_size
                mb_inds = b_inds[start:end]

                _, newlogprob, entropy, newvalue = agent.get_action_and_value(b_obs[mb_inds], b_actions[mb_inds])
                logratio = newlogprob - b_logprobs[mb_inds]
                ratio = logratio.exp()

                with torch.no_grad():
                    # calculate approx_kl http://joschu.net/blog/kl-approx.html
                    old_approx_kl = (-logratio).mean()
                    approx_kl = ((ratio - 1) - logratio).mean()
                    clipfracs += [((ratio - 1.0).abs() > args.clip_coef).float().mean().item()]

                mb_advantages = b_advantages[mb_inds]
                if args.norm_adv:
                    mb_advantages = (mb_advantages - mb_advantages.mean()) / (mb_advantages.std() + 1e-8)

                # Policy loss
                pg_loss1 = -mb_advantages * ratio
                pg_loss2 = -mb_advantages * torch.clamp(ratio, 1 - args.clip_coef, 1 + args.clip_coef)
                pg_loss = torch.max(pg_loss1, pg_loss2).mean()

                # Value loss
                newvalue = newvalue.view(-1)
                if args.clip_vloss:
                    v_loss_unclipped = (newvalue - b_returns[mb_inds]) ** 2
                    v_clipped = b_values[mb_inds] + torch.clamp(
                        newvalue - b_values[mb_inds],
                        -args.clip_coef,
                        args.clip_coef,
                    )
                    v_loss_clipped = (v_clipped - b_returns[mb_inds]) ** 2
                    v_loss_max = torch.max(v_loss_unclipped, v_loss_clipped)
                    v_loss = 0.5 * v_loss_max.mean()
                else:
                    v_loss = 0.5 * ((newvalue - b_returns[mb_inds]) ** 2).mean()

                entropy_loss = entropy.mean()
                loss = pg_loss - args.ent_coef * entropy_loss + v_loss * args.vf_coef

                optimizer.zero_grad()
                loss.backward()
                nn.utils.clip_grad_norm_(agent.parameters(), args.max_grad_norm)
                optimizer.step()

            if args.target_kl is not None:
                if approx_kl > args.target_kl:
                    break

        y_pred, y_true = b_values.cpu().numpy(), b_returns.cpu().numpy()
        var_y = np.var(y_true)
        explained_var = np.nan if var_y == 0 else 1 - np.var(y_true - y_pred) / var_y

        # TRY NOT TO MODIFY: record rewards for plotting purposes
        writer.add_scalar("charts/learning_rate", optimizer.param_groups[0]["lr"], global_step)
        writer.add_scalar("losses/value_loss", v_loss.item(), global_step)
        writer.add_scalar("losses/policy_loss", pg_loss.item(), global_step)
        writer.add_scalar("losses/entropy", entropy_loss.item(), global_step)
        writer.add_scalar("losses/old_approx_kl", old_approx_kl.item(), global_step)
        writer.add_scalar("losses/approx_kl", approx_kl.item(), global_step)
        writer.add_scalar("losses/clipfrac", np.mean(clipfracs), global_step)
        writer.add_scalar("losses/explained_variance", explained_var, global_step)
        print("SPS:", int(global_step / (time.time() - start_time)))
        writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step)

    envs.close()
    writer.close()
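For reference, a hypothetical way to launch the script (the flag names come from `parse_args` above; the file name follows the later commit that directly replaces `ppo_continuous_action.py`, and `HalfCheetah-v4` is the default environment set later in this PR):

python ppo_continuous_action.py --env-id HalfCheetah-v4 --total-timesteps 1000000
python ppo_continuous_action.py --env-id HalfCheetah-v4 --track --capture-video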
I observed in the wandb reports that this code does not behave properly in infinite-horizon environments. This may be due to how time limits are handled. Per the latest Gymnasium documentation on handling time limits, it may be possible to solve this by using only the `terminated` signal as the done signal during training.
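To make the suggestion concrete, here is a minimal sketch of one common way to handle time-limit truncation (an illustration only, not code from this PR). It keeps `done = terminated or truncated` for the GAE cut, but adds the discounted value of the truncated episode's final observation to the reward at that step, so the advantage estimate still bootstraps across the artificial episode boundary. It assumes the `infos["final_observation"]` key that Gymnasium's autoresetting vector envs populate (analogous to the `final_info` key already used above) and reuses the variable names from the rollout loop.

# Hypothetical sketch, not part of this PR: bootstrap through time-limit truncations.
next_obs, reward, terminated, truncated, infos = envs.step(action.cpu().numpy())
done = np.logical_or(terminated, truncated)
reward = reward.astype(np.float32)
if "final_observation" in infos:  # set by the autoresetting vector env when a sub-env finishes
    for idx, trunc in enumerate(truncated):
        if trunc and not terminated[idx]:
            final_obs = torch.Tensor(infos["final_observation"][idx]).to(device)
            with torch.no_grad():
                # add gamma * V(s_final) so GAE effectively bootstraps across the time limit
                reward[idx] += args.gamma * agent.get_value(final_obs).item()
rewards[step] = torch.tensor(reward).to(device).view(-1)
next_obs, next_done = torch.Tensor(next_obs).to(device), torch.Tensor(done).to(device)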
This is expected. The purpose of this integration is to support Gymnasium while keeping everything else largely the same. There is an ongoing debate about how to handle truncated states in Gymnasium's API; see sail-sg/envpool#194 (comment).