Skip to content

Commit

Permalink
Add support for tuple input on MultiBinary space (openai#2023)
Browse files Browse the repository at this point in the history
* Add support for tuple input on MultiBinary space

* Change input of multibinary

* Remove check of length != 1; Add shape check
  • Loading branch information
lsylusiyao committed Sep 25, 2020
1 parent 70d351c commit e862e46
Showing 1 changed file with 18 additions and 5 deletions.
23 changes: 18 additions & 5 deletions gym/spaces/multi_binary.py
Expand Up @@ -4,9 +4,9 @@

class MultiBinary(Space):
'''
An n-dimensional binary space.
An n-shape binary space.
The argument to MultiBinary defines n.
The argument to MultiBinary defines n, which could be a number or a `list` of numbers.
Example Usage:
Expand All @@ -16,18 +16,31 @@ class MultiBinary(Space):
array([0,1,0,1,0], dtype =int8)
>> self.observation_space = spaces.MultiBinary([3,2])
>> self.observation_space.sample()
array([[0, 0],
[0, 1],
[1, 1]], dtype=int8)
'''

def __init__(self, n):
self.n = n
super(MultiBinary, self).__init__((self.n,), np.int8)
if type(n) in [tuple, list, np.ndarray]:
input_n = n
else:
input_n = (n, )
super(MultiBinary, self).__init__(input_n, np.int8)

def sample(self):
return self.np_random.randint(low=0, high=2, size=self.n, dtype=self.dtype)

def contains(self, x):
if isinstance(x, list):
if isinstance(x, list) or isinstance(x, tuple):
x = np.array(x) # Promote list to array for contains check
if self.shape != x.shape:
return False
return ((x==0) | (x==1)).all()

def to_jsonable(self, sample_n):
Expand Down

0 comments on commit e862e46

Please sign in to comment.