Skip to content

Commit

Permalink
Allow custom spaces in VectorEnv (#2038)
Browse files Browse the repository at this point in the history
* Allow custom observation spaces in VectorEnv

* Replace np.copy by deepcopy in reset of SyncVectorEnv

* Add tests for VectorEnv with custom spaces

* Add tests for shared memory and batches of custom spaces

* Remove unused import in VectorEnv test

* Add warning note in the Space class for custom spaces
  • Loading branch information
tristandeleu committed Sep 21, 2020
1 parent 8cf2685 commit 58401db
Show file tree
Hide file tree
Showing 13 changed files with 203 additions and 27 deletions.
7 changes: 7 additions & 0 deletions gym/error.py
Expand Up @@ -165,3 +165,10 @@ class ClosedEnvironmentError(Exception):
Trying to call `reset`, or `step`, while the environment is closed.
"""
pass

class CustomSpaceError(Exception):
"""
The space is a custom gym.Space instance, and is not supported by
`AsyncVectorEnv` with `shared_memory=True`.
"""
pass
9 changes: 9 additions & 0 deletions gym/spaces/space.py
Expand Up @@ -5,6 +5,15 @@ class Space(object):
"""Defines the observation and action spaces, so you can write generic
code that applies to any Env. For example, you can choose a random
action.
WARNING - Custom observation & action spaces can inherit from the `Space`
class. However, most use-cases should be covered by the existing space
classes (e.g. `Box`, `Discrete`, etc...), and container classes (`Tuple` &
`Dict`). Note that parametrized probability distributions (through the
`sample()` method), and batching functions (in `gym.vector.VectorEnv`), are
only well-defined for instances of spaces provided in gym by default.
Moreover, some implementations of Reinforcement Learning algorithms might
not handle custom spaces properly. Use custom spaces with care.
"""
def __init__(self, shape=None, dtype=None):
import numpy as np # takes about 300-400ms to import, so we load lazily
Expand Down
23 changes: 16 additions & 7 deletions gym/vector/async_vector_env.py
Expand Up @@ -8,7 +8,7 @@
from gym import logger
from gym.vector.vector_env import VectorEnv
from gym.error import (AlreadyPendingCallError, NoAsyncCallError,
ClosedEnvironmentError)
ClosedEnvironmentError, CustomSpaceError)
from gym.vector.utils import (create_shared_memory, create_empty_array,
write_to_shared_memory, read_from_shared_memory,
concatenate, CloudpickleWrapper, clear_mpi_env_vars)
Expand Down Expand Up @@ -83,10 +83,18 @@ def __init__(self, env_fns, observation_space=None, action_space=None,
observation_space=observation_space, action_space=action_space)

if self.shared_memory:
_obs_buffer = create_shared_memory(self.single_observation_space,
n=self.num_envs, ctx=ctx)
self.observations = read_from_shared_memory(_obs_buffer,
self.single_observation_space, n=self.num_envs)
try:
_obs_buffer = create_shared_memory(self.single_observation_space,
n=self.num_envs, ctx=ctx)
self.observations = read_from_shared_memory(_obs_buffer,
self.single_observation_space, n=self.num_envs)
except CustomSpaceError:
raise ValueError('Using `shared_memory=True` in `AsyncVectorEnv` '
'is incompatible with non-standard Gym observation spaces '
'(i.e. custom spaces inheriting from `gym.Space`), and is '
'only compatible with default Gym spaces (e.g. `Box`, '
'`Tuple`, `Dict`) for batching. Set `shared_memory=False` '
'if you use custom observation spaces.')
else:
_obs_buffer = None
self.observations = create_empty_array(
Expand Down Expand Up @@ -171,7 +179,8 @@ def reset_wait(self, timeout=None):
self._state = AsyncState.DEFAULT

if not self.shared_memory:
concatenate(results, self.observations, self.single_observation_space)
self.observations = concatenate(results, self.observations,
self.single_observation_space)

return deepcopy(self.observations) if self.copy else self.observations

Expand Down Expand Up @@ -230,7 +239,7 @@ def step_wait(self, timeout=None):
observations_list, rewards, dones, infos = zip(*results)

if not self.shared_memory:
concatenate(observations_list, self.observations,
self.observations = concatenate(observations_list, self.observations,
self.single_observation_space)

return (deepcopy(self.observations) if self.copy else self.observations,
Expand Down
8 changes: 5 additions & 3 deletions gym/vector/sync_vector_env.py
Expand Up @@ -63,9 +63,10 @@ def reset_wait(self):
for env in self.envs:
observation = env.reset()
observations.append(observation)
concatenate(observations, self.observations, self.single_observation_space)
self.observations = concatenate(observations, self.observations,
self.single_observation_space)

return np.copy(self.observations) if self.copy else self.observations
return deepcopy(self.observations) if self.copy else self.observations

def step_async(self, actions):
self._actions = actions
Expand All @@ -78,7 +79,8 @@ def step_wait(self):
observation = env.reset()
observations.append(observation)
infos.append(info)
concatenate(observations, self.observations, self.single_observation_space)
self.observations = concatenate(observations, self.observations,
self.single_observation_space)

return (deepcopy(self.observations) if self.copy else self.observations,
np.copy(self._rewards), np.copy(self._dones), infos)
Expand Down
33 changes: 31 additions & 2 deletions gym/vector/tests/test_async_vector_env.py
Expand Up @@ -2,10 +2,11 @@
import numpy as np

from multiprocessing import TimeoutError
from gym.spaces import Box
from gym.spaces import Box, Tuple
from gym.error import (AlreadyPendingCallError, NoAsyncCallError,
ClosedEnvironmentError)
from gym.vector.tests.utils import make_env, make_slow_env
from gym.vector.tests.utils import (CustomSpace, make_env,
make_slow_env, make_custom_space_env)

from gym.vector.async_vector_env import AsyncVectorEnv

Expand Down Expand Up @@ -194,3 +195,31 @@ def test_check_observations_async_vector_env(shared_memory):
with pytest.raises(RuntimeError):
env = AsyncVectorEnv(env_fns, shared_memory=shared_memory)
env.close(terminate=True)


def test_custom_space_async_vector_env():
env_fns = [make_custom_space_env(i) for i in range(4)]
try:
env = AsyncVectorEnv(env_fns, shared_memory=False)
reset_observations = env.reset()
actions = ('action-2', 'action-3', 'action-5', 'action-7')
step_observations, rewards, dones, _ = env.step(actions)
finally:
env.close()

assert isinstance(env.single_observation_space, CustomSpace)
assert isinstance(env.observation_space, Tuple)

assert isinstance(reset_observations, tuple)
assert reset_observations == ('reset', 'reset', 'reset', 'reset')

assert isinstance(step_observations, tuple)
assert step_observations == ('step(action-2)', 'step(action-3)',
'step(action-5)', 'step(action-7)')


def test_custom_space_async_vector_env_shared_memory():
env_fns = [make_custom_space_env(i) for i in range(4)]
with pytest.raises(ValueError):
env = AsyncVectorEnv(env_fns, shared_memory=True)
env.close(terminate=True)
14 changes: 12 additions & 2 deletions gym/vector/tests/test_shared_memory.py
Expand Up @@ -6,9 +6,10 @@
from multiprocessing import Array, Process
from collections import OrderedDict

from gym.spaces import Tuple, Dict
from gym.spaces import Box, Tuple, Dict
from gym.error import CustomSpaceError
from gym.vector.utils.spaces import _BaseGymSpaces
from gym.vector.tests.utils import spaces
from gym.vector.tests.utils import spaces, custom_spaces

from gym.vector.utils.shared_memory import (create_shared_memory,
read_from_shared_memory, write_to_shared_memory)
Expand Down Expand Up @@ -60,6 +61,15 @@ def assert_nested_type(lhs, rhs, n):
assert_nested_type(shared_memory, expected_type, n=n)


@pytest.mark.parametrize('n', [1, 8])
@pytest.mark.parametrize('ctx', [None, 'fork', 'spawn'], ids=['default', 'fork', 'spawn'])
@pytest.mark.parametrize('space', custom_spaces)
def test_create_shared_memory_custom_space(n, ctx, space):
ctx = mp if (ctx is None) else mp.get_context(ctx)
with pytest.raises(CustomSpaceError):
shared_memory = create_shared_memory(space, n=n, ctx=ctx)


@pytest.mark.parametrize('space', spaces,
ids=[space.__class__.__name__ for space in spaces])
def test_write_to_shared_memory(space):
Expand Down
17 changes: 16 additions & 1 deletion gym/vector/tests/test_spaces.py
Expand Up @@ -2,7 +2,7 @@
import numpy as np

from gym.spaces import Box, MultiDiscrete, Tuple, Dict
from gym.vector.tests.utils import spaces
from gym.vector.tests.utils import spaces, custom_spaces, CustomSpace

from gym.vector.utils.spaces import _BaseGymSpaces, batch_space

Expand Down Expand Up @@ -32,8 +32,23 @@
})
]

expected_custom_batch_spaces_4 = [
Tuple((CustomSpace(), CustomSpace(), CustomSpace(), CustomSpace())),
Tuple((
Tuple((CustomSpace(), CustomSpace(), CustomSpace(), CustomSpace())),
Box(low=0, high=255, shape=(4,), dtype=np.uint8)
))
]

@pytest.mark.parametrize('space,expected_batch_space_4', list(zip(spaces,
expected_batch_spaces_4)), ids=[space.__class__.__name__ for space in spaces])
def test_batch_space(space, expected_batch_space_4):
batch_space_4 = batch_space(space, n=4)
assert batch_space_4 == expected_batch_space_4


@pytest.mark.parametrize('space,expected_batch_space_4', list(zip(custom_spaces,
expected_custom_batch_spaces_4)), ids=[space.__class__.__name__ for space in custom_spaces])
def test_batch_space_custom_space(space, expected_batch_space_4):
batch_space_4 = batch_space(space, n=4)
assert batch_space_4 == expected_batch_space_4
25 changes: 23 additions & 2 deletions gym/vector/tests/test_sync_vector_env.py
@@ -1,8 +1,8 @@
import pytest
import numpy as np

from gym.spaces import Box
from gym.vector.tests.utils import make_env
from gym.spaces import Box, Tuple
from gym.vector.tests.utils import CustomSpace, make_env, make_custom_space_env

from gym.vector.sync_vector_env import SyncVectorEnv

Expand Down Expand Up @@ -70,3 +70,24 @@ def test_check_observations_sync_vector_env():
with pytest.raises(RuntimeError):
env = SyncVectorEnv(env_fns)
env.close()


def test_custom_space_sync_vector_env():
env_fns = [make_custom_space_env(i) for i in range(4)]
try:
env = SyncVectorEnv(env_fns)
reset_observations = env.reset()
actions = ('action-2', 'action-3', 'action-5', 'action-7')
step_observations, rewards, dones, _ = env.step(actions)
finally:
env.close()

assert isinstance(env.single_observation_space, CustomSpace)
assert isinstance(env.observation_space, Tuple)

assert isinstance(reset_observations, tuple)
assert reset_observations == ('reset', 'reset', 'reset', 'reset')

assert isinstance(step_observations, tuple)
assert step_observations == ('step(action-2)', 'step(action-3)',
'step(action-5)', 'step(action-7)')
14 changes: 13 additions & 1 deletion gym/vector/tests/test_vector_env.py
@@ -1,10 +1,12 @@
import pytest
import numpy as np

from gym.vector.tests.utils import make_env
from gym.spaces import Tuple
from gym.vector.tests.utils import CustomSpace, make_env

from gym.vector.async_vector_env import AsyncVectorEnv
from gym.vector.sync_vector_env import SyncVectorEnv
from gym.vector.vector_env import VectorEnv

@pytest.mark.parametrize('shared_memory', [True, False])
def test_vector_env_equal(shared_memory):
Expand Down Expand Up @@ -41,3 +43,13 @@ def test_vector_env_equal(shared_memory):
finally:
async_env.close()
sync_env.close()


def test_custom_space_vector_env():
env = VectorEnv(4, CustomSpace(), CustomSpace())

assert isinstance(env.single_observation_space, CustomSpace)
assert isinstance(env.observation_space, Tuple)

assert isinstance(env.single_action_space, CustomSpace)
assert isinstance(env.action_space, Tuple)
31 changes: 31 additions & 0 deletions gym/vector/tests/utils.py
Expand Up @@ -47,6 +47,30 @@ def step(self, action):
reward, done = 0., False
return observation, reward, done, {}

class CustomSpace(gym.Space):
"""Minimal custom observation space."""
def __eq__(self, other):
return isinstance(other, CustomSpace)

custom_spaces = [
CustomSpace(),
Tuple((CustomSpace(), Box(low=0, high=255, shape=(), dtype=np.uint8)))
]

class CustomSpaceEnv(gym.Env):
def __init__(self):
super(CustomSpaceEnv, self).__init__()
self.observation_space = CustomSpace()
self.action_space = CustomSpace()

def reset(self):
return 'reset'

def step(self, action):
observation = 'step({0:s})'.format(action)
reward, done = 0., False
return observation, reward, done, {}

def make_env(env_name, seed):
def _make():
env = gym.make(env_name)
Expand All @@ -60,3 +84,10 @@ def _make():
env.seed(seed)
return env
return _make

def make_custom_space_env(seed):
def _make():
env = CustomSpaceEnv()
env.seed(seed)
return env
return _make
18 changes: 15 additions & 3 deletions gym/vector/utils/numpy_utils.py
@@ -1,6 +1,6 @@
import numpy as np

from gym.spaces import Tuple, Dict
from gym.spaces import Space, Tuple, Dict
from gym.vector.utils.spaces import _BaseGymSpaces
from collections import OrderedDict

Expand Down Expand Up @@ -42,8 +42,11 @@ def concatenate(items, out, space):
return concatenate_tuple(items, out, space)
elif isinstance(space, Dict):
return concatenate_dict(items, out, space)
elif isinstance(space, Space):
return concatenate_custom(items, out, space)
else:
raise NotImplementedError()
raise ValueError('Space of type `{0}` is not a valid `gym.Space` '
'instance.'.format(type(space)))

def concatenate_base(items, out, space):
return np.stack(items, axis=0, out=out)
Expand All @@ -56,6 +59,9 @@ def concatenate_dict(items, out, space):
return OrderedDict([(key, concatenate([item[key] for item in items],
out[key], subspace)) for (key, subspace) in space.spaces.items()])

def concatenate_custom(items, out, space):
return tuple(items)


def create_empty_array(space, n=1, fn=np.zeros):
"""Create an empty (possibly nested) numpy array.
Expand Down Expand Up @@ -96,8 +102,11 @@ def create_empty_array(space, n=1, fn=np.zeros):
return create_empty_array_tuple(space, n=n, fn=fn)
elif isinstance(space, Dict):
return create_empty_array_dict(space, n=n, fn=fn)
elif isinstance(space, Space):
return create_empty_array_custom(space, n=n, fn=fn)
else:
raise NotImplementedError()
raise ValueError('Space of type `{0}` is not a valid `gym.Space` '
'instance.'.format(type(space)))

def create_empty_array_base(space, n=1, fn=np.zeros):
shape = space.shape if (n is None) else (n,) + space.shape
Expand All @@ -110,3 +119,6 @@ def create_empty_array_tuple(space, n=1, fn=np.zeros):
def create_empty_array_dict(space, n=1, fn=np.zeros):
return OrderedDict([(key, create_empty_array(subspace, n=n, fn=fn))
for (key, subspace) in space.spaces.items()])

def create_empty_array_custom(space, n=1, fn=np.zeros):
return None

0 comments on commit 58401db

Please sign in to comment.