diff --git a/asgiref/sync.py b/asgiref/sync.py index b71b3799..02534cab 100644 --- a/asgiref/sync.py +++ b/asgiref/sync.py @@ -1,3 +1,4 @@ +import asyncio import asyncio.coroutines import contextvars import functools @@ -175,6 +176,16 @@ def __call__(self, *args, **kwargs): if not (self.main_event_loop and self.main_event_loop.is_running()): # Make our own event loop - in a new thread - and run inside that. loop = asyncio.new_event_loop() + create_task = loop.create_task + + def new_create_task(coro): + async def new_coro(): + self.executors.current = current_executor + return await coro + + return create_task(new_coro()) + + loop.create_task = new_create_task loop_executor = ThreadPoolExecutor(max_workers=1) loop_future = loop_executor.submit( self._run_event_loop, loop, awaitable diff --git a/tests/test_sync.py b/tests/test_sync.py index 8f563d92..2642396b 100644 --- a/tests/test_sync.py +++ b/tests/test_sync.py @@ -397,15 +397,22 @@ def test_thread_sensitive_outside_sync(): @async_to_sync async def middle(): await inner() + await asyncio.create_task(inner_task()) # Inner sync function @sync_to_async def inner(): result["thread"] = threading.current_thread() + # Inner sync function + @sync_to_async + def inner_task(): + result["thread2"] = threading.current_thread() + # Run it middle() assert result["thread"] == threading.current_thread() + assert result["thread2"] == threading.current_thread() @pytest.mark.asyncio