diff --git a/completions_test.go b/completions_test.go index df153fcf2..28cdcdf5a 100644 --- a/completions_test.go +++ b/completions_test.go @@ -828,13 +828,17 @@ func TestRequiredFlagNameCompletionInGo(t *testing.T) { requiredFlag := rootCmd.Flags().Lookup("requiredFlag") rootCmd.PersistentFlags().IntP("requiredPersistent", "p", -1, "required persistent") - assertNoErr(t, rootCmd.MarkPersistentFlagRequired("requiredPersistent")) + rootCmd.PersistentFlags().Float64P("requiredPersistentFloat", "f", -1, "required persistent float") + assertNoErr(t, rootCmd.MarkPersistentFlagsRequired("requiredPersistent", "requiredPersistentFloat")) requiredPersistent := rootCmd.PersistentFlags().Lookup("requiredPersistent") + requiredPersistentFloat := rootCmd.PersistentFlags().Lookup("requiredPersistentFloat") rootCmd.Flags().StringP("release", "R", "", "Release name") - childCmd.Flags().BoolP("subRequired", "s", false, "sub required flag") - assertNoErr(t, childCmd.MarkFlagRequired("subRequired")) + childCmd.Flags().BoolP("subRequiredOne", "s", false, "first sub required flag") + childCmd.Flags().BoolP("subRequiredTwo", "z", false, "second sub required flag") + assertNoErr(t, childCmd.MarkFlagsRequired("subRequiredOne", "subRequiredTwo")) + childCmd.Flags().BoolP("subNotRequired", "n", false, "sub not required flag") // Test that a required flag is suggested even without the - prefix @@ -851,6 +855,8 @@ func TestRequiredFlagNameCompletionInGo(t *testing.T) { "-r", "--requiredPersistent", "-p", + "--requiredPersistentFloat", + "-f", "realArg", ":4", "Completion ended with directive: ShellCompDirectiveNoFileComp", ""}, "\n") @@ -870,6 +876,8 @@ func TestRequiredFlagNameCompletionInGo(t *testing.T) { "-r", "--requiredPersistent", "-p", + "--requiredPersistentFloat", + "-f", ":4", "Completion ended with directive: ShellCompDirectiveNoFileComp", ""}, "\n") @@ -901,8 +909,12 @@ func TestRequiredFlagNameCompletionInGo(t *testing.T) { expected = strings.Join([]string{ "--requiredPersistent", "-p", - "--subRequired", + "--requiredPersistentFloat", + "-f", + "--subRequiredOne", "-s", + "--subRequiredTwo", + "-z", "subArg", ":4", "Completion ended with directive: ShellCompDirectiveNoFileComp", ""}, "\n") @@ -919,8 +931,12 @@ func TestRequiredFlagNameCompletionInGo(t *testing.T) { expected = strings.Join([]string{ "--requiredPersistent", "-p", - "--subRequired", + "--requiredPersistentFloat", + "-f", + "--subRequiredOne", "-s", + "--subRequiredTwo", + "-z", ":4", "Completion ended with directive: ShellCompDirectiveNoFileComp", ""}, "\n") @@ -953,6 +969,8 @@ func TestRequiredFlagNameCompletionInGo(t *testing.T) { expected = strings.Join([]string{ "--requiredPersistent", "-p", + "--requiredPersistentFloat", + "-f", "realArg", ":4", "Completion ended with directive: ShellCompDirectiveNoFileComp", ""}, "\n") @@ -962,12 +980,13 @@ func TestRequiredFlagNameCompletionInGo(t *testing.T) { } // Test that when a persistent required flag is present, it is not suggested anymore - output, err = executeCommand(rootCmd, ShellCompNoDescRequestCmd, "--requiredPersistent", "1", "") + output, err = executeCommand(rootCmd, ShellCompNoDescRequestCmd, "--requiredPersistent", "1", "--requiredPersistentFloat", "1.0", "") if err != nil { t.Errorf("Unexpected error: %v", err) } // Reset the flag for the next command requiredPersistent.Changed = false + requiredPersistentFloat.Changed = false expected = strings.Join([]string{ "childCmd", @@ -984,13 +1003,14 @@ func TestRequiredFlagNameCompletionInGo(t *testing.T) { } // Test that when all required flags are present, normal completion is done - output, err = executeCommand(rootCmd, ShellCompNoDescRequestCmd, "--requiredFlag", "1", "--requiredPersistent", "1", "") + output, err = executeCommand(rootCmd, ShellCompNoDescRequestCmd, "--requiredFlag", "1", "--requiredPersistent", "1", "--requiredPersistentFloat", "1.0", "") if err != nil { t.Errorf("Unexpected error: %v", err) } // Reset the flags for the next command requiredFlag.Changed = false requiredPersistent.Changed = false + requiredPersistentFloat.Changed = false expected = strings.Join([]string{ "realArg", diff --git a/shell_completions.go b/shell_completions.go index b035742d3..376c60d72 100644 --- a/shell_completions.go +++ b/shell_completions.go @@ -25,6 +25,13 @@ func (c *Command) MarkFlagRequired(name string) error { return MarkFlagRequired(c.Flags(), name) } +// MarkFlagsRequired instructs the various shell completion implementations to +// prioritize the named flags when performing completion, +// and causes your command to report an error if invoked without any of the flags. +func (c *Command) MarkFlagsRequired(names ...string) error { + return MarkFlagsRequired(c.Flags(), names...) +} + // MarkPersistentFlagRequired instructs the various shell completion implementations to // prioritize the named persistent flag when performing completion, // and causes your command to report an error if invoked without the flag. @@ -32,6 +39,13 @@ func (c *Command) MarkPersistentFlagRequired(name string) error { return MarkFlagRequired(c.PersistentFlags(), name) } +// MarkPersistentFlagsRequired instructs the various shell completion implementations to +// prioritize the named persistent flags when performing completion, +// and causes your command to report an error if invoked without any of the flags. +func (c *Command) MarkPersistentFlagsRequired(names ...string) error { + return MarkFlagsRequired(c.PersistentFlags(), names...) +} + // MarkFlagRequired instructs the various shell completion implementations to // prioritize the named flag when performing completion, // and causes your command to report an error if invoked without the flag. @@ -39,6 +53,18 @@ func MarkFlagRequired(flags *pflag.FlagSet, name string) error { return flags.SetAnnotation(name, BashCompOneRequiredFlag, []string{"true"}) } +// MarkFlagsRequired instructs the various shell completion implementations to +// prioritize the named flags when performing completion, +// and causes your command to report an error if invoked without any of the flags. +func MarkFlagsRequired(flags *pflag.FlagSet, names ...string) error { + for _, name := range names { + if err := MarkFlagRequired(flags, name); err != nil { + return err + } + } + return nil +} + // MarkFlagFilename instructs the various shell completion implementations to // limit completions for the named flag to the specified file extensions. func (c *Command) MarkFlagFilename(name string, extensions ...string) error { diff --git a/site/content/user_guide.md b/site/content/user_guide.md index 3b42ef044..a1cf00a18 100644 --- a/site/content/user_guide.md +++ b/site/content/user_guide.md @@ -331,6 +331,24 @@ rootCmd.PersistentFlags().StringVarP(&Region, "region", "r", "", "AWS region (re rootCmd.MarkPersistentFlagRequired("region") ``` +### Multiple Required flags + +If your command has multiple required flags that are not [grouped](#flag-groups) to report an error +when one or more flags have not been set, mark them as required: +```go +rootCmd.Flags().StringVarP(&Region, "region", "r", "", "AWS region (required)") +rootCmd.Flags().StringVarP(&Failover, "failover", "f", "", "AWS failover region (required)") +rootCmd.MarkFlagsRequired("region", "failover") +``` + +Or, for multiple persistent flags: +```go +rootCmd.PersistentFlags().StringVarP(&Region, "region", "r", "", "AWS region (required)") +rootCmd.PersistentFlags().StringVarP(&Failover, "failover", "f", "", "AWS failover region (required)") +rootCmd.MarkPersistentFlagsRequired("region", "failover") +``` + + ### Flag Groups If you have different flags that must be provided together (e.g. if they provide the `--username` flag they MUST provide the `--password` flag as well) then