Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix race condition #1039

Merged
merged 14 commits into from Apr 17, 2022
17 changes: 9 additions & 8 deletions discord/commands/context.py
Expand Up @@ -220,14 +220,15 @@ def send_modal(self) -> Callable[..., Awaitable[Interaction]]:
"""Sends a modal dialog to the user who invoked the interaction."""
return self.interaction.response.send_modal

@property
def respond(self) -> Callable[..., Awaitable[Union[Interaction, WebhookMessage]]]:
"""Callable[..., Union[:class:`~.Interaction`, :class:`~.Webhook`]]: Sends either a response
or a followup response depending on if the interaction has been responded to yet or not."""
if not self.interaction.response.is_done():
return self.interaction.response.send_message # self.response
else:
return self.followup.send # self.send_followup
async def respond(self, *args, **kwargs) -> Union[Interaction, WebhookMessage]:
"""Sends either a response or a followup response depending if the interaction has been responded to yet or not."""
try:
if not self.interaction.response.is_done():
return await self.interaction.response.send_message(*args, **kwargs) # self.response
else:
return await self.followup.send(*args, **kwargs) # self.send_followup
except discord.errors.InteractionResponded:
return await self.followup.send(*args, **kwargs)

@property
def send_response(self) -> Callable[..., Awaitable[Interaction]]:
Expand Down
112 changes: 75 additions & 37 deletions discord/interactions.py
Expand Up @@ -27,7 +27,8 @@

from __future__ import annotations

from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
import asyncio
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union, Coroutine

from . import utils
from .channel import ChannelType, PartialMessageable
Expand Down Expand Up @@ -436,11 +437,13 @@ class InteractionResponse:
__slots__: Tuple[str, ...] = (
"_responded",
"_parent",
"_response_lock",
)

def __init__(self, parent: Interaction):
self._parent: Interaction = parent
self._responded: bool = False
self._response_lock = asyncio.Lock()

def is_done(self) -> bool:
""":class:`bool`: Indicates whether an interaction response has been done before.
Expand Down Expand Up @@ -489,12 +492,14 @@ async def defer(self, *, ephemeral: bool = False) -> None:

if defer_type:
adapter = async_context.get()
await adapter.create_interaction_response(
parent.id,
parent.token,
session=parent._session,
type=defer_type,
data=data,
await self._locked_response(
adapter.create_interaction_response(
parent.id,
parent.token,
session=parent._session,
type=defer_type,
data=data,
)
)
self._responded = True

Expand All @@ -518,11 +523,13 @@ async def pong(self) -> None:
parent = self._parent
if parent.type is InteractionType.ping:
adapter = async_context.get()
await adapter.create_interaction_response(
parent.id,
parent.token,
session=parent._session,
type=InteractionResponseType.pong.value,
await self._locked_response(
adapter.create_interaction_response(
parent.id,
parent.token,
session=parent._session,
type=InteractionResponseType.pong.value,
)
)
self._responded = True

Expand Down Expand Up @@ -638,13 +645,15 @@ async def send_message(
parent = self._parent
adapter = async_context.get()
try:
await adapter.create_interaction_response(
parent.id,
parent.token,
session=parent._session,
type=InteractionResponseType.channel_message.value,
data=payload,
files=files,
await self._locked_response(
adapter.create_interaction_response(
parent.id,
parent.token,
session=parent._session,
type=InteractionResponseType.channel_message.value,
data=payload,
files=files,
)
)
finally:
if files:
Expand Down Expand Up @@ -734,12 +743,14 @@ async def edit_message(
state.prevent_view_updates_for(message_id)
payload["components"] = [] if view is None else view.to_components()
adapter = async_context.get()
await adapter.create_interaction_response(
parent.id,
parent.token,
session=parent._session,
type=InteractionResponseType.message_update.value,
data=payload,
await self._locked_response(
adapter.create_interaction_response(
parent.id,
parent.token,
session=parent._session,
type=InteractionResponseType.message_update.value,
data=payload,
)
)

if view and not view.is_finished():
Expand Down Expand Up @@ -780,12 +791,14 @@ async def send_autocomplete_result(
payload = {"choices": [c.to_dict() for c in choices]}

adapter = async_context.get()
await adapter.create_interaction_response(
parent.id,
parent.token,
session=parent._session,
type=InteractionResponseType.auto_complete_result.value,
data=payload,
await self._locked_response(
adapter.create_interaction_response(
parent.id,
parent.token,
session=parent._session,
type=InteractionResponseType.auto_complete_result.value,
data=payload,
)
)

self._responded = True
Expand All @@ -812,17 +825,42 @@ async def send_modal(self, modal: Modal) -> Interaction:

payload = modal.to_dict()
adapter = async_context.get()
await adapter.create_interaction_response(
self._parent.id,
self._parent.token,
session=self._parent._session,
type=InteractionResponseType.modal.value,
data=payload,
await self._locked_response(
adapter.create_interaction_response(
self._parent.id,
self._parent.token,
session=self._parent._session,
type=InteractionResponseType.modal.value,
data=payload,
)
)
self._responded = True
self._parent._state.store_modal(modal, self._parent.user.id)
return self._parent

async def _locked_response(self, coro: Coroutine[Any]):
"""|coro|

Wraps a response and makes sure that is locked while executing
Lulalaby marked this conversation as resolved.
Show resolved Hide resolved

Parameters
-----------
coro:
The coroutine to wrap
Lulalaby marked this conversation as resolved.
Show resolved Hide resolved

Raises
-------
HTTPException
Deferring the interaction failed.
Lulalaby marked this conversation as resolved.
Show resolved Hide resolved
InteractionResponded
This interaction has already been responded to before.
"""
async with self._response_lock:
if self.is_done():
coro.close() # Cleanup unawaited coroutine
Lulalaby marked this conversation as resolved.
Show resolved Hide resolved
raise InteractionResponded(self._parent)
await coro


class _InteractionMessageState:
__slots__ = ("_parent", "_interaction")
Expand Down