diff --git a/command.go b/command.go index ee5365bcb..0f4511f38 100644 --- a/command.go +++ b/command.go @@ -863,6 +863,10 @@ func (c *Command) execute(a []string) (err error) { if err := c.validateRequiredFlags(); err != nil { return err } + if err := c.validateFlagGroups(); err != nil { + return err + } + if c.RunE != nil { if err := c.RunE(c, argWoFlags); err != nil { return err diff --git a/flag_groups.go b/flag_groups.go new file mode 100644 index 000000000..2cbea113e --- /dev/null +++ b/flag_groups.go @@ -0,0 +1,135 @@ +// Copyright © 2022 Steve Francia . +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cobra + +import ( + "fmt" + "sort" + "strings" + + flag "github.com/spf13/pflag" +) + +const ( + RequiredAsGroup = "cobra_annotation_required_if_others_set" + MutuallyExclusive = "cobra_annotation_mutually_exclusive" +) + +func (c *Command) MarkFlagsRequiredTogether(flagNames ...string) { + for _, v := range flagNames { + f := c.Flags().Lookup(v) + if f.Annotations == nil { + f.Annotations = map[string][]string{} + } + // Each time this is called is a single new entry; this allows it to be a member of multiple groups if needed. + f.Annotations[RequiredAsGroup] = append(f.Annotations[RequiredAsGroup], strings.Join(flagNames, " ")) + } +} + +func (c *Command) MarkFlagsMutuallyExclusive(flagNames ...string) { + for _, v := range flagNames { + f := c.Flags().Lookup(v) + if f.Annotations == nil { + f.Annotations = map[string][]string{} + } + // Each time this is called is a single new entry; this allows it to be a member of multiple groups if needed. + f.Annotations[MutuallyExclusive] = append(f.Annotations[MutuallyExclusive], strings.Join(flagNames, " ")) + } +} + +// validateFlagGroups validates the mutuallyExclusive/requiredAsGroup logic. +func (c *Command) validateFlagGroups() error { + if c.DisableFlagParsing { + return nil + } + + flags := c.Flags() + + // groupStatus format is the list of flags as a unique ID, + // then a map of each flag name and whether it is set or not. + groupStatus := map[string]map[string]bool{} + mutuallyExclusiveGroupStatus := map[string]map[string]bool{} + flags.VisitAll(func(pflag *flag.Flag) { + groupInfo, found := pflag.Annotations[RequiredAsGroup] + if found { + // If not tracking a group, start. + for _, group := range groupInfo { + if groupStatus[group] == nil { + groupStatus[group] = map[string]bool{} + // Track each flag by name. + flagnames := strings.Split(group, " ") + for _, name := range flagnames { + groupStatus[group][name] = false + } + } + + // Record we've seen this flag for each group its in. + groupStatus[group][pflag.Name] = pflag.Changed + } + } + + groupInfo, found = pflag.Annotations[MutuallyExclusive] + if found { + // If not tracking a group, start. + for _, group := range groupInfo { + if mutuallyExclusiveGroupStatus[group] == nil { + mutuallyExclusiveGroupStatus[group] = map[string]bool{} + // Track each flag by name. + flagnames := strings.Split(group, " ") + for _, name := range flagnames { + mutuallyExclusiveGroupStatus[group][name] = false + } + } + + // Record we've seen this flag for each group its in. + mutuallyExclusiveGroupStatus[group][pflag.Name] = pflag.Changed + } + } + }) + + // Now review the groups and form errors as needed. + errMsgs := []string{} + for flagList, flagnameAndStatus := range groupStatus { + unset := []string{} + for flagname, isSet := range flagnameAndStatus { + if !isSet { + unset = append(unset, flagname) + } + } + if len(unset) == len(flagnameAndStatus) || len(unset) == 0 { + continue + } + sort.Strings(unset) + errMsgs = append(errMsgs, fmt.Sprintf("if any flags in the group [%v] are set they must all be set; missing %v", flagList, unset)) + } + + for flagList, flagnameAndStatus := range mutuallyExclusiveGroupStatus { + set := []string{} + for flagname, isSet := range flagnameAndStatus { + if isSet { + set = append(set, flagname) + } + } + if len(set) == 0 || len(set) == 1 { + continue + } + sort.Strings(set) + errMsgs = append(errMsgs, fmt.Sprintf("if any flags in the group [%v] are set none of the others can be; %v were all set", flagList, set)) + } + + if len(errMsgs) > 0 { + return fmt.Errorf(strings.Join(errMsgs, `, `)) + } + return nil +} diff --git a/flag_groups_test.go b/flag_groups_test.go new file mode 100644 index 000000000..d10927bc6 --- /dev/null +++ b/flag_groups_test.go @@ -0,0 +1,91 @@ +// Copyright © 2022 Steve Francia . +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cobra + +import ( + "strings" + "testing" +) + +func TestValidateFlagGroups(t *testing.T) { + getCmd := func() *Command { + c := &Command{ + Use: "testcmd", + Run: func(cmd *Command, args []string) { + }} + // Define lots of flags to utilize for testing. + for _, v := range []string{"a", "b", "c", "d", "e", "f", "g"} { + c.Flags().String(v, "", "") + } + return c + } + + // Each test case uses a unique command from the function above. + testcases := []struct { + desc string + flagGroupsRequired []string + flagGroupsExclusive []string + args []string + expectErr string + }{ + { + desc: "No flags no problem", + }, { + desc: "No flags no problem even with conflicting groups", + flagGroupsRequired: []string{"a b"}, + flagGroupsExclusive: []string{"a b"}, + }, { + desc: "Required flag group not satisfied", + flagGroupsRequired: []string{"a b c"}, + args: []string{"--a=foo"}, + expectErr: "if any flags in the group [a b c] are set they must all be set; missing [b c]", + }, { + desc: "Exclusive flag group not satisfied", + flagGroupsExclusive: []string{"a b c"}, + args: []string{"--a=foo", "--b=foo"}, + expectErr: "if any flags in the group [a b c] are set none of the others can be; [a b] were all set", + }, { + desc: "Multiple required flag group not satisfied", + flagGroupsRequired: []string{"a b c", "a d"}, + args: []string{"--c=foo", "--d=foo"}, + expectErr: "if any flags in the group [a b c] are set they must all be set; missing [a b], if any flags in the group [a d] are set they must all be set; missing [a]", + }, { + desc: "Multiple exclusive flag group not satisfied", + flagGroupsExclusive: []string{"a b c", "a d"}, + args: []string{"testcmd", "--a=foo", "--c=foo", "--d=foo"}, + expectErr: "if any flags in the group [a b c] are set none of the others can be; [a c] were all set, if any flags in the group [a d] are set none of the others can be; [a d] were all set", + }, + } + for _, tc := range testcases { + t.Run(tc.desc, func(t *testing.T) { + c := getCmd() + for _, flagGroup := range tc.flagGroupsRequired { + c.MarkFlagsRequiredTogether(strings.Split(flagGroup, " ")...) + } + for _, flagGroup := range tc.flagGroupsExclusive { + c.MarkFlagsMutuallyExclusive(strings.Split(flagGroup, " ")...) + } + c.SetArgs(tc.args) + err := c.Execute() + switch { + case err == nil && len(tc.expectErr) > 0: + t.Errorf("Expected error %q but got nil", tc.expectErr) + case err == nil && len(tc.expectErr) == 0: + case err != nil && err.Error() == tc.expectErr: + case err != nil && err.Error() != tc.expectErr: + t.Errorf("Expected error %q but got %v", tc.expectErr, err) + } + }) + } +} diff --git a/go.sum b/go.sum index 431058ed0..6d5345968 100644 --- a/go.sum +++ b/go.sum @@ -2,17 +2,11 @@ github.com/cpuguy83/go-md2man/v2 v2.0.1 h1:r/myEWzV9lfsM1tFLgDyu0atFtJ1fXn261LKY github.com/cpuguy83/go-md2man/v2 v2.0.1/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= github.com/inconshreveable/mousetrap v1.0.0 h1:Z8tu5sraLXCXIcARxBp/8cbvlwVa7Z1NHg9XEKhtSvM= github.com/inconshreveable/mousetrap v1.0.0/go.mod h1:PxqpIevigyE2G7u3NXJIT2ANytuPF1OarO4DADm73n8= -github.com/kr/pretty v0.2.0 h1:s5hAObm+yFO5uHYt5dYjxi2rXrsnmRpJx4OYvIWUaQs= -github.com/kr/pretty v0.2.0/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= -github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= -github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= -github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/russross/blackfriday/v2 v2.1.0 h1:JIOH55/0cWyOuilr9/qlrm0BSXldqnqwMsf35Ld67mk= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo= -gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ=