
add tianshou-like JAX+PPO+Mujoco #355

Draft: wants to merge 14 commits into master
Conversation

@quangr commented Jan 31, 2023

Description

Add tianshou-like JAX+PPO+MuJoCo code, tested on Hopper-v3 and HalfCheetah-v3.

11-seed test

Hopper-v3 (Tianshou 1M: 2609.3 ± 700.8; 3M: 3127.7 ± 413.0)
My result: [Hopper-v3 learning curve]

HalfCheetah-v3 (Tianshou 1M: 5783.9 ± 1244.0; 3M: 7337.4 ± 1508.2)
My result: [HalfCheetah-v3 learning curve]

This implementation uses a custom EnvWrapper class to wrap the environment. Unlike a traditional Gym-style wrapper, which exposes step and reset methods, EnvWrapper requires three methods: recv, send, and reset. These methods need to be pure functions so that they can be transformed by JAX. The recv method transforms what the env returns after an action step, and the send method transforms the action sent to the env (a minimal sketch of the interface follows below).
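A minimal sketch of this interface, with illustrative names rather than the exact classes in this PR, assuming the wrapper state is a flax.struct dataclass so it can be treated as a JAX pytree:

import jax.numpy as jnp
from flax import struct


@struct.dataclass
class ExampleClipActWrapper:  # hypothetical wrapper, for illustration only
    low: jnp.ndarray
    high: jnp.ndarray

    def reset(self, ret):
        # pass through whatever the underlying env returned from reset
        return self, ret

    def send(self, action):
        # transform the action before it is sent to the env
        return self, jnp.clip(action, self.low, self.high)

    def recv(self, ret):
        # transform what the env returned after a step (identity here)
        next_obs, reward, next_done, info = ret
        return self, (next_obs, reward, next_done, info)

Each method returns a new wrapper state instead of mutating self, so the whole pipeline can be jit-compiled and carried through jax.lax.scan.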

Types of changes

  • Bug fix
  • New feature
  • New algorithm
  • Documentation

Checklist:

  • [x] I've read the CONTRIBUTION guide (required).
  • I have ensured pre-commit run --all-files passes (required).
  • I have updated the documentation and previewed the changes via mkdocs serve.
  • I have updated the tests accordingly (if applicable).

If you are adding new algorithm variants or your change could result in performance difference, you may need to (re-)run tracked experiments. See #137 as an example PR.

  • I have contacted vwxyzjn to obtain access to the openrlbenchmark W&B team (required).
  • I have tracked applicable experiments in openrlbenchmark/cleanrl with --capture-video flag toggled on (required).
  • I have added additional documentation and previewed the changes via mkdocs serve.
    • I have explained note-worthy implementation details.
    • I have explained the logged metrics.
    • I have added links to the original paper and related papers (if applicable).
    • I have added links to the PR related to the algorithm variant.
    • I have created a table comparing my results against those from reputable sources (i.e., the original paper or other reference implementation).
    • I have added the learning curves (in PNG format).
    • I have added links to the tracked experiments.
    • I have updated the overview sections at the docs and the repo
  • I have updated the tests accordingly (if applicable).

@vercel bot commented Jan 31, 2023

The latest updates on your projects:

cleanrl: ✅ Ready (updated Feb 5, 2023 at 5:55AM UTC)

@vwxyzjn (Owner) left a comment

Hi @quangr, thanks for this awesome contribution! Being able to use JAX+PPO+MuJoCo+EnvPool will be a game-changer for a lot of folks! This PR will also make #217 unnecessary.

Some comments and thoughts:

  • Would you mind sharing your wandb username so I can add you to the openrlbenchmark entity? It would be great if you could contribute tracked experiments there, and we can use our CLI utility (https://github.com/openrlbenchmark/openrlbenchmark) to plot charts.
  • Could you share your huggingface username and help add saved models?
    • On second thought, there might not be a way to render MuJoCo images with EnvPool. Don't worry about this yet.
    • We recently added the Hugging Face integration as follows:
      if args.save_model:
          model_path = f"runs/{run_name}/{args.exp_name}.cleanrl_model"
          with open(model_path, "wb") as f:
              f.write(
                  flax.serialization.to_bytes(
                      [
                          vars(args),
                          [
                              agent_state.params.network_params,
                              agent_state.params.actor_params,
                              agent_state.params.critic_params,
                          ],
                      ]
                  )
              )
          print(f"model saved to {model_path}")

          from cleanrl_utils.evals.ppo_envpool_jax_eval import evaluate

          episodic_returns = evaluate(
              model_path,
              make_env,
              args.env_id,
              eval_episodes=10,
              run_name=f"{run_name}-eval",
              Model=(Network, Actor, Critic),
          )
          for idx, episodic_return in enumerate(episodic_returns):
              writer.add_scalar("eval/episodic_return", episodic_return, idx)

          if args.upload_model:
              from cleanrl_utils.huggingface import push_to_hub

              repo_name = f"{args.env_id}-{args.exp_name}-seed{args.seed}"
              repo_id = f"{args.hf_entity}/{repo_name}" if args.hf_entity else repo_name
              push_to_hub(args, episodic_returns, repo_id, "PPO", f"runs/{run_name}", f"videos/{run_name}-eval")
    • You can load the trained model by running python -m cleanrl_utils.enjoy --exp-name ppo_atari_envpool_xla_jax_scan --env-id Breakout-v5

cleanrl/ppo_mujoco_envpool_xla_jax.py:
def compute_gae_once(carry, inp, gamma, gae_lambda):
    advantages = carry
    nextdone, nexttruncated, nextvalues, curvalues, reward = inp
    nextnonterminal = (1.0 - nextdone) * (1.0 - nexttruncated)
@vwxyzjn (Owner) commented:

gym's truncation / termination has been a mess, so it's confusing to handle value estimation on truncation correctly.

If I understood correctly, to handle value estimation correctly you should use nextnonterminal = (1.0 - nextdone), which is equivalent to nextnonterminal = (1.0 - next_terminated) under env_type="gymnasium".

If you don't (which is fine, since it keeps this implementation consistent with other PPO variants), then the current implementation is fine :)

If you choose to handle it correctly, make sure to add a note to the implementation details section of the docs, since it's a deviation from the original implementation.
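For reference, a sketch of the two options, using the variable names from the snippet quoted above; the delta/advantage lines follow the standard GAE recursion and may not match the PR's exact code:

def compute_gae_once(carry, inp, gamma, gae_lambda):
    advantages = carry
    nextdone, nexttruncated, nextvalues, curvalues, reward = inp

    # Current behaviour: stop bootstrapping on both termination and truncation.
    nextnonterminal = (1.0 - nextdone) * (1.0 - nexttruncated)
    # Alternative discussed above: only stop on true terminations, i.e. keep
    # bootstrapping the value through time-limit truncations.
    # nextnonterminal = 1.0 - nextdone

    delta = reward + gamma * nextvalues * nextnonterminal - curvalues
    advantages = delta + gamma * gae_lambda * nextnonterminal * advantages
    return advantages, advantages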

A collaborator commented:
Just curious: are we currently using gymnasium? My last PR still uses the normal gym interface. Is this version bump a dependency of any of these libs?

Comment on lines 448 to 450 of cleanrl/ppo_mujoco_envpool_xla_jax.py:
if args.rew_norm:
    returns = (returns / jnp.sqrt(agent_state.ret_rms.var + 10e-8)).astype(jnp.float32)
    agent_state = agent_state.replace(ret_rms=agent_state.ret_rms.update(returns.flatten()))
@vwxyzjn (Owner) commented:

Interesting. Going through the logic here, I realize this reward normalization is quite different from the original implementation, which may or may not be fine, but it's worth pointing out the difference.

The original implementation keeps a "forward discounted return" and normalizes the reward on a per-step basis.
See source1 and source2.

Here, the returns are normalized only after the rollout phase, so the rewards themselves are not normalized.

These two approaches don't feel equivalent. If you want to get to the bottom of this, it might be worth conducting an empirical study comparing this implementation with #348.

Both approaches seem fine to me, but we should document it if we choose the current approach.
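To make the difference concrete, an illustrative sketch of the two schemes (simplified; ret_rms here is an assumed running-variance tracker with update() and var, and the exact update rules in gym's NormalizeReward wrapper and in this PR may differ):

import numpy as np


def per_step_reward_norm(reward, done, discounted_return, ret_rms, gamma=0.99, eps=1e-8):
    """Original-style: keep a per-env forward discounted return and scale each
    reward by the std of that return as it is collected."""
    discounted_return = discounted_return * gamma * (1.0 - done) + reward
    ret_rms.update(discounted_return)
    return reward / np.sqrt(ret_rms.var + eps), discounted_return


def post_rollout_return_norm(returns, ret_rms, eps=1e-8):
    """Tianshou-style (as in this PR): leave rewards untouched and scale the
    bootstrapped returns after the rollout, then update the running stats."""
    normalized = returns / np.sqrt(ret_rms.var + eps)
    ret_rms.update(returns.flatten())
    return normalized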

@vwxyzjn (Owner) commented:

I also noticed there is no observation normalization. Is this intended?

@51616 (Collaborator) left a comment

Thank you for a nice PR! Still, there are some unjustified changes which might cause performance differences vs. other versions. In general, aim to reduce the number of code changes (see #328 (comment) for how to compare your PR with a reference implementation). In your case, it might be a good idea to compare against ppo_continuous_action.py and ppo_atari_envpool_xla_jax_scan.py. A good rule of thumb is to stay as close to these references as possible.

@quangr (Author) commented Feb 1, 2023

Hi @quangr, thanks for this awesome contribution! Being able to use JAX+PPO+MuJoCo+EnvPool will be a game-changer for a lot of folks! This PR will also make #217 unnecessary.

Some comments and thoughts:

  • Would you mind sharing your wandb username so I can add you to the openrlbenchmark entity? It would be great if you could contribute tracked experiments there, and we can use our CLI utility (https://github.com/openrlbenchmark/openrlbenchmark) to plot charts.

Thanks @51616 and @vwxyzjn for the code review ❤️! My wandb username is quangr. I will go through these comments and improve my code soon.

@quangr (Author) commented Feb 1, 2023

I have submitted new commits addressing most of the comments, and here are my answers to some of the others. If anything is missing, please let me know.

Why values[1:] * (1.0 - dones[1:])? Maybe this should be handled within the compute_gae_once

I'm masking the done value because tianshou does so: https://github.com/thu-ml/tianshou/blob/774d3d8e833a1b1c1ed320e0971ab125161f4264/tianshou/policy/base.py#L288.

You're right, I'm moving it into the compute_gae_once function.

API change

The XLA API provided by envpool is not a pure function. The handle passed to the send function is just a fat pointer to the envpool class.

If we keep all state inside a handle tuple and then reset the environment, the pointer remains unchanged, and other parts (like new statistics state) also require a reset state. So I think some change to the envpool API is needed.

To make the API less confusing, maybe we could remove handle from the return values of envs.xla(). For now, I can't think of a way to keep things consistent with envpool (the usual handle-threading pattern is sketched below for reference).
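A sketch of the envpool XLA pattern being discussed, simplified from the pattern in cleanrl's existing envpool+XLA scripts; policy_fn is a placeholder, and the exact tuple returned by step_env depends on env_type:

import envpool
import jax


def make_rollout_fn(env_id: str, num_envs: int, policy_fn):
    envs = envpool.make(env_id, env_type="gym", num_envs=num_envs)
    # `handle` is effectively an opaque pointer into the C++ env pool, not pure
    # JAX state, so wrapper state (e.g. running obs statistics) has to be
    # carried separately through the scan.
    handle, recv, send, step_env = envs.xla()

    def env_step(carry, _):
        handle, obs = carry
        action = policy_fn(obs)
        # even after the envs reset internally, `handle` is the same pointer
        handle, (next_obs, reward, next_done, info) = step_env(handle, action)
        return (handle, next_obs), (next_obs, reward, next_done)

    def rollout(handle, obs, num_steps):
        return jax.lax.scan(env_step, (handle, obs), None, length=num_steps)

    return envs, handle, rollout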

Observation Normalization and gym or gymnasium

wrappers = [
    VectorEnvNormObs(),
    VectorEnvClipAct(envs.action_space.low, envs.action_space.high),
]
envs = VectorEnvWrapper(envs, wrappers)

There is observation normalization; it is implemented as a wrapper. I think observation normalization is mandatory for achieving high scores in MuJoCo envs.

In this observation normalization wrapper I also turn the gym API into the gymnasium API:

def recv(self, ret):
    next_obs, reward, next_done, info = ret
    next_truncated = info["TimeLimit.truncated"]
    obs_rms = self.obs_rms.update(next_obs)
    return self.replace(obs_rms=obs_rms), (
        obs_rms.norm(next_obs),
        reward,
        next_done,
        next_truncated,
        info,
    )

This is because when I wrote this code I used the latest envpool version (0.8.1), which uses the gymnasium API.
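For completeness, a minimal sketch of the kind of running-statistics object implied by obs_rms.update / obs_rms.norm above (illustrative only; the PR's actual implementation may differ in detail):

import jax.numpy as jnp
from flax import struct


@struct.dataclass
class RunningMeanStd:
    mean: jnp.ndarray
    var: jnp.ndarray
    count: jnp.ndarray  # scalar

    def update(self, batch):
        # standard parallel-variance (Chan et al.) update, as in gym's RunningMeanStd
        batch_mean = batch.mean(axis=0)
        batch_var = batch.var(axis=0)
        batch_count = batch.shape[0]

        delta = batch_mean - self.mean
        tot_count = self.count + batch_count
        new_mean = self.mean + delta * batch_count / tot_count
        m_a = self.var * self.count
        m_b = batch_var * batch_count
        m_2 = m_a + m_b + delta**2 * self.count * batch_count / tot_count
        return self.replace(mean=new_mean, var=m_2 / tot_count, count=tot_count)

    def norm(self, x, eps=1e-8):
        return (x - self.mean) / jnp.sqrt(self.var + eps)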

reward normalization

This reward normalization is bizarre to me too, but this is how tianshou implements it, and it really works.

HalfCheetah-v3: [W&B chart, 2023-01-31]

Hopper-v3: [W&B chart, 2023-01-31]

@quangr (Author) commented Feb 3, 2023

I have run experiments for Ant-v4, HalfCheetah-v4, Hopper-v4, Walker2d-v4, Swimmer-v4, Humanoid-v4, Reacher-v4, InvertedPendulum-v4, and InvertedDoublePendulum-v4. Here is the report: https://wandb.ai/openrlbenchmark/cleanrl/reports/MuJoCo-jax-EnvPool--VmlldzozNDczNDkz.

Comparing the results to tianshou

I notice that tianshou uses 10 envs to evaluate performance from the reset state every epoch as their benchmark; I wonder if that is a problem for the comparison.

Comparing with ppo_continuous_action_8M

  • better: Ant-v4, HalfCheetah-v4
  • similar: Hopper-v4, Walker2d-v4, Humanoid-v4, Reacher-v4
  • worse: Swimmer-v4

In the tianshou benchmark, their parameters also perform poorly on Swimmer, so this is a reasonable result.

As for InvertedDoublePendulum-v4 and InvertedPendulum-v4, every agent in my version reaches a score of 1000, which does not happen in ppo_continuous_action_8M. But the score starts to decline afterward, and the same decline curve can be observed in tianshou's training data:
https://drive.google.com/drive/folders/1tQvgmsBbuLPNU3qo5thTBi03QzGXygXf
https://drive.google.com/drive/folders/1ns2cGnAn_39wqCItmhDZIxihLi8-DBei

@vwxyzjn (Owner) commented Feb 4, 2023

Thanks for running the experiments! The results look great. Feel free to click resolve conversation as you resolve the PR comments, and let me know when it's ready for another review.

Meanwhile, you might find the following tool handy

pip install openrlbenchmark
python -m openrlbenchmark.rlops \
    --filters '?we=openrlbenchmark&wpn=cleanrl&ceik=env_id&cen=exp_name&metric=charts/episodic_return' 'ppo_continuous_action_8M?tag=v1.0.0-13-gcbd83f6'  \
    --filters '?we=openrlbenchmark&wpn=cleanrl&ceik=env_id&cen=exp_name&metric=charts/avg_episodic_return' 'ppo_continuous_action_envpool_xla_jax_scan?tag=v1.0.0-jax-ca-be3113b'   \
    --env-ids Ant-v4 InvertedDoublePendulum-v4 Reacher-v4  Hopper-v4 HalfCheetah-v4 Swimmer-v4  Humanoid-v4 InvertedPendulum-v4 Walker2d-v4 \
    --check-empty-runs False \
    --ncols 4 \
    --ncols-legend 3 \
    --output-filename ppo_continuous_action_envpool_xla_jax_scan \
    --scan-history --report

Which generates

ppo_continuous_action_envpool_xla_jax_scan
ppo_continuous_action_envpool_xla_jax_scan-time

and the wandb report:

https://wandb.ai/costa-huang/cleanRL/reports/Regression-Report-ppo_continuous_action_envpool_xla_jax_scan--VmlldzozNDgzMzM4

A couple of notes:

  • Would you mind running the experiments for HumanoidStandup-v4, Pusher-v4 as well?
  • FWIW, it is possible to get high scores in Humanoid-v4 as well.

See

python -m openrlbenchmark.rlops \
    --filters '?we=openrlbenchmark&wpn=cleanrl&ceik=env_id&cen=exp_name&metric=charts/episodic_return' 'ppo_continuous_action_8M?tag=v1.0.0-13-gcbd83f6'  \
    --filters '?we=openrlbenchmark&wpn=cleanrl&ceik=env_id&cen=exp_name&metric=charts/avg_episodic_return' 'ppo_continuous_action_envpool_xla_jax_scan?tag=v1.0.0-jax-ca-be3113b'   \
    --filters '?we=openrlbenchmark&wpn=envpool-cleanrl&ceik=env_id&cen=exp_name&metric=charts/episodic_return' 'ppo_continuous_action_envpool' \
    --env-ids HalfCheetah-v4 Walker2d-v4 Hopper-v4 InvertedPendulum-v4 Humanoid-v4  \
    --check-empty-runs False \
    --ncols 4 \
    --ncols-legend 3 \
    --output-filename ppo_continuous_action_envpool_xla_jax_scan \
    --scan-history

ppo_continuous_action_envpool_xla_jax_scan

The hyperparams used are in https://github.com/vwxyzjn/envpool-cleanrl/blob/880552a168c08af334b5e5d8868bfbe5ea881445/ppo_continuous_action_envpool.py#L40-L75, with no obs or reward normalization.

@quangr (Author) commented Feb 5, 2023

Thanks for the update. I'll be ready for the code review once the documentation is finished. I'm also happy to run the experiments you suggested.

@quangr (Author) commented Feb 5, 2023

I have documented the questions you brought up and am now ready for the code review. I would be happy to hear your feedback.
