diff --git a/discord/commands/core.py b/discord/commands/core.py index dfc3924d99..14065bed9a 100644 --- a/discord/commands/core.py +++ b/discord/commands/core.py @@ -655,45 +655,49 @@ def __init__(self, func: Callable, *args, **kwargs) -> None: if self.permissions and self.default_permission: self.default_permission = False - def _parse_options(self, params) -> List[Option]: - if list(params.items())[0][0] == "self": - temp = list(params.items()) - temp.pop(0) - params = dict(temp) + def _check_required_params(self, params): params = iter(params.items()) + required_params = ["self", "context"] if self.attached_to_group or self.cog else ["context"] + for p in required_params: + try: + next(params) + except StopIteration: + raise ClientException(f'Callback for {self.name} command is missing "{p}" parameter.') - # next we have the 'ctx' as the next parameter - try: - next(params) - except StopIteration: - raise ClientException(f'Callback for {self.name} command is missing "ctx" parameter.') + return params + + def _parse_options(self, params, *, check_params: bool = True) -> List[Option]: + if check_params: + params = self._check_required_params(params) final_options = [] for p_name, p_obj in params: - option = p_obj.annotation if option == inspect.Parameter.empty: option = str if self._is_typing_union(option): if self._is_typing_optional(option): - option = Option(option.__args__[0], "No description provided", required=False) + option = Option(option.__args__[0], "No description provided", required=False) # type: ignore # union type else: - option = Option(option.__args__, "No description provided") + option = Option(option.__args__, "No description provided") # type: ignore # union type if not isinstance(option, Option): - option = Option(option, "No description provided") + if isinstance(p_obj.default, Option): # arg: type = Option(...) + p_obj.default.input_type = SlashCommandOptionType.from_datatype(option) + option = p_obj.default + else: # arg: Option(...) = default + option = Option(option, "No description provided") if option.default is None: - if p_obj.default == inspect.Parameter.empty: - option.default = None - else: + if not p_obj.default == inspect.Parameter.empty and not isinstance(p_obj.default, Option): option.default = p_obj.default option.required = False if option.name is None: option.name = p_name - option._parameter_name = p_name + if option.name != p_name or option._parameter_name is None: + option._parameter_name = p_name _validate_names(option) _validate_descriptions(option) @@ -703,25 +707,15 @@ def _parse_options(self, params) -> List[Option]: return final_options def _match_option_param_names(self, params, options): - if list(params.items())[0][0] == "self": - temp = list(params.items()) - temp.pop(0) - params = dict(temp) - params = iter(params.items()) + params = self._check_required_params(params) - # next we have the 'ctx' as the next parameter - try: - next(params) - except StopIteration: - raise ClientException(f'Callback for {self.name} command is missing "ctx" parameter.') - - check_annotations = [ + check_annotations: List[Callable[[Option, Type], bool]] = [ lambda o, a: o.input_type == SlashCommandOptionType.string and o.converter is not None, # pass on converters lambda o, a: isinstance(o.input_type, SlashCommandOptionType), # pass on slash cmd option type enums lambda o, a: isinstance(o._raw_type, tuple) and a == Union[o._raw_type], # type: ignore # union types lambda o, a: self._is_typing_optional(a) and not o.required and o._raw_type in a.__args__, # optional - lambda o, a: inspect.isclass(a) and issubclass(a, o._raw_type), # 'normal' types + lambda o, a: isinstance(a, type) and issubclass(a, o._raw_type), # 'normal' types ] for o in options: _validate_names(o) @@ -732,15 +726,14 @@ def _match_option_param_names(self, params, options): raise ClientException(f"Too many arguments passed to the options kwarg.") p_obj = p_obj.annotation - if not any(c(o, p_obj) for c in check_annotations): + if not any(check(o, p_obj) for check in check_annotations): raise TypeError(f"Parameter {p_name} does not match input type of {o.name}.") o._parameter_name = p_name left_out_params = OrderedDict() - left_out_params[""] = "" # bypass first iter (ctx) for k, v in params: left_out_params[k] = v - options.extend(self._parse_options(left_out_params)) + options.extend(self._parse_options(left_out_params, check_params=False)) return options @@ -752,6 +745,12 @@ def _is_typing_union(self, annotation): def _is_typing_optional(self, annotation): return self._is_typing_union(annotation) and type(None) in annotation.__args__ # type: ignore + def _set_cog(self, cog): + prev = self.cog + super()._set_cog(cog) + if (prev is None and cog is not None) or (prev is not None and cog is None): + self.options = self._parse_options(self._get_signature_parameters()) # parse again to leave out self + @property def is_subcommand(self) -> bool: return self.parent is not None @@ -1162,7 +1161,7 @@ def _update_copy(self, kwargs: Dict[str, Any]): return self.copy() def _set_cog(self, cog): - self.cog = cog + super()._set_cog(cog) for subcommand in self.subcommands: subcommand._set_cog(cog) diff --git a/discord/commands/options.py b/discord/commands/options.py index 1cd0260e57..e60c53c440 100644 --- a/discord/commands/options.py +++ b/discord/commands/options.py @@ -80,12 +80,12 @@ async def hello( ---------- input_type: :class:`Any` The type of input that is expected for this option. - description: :class:`str` - The description of this option. - Must be 100 characters or fewer. name: :class:`str` The name of this option visible in the UI. Inherits from the variable name if not provided as a parameter. + description: Optional[:class:`str`] + The description of this option. + Must be 100 characters or fewer. choices: Optional[List[Union[:class:`Any`, :class:`OptionChoice`]]] The list of available choices for this option. Can be a list of values or :class:`OptionChoice` objects (which represent a name:value pair). @@ -115,10 +115,11 @@ async def hello( See `here `_ for a list of valid locales. """ - def __init__(self, input_type: Any, /, description: str = None, **kwargs) -> None: + def __init__(self, input_type: Any = str, /, description: Optional[str] = None, **kwargs) -> None: self.name: Optional[str] = kwargs.pop("name", None) if self.name is not None: self.name = str(self.name) + self._parameter_name = self.name # default self.description = description or "No description provided" self.converter = None self._raw_type = input_type @@ -140,7 +141,10 @@ def __init__(self, input_type: Any, /, description: str = None, **kwargs) -> Non else: if _type == SlashCommandOptionType.channel: if not isinstance(input_type, tuple): - input_type = (input_type,) + if hasattr(input_type, "__args__"): # Union + input_type = input_type.__args__ + else: + input_type = (input_type,) for i in input_type: if i.__name__ == "GuildChannel": continue diff --git a/discord/enums.py b/discord/enums.py index 4d9805e02e..7a9230de11 100644 --- a/discord/enums.py +++ b/discord/enums.py @@ -695,8 +695,9 @@ def from_datatype(cls, datatype): if issubclass(datatype, float): return cls.number - # TODO: Improve the error message - raise TypeError(f"Invalid class {datatype} used as an input type for an Option") + from .commands.context import ApplicationContext + if not issubclass(datatype, ApplicationContext): # TODO: prevent ctx being passed here in cog commands + raise TypeError(f"Invalid class {datatype} used as an input type for an Option") # TODO: Improve the error message class EmbeddedActivity(Enum):