Skip to content

Commit

Permalink
Fix task cancellation propagation to subtasks when using sync middlew…
Browse files Browse the repository at this point in the history
…are (#435)
  • Loading branch information
ttys0dev committed Jan 27, 2024
1 parent 088d6a5 commit 0503c2c
Show file tree
Hide file tree
Showing 2 changed files with 201 additions and 17 deletions.
59 changes: 45 additions & 14 deletions asgiref/sync.py
Expand Up @@ -203,6 +203,10 @@ def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _R:
# `main_wrap`.
context = [contextvars.copy_context()]

# Get task context so that parent task knows which task to propagate
# an asyncio.CancelledError to.
task_context = getattr(SyncToAsync.threadlocal, "task_context", None)

loop = None
# Use call_soon_threadsafe to schedule a synchronous callback on the
# main event loop's thread if it's there, otherwise make a new loop
Expand All @@ -211,6 +215,7 @@ def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _R:
awaitable = self.main_wrap(
call_result,
sys.exc_info(),
task_context,
context,
*args,
**kwargs,
Expand Down Expand Up @@ -295,6 +300,7 @@ async def main_wrap(
self,
call_result: "Future[_R]",
exc_info: "OptExcInfo",
task_context: "Optional[List[asyncio.Task[Any]]]",
context: List[contextvars.Context],
*args: _P.args,
**kwargs: _P.kwargs,
Expand All @@ -309,6 +315,10 @@ async def main_wrap(
if context is not None:
_restore_context(context[0])

current_task = asyncio.current_task()
if current_task is not None and task_context is not None:
task_context.append(current_task)

try:
# If we have an exception, run the function inside the except block
# after raising it so exc_info is correctly populated.
Expand All @@ -324,6 +334,8 @@ async def main_wrap(
else:
call_result.set_result(result)
finally:
if current_task is not None and task_context is not None:
task_context.remove(current_task)
context[0] = contextvars.copy_context()


Expand Down Expand Up @@ -437,20 +449,38 @@ async def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _R:
context = contextvars.copy_context()
child = functools.partial(self.func, *args, **kwargs)
func = context.run

task_context: List[asyncio.Task[Any]] = []

# Run the code in the right thread
exec_coro = loop.run_in_executor(
executor,
functools.partial(
self.thread_handler,
loop,
sys.exc_info(),
task_context,
func,
child,
),
)
ret: _R
try:
# Run the code in the right thread
ret: _R = await loop.run_in_executor(
executor,
functools.partial(
self.thread_handler,
loop,
sys.exc_info(),
func,
child,
),
)

ret = await asyncio.shield(exec_coro)
except asyncio.CancelledError:
cancel_parent = True
try:
task = task_context[0]
task.cancel()
try:
await task
cancel_parent = False
except asyncio.CancelledError:
pass
except IndexError:
pass
if cancel_parent:
exec_coro.cancel()
ret = await exec_coro
finally:
_restore_context(context)
self.deadlock_context.set(False)
Expand All @@ -466,7 +496,7 @@ def __get__(
func = functools.partial(self.__call__, parent)
return functools.update_wrapper(func, self.func)

def thread_handler(self, loop, exc_info, func, *args, **kwargs):
def thread_handler(self, loop, exc_info, task_context, func, *args, **kwargs):
"""
Wraps the sync application with exception handling.
"""
Expand All @@ -476,6 +506,7 @@ def thread_handler(self, loop, exc_info, func, *args, **kwargs):
# Set the threadlocal for AsyncToSync
self.threadlocal.main_event_loop = loop
self.threadlocal.main_event_loop_pid = os.getpid()
self.threadlocal.task_context = task_context

# Run the function
# If we have an exception, run the function inside the except block
Expand Down
159 changes: 156 additions & 3 deletions tests/test_sync.py
Expand Up @@ -852,13 +852,10 @@ def sync_task():


@pytest.mark.asyncio
@pytest.mark.skip(reason="deadlocks")
async def test_inner_shield_sync_middleware():
"""
Tests that asyncio.shield is capable of preventing http.disconnect from
cancelling a django request task when using sync middleware.
Currently this tests is skipped as it causes a deadlock.
"""

# Hypothetical Django scenario - middleware function is sync
Expand Down Expand Up @@ -968,3 +965,159 @@ async def async_task():
assert task_complete

assert task_executed


@pytest.mark.asyncio
async def test_inner_shield_sync_and_async_middleware():
"""
Tests that asyncio.shield is capable of preventing http.disconnect from
cancelling a django request task when using sync and middleware chained
together.
"""

# Hypothetical Django scenario - middleware function is sync
def sync_middleware_1():
async_to_sync(async_middleware_2)()

# Hypothetical Django scenario - middleware function is async
async def async_middleware_2():
await sync_to_async(sync_middleware_3)()

# Hypothetical Django scenario - middleware function is sync
def sync_middleware_3():
async_to_sync(async_middleware_4)()

# Hypothetical Django scenario - middleware function is async
async def async_middleware_4():
await sync_to_async(sync_middleware_5)()

# Hypothetical Django scenario - middleware function is sync
def sync_middleware_5():
async_to_sync(async_view)()

task_complete = False
task_cancel_caught = False

# Future that completes when subtask cancellation attempt is caught
task_blocker = asyncio.Future()

async def async_view():
"""Async view with a task that is shielded from cancellation."""
nonlocal task_complete, task_cancel_caught, task_blocker
task = asyncio.create_task(async_task())
try:
await asyncio.shield(task)
except asyncio.CancelledError:
task_cancel_caught = True
task_blocker.set_result(True)
await task
task_complete = True

task_executed = False

# Future that completes after subtask is created
task_started_future = asyncio.Future()

async def async_task():
"""Async subtask that should not be canceled when parent is canceled."""
nonlocal task_started_future, task_executed, task_blocker
task_started_future.set_result(True)
await task_blocker
task_executed = True

task_cancel_propagated = False

async with ThreadSensitiveContext():
task = asyncio.create_task(sync_to_async(sync_middleware_1)())
await task_started_future
task.cancel()
try:
await task
except asyncio.CancelledError:
task_cancel_propagated = True
assert not task_cancel_propagated
assert task_cancel_caught
assert task_complete

assert task_executed


@pytest.mark.asyncio
async def test_inner_shield_sync_and_async_middleware_sync_task():
"""
Tests that asyncio.shield is capable of preventing http.disconnect from
cancelling a django request task when using sync and middleware chained
together with an async view calling a sync function calling an async task.
This test ensures that a parent initiated task cancellation will not
propagate to a shielded subtask.
"""

# Hypothetical Django scenario - middleware function is sync
def sync_middleware_1():
async_to_sync(async_middleware_2)()

# Hypothetical Django scenario - middleware function is async
async def async_middleware_2():
await sync_to_async(sync_middleware_3)()

# Hypothetical Django scenario - middleware function is sync
def sync_middleware_3():
async_to_sync(async_middleware_4)()

# Hypothetical Django scenario - middleware function is async
async def async_middleware_4():
await sync_to_async(sync_middleware_5)()

# Hypothetical Django scenario - middleware function is sync
def sync_middleware_5():
async_to_sync(async_view)()

task_complete = False
task_cancel_caught = False

# Future that completes when subtask cancellation attempt is caught
task_blocker = asyncio.Future()

async def async_view():
"""Async view with a task that is shielded from cancellation."""
nonlocal task_complete, task_cancel_caught, task_blocker
task = asyncio.create_task(sync_to_async(sync_parent)())
try:
await asyncio.shield(task)
except asyncio.CancelledError:
task_cancel_caught = True
task_blocker.set_result(True)
await task
task_complete = True

task_executed = False

# Future that completes after subtask is created
task_started_future = asyncio.Future()

def sync_parent():
async_to_sync(async_task)()

async def async_task():
"""Async subtask that should not be canceled when parent is canceled."""
nonlocal task_started_future, task_executed, task_blocker
task_started_future.set_result(True)
await task_blocker
task_executed = True

task_cancel_propagated = False

async with ThreadSensitiveContext():
task = asyncio.create_task(sync_to_async(sync_middleware_1)())
await task_started_future
task.cancel()
try:
await task
except asyncio.CancelledError:
task_cancel_propagated = True
assert not task_cancel_propagated
assert task_cancel_caught
assert task_complete

assert task_executed

0 comments on commit 0503c2c

Please sign in to comment.