From e862e46d7299366731ab59dcb6ec108b6e3922be Mon Sep 17 00:00:00 2001 From: lsylusiyao Date: Sat, 26 Sep 2020 06:43:15 +0800 Subject: [PATCH] Add support for tuple input on MultiBinary space (#2023) * Add support for tuple input on MultiBinary space * Change input of multibinary * Remove check of length != 1; Add shape check --- gym/spaces/multi_binary.py | 23 ++++++++++++++++++----- 1 file changed, 18 insertions(+), 5 deletions(-) diff --git a/gym/spaces/multi_binary.py b/gym/spaces/multi_binary.py index f1876ee8611..e3b37890d27 100644 --- a/gym/spaces/multi_binary.py +++ b/gym/spaces/multi_binary.py @@ -4,9 +4,9 @@ class MultiBinary(Space): ''' - An n-dimensional binary space. + An n-shape binary space. - The argument to MultiBinary defines n. + The argument to MultiBinary defines n, which could be a number or a `list` of numbers. Example Usage: @@ -16,18 +16,31 @@ class MultiBinary(Space): array([0,1,0,1,0], dtype =int8) + >> self.observation_space = spaces.MultiBinary([3,2]) + + >> self.observation_space.sample() + + array([[0, 0], + [0, 1], + [1, 1]], dtype=int8) + ''' - def __init__(self, n): self.n = n - super(MultiBinary, self).__init__((self.n,), np.int8) + if type(n) in [tuple, list, np.ndarray]: + input_n = n + else: + input_n = (n, ) + super(MultiBinary, self).__init__(input_n, np.int8) def sample(self): return self.np_random.randint(low=0, high=2, size=self.n, dtype=self.dtype) def contains(self, x): - if isinstance(x, list): + if isinstance(x, list) or isinstance(x, tuple): x = np.array(x) # Promote list to array for contains check + if self.shape != x.shape: + return False return ((x==0) | (x==1)).all() def to_jsonable(self, sample_n):