Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Preserve CurrentThreadExecutor across create_task #320

Merged
merged 1 commit into from Apr 19, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
12 changes: 12 additions & 0 deletions asgiref/sync.py
@@ -1,3 +1,4 @@
import asyncio
import asyncio.coroutines
import contextvars
import functools
Expand Down Expand Up @@ -101,6 +102,10 @@ class AsyncToSync:
# Local, not a threadlocal, so that tasks can work out what their parent used.
executors = Local()

# When we can't find a CurrentThreadExecutor from the context, such as
# inside create_task, we'll look it up here from the running event loop.
loop_thread_executors: "Dict[asyncio.AbstractEventLoop, CurrentThreadExecutor]" = {}

def __init__(self, awaitable, force_new_loop=False):
if not callable(awaitable) or not _iscoroutinefunction_or_partial(awaitable):
# Python does not have very reliable detection of async functions
Expand Down Expand Up @@ -164,6 +169,7 @@ def __call__(self, *args, **kwargs):
old_current_executor = None
current_executor = CurrentThreadExecutor()
self.executors.current = current_executor
loop = None
# Use call_soon_threadsafe to schedule a synchronous callback on the
# main event loop's thread if it's there, otherwise make a new loop
# in this thread.
Expand All @@ -175,6 +181,7 @@ def __call__(self, *args, **kwargs):
if not (self.main_event_loop and self.main_event_loop.is_running()):
# Make our own event loop - in a new thread - and run inside that.
loop = asyncio.new_event_loop()
self.loop_thread_executors[loop] = current_executor
loop_executor = ThreadPoolExecutor(max_workers=1)
loop_future = loop_executor.submit(
self._run_event_loop, loop, awaitable
Expand All @@ -194,6 +201,8 @@ def __call__(self, *args, **kwargs):
current_executor.run_until_future(call_result)
finally:
# Clean up any executor we were running
if loop is not None:
del self.loop_thread_executors[loop]
if hasattr(self.executors, "current"):
del self.executors.current
if old_current_executor:
Expand Down Expand Up @@ -378,6 +387,9 @@ async def __call__(self, *args, **kwargs):
# Create new thread executor in current context
executor = ThreadPoolExecutor(max_workers=1)
self.context_to_thread_executor[thread_sensitive_context] = executor
elif loop in AsyncToSync.loop_thread_executors:
# Re-use thread executor for running loop
executor = AsyncToSync.loop_thread_executors[loop]
elif self.deadlock_context and self.deadlock_context.get(False):
raise RuntimeError(
"Single thread executor already being used, would deadlock"
Expand Down
8 changes: 7 additions & 1 deletion tests/test_sync.py
Expand Up @@ -397,15 +397,21 @@ def test_thread_sensitive_outside_sync():
@async_to_sync
async def middle():
await inner()
await asyncio.create_task(inner_task())

# Inner sync function
# Inner sync functions
@sync_to_async
def inner():
result["thread"] = threading.current_thread()

@sync_to_async
def inner_task():
result["thread2"] = threading.current_thread()

# Run it
middle()
assert result["thread"] == threading.current_thread()
assert result["thread2"] == threading.current_thread()


@pytest.mark.asyncio
Expand Down