Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix race condition #1039

Merged
merged 14 commits into from Apr 17, 2022
21 changes: 16 additions & 5 deletions discord/commands/context.py
Expand Up @@ -223,11 +223,22 @@ def send_modal(self) -> Callable[..., Awaitable[Interaction]]:
@property
def respond(self) -> Callable[..., Awaitable[Union[Interaction, WebhookMessage]]]:
"""Callable[..., Union[:class:`~.Interaction`, :class:`~.Webhook`]]: Sends either a response
Lulalaby marked this conversation as resolved.
Show resolved Hide resolved
or a followup response depending on if the interaction has been responded to yet or not."""
if not self.interaction.response.is_done():
return self.interaction.response.send_message # self.response
else:
return self.followup.send # self.send_followup
or a followup response depending if the interaction has been responded to yet or not."""

# This technically can still be effected by the race condition. Solving that would include a breaking change
# But now it will raise InteractionResponded not the unexpected HTTP exception
# So we return a wrapper that retires when that exception got raised

async def wrapper(*args, **kwargs):
try:
if not self.interaction.response.is_done():
return await self.interaction.response.send_message(*args, **kwargs) # self.response
else:
return await self.followup.send(*args, **kwargs) # self.send_followup
except discord.errors.InteractionResponded:
return await self.followup.send(*args, **kwargs)

return wrapper
Lulalaby marked this conversation as resolved.
Show resolved Hide resolved

@property
def send_response(self) -> Callable[..., Awaitable[Interaction]]:
Expand Down
112 changes: 75 additions & 37 deletions discord/interactions.py
Expand Up @@ -27,7 +27,8 @@

from __future__ import annotations

from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
import asyncio
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union, Coroutine

from . import utils
from .channel import ChannelType, PartialMessageable
Expand Down Expand Up @@ -438,11 +439,13 @@ class InteractionResponse:
__slots__: Tuple[str, ...] = (
"_responded",
"_parent",
"_response_lock",
)

def __init__(self, parent: Interaction):
self._parent: Interaction = parent
self._responded: bool = False
self._response_lock = asyncio.Lock()

def is_done(self) -> bool:
""":class:`bool`: Indicates whether an interaction response has been done before.
Expand Down Expand Up @@ -491,12 +494,14 @@ async def defer(self, *, ephemeral: bool = False) -> None:

if defer_type:
adapter = async_context.get()
await adapter.create_interaction_response(
parent.id,
parent.token,
session=parent._session,
type=defer_type,
data=data,
await self._locked_response(
adapter.create_interaction_response(
parent.id,
parent.token,
session=parent._session,
type=defer_type,
data=data,
)
)
self._responded = True

Expand All @@ -520,11 +525,13 @@ async def pong(self) -> None:
parent = self._parent
if parent.type is InteractionType.ping:
adapter = async_context.get()
await adapter.create_interaction_response(
parent.id,
parent.token,
session=parent._session,
type=InteractionResponseType.pong.value,
await self._locked_response(
adapter.create_interaction_response(
parent.id,
parent.token,
session=parent._session,
type=InteractionResponseType.pong.value,
)
)
self._responded = True

Expand Down Expand Up @@ -640,13 +647,15 @@ async def send_message(
parent = self._parent
adapter = async_context.get()
try:
await adapter.create_interaction_response(
parent.id,
parent.token,
session=parent._session,
type=InteractionResponseType.channel_message.value,
data=payload,
files=files,
await self._locked_response(
adapter.create_interaction_response(
parent.id,
parent.token,
session=parent._session,
type=InteractionResponseType.channel_message.value,
data=payload,
files=files,
)
)
finally:
if files:
Expand Down Expand Up @@ -736,12 +745,14 @@ async def edit_message(
state.prevent_view_updates_for(message_id)
payload["components"] = [] if view is None else view.to_components()
adapter = async_context.get()
await adapter.create_interaction_response(
parent.id,
parent.token,
session=parent._session,
type=InteractionResponseType.message_update.value,
data=payload,
await self._locked_response(
adapter.create_interaction_response(
parent.id,
parent.token,
session=parent._session,
type=InteractionResponseType.message_update.value,
data=payload,
)
)

if view and not view.is_finished():
Expand Down Expand Up @@ -782,12 +793,14 @@ async def send_autocomplete_result(
payload = {"choices": [c.to_dict() for c in choices]}

adapter = async_context.get()
await adapter.create_interaction_response(
parent.id,
parent.token,
session=parent._session,
type=InteractionResponseType.auto_complete_result.value,
data=payload,
await self._locked_response(
adapter.create_interaction_response(
parent.id,
parent.token,
session=parent._session,
type=InteractionResponseType.auto_complete_result.value,
data=payload,
)
)

self._responded = True
Expand All @@ -814,17 +827,42 @@ async def send_modal(self, modal: Modal) -> Interaction:

payload = modal.to_dict()
adapter = async_context.get()
await adapter.create_interaction_response(
self._parent.id,
self._parent.token,
session=self._parent._session,
type=InteractionResponseType.modal.value,
data=payload,
await self._locked_response(
adapter.create_interaction_response(
self._parent.id,
self._parent.token,
session=self._parent._session,
type=InteractionResponseType.modal.value,
data=payload,
)
)
self._responded = True
self._parent._state.store_modal(modal, self._parent.user.id)
return self._parent

async def _locked_response(self, coro: Coroutine[Any]):
"""|coro|

Wraps a response and makes sure that is locked while executing
Lulalaby marked this conversation as resolved.
Show resolved Hide resolved

Parameters
-----------
coro:
The coroutine to wrap
Lulalaby marked this conversation as resolved.
Show resolved Hide resolved

Raises
-------
HTTPException
Deferring the interaction failed.
Lulalaby marked this conversation as resolved.
Show resolved Hide resolved
InteractionResponded
This interaction has already been responded to before.
"""
async with self._response_lock:
if self.is_done():
coro.close() # Cleanup unawaited coroutine
Lulalaby marked this conversation as resolved.
Show resolved Hide resolved
raise InteractionResponded(self._parent)
await coro


class _InteractionMessageState:
__slots__ = ("_parent", "_interaction")
Expand Down