-
Notifications
You must be signed in to change notification settings - Fork 97
/
mpi.py
106 lines (78 loc) · 3.15 KB
/
mpi.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
# This file is part of the pyMOR project (https://www.pymor.org).
# Copyright pyMOR developers and contributors. All rights reserved.
# License: BSD 2-Clause License (https://opensource.org/licenses/BSD-2-Clause)
from itertools import chain
import os
from pymor.parallel.basic import WorkerPoolBase
from pymor.tools import mpi
from pymor.tools.random import get_seed_seq
class MPIPool(WorkerPoolBase):
    """|WorkerPool| based pyMOR's MPI :mod:`event loop <pymor.tools.mpi>`.

    Construction creates a one-slot payload list on the ranks (via
    ``_setup_worker``). It then makes every rank ``os.chdir`` into rank 0's
    current working directory and installs a separate RNG on each rank,
    spawned from ``get_seed_seq()``.
    """

    def __init__(self):
        super().__init__()
        self.logger.info(f'Connected to {mpi.size} ranks')
        # Handle to the per-rank ``[None]`` payload list. ``_apply_only`` and
        # ``_map`` temporarily place their data into rank 0's slot and pass
        # only this handle to the worker-side function, which distributes the
        # data itself (send / scatter).
        self._payload = mpi.call(mpi.function_call_manage, _setup_worker)
        # Evaluated on rank 0, so all ranks chdir into rank 0's cwd.
        self._apply(os.chdir, os.getcwd())
        # One argument, chunked into one single-element chunk per rank, so
        # each rank receives its own spawned seed sequence.
        self._map(_setup_rng, [[[s] for s in get_seed_seq().spawn(mpi.size)]])

    def __del__(self):
        # ``__init__`` may have raised before ``_payload`` was assigned; in
        # that case there is nothing to release and accessing the attribute
        # directly would raise AttributeError during finalization.
        payload = getattr(self, '_payload', None)
        if payload is not None:
            mpi.call(mpi.remove_object, payload)

    def __len__(self):
        """Return the number of MPI ranks, i.e. the number of workers."""
        return mpi.size

    def _push_object(self, obj):
        """Make ``obj`` a managed object on the ranks and return its handle."""
        return mpi.call(mpi.function_call_manage, _push_object, obj)

    def _apply(self, function, *args, **kwargs):
        """Run ``function(*args, **kwargs)`` on every rank.

        Returns whatever ``mpi.call`` returns for ``_worker_call_function``,
        i.e. the gathered per-rank results.
        """
        return mpi.call(mpi.function_call, _worker_call_function, function, *args, **kwargs)

    def _apply_only(self, function, worker, *args, **kwargs):
        """Run ``function(*args, **kwargs)`` on rank ``worker`` only and return its result."""
        payload = mpi.get_object(self._payload)
        # The call description is stored in rank 0's payload slot; rank 0
        # forwards it to ``worker`` inside _single_worker_call_function.
        payload[0] = (function, args, kwargs)
        try:
            result = mpi.call(mpi.function_call, _single_worker_call_function, self._payload, worker)
        finally:
            # Drop the reference so the arguments are not kept alive.
            payload[0] = None
        return result

    def _map(self, function, chunks, **kwargs):
        """Map ``function`` over per-rank argument chunks; return the flattened results.

        ``chunks`` holds, for each positional argument, a sequence with one
        chunk per rank.
        """
        payload = mpi.get_object(self._payload)
        # Only rank 0 holds the chunks; _worker_map_function scatters them.
        payload[0] = chunks
        try:
            result = mpi.call(mpi.function_call, _worker_map_function, self._payload, function, **kwargs)
        finally:
            payload[0] = None
        return result

    def _remove_object(self, remote_id):
        """Release the managed object identified by ``remote_id`` on the ranks."""
        mpi.call(mpi.remove_object, remote_id)
def _worker_call_function(function, *args, **kwargs):
    """Evaluate ``function(*args, **kwargs)`` on the calling rank and gather at rank 0.

    Returns the value of ``mpi.comm.gather(..., root=0)``. Assuming ``mpi.comm``
    follows mpi4py semantics, that is the list of all per-rank results on rank 0
    and ``None`` on the other ranks.
    """
    return mpi.comm.gather(function(*args, **kwargs), root=0)
def _single_worker_call_function(payload, worker):
    """Execute the call stored in rank 0's ``payload[0]`` on rank ``worker`` only.

    On rank 0, ``payload[0]`` is expected to hold ``(function, args, kwargs)``
    (as stored by ``MPIPool._apply_only``). Rank 0 returns the call's result;
    every other rank returns ``None``.
    """
    if not mpi.rank0:
        # Non-root ranks: only the selected worker takes part.
        if mpi.rank == worker:
            function, args, kwargs = mpi.comm.recv(source=0)
            mpi.comm.send(mpi.function_call(function, *args, **kwargs), dest=0)
        return None

    if worker != 0:
        # Forward the call description and wait for the worker's answer.
        mpi.comm.send(payload[0], dest=worker)
        return mpi.comm.recv(source=worker)

    # Rank 0 is itself the target worker: run locally.
    function, args, kwargs = payload[0]
    return mpi.function_call(function, *args, **kwargs)
def _worker_map_function(payload, function, **kwargs):
    """Scatter rank 0's argument chunks, apply ``function`` locally, and gather.

    On rank 0, ``payload[0]`` holds one sequence of per-rank chunks for each
    positional argument. Each rank evaluates ``function`` once per argument
    tuple in its chunk. Rank 0 returns the concatenation of all ranks' result
    lists in rank order; the other ranks return ``None``.
    """
    # Transpose to one tuple of argument chunks per rank (rank 0 only).
    per_rank_chunks = list(zip(*payload[0])) if mpi.rank0 else None
    my_chunk = mpi.comm.scatter(per_rank_chunks, root=0)
    local_results = [mpi.function_call(function, *call_args, **kwargs)
                     for call_args in zip(*my_chunk)]
    gathered = mpi.comm.gather(local_results, root=0)
    return list(chain.from_iterable(gathered)) if mpi.rank0 else None
def _setup_worker():
return [None]
def _setup_rng(seed_seq):
    """Install a new RNG created from ``seed_seq`` on the calling rank.

    ``MPIPool`` passes each rank a different spawned seed sequence, so the
    ranks get distinct but reproducible random streams.
    """
    from pymor.tools.random import new_rng
    rng = new_rng(seed_seq)
    rng.install()
def _push_object(obj):
return obj