Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature merge] Better sync command #168

Merged
merged 10 commits into from Apr 27, 2022
127 changes: 113 additions & 14 deletions jishaku/features/management.py
Expand Up @@ -13,8 +13,10 @@

import itertools
import math
import re
import time
import traceback
import typing
from urllib.parse import urlencode

import discord
Expand All @@ -23,7 +25,7 @@
from jishaku.features.baseclass import Feature
from jishaku.flags import Flags
from jishaku.modules import ExtensionConverter
from jishaku.paginators import WrappedPaginator
from jishaku.repl import inspections


class ManagementFeature(Feature):
Expand All @@ -39,7 +41,7 @@ async def jsk_load(self, ctx: commands.Context, *extensions: ExtensionConverter)
Reports any extensions that failed to load.
"""

paginator = WrappedPaginator(prefix='', suffix='')
paginator = commands.Paginator(prefix='', suffix='')

# 'jsk reload' on its own just reloads jishaku
if ctx.invoked_with == 'reload' and not extensions:
Expand Down Expand Up @@ -75,7 +77,7 @@ async def jsk_unload(self, ctx: commands.Context, *extensions: ExtensionConverte
Reports any extensions that failed to unload.
"""

paginator = WrappedPaginator(prefix='', suffix='')
paginator = commands.Paginator(prefix='', suffix='')
icon = "\N{OUTBOX TRAY}"

for extension in itertools.chain(*extensions):
Expand Down Expand Up @@ -191,25 +193,122 @@ async def jsk_rtt(self, ctx: commands.Context):
if self.bot.latency > 0.0:
websocket_readings.append(self.bot.latency)

SLASH_COMMAND_ERROR = re.compile(r"In ((?:\d+\.[a-z]+\.?)+)")

@Feature.Command(parent="jsk", name="sync")
async def jsk_sync(self, ctx: commands.Context, *guild_ids: int):
async def jsk_sync(self, ctx: commands.Context, *targets: str):
"""
Sync global or guild application commands to Discord.
"""

paginator = WrappedPaginator(prefix='', suffix='')
paginator = commands.Paginator(prefix='', suffix='')

if not guild_ids:
synced = await self.bot.tree.sync()
paginator.add_line(f"\N{SATELLITE ANTENNA} Synced {len(synced)} global commands")
else:
for guild_id in guild_ids:
guilds = set()
for target in targets:
if target == '$':
guilds.add(None)
elif target == '*':
guilds |= set(self.bot.tree._guild_commands.keys()) # pylint: disable=protected-access
elif target == '.':
guilds.add(ctx.guild.id)
else:
try:
synced = await self.bot.tree.sync(guild=discord.Object(guild_id))
except discord.HTTPException as exc:
paginator.add_line(f"\N{WARNING SIGN} `{guild_id}`: {exc.text}")
guilds.add(int(target))
except ValueError as error:
raise commands.BadArgument(f"{target} is not a valid guild ID") from error

if not guilds:
guilds.add(None)

guilds: typing.List[typing.Optional[int]] = list(guilds)
guilds.sort(key=lambda g: (g is not None, g))

for guild in guilds:
slash_commands = self.bot.tree._get_all_commands( # pylint: disable=protected-access
guild=discord.Object(guild) if guild else None
)
payload = [command.to_dict() for command in slash_commands]

try:
if guild is None:
data = await self.bot.http.bulk_upsert_global_commands(self.bot.application_id, payload=payload)
else:
data = await self.bot.http.bulk_upsert_guild_commands(self.bot.application_id, guild, payload=payload)

synced = [
discord.app_commands.AppCommand(data=d, state=ctx._state) # pylint: disable=protected-access,no-member
for d in data
]

except discord.HTTPException as error:
# It's diagnosis time
error_text = []
for line in str(error).split("\n"):
error_text.append(line)

try:
match = self.SLASH_COMMAND_ERROR.match(line)
if not match:
continue

pool = slash_commands
selected_command = None
name = ""
parts = match.group(1).split('.')
assert len(parts) % 2 == 0

for part_index in range(0, len(parts), 2):
index = int(parts[part_index])
# prop = parts[part_index + 1]

if pool:
# If the pool exists, this should be a subcommand
selected_command = pool[index]
name += selected_command.name + " "

if hasattr(selected_command, '_children'):
pool = list(selected_command._children.values()) # pylint: disable=protected-access
else:
pool = None
else:
# Otherwise, the pool has been exhausted, and this likely is referring to a parameter
param = list(selected_command._params.keys())[index] # pylint: disable=protected-access
name += f"(parameter: {param}) "

if selected_command:
to_inspect = None

if hasattr(selected_command, 'callback'):
to_inspect = selected_command.callback
elif isinstance(selected_command, commands.Cog):
to_inspect = type(selected_command)

try:
error_text.append(''.join([
"\N{MAGNET} This is likely caused by: `",
name,
"` at ",
str(inspections.file_loc_inspection(to_inspect)),
":",
str(inspections.line_span_inspection(to_inspect))
]))
except Exception: # pylint: disable=broad-except
error_text.append(f"\N{MAGNET} This is likely caused by: `{name}`")

except Exception as diag_error: # pylint: disable=broad-except
error_text.append(f"\N{MAGNET} Couldn't determine cause: {type(diag_error).__name__}: {diag_error}")

error_text = '\n'.join(error_text)

if guild:
paginator.add_line(f"\N{WARNING SIGN} `{guild}`: {error_text}", empty=True)
else:
paginator.add_line(f"\N{WARNING SIGN} Global: {error_text}", empty=True)
else:
if guild:
paginator.add_line(f"\N{SATELLITE ANTENNA} `{guild}` Synced {len(synced)} guild commands", empty=True)
else:
paginator.add_line(f"\N{SATELLITE ANTENNA} `{guild_id}` Synced {len(synced)} guild commands")
paginator.add_line(f"\N{SATELLITE ANTENNA} Synced {len(synced)} global commands", empty=True)

for page in paginator.pages:
await ctx.send(page)