Skip to content

Commit

Permalink
Add hub integration
Browse files Browse the repository at this point in the history
  • Loading branch information
vwxyzjn committed Oct 13, 2022
1 parent fa82356 commit 4074eee
Showing 1 changed file with 19 additions and 1 deletion.
20 changes: 19 additions & 1 deletion cleanrl_utils/evals/dqn_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,9 @@ def evaluate(
QNetwork: torch.nn.Module,
device: torch.device,
epsilon: float = 0.05,
capture_video: bool = True,
):
envs = gym.vector.SyncVectorEnv([make_env(env_id, 0, 0, True, run_name)])
envs = gym.vector.SyncVectorEnv([make_env(env_id, 0, 0, capture_video, run_name)])
q_network = QNetwork(envs).to(device)
q_network.load_state_dict(torch.load(model_path))
q_network.eval()
Expand All @@ -37,3 +38,20 @@ def evaluate(
obs = next_obs

return episodic_returns

if __name__ == "__main__":
from huggingface_hub import hf_hub_download
from cleanrl.dqn import QNetwork, make_env

model_path = hf_hub_download(repo_id="cleanrl/CartPole-v1-dqn-seed1", filename="q_network.pth")
evaluate(
model_path,
make_env,
"CartPole-v1",
eval_episodes=10,
run_name=f"eval",
QNetwork=QNetwork,
device="cpu",
epsilon=0.05,
capture_video=False,
)

0 comments on commit 4074eee

Please sign in to comment.