Skip to content

Commit

Permalink
Fix torch compute_advantages (#1209)
Browse files Browse the repository at this point in the history
Fix a bug that breaks compute_advantage() when rewards has a shape
of like (1, N). The expected shape of advantages is the same of that of
rewards, but the Tensor.seqeeze() will wrongly squeeze out the
the first dimension in such a case.
  • Loading branch information
naeioi committed Mar 11, 2020
1 parent 61fe8db commit 883ef47
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 39 deletions.
2 changes: 1 addition & 1 deletion src/garage/torch/algos/loss_function_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,5 +60,5 @@ def compute_advantages(discount, gae_lambda, max_path_length, baselines,
deltas = (rewards + discount * F.pad(baselines, (0, 1))[:, 1:] - baselines)
deltas = F.pad(deltas, (0, max_path_length - 1)).unsqueeze(0).unsqueeze(0)

advantages = F.conv2d(deltas, adv_filter, stride=1).squeeze()
advantages = F.conv2d(deltas, adv_filter, stride=1).reshape(rewards.shape)
return advantages
76 changes: 38 additions & 38 deletions tests/garage/torch/algos/test_loss_function_utils.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
import numpy as np
import pytest
import tensorflow as tf
import torch

import garage.tf.misc.tensor_utils as tf_utils
import garage.torch.algos.loss_function_utils as torch_loss_utils
from tests.fixtures import TfGraphTestCase

Expand All @@ -12,45 +10,47 @@ def stack(d, arr):
return np.repeat(np.expand_dims(arr, axis=0), repeats=d, axis=0)


ONES = np.ones((4, 6))
ZEROS = np.zeros((4, 6))
ARRANGE = stack(4, np.arange(6))
PI_DIGITS = stack(4, [3, 1, 4, 1, 5, 9])
E_DIGITS = stack(4, [2, 7, 1, 8, 2, 8])
FIBS = stack(4, [1, 1, 2, 3, 5, 8])
ONES = np.ones((6, ))
ZEROS = np.zeros((6, ))
ARRANGE = np.arange(6)
PI_DIGITS = np.array([3, 1, 4, 1, 5, 9])
FIBS = np.array([1, 1, 2, 3, 5, 8])


class TestLossFunctionUtils(TfGraphTestCase):
"""Test class for torch algo utility functions."""

# yapf: disable
@pytest.mark.parametrize('gae_lambda, rewards_val, baselines_val', [
(0.4, ONES, ZEROS),
(0.8, PI_DIGITS, ARRANGE),
(1.2, ONES, FIBS),
(1.7, E_DIGITS, PI_DIGITS),
@pytest.mark.parametrize('discount', [1, 0.95])
@pytest.mark.parametrize('num_trajs', [1, 5])
@pytest.mark.parametrize('gae_lambda', [0, 0.5, 1])
@pytest.mark.parametrize('rewards_traj, baselines_traj', [
(ONES, ZEROS),
(PI_DIGITS, ARRANGE),
(ONES, FIBS),
])
# yapf: enable
def test_compute_advantages(self, gae_lambda, rewards_val, baselines_val):
discount = 0.99
max_len = rewards_val.shape[-1]

torch_advs = torch_loss_utils.compute_advantages(
discount, gae_lambda, max_len, torch.Tensor(baselines_val),
torch.Tensor(rewards_val))

rewards = tf.compat.v1.placeholder(dtype=tf.float32,
name='reward',
shape=[None, None])
baselines = tf.compat.v1.placeholder(dtype=tf.float32,
name='baseline',
shape=[None, None])
adv = tf_utils.compute_advantages(discount, gae_lambda, max_len,
baselines, rewards)
tf_advs = self.sess.run(adv,
feed_dict={
rewards: rewards_val,
baselines: baselines_val,
})

assert np.allclose(torch_advs.numpy(),
tf_advs.reshape(torch_advs.shape),
atol=1e-5)
def test_compute_advantages(self, num_trajs, discount, gae_lambda,
rewards_traj, baselines_traj):
"""Test compute_advantage function."""

def get_advantage(discount, gae_lambda, rewards, baselines):
adv = torch.zeros(rewards.shape)
for i in range(rewards.shape[0]):
acc = 0
for j in range(rewards.shape[1]):
acc = acc * discount * gae_lambda
acc += rewards[i][-j - 1] - baselines[i][-j - 1]
acc += discount * baselines[i][-j] if j else 0
adv[i][-j - 1] = acc
return adv

length = len(rewards_traj)

rewards = torch.Tensor(stack(num_trajs, rewards_traj))
baselines = torch.Tensor(stack(num_trajs, baselines_traj))
expected_adv = get_advantage(discount, gae_lambda, rewards, baselines)
computed_adv = torch_loss_utils.compute_advantages(
discount, gae_lambda, length, baselines, rewards)

assert torch.allclose(expected_adv, computed_adv)

0 comments on commit 883ef47

Please sign in to comment.