Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Multiple Required/Persistent Flags #2110

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
34 changes: 27 additions & 7 deletions completions_test.go
Expand Up @@ -828,13 +828,17 @@ func TestRequiredFlagNameCompletionInGo(t *testing.T) {
requiredFlag := rootCmd.Flags().Lookup("requiredFlag")

rootCmd.PersistentFlags().IntP("requiredPersistent", "p", -1, "required persistent")
assertNoErr(t, rootCmd.MarkPersistentFlagRequired("requiredPersistent"))
rootCmd.PersistentFlags().Float64P("requiredPersistentFloat", "f", -1, "required persistent float")
assertNoErr(t, rootCmd.MarkPersistentFlagsRequired("requiredPersistent", "requiredPersistentFloat"))
requiredPersistent := rootCmd.PersistentFlags().Lookup("requiredPersistent")
requiredPersistentFloat := rootCmd.PersistentFlags().Lookup("requiredPersistentFloat")

rootCmd.Flags().StringP("release", "R", "", "Release name")

childCmd.Flags().BoolP("subRequired", "s", false, "sub required flag")
assertNoErr(t, childCmd.MarkFlagRequired("subRequired"))
childCmd.Flags().BoolP("subRequiredOne", "s", false, "first sub required flag")
childCmd.Flags().BoolP("subRequiredTwo", "z", false, "second sub required flag")
assertNoErr(t, childCmd.MarkFlagsRequired("subRequiredOne", "subRequiredTwo"))

childCmd.Flags().BoolP("subNotRequired", "n", false, "sub not required flag")

// Test that a required flag is suggested even without the - prefix
Expand All @@ -851,6 +855,8 @@ func TestRequiredFlagNameCompletionInGo(t *testing.T) {
"-r",
"--requiredPersistent",
"-p",
"--requiredPersistentFloat",
"-f",
"realArg",
":4",
"Completion ended with directive: ShellCompDirectiveNoFileComp", ""}, "\n")
Expand All @@ -870,6 +876,8 @@ func TestRequiredFlagNameCompletionInGo(t *testing.T) {
"-r",
"--requiredPersistent",
"-p",
"--requiredPersistentFloat",
"-f",
":4",
"Completion ended with directive: ShellCompDirectiveNoFileComp", ""}, "\n")

