forked from celery/celery
/
asynchronous.py
315 lines (254 loc) · 9.42 KB
/
asynchronous.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
"""Async I/O backend support utilities."""
import socket
import threading
import time
from collections import deque
from queue import Empty
from time import sleep
from weakref import WeakKeyDictionary, WeakSet
from kombu.utils.compat import detect_environment
from celery import states
from celery.exceptions import TimeoutError
from celery.utils.threads import THREAD_TIMEOUT_MAX
# Names exported by ``from celery.backends.asynchronous import *``.
__all__ = (
    'AsyncBackendMixin',
    'BaseResultConsumer',
    'Drainer',
    'register_drainer',
)
# Registry of drainer classes keyed by the name passed to register_drainer().
drainers = {}


def register_drainer(name):
    """Decorator used to register a new result drainer type.

    Returns a class decorator that stores the decorated class in
    ``drainers`` under *name* and hands the class back unchanged.
    """
    def decorator(drainer_cls):
        drainers[name] = drainer_cls
        return drainer_cls
    return decorator
@register_drainer('default')
class Drainer:
    """Result draining service.

    The default drainer drains in the caller's own context: it repeatedly
    calls a ``wait`` callable (by default ``result_consumer.drain_events``)
    until the object it is waiting on reports ``ready``.
    """

    def __init__(self, result_consumer):
        self.result_consumer = result_consumer

    def start(self):
        """Do nothing: the default drainer has no background worker."""

    def stop(self):
        """Do nothing: counterpart of :meth:`start`."""

    def drain_events_until(self, p, timeout=None, interval=1, on_interval=None, wait=None):
        """Drain events until ``p.ready`` is true, yielding after each wait.

        Arguments:
            p: Object whose ``ready`` attribute is checked after every wait.
            timeout (float): Total number of seconds to wait.  A falsy value
                (``None`` or ``0``) means wait indefinitely.
            interval (float): Upper bound in seconds for a single wait.  When
                *timeout* is set, each wait is additionally capped by the time
                remaining, so the total wait does not overshoot *timeout* by
                up to one interval.
            on_interval (Callable): Called after every wait, if given.
            wait (Callable): Called as ``wait(timeout=...)`` to drain events;
                defaults to ``self.result_consumer.drain_events``.

        Raises:
            socket.timeout: when *timeout* seconds elapse before ``p`` is ready.
        """
        wait = wait or self.result_consumer.drain_events
        time_start = time.monotonic()

        while 1:
            wait_timeout = interval
            if timeout:
                # Total time spent may exceed a single call to wait(), so the
                # remaining budget is tracked across iterations.
                remaining = timeout - (time.monotonic() - time_start)
                if remaining <= 0:
                    raise socket.timeout()
                if wait_timeout is None or remaining < wait_timeout:
                    wait_timeout = remaining
            try:
                yield self.wait_for(p, wait, timeout=wait_timeout)
            except socket.timeout:
                pass
            if on_interval:
                on_interval()
            if p.ready:  # got event on the wanted channel.
                break

    def wait_for(self, p, wait, timeout=None):
        """Call ``wait(timeout=timeout)`` once; *p* is not consulted here."""
        wait(timeout=timeout)
class greenletDrainer(Drainer):
    """Drainer that drains events in a separately spawned green thread.

    Subclasses override :meth:`spawn` (start a callable concurrently and
    return a handle to it) and :meth:`wait_for`.
    """

    # Overridden by subclasses with a method that starts ``func``.
    spawn = None
    # Handle returned by spawn() once the drain loop has been started.
    _g = None

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._started = threading.Event()
        self._stopped = threading.Event()
        self._shutdown = threading.Event()

    def run(self):
        """Drain events until :meth:`stop` is called.

        ``_shutdown`` is set whenever the loop exits, including when
        ``drain_events`` raises something other than ``socket.timeout``.
        Without this, :meth:`stop` would block for ``THREAD_TIMEOUT_MAX``
        waiting on a drain loop that has already died.
        """
        self._started.set()
        try:
            while not self._stopped.is_set():
                try:
                    self.result_consumer.drain_events(timeout=1)
                except socket.timeout:
                    pass
        finally:
            self._shutdown.set()

    def start(self):
        """Spawn :meth:`run` once and block until it has signalled start-up."""
        if not self._started.is_set():
            self._g = self.spawn(self.run)
            self._started.wait()

    def stop(self):
        """Ask :meth:`run` to exit and wait up to ``THREAD_TIMEOUT_MAX`` for it."""
        self._stopped.set()
        self._shutdown.wait(THREAD_TIMEOUT_MAX)
