Skip to content

Commit

Permalink
vanilla_policy_gradient: using gymnasium and jaxtyping (don't use it)
Browse files Browse the repository at this point in the history
It should be used once the newest version of torchtyping works with
jupyter, because the error reporting in the old versions for jaxtyping
are bad (don't tell the shape).

See:
- agronholm/typeguard#364
- patrick-kidger/jaxtyping#6
  • Loading branch information
ddorn committed Aug 28, 2023
1 parent fc5e3eb commit aab201f
Showing 1 changed file with 225 additions and 0 deletions.
225 changes: 225 additions & 0 deletions workshops/vanilla_policy_gradient/vanilla_policy_gradient.ipynb
@@ -0,0 +1,225 @@
{
"cells": [
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Vanilla Policy Optimisation\n",
"\n",
"<a href=\"https://colab.research.google.com/github/EffiSciencesResearch/ML4G/blob/main/days/w1d5/vanilla_policy_gradient.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>\n",
"\n",
"Preliminary questions:\n",
"- Run the script with the defaults parameters on the terminal\n",
"- Explain from torch.distributions.categorical import Categorical\n",
"- google gym python, why is it useful?\n",
"- Policy gradient is model based or model free?\n",
"- Is policy gradient on-policy or off-policy?\n",
"\n",
"Read all the code, then:\n",
"- Complete the ... in the compute_loss function.\n",
"- Use https://github.com/google/jaxtyping to type the functions get_policy, get_action. You can draw inspiration from the compute_loss function.\n",
"- Answer the questions\n",
"\n",
"\n",
"Don't begin working on this algorithms if you don't understand the blog: https://spinningup.openai.com/en/latest/spinningup/rl_intro3.html\n",
"\n",
"This exercise is short, but you should aim to understand everything in this code. Simply completing the types is not sufficient. The important thing here is to have a good understanding of each line of code, as well as the policy gradient theorem that we are using."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!pip install jaxtyping typeguard==2.13.3 gym==0.25.2"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"from torch import Tensor\n",
"from torch.distributions.categorical import Categorical\n",
"from torch.optim import Adam\n",
"import numpy as np\n",
"import gymnasium as gym\n",
"from gymnasium.spaces import Discrete, Box\n",
"\n",
"from jaxtyping import Float, Int, jaxtyped\n",
"from typeguard import typechecked\n",
"\n",
"\n",
"def mlp(sizes, activation=nn.Tanh, output_activation=nn.Identity):\n",
" # Build a feedforward neural network.\n",
" layers = []\n",
" for j in range(len(sizes)-1):\n",
" act = activation if j < len(sizes)-2 else output_activation\n",
" layers += [nn.Linear(sizes[j], sizes[j+1]), act()]\n",
" \n",
" # What does * mean here? Search for unpacking in python\n",
" return nn.Sequential(*layers)\n",
"\n",
"def train(env_name='CartPole-v0', hidden_sizes=[32], lr=1e-2, \n",
" epochs=50, batch_size=5000, render=False):\n",
"\n",
" # make environment, check spaces, get obs / act dims\n",
" env = gym.make(env_name)\n",
" assert isinstance(env.observation_space, Box), \\\n",
" \"This example only works for envs with continuous state spaces.\"\n",
" assert isinstance(env.action_space, Discrete), \\\n",
" \"This example only works for envs with discrete action spaces.\"\n",
"\n",
" obs_dim = env.observation_space.shape[0]\n",
" n_acts = env.action_space.n\n",
"\n",
" # Core of policy network\n",
" # What should be the sizes of the layers of the policy network?\n",
" logits_net = mlp(sizes=[obs_dim]+hidden_sizes+[n_acts])\n",
"\n",
" # make function to compute action distribution\n",
" @jaxtyped # Checks that the sizes are consistent between tensors\n",
" @typechecked # TODO: What is the shape of obs?\n",
" def get_policy(obs: Float[Tensor, \"batch obs_dim\"]) -> Categorical:\n",
" logits = logits_net(obs)\n",
" # Tip: Categorical is a convenient pytorch object which enable register logits (or a batch of logits)\n",
" # and then being able to sample from this pseudo-probability distribution with the \".sample()\" method.\n",
" return Categorical(logits=logits)\n",
"\n",
" # make action selection function (outputs int actions, sampled from policy)\n",
" # What is the shape of obs?\n",
" @jaxtyped\n",
" @typechecked # TODO: What is the shape of obs?\n",
" def get_action(obs: Float[Tensor, \"obs_dim\"]) -> int:\n",
" return get_policy(obs.unsqueeze(0)).sample().item()\n",
"\n",
" # make loss function whose gradient, for the right data, is policy gradient\n",
" @jaxtyped\n",
" @typechecked\n",
" def compute_loss(obs: Float[Tensor, \"batch obs_dim\"], acts: Int[Tensor, \"batch\"], rewards: Float[Tensor, \"batch\"]) -> Float[Tensor, \"\"]:\n",
" # TODO: \n",
" # rewards: a piecewise constant vector containing the total reward of each episode.\n",
" \n",
" # Use the get_policy function to get the categorical object, then sample from it with the 'log_prob' method.‹\n",
" log_probs = get_policy(obs).log_prob(acts)\n",
" return -(log_probs * rewards).mean()\n",
"\n",
"\n",
" # make optimizer\n",
" optimizer = Adam(logits_net.parameters(), lr=lr)\n",
"\n",
" # for training policy\n",
" def train_one_epoch():\n",
" # make some empty lists for logging.\n",
" batch_obs = [] # for observations\n",
" batch_acts = [] # for actions\n",
" batch_weights = [] # for R(tau) weighting in policy gradient\n",
" batch_rets = [] # for measuring episode returns # What is the return?\n",
" batch_lens = [] # for measuring episode lengths\n",
"\n",
" # reset episode-specific variables\n",
" obs, _ = env.reset() # first obs comes from starting distribution \n",
" ep_rews = [] # list for rewards accrued throughout ep\n",
"\n",
" # render first episode of each epoch\n",
" finished_rendering_this_epoch = False\n",
"\n",
" # collect experience by acting in the environment with current policy\n",
" while True:\n",
"\n",
" # rendering\n",
" if (not finished_rendering_this_epoch) and render:\n",
" env.render()\n",
"\n",
" # save obs\n",
" batch_obs.append(obs.copy())\n",
"\n",
" # act in the environment\n",
" act = get_action(torch.as_tensor(obs, dtype=torch.float32))\n",
" obs, rew, terminated, truncated, _ = env.step(act)\n",
"\n",
" # save action, reward\n",
" batch_acts.append(act)\n",
" ep_rews.append(rew)\n",
"\n",
" if terminated or truncated:\n",
" # if episode is over, record info about episode\n",
" # Is the reward discounted?\n",
" ep_ret, ep_len = sum(ep_rews), len(ep_rews)\n",
" batch_rets.append(ep_ret)\n",
" batch_lens.append(ep_len)\n",
"\n",
" # the weight for each logprob(a|s) is R(tau)\n",
" # Why do we use a constant vector here?\n",
" batch_weights += [ep_ret] * ep_len\n",
"\n",
" # reset episode-specific variables\n",
" obs, _ = env.reset()\n",
" ep_rews = []\n",
"\n",
" # won't render again this epoch\n",
" finished_rendering_this_epoch = True\n",
"\n",
" # end experience loop if we have enough of it\n",
" if len(batch_obs) > batch_size:\n",
" break\n",
"\n",
" # take a single policy gradient update step\n",
" optimizer.zero_grad()\n",
"\n",
" batch_loss = compute_loss(obs=torch.as_tensor(batch_obs, dtype=torch.float32),\n",
" acts=torch.as_tensor(batch_acts, dtype=torch.int32),\n",
" rewards=torch.as_tensor(batch_weights, dtype=torch.float32)\n",
" )\n",
" batch_loss.backward()\n",
" optimizer.step()\n",
" return batch_loss, batch_rets, batch_lens\n",
"\n",
" # training loop\n",
" for i in range(epochs):\n",
" batch_loss, batch_rets, batch_lens = train_one_epoch()\n",
" print('epoch: %3d \\t loss: %.3f \\t return: %.3f \\t ep_len: %.3f'% \n",
" (i, batch_loss, np.mean(batch_rets), np.mean(batch_lens)))\n",
"\n",
"\n",
"\n",
"train(env_name='CartPole-v0', hidden_sizes=[32], lr=1e-2, \n",
" epochs=50, batch_size=50, render=False)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Original algo here: https://github.com/openai/spinningup/blob/master/spinup/algos/pytorch/vpg/vpg.py"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.3"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}

0 comments on commit aab201f

Please sign in to comment.