Truncation not handled correctly when optimize_memory_usage=True #460

Open
samlobel opened this issue Apr 25, 2024 · 0 comments


@samlobel

Problem Description

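In dqn_atari.py, for example, the replay buffer is instantiated with the optimize_memory_usage=True flag. With this flag, the buffer keeps a single observations array and, when sampling, reads the next observation of transition i from observations[i+1]. CleanRL, however, handles truncation itself (if trunc: real_next_obs[idx] = infos["final_observation"][idx]). With optimize_memory_usage=True, this correction does not survive in the stored data: add() writes real_next_obs into slot i+1, but the next call to add() overwrites that slot with the following obs, which after an auto-reset is the first observation of the new episode. The sampled next observation for a truncated transition is therefore wrong.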
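To make the overwrite concrete, here is a toy model of the single-array layout. It is a hypothetical simplification (SingleArrayBuffer and its methods are made up for illustration and are not the stable_baselines3 class): add() writes obs into slot pos and next_obs into slot pos+1, and the next observation of transition i is read back from slot i+1. The string values stand in for observations.

```python
class SingleArrayBuffer:
    """Toy model of the optimize_memory_usage=True layout (hypothetical,
    not the stable_baselines3 class): one ring of observations, where the
    next observation of transition i lives in slot i + 1."""

    def __init__(self, size):
        self.size = size
        self.observations = [None] * size
        self.pos = 0

    def add(self, obs, next_obs):
        self.observations[self.pos] = obs
        # next_obs goes into the following slot; the next call to add()
        # overwrites that slot with its own obs.
        self.observations[(self.pos + 1) % self.size] = next_obs
        self.pos = (self.pos + 1) % self.size

    def transition(self, i):
        # The next observation is read back from slot i + 1.
        return self.observations[i], self.observations[(i + 1) % self.size]


buf = SingleArrayBuffer(size=10)
buf.add("a", "b")
buf.add("b", "final")  # truncated: real_next_obs = infos["final_observation"]
buf.add("reset", "c")  # obs = next_obs, i.e. the auto-reset observation
print(buf.transition(1))  # ('b', 'reset') instead of ('b', 'final')
```

The non-truncated transition ("a", "b") comes back intact because the overwriting obs equals the stored next_obs; only the truncated one loses its final observation.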


Current Behavior

When an episode is truncated, data.next_observations[i] for its last transition is the first observation of the reset environment instead of the episode's true final observation.

Expected Behavior

data.next_observations[i] should be the true next observation, i.e. the value of infos["final_observation"][idx] that was passed to rb.add().

Possible Solution

There may be a way to make this work with the memory optimization enabled, but for now the simplest workaround is to set optimize_memory_usage=False.
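One possible direction, sketched under assumptions: keep the single observations array, but store the true next observation of each truncated transition in a small side table keyed by slot, and consult that table when sampling. This is not an existing stable_baselines3 option; PatchedSingleArrayBuffer, its truncated argument, and its methods are hypothetical, and strings stand in for observations.

```python
class PatchedSingleArrayBuffer:
    """Hypothetical sketch: single observations ring plus a side table
    holding the true next observation of truncated transitions."""

    def __init__(self, size):
        self.size = size
        self.observations = [None] * size
        self.final_next_obs = {}  # slot index -> true next obs (truncations only)
        self.pos = 0

    def add(self, obs, next_obs, truncated):
        self.observations[self.pos] = obs
        self.observations[(self.pos + 1) % self.size] = next_obs
        if truncated:
            # Store it separately so later add() calls cannot overwrite it.
            self.final_next_obs[self.pos] = next_obs
        else:
            # Drop any stale entry left from a previous lap of the ring.
            self.final_next_obs.pop(self.pos, None)
        self.pos = (self.pos + 1) % self.size

    def transition(self, i):
        if i in self.final_next_obs:
            return self.observations[i], self.final_next_obs[i]
        return self.observations[i], self.observations[(i + 1) % self.size]


buf = PatchedSingleArrayBuffer(size=10)
buf.add("a", "b", truncated=False)
buf.add("b", "final", truncated=True)
buf.add("reset", "c", truncated=False)
print(buf.transition(1))  # ('b', 'final')
```

Because truncations are rare relative to buffer size, the side table stays small, so most of the memory savings would be kept. Terminated transitions should not need an entry, since their next observation is masked out of the DQN target by the (1 - done) factor.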

Steps to Reproduce

Here's a minimal code example, where the important parts are directly cribbed from dqn_atari.py. Switching to optimize_memory_usage=False prevents the assertion error.

import gymnasium as gym

from stable_baselines3.common.buffers import ReplayBuffer
import numpy as np


def make_env(env_id, seed, idx, capture_video, run_name):
    def thunk():
        if capture_video and idx == 0:
            env = gym.make(env_id, render_mode="rgb_array")
            env = gym.wrappers.RecordVideo(env, f"videos/{run_name}")
        else:
            env = gym.make(env_id)
        env = gym.wrappers.RecordEpisodeStatistics(env)
        env.action_space.seed(seed)
        return env

    return thunk


envs = gym.vector.SyncVectorEnv(
    [make_env("MountainCar-v0", i, i, False, "testing") for i in [0]]
)

obs, _ = envs.reset(seed=0)

rb = ReplayBuffer(
    1000,
    envs.single_observation_space,
    envs.single_action_space,
    "cpu",
    optimize_memory_usage=True,
    # optimize_memory_usage=False,
    handle_timeout_termination=False,
)

seen_obs_and_next = set()
for i in range(1000):
    actions = np.array([envs.single_action_space.sample() for _ in range(envs.num_envs)])
    next_obs, rewards, terminations, truncations, infos = envs.step(actions)
    # TRY NOT TO MODIFY: save data to replay buffer; handle `final_observation`
    real_next_obs = next_obs.copy()
    for idx, trunc in enumerate(truncations):
        if trunc:
            real_next_obs[idx] = infos["final_observation"][idx]

    rb.add(obs, real_next_obs, actions, rewards, terminations, infos)
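    for o, next_o in zip(obs, real_next_obs):  # one pair per sub-environment
        seen_obs_and_next.add((tuple(o.tolist()), tuple(next_o.tolist())))

    # Advance to the next step, as dqn_atari.py does; without this, obs never changes
    obs = next_obs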


data = rb.sample(10000)
for i in range(10000):
    o = data.observations[i]
    no = data.next_observations[i]
    assert (tuple(o.tolist()), tuple(no.tolist())) in seen_obs_and_next


@samlobel changed the title from "Another truncation bug" to "Truncation not handled correctly when optimize_memory_usage=True" on Apr 26, 2024.