/
test_shared_memory.py
142 lines (110 loc) · 4.95 KB
/
test_shared_memory.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
import pytest
import numpy as np
import multiprocessing as mp
from multiprocessing.sharedctypes import SynchronizedArray
from multiprocessing import Array, Process
from collections import OrderedDict
from gym.spaces import Tuple, Dict
from gym.vector.utils.spaces import _BaseGymSpaces
from gym.vector.tests.utils import spaces
from gym.vector.utils.shared_memory import (create_shared_memory,
read_from_shared_memory, write_to_shared_memory)
# Reference shared-memory layouts, paired positionally (via ``zip``) with the
# entries of ``spaces`` in ``test_create_shared_memory``. Each ``Array`` gives
# the ctypes typecode and the per-environment length; the test compares the
# created buffer against ``n`` times that length.
# Typecodes: 'd' = double, 'f' = float, 'B' = unsigned char, 'i' = signed int.
expected_types = [
    # Flat buffers.
    Array('d', 1),
    Array('f', 1),
    Array('f', 3),
    Array('f', 4),
    Array('B', 1),
    Array('B', 32 * 32 * 3),
    Array('i', 1),
    # Tuples of buffers.
    (Array('i', 1), Array('i', 1)),
    (Array('i', 1), Array('f', 2)),
    # More flat buffers.
    Array('B', 3),
    Array('B', 19),
    # Flat dictionary of buffers.
    OrderedDict((
        ('position', Array('i', 1)),
        ('velocity', Array('f', 1)),
    )),
    # Nested dictionary mixing a dictionary and a tuple of buffers.
    OrderedDict((
        ('position', OrderedDict((
            ('x', Array('i', 1)),
            ('y', Array('i', 1)),
        ))),
        ('velocity', (Array('i', 1), Array('B', 1))),
    )),
]
@pytest.mark.parametrize('n', [1, 8])
@pytest.mark.parametrize('space,expected_type', list(zip(spaces, expected_types)),
    ids=[space.__class__.__name__ for space in spaces])
@pytest.mark.parametrize('ctx', [None, 'fork', 'spawn'],
    ids=['default', 'fork', 'spawn'])
def test_create_shared_memory(space, expected_type, n, ctx):
    """Check the structure of the shared memory created for ``space``.

    The object returned by ``create_shared_memory`` must mirror the nesting of
    ``expected_type`` (tuples, dictionaries, shared arrays), with every array
    ``n`` times as long as its reference and holding the same element type.
    ``ctx`` is either ``None`` (use the ``multiprocessing`` module itself) or
    the name of a start method resolved through ``mp.get_context``.
    """
    def assert_nested_type(lhs, rhs, n):
        # Containers and arrays must be of exactly the same type.
        assert type(lhs) == type(rhs)

        if isinstance(lhs, SynchronizedArray):
            # Leaf: length scales with the number of environments...
            assert len(lhs[:]) == n * len(rhs[:])
            # ...and elements come back as the same Python type.
            assert type(lhs[0]) == type(rhs[0])
        elif isinstance(lhs, dict):
            # Same key set on both sides, then recurse per key.
            assert set(lhs.keys()) == set(rhs.keys())
            for name in lhs.keys():
                assert_nested_type(lhs[name], rhs[name], n)
        elif isinstance(lhs, (list, tuple)):
            assert len(lhs) == len(rhs)
            for left_item, right_item in zip(lhs, rhs):
                assert_nested_type(left_item, right_item, n)
        else:
            raise TypeError('Got unknown type `{0}`.'.format(type(lhs)))

    if ctx is None:
        context = mp
    else:
        context = mp.get_context(ctx)

    shared_memory = create_shared_memory(space, n=n, ctx=context)
    assert_nested_type(shared_memory, expected_type, n=n)
@pytest.mark.parametrize('space', spaces,
    ids=[space.__class__.__name__ for space in spaces])