@register_drainer('eventlet')
class eventletDrainer(greenletDrainer):
    """Drainer whose drain loop runs in an eventlet green thread."""

    def spawn(self, func):
        """Start *func* in a green thread, yield once so it can run, return it."""
        import eventlet
        greenthread = eventlet.spawn(func)
        eventlet.sleep(0)
        return greenthread

    def wait_for(self, p, wait, timeout=None):
        """Start the drain loop if needed, then wait unless *p* is ready.

        The wait is on the green thread's exit event, bounded by *timeout*.
        *wait* is unused here.
        """
        self.start()
        if p.ready:
            return
        self._g._exit_event.wait(timeout=timeout)
@register_drainer('gevent')
class geventDrainer(greenletDrainer):
    """Drainer whose drain loop runs in a gevent greenlet."""

    def spawn(self, func):
        """Start *func* in a greenlet, yield once so it can run, return it."""
        from gevent import sleep as gevent_sleep, spawn as gevent_spawn
        greenlet = gevent_spawn(func)
        gevent_sleep(0)
        return greenlet

    def wait_for(self, p, wait, timeout=None):
        """Start the drain loop if needed, then wait unless *p* is ready.

        The wait lasts until the drain greenlet exits or *timeout* elapses.
        *wait* is unused here.
        """
        import gevent
        self.start()
        if p.ready:
            return
        gevent.wait([self._g], timeout=timeout)
class AsyncBackendMixin:
    """Mixin for backends that enables the async API.

    The backend this is mixed into is expected to provide
    ``result_consumer``, ``_pending_results`` (a pair of mappings unpacked
    as ``concrete, weak``), ``_pending_messages`` and ``_ensure_not_eager()``.
    """

    def _collect_into(self, result, bucket):
        """Register *bucket* as the place the consumer delivers *result* to."""
        self.result_consumer.buckets[result] = bucket

    def iter_native(self, result, no_ack=True, **kwargs):
        """Yield ``(id, value)`` pairs for the nodes of ``result.results``.

        Nodes without a ``_cache`` attribute, or whose ``_cache`` is already
        filled, are reported after the first wait round.  The remaining
        nodes are reported as the result consumer appends them to the
        shared bucket.  The value is ``node.children`` for nodes without
        ``_cache`` and ``node._cache`` otherwise.

        Extra keyword arguments are forwarded to :meth:`_wait_for_pending`.
        """
        self._ensure_not_eager()

        results = result.results
        if not results:
            # A plain return ends the generator.  ``raise StopIteration``
            # here would be turned into RuntimeError by PEP 479.
            return

        # we tell the result consumer to put consumed results
        # into these buckets.
        bucket = deque()
        for node in results:
            if not hasattr(node, '_cache') or node._cache:
                bucket.append(node)
            else:
                self._collect_into(node, bucket)

        for _ in self._wait_for_pending(result, no_ack=no_ack, **kwargs):
            while bucket:
                yield self._iter_native_entry(bucket.popleft())

        # Deliver anything that arrived during the final wait round.
        while bucket:
            yield self._iter_native_entry(bucket.popleft())

    @staticmethod
    def _iter_native_entry(node):
        """Return the ``(id, value)`` pair :meth:`iter_native` reports for *node*."""
        if not hasattr(node, '_cache'):
            return node.id, node.children
        return node.id, node._cache

    def add_pending_result(self, result, weak=False, start_drainer=True):
        """Register *result* as pending and return it.

        If a buffered message for ``result.id`` is available it is applied
        immediately.  Otherwise (the buffer lookup raises ``Empty``) the
        result is tracked until its state arrives.
        """
        if start_drainer:
            self.result_consumer.drainer.start()
        try:
            self._maybe_resolve_from_buffer(result)
        except Empty:
            self._add_pending_result(result.id, result, weak=weak)
        return result

    def _maybe_resolve_from_buffer(self, result):
        """Apply the buffered message for ``result.id`` to *result*."""
        result._maybe_set_cache(self._pending_messages.take(result.id))

    def _add_pending_result(self, task_id, result, weak=False):
        """Track *result* under *task_id* and start consuming from *task_id*.

        The result is stored in a ``WeakSet`` in the weak mapping when
        *weak* is true, otherwise in a plain ``set`` in the concrete mapping.
        """
        concrete, weak_ = self._pending_results
        # NOTE(review): because of this guard, ``ref.get`` below always falls
        # back to a fresh set.  A second result object for an already
        # pending task id is therefore not recorded -- confirm this is intended.
        if task_id not in weak_ and result.id not in concrete:
            ref = (weak_ if weak else concrete)
            results = ref.get(task_id, WeakSet() if weak else set())
            results.add(result)
            ref[task_id] = results
            self.result_consumer.consume_from(task_id)

    def add_pending_results(self, results, weak=False):
        """Start the drainer once, then register every result in *results*."""
        self.result_consumer.drainer.start()
        return [self.add_pending_result(result, weak=weak, start_drainer=False)
                for result in results]

    def remove_pending_result(self, result):
        """Stop tracking *result*, cancel its consumption and return it."""
        self._remove_pending_result(result.id)
        self.on_result_fulfilled(result)
        return result

    def _remove_pending_result(self, task_id):
        """Drop *task_id* from both pending mappings, if present."""
        for mapping in self._pending_results:
            mapping.pop(task_id, None)

    def on_result_fulfilled(self, result):
        """Tell the result consumer to stop consuming from ``result.id``."""
        self.result_consumer.cancel_for(result.id)

    def wait_for_pending(self, result,
                         callback=None, propagate=True, **kwargs):
        """Block until *result* is ready, then return ``result.maybe_throw(...)``."""
        self._ensure_not_eager()
        for _ in self._wait_for_pending(result, **kwargs):
            pass
        return result.maybe_throw(callback=callback, propagate=propagate)

    def _wait_for_pending(self, result,
                          timeout=None, on_interval=None, on_message=None,
                          **kwargs):
        """Delegate to the result consumer's ``_wait_for_pending`` generator."""
        return self.result_consumer._wait_for_pending(
            result, timeout=timeout,
            on_interval=on_interval, on_message=on_message,
            **kwargs
        )

    @property
    def is_async(self):
        """Backends using this mixin always support the async API."""
        return True
