Skip to content

Commit

Permalink
Preserve CurrentThreadExecutor across create_task
Browse files Browse the repository at this point in the history
Fixes django#214.

Signed-off-by: Anders Kaseorg <andersk@mit.edu>
  • Loading branch information
andersk committed Mar 19, 2022
1 parent cde961b commit 3ae7bd4
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 1 deletion.
7 changes: 7 additions & 0 deletions asgiref/sync.py
@@ -1,3 +1,4 @@
import asyncio
import asyncio.coroutines
import contextvars
import functools
Expand Down Expand Up @@ -100,6 +101,9 @@ class AsyncToSync:
# Keeps track of which CurrentThreadExecutor to use. This uses an asgiref
# Local, not a threadlocal, so that tasks can work out what their parent used.
executors = Local()
loop_thread_executors: "weakref.WeakKeyDictionary[asyncio.AbstractEventLoop, CurrentThreadExecutor]" = (
weakref.WeakKeyDictionary()
)

def __init__(self, awaitable, force_new_loop=False):
if not callable(awaitable) or not _iscoroutinefunction_or_partial(awaitable):
Expand Down Expand Up @@ -175,6 +179,7 @@ def __call__(self, *args, **kwargs):
if not (self.main_event_loop and self.main_event_loop.is_running()):
# Make our own event loop - in a new thread - and run inside that.
loop = asyncio.new_event_loop()
self.loop_thread_executors[loop] = current_executor
loop_executor = ThreadPoolExecutor(max_workers=1)
loop_future = loop_executor.submit(
self._run_event_loop, loop, awaitable
Expand Down Expand Up @@ -378,6 +383,8 @@ async def __call__(self, *args, **kwargs):
# Create new thread executor in current context
executor = ThreadPoolExecutor(max_workers=1)
self.context_to_thread_executor[thread_sensitive_context] = executor
elif loop in AsyncToSync.loop_thread_executors:
executor = AsyncToSync.loop_thread_executors[loop]
elif self.deadlock_context and self.deadlock_context.get(False):
raise RuntimeError(
"Single thread executor already being used, would deadlock"
Expand Down
8 changes: 7 additions & 1 deletion tests/test_sync.py
Expand Up @@ -397,15 +397,21 @@ def test_thread_sensitive_outside_sync():
@async_to_sync
async def middle():
await inner()
await asyncio.create_task(inner_task())

# Inner sync function
# Inner sync functions
@sync_to_async
def inner():
result["thread"] = threading.current_thread()

@sync_to_async
def inner_task():
result["thread2"] = threading.current_thread()

# Run it
middle()
assert result["thread"] == threading.current_thread()
assert result["thread2"] == threading.current_thread()


@pytest.mark.asyncio
Expand Down

0 comments on commit 3ae7bd4

Please sign in to comment.