def test_write_to_shared_memory(space):
    """Check that samples written from worker processes land in shared memory.

    Eight samples are drawn from ``space``; one process per sample calls
    ``write_to_shared_memory`` for its own index. After all processes have
    finished, the shared buffers must equal the stacked, flattened samples.
    """
    def assert_nested_equal(lhs, rhs):
        # `rhs` is always the list of per-environment samples for this node.
        assert isinstance(rhs, list)
        if isinstance(lhs, (list, tuple)):
            for i, lhs_ in enumerate(lhs):
                assert_nested_equal(lhs_, [rhs_[i] for rhs_ in rhs])
        elif isinstance(lhs, dict):
            for key in lhs:
                assert_nested_equal(lhs[key], [rhs_[key] for rhs_ in rhs])
        elif isinstance(lhs, SynchronizedArray):
            # The shared buffer is flat, so compare against the flattened stack.
            assert np.all(np.array(lhs[:]) == np.stack(rhs, axis=0).flatten())
        else:
            raise TypeError('Got unknown type `{0}`.'.format(type(lhs)))

    shared_memory_n8 = create_shared_memory(space, n=8)
    samples = [space.sample() for _ in range(8)]

    # The module-level function is used directly as the target (with `space`
    # passed explicitly) instead of a local closure: nested functions cannot
    # be pickled, which the 'spawn' and 'forkserver' start methods require.
    processes = [
        Process(target=write_to_shared_memory,
                args=(i, sample, shared_memory_n8, space))
        for i, sample in enumerate(samples)
    ]

    for process in processes:
        process.start()
    for process in processes:
        process.join()
    # A worker that crashed must fail the test rather than go unnoticed.
    assert all(process.exitcode == 0 for process in processes)

    assert_nested_equal(shared_memory_n8, samples)
@pytest.mark.parametrize('space', spaces,
    ids=[space.__class__.__name__ for space in spaces])
def test_read_from_shared_memory(space):
    """Check that ``read_from_shared_memory`` exposes data written by workers.

    The result of ``read_from_shared_memory`` is obtained first, then eight
    processes each write one sample through ``write_to_shared_memory``. The
    previously obtained object must afterwards match the stacked samples in
    structure, shape, dtype and value.
    """
    def assert_nested_equal(lhs, rhs, space, n):
        # `rhs` is always the list of per-environment samples for this node.
        assert isinstance(rhs, list)
        if isinstance(space, Tuple):
            assert isinstance(lhs, tuple)
            for i, lhs_ in enumerate(lhs):
                assert_nested_equal(lhs_, [rhs_[i] for rhs_ in rhs],
                    space.spaces[i], n)
        elif isinstance(space, Dict):
            assert isinstance(lhs, OrderedDict)
            for key in lhs:
                assert_nested_equal(lhs[key], [rhs_[key] for rhs_ in rhs],
                    space.spaces[key], n)
        elif isinstance(space, _BaseGymSpaces):
            # Leaf: a batched array with one leading axis of size `n`.
            assert isinstance(lhs, np.ndarray)
            assert lhs.shape == ((n,) + space.shape)
            assert lhs.dtype == space.dtype
            assert np.all(lhs == np.stack(rhs, axis=0))
        else:
            raise TypeError('Got unknown type `{0}`'.format(type(space)))

    shared_memory_n8 = create_shared_memory(space, n=8)
    # Taken before any write happens; presumably this verifies that the
    # returned arrays alias the shared buffers rather than copy them.
    memory_view_n8 = read_from_shared_memory(shared_memory_n8, space, n=8)
    samples = [space.sample() for _ in range(8)]

    # The module-level function is used directly as the target (with `space`
    # passed explicitly) instead of a local closure: nested functions cannot
    # be pickled, which the 'spawn' and 'forkserver' start methods require.
    processes = [
        Process(target=write_to_shared_memory,
                args=(i, sample, shared_memory_n8, space))
        for i, sample in enumerate(samples)
    ]

    for process in processes:
        process.start()
    for process in processes:
        process.join()
    # A worker that crashed must fail the test rather than go unnoticed.
    assert all(process.exitcode == 0 for process in processes)

    assert_nested_equal(memory_view_n8, samples, space, n=8)