diff --git a/asgiref/sync.py b/asgiref/sync.py index b71b3799..04b9a6a1 100644 --- a/asgiref/sync.py +++ b/asgiref/sync.py @@ -1,3 +1,4 @@ +import asyncio import asyncio.coroutines import contextvars import functools @@ -7,7 +8,7 @@ import threading import warnings import weakref -from concurrent.futures import Future, ThreadPoolExecutor +from concurrent.futures import Executor, Future, ThreadPoolExecutor from typing import Any, Callable, Dict, Optional, overload from .current_thread_executor import CurrentThreadExecutor @@ -100,6 +101,9 @@ class AsyncToSync: # Keeps track of which CurrentThreadExecutor to use. This uses an asgiref # Local, not a threadlocal, so that tasks can work out what their parent used. executors = Local() + loop_thread_executors: "weakref.WeakKeyDictionary[asyncio.AbstractEventLoop, Executor]" = ( + weakref.WeakKeyDictionary() + ) def __init__(self, awaitable, force_new_loop=False): if not callable(awaitable) or not _iscoroutinefunction_or_partial(awaitable): @@ -175,6 +179,7 @@ def __call__(self, *args, **kwargs): if not (self.main_event_loop and self.main_event_loop.is_running()): # Make our own event loop - in a new thread - and run inside that. loop = asyncio.new_event_loop() + self.loop_thread_executors[loop] = current_executor loop_executor = ThreadPoolExecutor(max_workers=1) loop_future = loop_executor.submit( self._run_event_loop, loop, awaitable @@ -378,6 +383,8 @@ async def __call__(self, *args, **kwargs): # Create new thread executor in current context executor = ThreadPoolExecutor(max_workers=1) self.context_to_thread_executor[thread_sensitive_context] = executor + elif loop in AsyncToSync.loop_thread_executors: + executor = AsyncToSync.loop_thread_executors[loop] elif self.deadlock_context and self.deadlock_context.get(False): raise RuntimeError( "Single thread executor already being used, would deadlock" diff --git a/tests/test_sync.py b/tests/test_sync.py index 8f563d92..2837423b 100644 --- a/tests/test_sync.py +++ b/tests/test_sync.py @@ -397,15 +397,21 @@ def test_thread_sensitive_outside_sync(): @async_to_sync async def middle(): await inner() + await asyncio.create_task(inner_task()) - # Inner sync function + # Inner sync functions @sync_to_async def inner(): result["thread"] = threading.current_thread() + @sync_to_async + def inner_task(): + result["thread2"] = threading.current_thread() + # Run it middle() assert result["thread"] == threading.current_thread() + assert result["thread2"] == threading.current_thread() @pytest.mark.asyncio