Expand Down Expand Up @@ -901,8 +909,12 @@ func TestRequiredFlagNameCompletionInGo(t *testing.T) {
expected = strings.Join([]string{
"--requiredPersistent",
"-p",
"--subRequired",
"--requiredPersistentFloat",
"-f",
"--subRequiredOne",
"-s",
"--subRequiredTwo",
"-z",
"subArg",
":4",
"Completion ended with directive: ShellCompDirectiveNoFileComp", ""}, "\n")
Expand All @@ -919,8 +931,12 @@ func TestRequiredFlagNameCompletionInGo(t *testing.T) {
expected = strings.Join([]string{
"--requiredPersistent",
"-p",
"--subRequired",
"--requiredPersistentFloat",
"-f",
"--subRequiredOne",
"-s",
"--subRequiredTwo",
"-z",
":4",
"Completion ended with directive: ShellCompDirectiveNoFileComp", ""}, "\n")

Expand Down Expand Up @@ -953,6 +969,8 @@ func TestRequiredFlagNameCompletionInGo(t *testing.T) {
expected = strings.Join([]string{
"--requiredPersistent",
"-p",
"--requiredPersistentFloat",
"-f",
"realArg",
":4",
"Completion ended with directive: ShellCompDirectiveNoFileComp", ""}, "\n")
Expand All @@ -962,12 +980,13 @@ func TestRequiredFlagNameCompletionInGo(t *testing.T) {
}

// Test that when a persistent required flag is present, it is not suggested anymore
output, err = executeCommand(rootCmd, ShellCompNoDescRequestCmd, "--requiredPersistent", "1", "")
output, err = executeCommand(rootCmd, ShellCompNoDescRequestCmd, "--requiredPersistent", "1", "--requiredPersistentFloat", "1.0", "")
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
// Reset the flag for the next command
requiredPersistent.Changed = false
requiredPersistentFloat.Changed = false

expected = strings.Join([]string{
"childCmd",
Expand All @@ -984,13 +1003,14 @@ func TestRequiredFlagNameCompletionInGo(t *testing.T) {
}

// Test that when all required flags are present, normal completion is done
output, err = executeCommand(rootCmd, ShellCompNoDescRequestCmd, "--requiredFlag", "1", "--requiredPersistent", "1", "")
output, err = executeCommand(rootCmd, ShellCompNoDescRequestCmd, "--requiredFlag", "1", "--requiredPersistent", "1", "--requiredPersistentFloat", "1.0", "")
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
// Reset the flags for the next command
requiredFlag.Changed = false
requiredPersistent.Changed = false
requiredPersistentFloat.Changed = false

expected = strings.Join([]string{
"realArg",
Expand Down
26 changes: 26 additions & 0 deletions shell_completions.go
Expand Up @@ -25,20 +25,46 @@ func (c *Command) MarkFlagRequired(name string) error {
return MarkFlagRequired(c.Flags(), name)
}

// MarkFlagsRequired instructs the various shell completion implementations to
// prioritize the named flags when performing completion,
// and causes your command to report an error if invoked without any of the flags.
func (c *Command) MarkFlagsRequired(names ...string) error {
return MarkFlagsRequired(c.Flags(), names...)
}

// MarkPersistentFlagRequired instructs the various shell completion implementations to
// prioritize the named persistent flag when performing completion,
// and causes your command to report an error if invoked without the flag.
func (c *Command) MarkPersistentFlagRequired(name string) error {
return MarkFlagRequired(c.PersistentFlags(), name)
}

// MarkPersistentFlagsRequired instructs the various shell completion implementations to
// prioritize the named persistent flags when performing completion,
// and causes your command to report an error if invoked without any of the flags.
func (c *Command) MarkPersistentFlagsRequired(names ...string) error {
return MarkFlagsRequired(c.PersistentFlags(), names...)
}

// MarkFlagRequired instructs the various shell completion implementations to
// prioritize the named flag when performing completion,
// and causes your command to report an error if invoked without the flag.
func MarkFlagRequired(flags *pflag.FlagSet, name string) error {
return flags.SetAnnotation(name, BashCompOneRequiredFlag, []string{"true"})
}

// MarkFlagsRequired instructs the various shell completion implementations to
// prioritize the named flags when performing completion,
// and causes your command to report an error if invoked without any of the flags.
func MarkFlagsRequired(flags *pflag.FlagSet, names ...string) error {
for _, name := range names {
if err := MarkFlagRequired(flags, name); err != nil {
return err
}
}
return nil
}

// MarkFlagFilename instructs the various shell completion implementations to
// limit completions for the named flag to the specified file extensions.
func (c *Command) MarkFlagFilename(name string, extensions ...string) error {
Expand Down
18 changes: 18 additions & 0 deletions site/content/user_guide.md
Expand Up @@ -331,6 +331,24 @@ rootCmd.PersistentFlags().StringVarP(&Region, "region", "r", "", "AWS region (re
rootCmd.MarkPersistentFlagRequired("region")
```

### Multiple Required flags

If your command has multiple required flags that are not [grouped](#flag-groups) to report an error
when one or more flags have not been set, mark them as required:
```go
rootCmd.Flags().StringVarP(&Region, "region", "r", "", "AWS region (required)")
rootCmd.Flags().StringVarP(&Failover, "failover", "f", "", "AWS failover region (required)")
rootCmd.MarkFlagsRequired("region", "failover")
```

Or, for multiple persistent flags:
```go
rootCmd.PersistentFlags().StringVarP(&Region, "region", "r", "", "AWS region (required)")
rootCmd.PersistentFlags().StringVarP(&Failover, "failover", "f", "", "AWS failover region (required)")
rootCmd.MarkPersistentFlagsRequired("region", "failover")
```


### Flag Groups

If you have different flags that must be provided together (e.g. if they provide the `--username` flag they MUST provide the `--password` flag as well) then
Expand Down