class BaseResultConsumer:
    """Manager responsible for consuming result messages.

    Subclasses implement :meth:`start`, :meth:`drain_events`,
    :meth:`consume_from` and :meth:`cancel_for`.
    """

    def __init__(self, backend, app, accept,
                 pending_results, pending_messages):
        self.backend, self.app, self.accept = backend, app, accept
        self._pending_results = pending_results
        self._pending_messages = pending_messages
        self.on_message = None
        self.buckets = WeakKeyDictionary()
        # Pick the drainer class registered for the detected environment.
        drainer_cls = drainers[detect_environment()]
        self.drainer = drainer_cls(self)

    def start(self, initial_task_id, **kwargs):
        raise NotImplementedError()

    def stop(self):
        """Do nothing by default."""

    def drain_events(self, timeout=None):
        raise NotImplementedError()

    def consume_from(self, task_id):
        raise NotImplementedError()

    def cancel_for(self, task_id):
        raise NotImplementedError()

    def _after_fork(self):
        """Reset buckets and the message callback, then call :meth:`on_after_fork`."""
        self.buckets.clear()
        self.buckets = WeakKeyDictionary()
        self.on_message = None
        self.on_after_fork()

    def on_after_fork(self):
        """Hook for subclasses; invoked at the end of :meth:`_after_fork`."""

    def drain_events_until(self, p, timeout=None, on_interval=None):
        """Delegate to the drainer's ``drain_events_until``."""
        drainer = self.drainer
        return drainer.drain_events_until(
            p, timeout=timeout, on_interval=on_interval)

    def _wait_for_pending(self, result,
                          timeout=None, on_interval=None, on_message=None,
                          **kwargs):
        """Yield once per drain round until ``result.on_ready`` is ready.

        *on_message* temporarily replaces ``self.on_message`` and the
        previous callback is restored afterwards.  A ``socket.timeout`` from
        the drainer is re-raised as ``TimeoutError``.
        """
        self.on_wait_for_pending(result, timeout=timeout, **kwargs)
        saved_on_message = self.on_message
        self.on_message = on_message
        try:
            rounds = self.drain_events_until(
                result.on_ready, timeout=timeout, on_interval=on_interval)
            for _ in rounds:
                yield
                sleep(0)
        except socket.timeout:
            raise TimeoutError('The operation timed out.')
        finally:
            self.on_message = saved_on_message

    def on_wait_for_pending(self, result, timeout=None, **kwargs):
        """Hook called before waiting starts; does nothing by default."""

    def on_out_of_band_result(self, message):
        """Handle *message* as a state change using its payload."""
        self.on_state_change(message.payload, message)

    def _get_pending_result(self, task_id):
        """Return the pending entry for *task_id* from the first mapping holding it.

        Raises:
            KeyError: if no pending mapping contains *task_id*.
        """
        for mapping in self._pending_results:
            try:
                return mapping[task_id]
            except KeyError:
                continue
        raise KeyError(task_id)

    def on_state_change(self, meta, message):
        """Process a task state update described by *meta*."""
        if self.on_message:
            self.on_message(meta)
        if meta['status'] in states.READY_STATES:
            task_id = meta['task_id']
            try:
                waiting = self._get_pending_result(task_id)
            except KeyError:
                # send to buffer in case we received this result
                # before it was added to _pending_results.
                self._pending_messages.put(task_id, meta)
            else:
                for pending in waiting:
                    pending._maybe_set_cache(meta)
                    # The result is fulfilled, so its bucket is removed; if
                    # one existed, the result is sent to the waiter through it.
                    bucket = self.buckets.pop(pending, None)
                    if bucket is not None:
                        bucket.append(pending)
        sleep(0)