-
Notifications
You must be signed in to change notification settings - Fork 28
/
half_cheetah_environment.py
49 lines (43 loc) · 1.56 KB
/
half_cheetah_environment.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
"""
This code adds event detectors to the HalfCheetah environment.
"""
import gym
import numpy as np
from gym.envs.mujoco.half_cheetah_v3 import HalfCheetahEnv
from reward_machines.rm_environment import RewardMachineEnv
class MyHalfCheetahEnv(gym.Wrapper):
    """HalfCheetah wrapper that reports position-based events.

    The events returned by :meth:`get_events` are computed from the
    ``'x_position'`` entry of the ``info`` dict produced by the most recent
    call to :meth:`step`.
    """

    def __init__(self):
        # Note that the current position is key for our tasks, so keep it in
        # the observation instead of letting HalfCheetahEnv exclude it.
        super().__init__(HalfCheetahEnv(exclude_current_positions_from_observation=False))
        # Info dict of the most recent step. Starts empty so get_events() is
        # well-defined (reports no events) before the first step instead of
        # raising AttributeError via gym.Wrapper's attribute forwarding.
        self.info = {}

    def reset(self, **kwargs):
        """Reset the wrapped environment and forget the previous step's info.

        Without clearing ``self.info``, get_events() called right after a
        reset would report events computed from the last step of the
        previous episode.

        Returns:
            Whatever the wrapped environment's ``reset`` returns, unchanged.
        """
        self.info = {}
        return super().reset(**kwargs)

    def step(self, action):
        """Execute ``action`` and remember the returned ``info`` for get_events().

        Returns:
            The wrapped environment's ``(next_obs, reward, done, info)`` tuple,
            unchanged.
        """
        # executing the action in the environment
        next_obs, original_reward, env_done, info = self.env.step(action)
        self.info = info
        return next_obs, original_reward, env_done, info

    def get_events(self):
        """Return the propositions that hold at the latest x position.

        Returns:
            str: concatenation of event letters, always in this order:
            'b' if x < -10, 'a' if x > 10, 'd' if x < -2, 'c' if x > 2,
            'e' if x > 4, 'f' if x > 6, 'g' if x > 8.
            The empty string if no step has been taken since construction
            or the last reset.
        """
        x = self.info.get('x_position')
        if x is None:
            # No step since construction/reset: there is no position to test.
            return ''
        events = ''
        if x < -10:
            events += 'b'
        if x > 10:
            events += 'a'
        if x < -2:
            events += 'd'
        if x > 2:
            events += 'c'
        if x > 4:
            events += 'e'
        if x > 6:
            events += 'f'
        if x > 8:
            events += 'g'
        return events
class MyHalfCheetahEnvRM1(RewardMachineEnv):
    """MyHalfCheetahEnv combined with the reward machine defined in t1.txt."""

    def __init__(self):
        # The environment is built first, then the single reward-machine file
        # is passed along with it to RewardMachineEnv.
        super().__init__(
            MyHalfCheetahEnv(),
            ["./envs/mujoco_rm/reward_machines/t1.txt"],
        )
class MyHalfCheetahEnvRM2(RewardMachineEnv):
    """MyHalfCheetahEnv combined with the reward machine defined in t2.txt."""

    def __init__(self):
        # The environment is built first, then the single reward-machine file
        # is passed along with it to RewardMachineEnv.
        super().__init__(
            MyHalfCheetahEnv(),
            ["./envs/mujoco_rm/reward_machines/t2.txt"],
        )