Skip to content

Commit

Permalink
feat: add MarkFlagsMutuallyExclusiveAndRequired
Browse files Browse the repository at this point in the history
  • Loading branch information
baruchiro committed Jun 6, 2023
1 parent 284f410 commit 4ec9d93
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 9 deletions.
51 changes: 49 additions & 2 deletions flag_groups.go
Expand Up @@ -23,8 +23,9 @@ import (
)

const (
requiredAsGroup = "cobra_annotation_required_if_others_set"
mutuallyExclusive = "cobra_annotation_mutually_exclusive"
requiredAsGroup = "cobra_annotation_required_if_others_set"
mutuallyExclusive = "cobra_annotation_mutually_exclusive"
mutuallyExclusiveAndRequired = "cobra_annotation_mutually_exclusive_and_required"
)

// MarkFlagsRequiredTogether marks the given flags with annotations so that Cobra errors
Expand Down Expand Up @@ -59,6 +60,22 @@ func (c *Command) MarkFlagsMutuallyExclusive(flagNames ...string) {
}
}

// MarkFlagsMutuallyExclusiveAndRequired marks the given flags with annotations so that Cobra errors
// if the command is invoked without exactly one flag from the given set of flags.
func (c *Command) MarkFlagsMutuallyExclusiveAndRequired(flagNames ...string) {
c.mergePersistentFlags()
for _, v := range flagNames {
f := c.Flags().Lookup(v)
if f == nil {
panic(fmt.Sprintf("Failed to find flag %q and mark it as being in a mutually exclusive flag group", v))
}
// Each time this is called is a single new entry; this allows it to be a member of multiple groups if needed.
if err := c.Flags().SetAnnotation(v, mutuallyExclusiveAndRequired, append(f.Annotations[mutuallyExclusiveAndRequired], strings.Join(flagNames, " "))); err != nil {
panic(err)
}
}
}

// ValidateFlagGroups validates the mutuallyExclusive/requiredAsGroup logic and returns the
// first error encountered.
func (c *Command) ValidateFlagGroups() error {
Expand All @@ -72,9 +89,11 @@ func (c *Command) ValidateFlagGroups() error {
// then a map of each flag name and whether it is set or not.
groupStatus := map[string]map[string]bool{}
mutuallyExclusiveGroupStatus := map[string]map[string]bool{}
mutuallyExclusiveAndRequiredGroupStatus := map[string]map[string]bool{}
flags.VisitAll(func(pflag *flag.Flag) {
processFlagForGroupAnnotation(flags, pflag, requiredAsGroup, groupStatus)
processFlagForGroupAnnotation(flags, pflag, mutuallyExclusive, mutuallyExclusiveGroupStatus)
processFlagForGroupAnnotation(flags, pflag, mutuallyExclusiveAndRequired, mutuallyExclusiveAndRequiredGroupStatus)
})

if err := validateRequiredFlagGroups(groupStatus); err != nil {
Expand All @@ -83,6 +102,9 @@ func (c *Command) ValidateFlagGroups() error {
if err := validateExclusiveFlagGroups(mutuallyExclusiveGroupStatus); err != nil {
return err
}
if err := validateExclusiveAndRequiredFlagGroups(mutuallyExclusiveAndRequiredGroupStatus); err != nil {
return err
}
return nil
}

Expand Down Expand Up @@ -163,6 +185,31 @@ func validateExclusiveFlagGroups(data map[string]map[string]bool) error {
return nil
}

func validateExclusiveAndRequiredFlagGroups(data map[string]map[string]bool) error {
keys := sortedKeys(data)
for _, flagList := range keys {
flagnameAndStatus := data[flagList]
var set []string
for flagname, isSet := range flagnameAndStatus {
if isSet {
set = append(set, flagname)
}
}

if len(set) == 0 {
// Sort values, so they can be tested/scripted against consistently.
sort.Strings(set)
return fmt.Errorf("exactly one of the flags in the group [%v] must be set; none were set", flagList)
}
if len(set) > 1 {
// Sort values, so they can be tested/scripted against consistently.
sort.Strings(set)
return fmt.Errorf("exactly one of the flags in the group [%v] must be set; %v were all set", flagList, set)
}
}
return nil
}

func sortedKeys(m map[string]map[string]bool) []string {
keys := make([]string, len(m))
i := 0
Expand Down
28 changes: 21 additions & 7 deletions flag_groups_test.go
Expand Up @@ -43,13 +43,14 @@ func TestValidateFlagGroups(t *testing.T) {

// Each test case uses a unique command from the function above.
testcases := []struct {
desc string
flagGroupsRequired []string
flagGroupsExclusive []string
subCmdFlagGroupsRequired []string
subCmdFlagGroupsExclusive []string
args []string
expectErr string
desc string
flagGroupsRequired []string
flagGroupsExclusive []string
flagGroupsExclusiveRequires []string
subCmdFlagGroupsRequired []string
subCmdFlagGroupsExclusive []string
args []string
expectErr string
}{
{
desc: "No flags no problem",
Expand All @@ -67,6 +68,16 @@ func TestValidateFlagGroups(t *testing.T) {
flagGroupsExclusive: []string{"a b c"},
args: []string{"--a=foo", "--b=foo"},
expectErr: "if any flags in the group [a b c] are set none of the others can be; [a b] were all set",
}, {
desc: "Required exclusive group not satisfied",
flagGroupsExclusiveRequires: []string{"a b c"},
args: []string{"--d=foo"},
expectErr: "exactly one of the flags in the group [a b c] must be set; none were set",
}, {
desc: "Required exclusive group selected more than one",
flagGroupsExclusiveRequires: []string{"a b c"},
args: []string{"--a=foo", "--b=foo"},
expectErr: "exactly one of the flags in the group [a b c] must be set; [a b] were all set",
}, {
desc: "Multiple required flag group not satisfied returns first error",
flagGroupsRequired: []string{"a b c", "a d"},
Expand Down Expand Up @@ -133,6 +144,9 @@ func TestValidateFlagGroups(t *testing.T) {
for _, flagGroup := range tc.flagGroupsExclusive {
c.MarkFlagsMutuallyExclusive(strings.Split(flagGroup, " ")...)
}
for _, flagGroup := range tc.flagGroupsExclusiveRequires {
c.MarkFlagsMutuallyExclusiveAndRequired(strings.Split(flagGroup, " ")...)
}
for _, flagGroup := range tc.subCmdFlagGroupsRequired {
sub.MarkFlagsRequiredTogether(strings.Split(flagGroup, " ")...)
}
Expand Down

0 comments on commit 4ec9d93

Please sign in to comment.