Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add MarkFlagsMutuallyExclusiveAndRequired #1972

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
51 changes: 49 additions & 2 deletions flag_groups.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,9 @@ import (
)

const (
requiredAsGroup = "cobra_annotation_required_if_others_set"
mutuallyExclusive = "cobra_annotation_mutually_exclusive"
requiredAsGroup = "cobra_annotation_required_if_others_set"
mutuallyExclusive = "cobra_annotation_mutually_exclusive"
mutuallyExclusiveAndRequired = "cobra_annotation_mutually_exclusive_and_required"
)

// MarkFlagsRequiredTogether marks the given flags with annotations so that Cobra errors
Expand Down Expand Up @@ -59,6 +60,22 @@ func (c *Command) MarkFlagsMutuallyExclusive(flagNames ...string) {
}
}

// MarkFlagsMutuallyExclusiveAndRequired marks the given flags with annotations so that Cobra errors
// if the command is invoked without exactly one flag from the given set of flags.
func (c *Command) MarkFlagsMutuallyExclusiveAndRequired(flagNames ...string) {
c.mergePersistentFlags()
for _, v := range flagNames {
f := c.Flags().Lookup(v)
if f == nil {
panic(fmt.Sprintf("Failed to find flag %q and mark it as being in a mutually exclusive flag group", v))
}
// Each time this is called is a single new entry; this allows it to be a member of multiple groups if needed.
if err := c.Flags().SetAnnotation(v, mutuallyExclusiveAndRequired, append(f.Annotations[mutuallyExclusiveAndRequired], strings.Join(flagNames, " "))); err != nil {
panic(err)
}
}
}

// ValidateFlagGroups validates the mutuallyExclusive/requiredAsGroup logic and returns the
// first error encountered.
func (c *Command) ValidateFlagGroups() error {
Expand All @@ -72,9 +89,11 @@ func (c *Command) ValidateFlagGroups() error {
// then a map of each flag name and whether it is set or not.
groupStatus := map[string]map[string]bool{}
mutuallyExclusiveGroupStatus := map[string]map[string]bool{}
mutuallyExclusiveAndRequiredGroupStatus := map[string]map[string]bool{}
flags.VisitAll(func(pflag *flag.Flag) {
processFlagForGroupAnnotation(flags, pflag, requiredAsGroup, groupStatus)
processFlagForGroupAnnotation(flags, pflag, mutuallyExclusive, mutuallyExclusiveGroupStatus)
processFlagForGroupAnnotation(flags, pflag, mutuallyExclusiveAndRequired, mutuallyExclusiveAndRequiredGroupStatus)
})

if err := validateRequiredFlagGroups(groupStatus); err != nil {
Expand All @@ -83,6 +102,9 @@ func (c *Command) ValidateFlagGroups() error {
if err := validateExclusiveFlagGroups(mutuallyExclusiveGroupStatus); err != nil {
return err
}
if err := validateExclusiveAndRequiredFlagGroups(mutuallyExclusiveAndRequiredGroupStatus); err != nil {
return err
}
return nil
}

Expand Down Expand Up @@ -163,6 +185,31 @@ func validateExclusiveFlagGroups(data map[string]map[string]bool) error {
return nil
}

func validateExclusiveAndRequiredFlagGroups(data map[string]map[string]bool) error {
keys := sortedKeys(data)
for _, flagList := range keys {
flagnameAndStatus := data[flagList]
var set []string
for flagname, isSet := range flagnameAndStatus {
if isSet {
set = append(set, flagname)
}
}

if len(set) == 0 {
// Sort values, so they can be tested/scripted against consistently.
sort.Strings(set)
return fmt.Errorf("exactly one of the flags in the group [%v] must be set; none were set", flagList)
}
if len(set) > 1 {
// Sort values, so they can be tested/scripted against consistently.
sort.Strings(set)
return fmt.Errorf("exactly one of the flags in the group [%v] must be set; %v were all set", flagList, set)
}
}
return nil
}

func sortedKeys(m map[string]map[string]bool) []string {
keys := make([]string, len(m))
i := 0
Expand Down
28 changes: 21 additions & 7 deletions flag_groups_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,13 +43,14 @@ func TestValidateFlagGroups(t *testing.T) {

// Each test case uses a unique command from the function above.
testcases := []struct {
desc string
flagGroupsRequired []string
flagGroupsExclusive []string
subCmdFlagGroupsRequired []string
subCmdFlagGroupsExclusive []string
args []string
expectErr string
desc string
flagGroupsRequired []string
flagGroupsExclusive []string
flagGroupsExclusiveRequires []string
subCmdFlagGroupsRequired []string
subCmdFlagGroupsExclusive []string
args []string
expectErr string
}{
{
desc: "No flags no problem",
Expand All @@ -67,6 +68,16 @@ func TestValidateFlagGroups(t *testing.T) {
flagGroupsExclusive: []string{"a b c"},
args: []string{"--a=foo", "--b=foo"},
expectErr: "if any flags in the group [a b c] are set none of the others can be; [a b] were all set",
}, {
desc: "Required exclusive group not satisfied",
flagGroupsExclusiveRequires: []string{"a b c"},
args: []string{"--d=foo"},
expectErr: "exactly one of the flags in the group [a b c] must be set; none were set",
}, {
desc: "Required exclusive group selected more than one",
flagGroupsExclusiveRequires: []string{"a b c"},
args: []string{"--a=foo", "--b=foo"},
expectErr: "exactly one of the flags in the group [a b c] must be set; [a b] were all set",
}, {
desc: "Multiple required flag group not satisfied returns first error",
flagGroupsRequired: []string{"a b c", "a d"},
Expand Down Expand Up @@ -133,6 +144,9 @@ func TestValidateFlagGroups(t *testing.T) {
for _, flagGroup := range tc.flagGroupsExclusive {
c.MarkFlagsMutuallyExclusive(strings.Split(flagGroup, " ")...)
}
for _, flagGroup := range tc.flagGroupsExclusiveRequires {
c.MarkFlagsMutuallyExclusiveAndRequired(strings.Split(flagGroup, " ")...)
}
for _, flagGroup := range tc.subCmdFlagGroupsRequired {
sub.MarkFlagsRequiredTogether(strings.Split(flagGroup, " ")...)
}
Expand Down