Skip to content

Commit

Permalink
Handle requiredFlags errors in the same way as parsed flag errors.
Browse files Browse the repository at this point in the history
At the moment they are handled differently, which means that you will get inconsistent output when you do not pass a required flag compared to when you pass in an incorrect flag.

Merge spf13/cobra#1504
  • Loading branch information
claeysn authored and hoshsadiq committed Feb 8, 2022
1 parent 9139e59 commit bd8b35c
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 2 deletions.
2 changes: 1 addition & 1 deletion command.go
Original file line number Diff line number Diff line change
Expand Up @@ -862,7 +862,7 @@ func (c *Command) execute(a []string) (err error) {
}

if err := c.validateRequiredFlags(); err != nil {
return err
return c.FlagErrorFunc()(c, err)
}
if c.RunE != nil {
if err := c.RunE(c, argWoFlags); err != nil {
Expand Down
24 changes: 23 additions & 1 deletion command_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package cobra
import (
"bytes"
"context"
"errors"
"fmt"
"io/ioutil"
"os"
Expand Down Expand Up @@ -781,7 +782,6 @@ func TestRequiredFlags(t *testing.T) {
c.Flags().String("foo2", "", "")
assertNoErr(t, c.MarkFlagRequired("foo2"))
c.Flags().String("bar", "", "")

expected := fmt.Sprintf("required flag(s) %q, %q not set", "foo1", "foo2")

_, err := executeCommand(c)
Expand All @@ -792,6 +792,28 @@ func TestRequiredFlags(t *testing.T) {
}
}

func TestRequiredFlagsWithCustomFlagErrorFunc(t *testing.T) {
c := &Command{Use: "c", Run: emptyRun}
c.Flags().String("foo1", "", "")
assertNoErr(t, c.MarkFlagRequired("foo1"))
silentError := "failed flag parsing"
c.SetFlagErrorFunc(func(c *Command, err error) error {
c.Println(err)
c.Println(c.UsageString())
return errors.New(silentError)
})
requiredFlagErrorMessage := fmt.Sprintf("required flag(s) %q not set", "foo1")

output, err := executeCommand(c)
got := err.Error()

if got != silentError {
t.Errorf("Expected error %s but got %s", silentError, got)
}
checkStringContains(t, output, requiredFlagErrorMessage)
checkStringContains(t, output, c.UsageString())
}

func TestPersistentRequiredFlags(t *testing.T) {
parent := &Command{Use: "parent", Run: emptyRun}
parent.PersistentFlags().String("foo1", "", "")
Expand Down

0 comments on commit bd8b35c

Please sign in to comment.