Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Required flags #155

Merged
merged 11 commits into from
Aug 2, 2019
32 changes: 30 additions & 2 deletions app.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,14 +87,25 @@ func (a *App) Run(arguments []string) error {
set.SetOutput(ioutil.Discard)
err := set.Parse(arguments[1:])
nerr := normalizeFlags(a.Flags, set)
cerr := checkRequiredFlags(a.Flags, set)

context := NewContext(a, set, set)

if nerr != nil {
fmt.Println(nerr)
context := NewContext(a, set, set)
fmt.Println("")
ShowAppHelp(context)
fmt.Println("")
return nerr
}
context := NewContext(a, set, set)

if cerr != nil {
fmt.Println(cerr)
fmt.Println("")
ShowAppHelp(context)
fmt.Println("")
return cerr
}

if err != nil {
fmt.Printf("Incorrect Usage.\n\n")
Expand Down Expand Up @@ -164,10 +175,13 @@ func (a *App) RunAsSubcommand(ctx *Context) error {
set.SetOutput(ioutil.Discard)
err := set.Parse(ctx.Args().Tail())
nerr := normalizeFlags(a.Flags, set)
cerr := checkRequiredFlags(a.Flags, set)

context := NewContext(a, set, ctx.globalSet)

if nerr != nil {
fmt.Println(nerr)
fmt.Println("")
if len(a.Commands) > 0 {
ShowSubcommandHelp(context)
} else {
Expand All @@ -177,6 +191,20 @@ func (a *App) RunAsSubcommand(ctx *Context) error {
return nerr
}

if cerr != nil {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It'd be nice to DRY this up by combining with the above conditional. The only thing that differs is which error you print.

fmt.Println(cerr)
fmt.Println("")
if len(a.Commands) > 0 {
ShowSubcommandHelp(context)
fmt.Println("subcommands")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this here?

} else {
ShowCommandHelp(ctx, context.Args().First())
fmt.Println("commands")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this here as well?

}
fmt.Println("")
return cerr
}

if err != nil {
fmt.Printf("Incorrect Usage.\n\n")
ShowSubcommandHelp(context)
Expand Down
10 changes: 10 additions & 0 deletions command.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,16 @@ func (c Command) Run(ctx *Context) error {
fmt.Println("")
return nerr
}

cerr := checkRequiredFlags(c.Flags, set)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here with respect to DRYing.

if cerr != nil {
fmt.Println(cerr)
fmt.Println("")
ShowCommandHelp(ctx, c.Name)
fmt.Println("")
return cerr
}

context := NewContext(ctx.App, set, ctx.globalSet)

if checkCommandCompletions(context, c.Name) {
Expand Down
18 changes: 18 additions & 0 deletions context.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package cli
import (
"errors"
"flag"
"fmt"
"strconv"
"strings"
"time"
Expand Down Expand Up @@ -337,3 +338,20 @@ func normalizeFlags(flags []Flag, set *flag.FlagSet) error {
}
return nil
}

func checkRequiredFlags(flags []Flag, set *flag.FlagSet) error {
visited := make(map[string]bool)
set.Visit(func(f *flag.Flag) {
visited[f.Name] = true
})

for _, f := range flags {
if f.IsRequired() {
key := strings.Split(f.getName(), ",")[0]
if !visited[key] {
return fmt.Errorf("Required flag %s not set", f.getName())
}
}
}
return nil
}
144 changes: 101 additions & 43 deletions flag.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ type Flag interface {
// Apply Flag settings to the given flag set
Apply(*flag.FlagSet)
getName() string
IsRequired() bool
}

func flagSet(name string, flags []Flag) *flag.FlagSet {
Expand Down Expand Up @@ -61,14 +62,15 @@ type Generic interface {

// GenericFlag is the flag type for types implementing Generic
type GenericFlag struct {
Name string
Value Generic
Usage string
EnvVar string
Name string
Value Generic
Usage string
EnvVar string
Required bool
}

func (f GenericFlag) String() string {
return withEnvHint(f.EnvVar, fmt.Sprintf("%s%s %v\t`%v` %s", prefixFor(f.Name), f.Name, f.Value, "-"+f.Name+" option -"+f.Name+" option", f.Usage))
return withHints(f.EnvVar, f.Required, fmt.Sprintf("%s%s %v\t`%v` %s", prefixFor(f.Name), f.Name, f.Value, "-"+f.Name+" option -"+f.Name+" option", f.Usage))
}

func (f GenericFlag) Apply(set *flag.FlagSet) {
Expand All @@ -88,6 +90,10 @@ func (f GenericFlag) getName() string {
return f.Name
}

func (f GenericFlag) IsRequired() bool {
return f.Required
}

type StringSlice []string

func (f *StringSlice) Set(value string) error {
Expand All @@ -104,16 +110,17 @@ func (f *StringSlice) Value() []string {
}

type StringSliceFlag struct {
Name string
Value *StringSlice
Usage string
EnvVar string
Name string
Value *StringSlice
Usage string
EnvVar string
Required bool
}

func (f StringSliceFlag) String() string {
firstName := strings.Trim(strings.Split(f.Name, ",")[0], " ")
pref := prefixFor(firstName)
return withEnvHint(f.EnvVar, fmt.Sprintf("%s '%v'\t%v", prefixedNames(f.Name), pref+firstName+" option "+pref+firstName+" option", f.Usage))
return withHints(f.EnvVar, f.Required, fmt.Sprintf("%s '%v'\t%v", prefixedNames(f.Name), pref+firstName+" option "+pref+firstName+" option", f.Usage))
}

func (f StringSliceFlag) Apply(set *flag.FlagSet) {
Expand All @@ -136,6 +143,10 @@ func (f StringSliceFlag) getName() string {
return f.Name
}

func (f StringSliceFlag) IsRequired() bool {
return f.Required
}

type IntSlice []int

func (f *IntSlice) Set(value string) error {
Expand All @@ -158,16 +169,17 @@ func (f *IntSlice) Value() []int {
}

type IntSliceFlag struct {
Name string
Value *IntSlice
Usage string
EnvVar string
Name string
Value *IntSlice
Usage string
EnvVar string
Required bool
}

func (f IntSliceFlag) String() string {
firstName := strings.Trim(strings.Split(f.Name, ",")[0], " ")
pref := prefixFor(firstName)
return withEnvHint(f.EnvVar, fmt.Sprintf("%s '%v'\t%v", prefixedNames(f.Name), pref+firstName+" option "+pref+firstName+" option", f.Usage))
return withHints(f.EnvVar, f.Required, fmt.Sprintf("%s '%v'\t%v", prefixedNames(f.Name), pref+firstName+" option "+pref+firstName+" option", f.Usage))
}

func (f IntSliceFlag) Apply(set *flag.FlagSet) {
Expand All @@ -193,14 +205,19 @@ func (f IntSliceFlag) getName() string {
return f.Name
}

func (f IntSliceFlag) IsRequired() bool {
return f.Required
}

type BoolFlag struct {
Name string
Usage string
EnvVar string
Name string
Usage string
EnvVar string
Required bool
}

func (f BoolFlag) String() string {
return withEnvHint(f.EnvVar, fmt.Sprintf("%s\t%v", prefixedNames(f.Name), f.Usage))
return withHints(f.EnvVar, f.Required, fmt.Sprintf("%s\t%v", prefixedNames(f.Name), f.Usage))
}

func (f BoolFlag) Apply(set *flag.FlagSet) {
Expand All @@ -223,14 +240,19 @@ func (f BoolFlag) getName() string {
return f.Name
}

func (f BoolFlag) IsRequired() bool {
return f.Required
}

type BoolTFlag struct {
Name string
Usage string
EnvVar string
Name string
Usage string
EnvVar string
Required bool
}

func (f BoolTFlag) String() string {
return withEnvHint(f.EnvVar, fmt.Sprintf("%s\t%v", prefixedNames(f.Name), f.Usage))
return withHints(f.EnvVar, f.Required, fmt.Sprintf("%s\t%v", prefixedNames(f.Name), f.Usage))
}

func (f BoolTFlag) Apply(set *flag.FlagSet) {
Expand All @@ -253,11 +275,16 @@ func (f BoolTFlag) getName() string {
return f.Name
}

func (f BoolTFlag) IsRequired() bool {
return f.Required
}

type StringFlag struct {
Name string
Value string
Usage string
EnvVar string
Name string
Value string
Usage string
EnvVar string
Required bool
}

func (f StringFlag) String() string {
Expand All @@ -270,7 +297,7 @@ func (f StringFlag) String() string {
fmtString = "%s %v\t%v"
}

return withEnvHint(f.EnvVar, fmt.Sprintf(fmtString, prefixedNames(f.Name), f.Value, f.Usage))
return withHints(f.EnvVar, f.Required, fmt.Sprintf(fmtString, prefixedNames(f.Name), f.Value, f.Usage))
}

func (f StringFlag) Apply(set *flag.FlagSet) {
Expand All @@ -289,15 +316,20 @@ func (f StringFlag) getName() string {
return f.Name
}

func (f StringFlag) IsRequired() bool {
return f.Required
}

type IntFlag struct {
Name string
Value int
Usage string
EnvVar string
Name string
Value int
Usage string
EnvVar string
Required bool
}

func (f IntFlag) String() string {
return withEnvHint(f.EnvVar, fmt.Sprintf("%s '%v'\t%v", prefixedNames(f.Name), f.Value, f.Usage))
return withHints(f.EnvVar, f.Required, fmt.Sprintf("%s '%v'\t%v", prefixedNames(f.Name), f.Value, f.Usage))
}

func (f IntFlag) Apply(set *flag.FlagSet) {
Expand All @@ -319,15 +351,20 @@ func (f IntFlag) getName() string {
return f.Name
}

func (f IntFlag) IsRequired() bool {
return f.Required
}

type DurationFlag struct {
Name string
Value time.Duration
Usage string
EnvVar string
Name string
Value time.Duration
Usage string
EnvVar string
Required bool
}

func (f DurationFlag) String() string {
return withEnvHint(f.EnvVar, fmt.Sprintf("%s '%v'\t%v", prefixedNames(f.Name), f.Value, f.Usage))
return withHints(f.EnvVar, f.Required, fmt.Sprintf("%s '%v'\t%v", prefixedNames(f.Name), f.Value, f.Usage))
}

func (f DurationFlag) Apply(set *flag.FlagSet) {
Expand All @@ -349,15 +386,20 @@ func (f DurationFlag) getName() string {
return f.Name
}

func (f DurationFlag) IsRequired() bool {
return f.Required
}

type Float64Flag struct {
Name string
Value float64
Usage string
EnvVar string
Name string
Value float64
Usage string
EnvVar string
Required bool
}

func (f Float64Flag) String() string {
return withEnvHint(f.EnvVar, fmt.Sprintf("%s '%v'\t%v", prefixedNames(f.Name), f.Value, f.Usage))
return withHints(f.EnvVar, f.Required, fmt.Sprintf("%s '%v'\t%v", prefixedNames(f.Name), f.Value, f.Usage))
}

func (f Float64Flag) Apply(set *flag.FlagSet) {
Expand All @@ -379,6 +421,10 @@ func (f Float64Flag) getName() string {
return f.Name
}

func (f Float64Flag) IsRequired() bool {
return f.Required
}

func prefixFor(name string) (prefix string) {
if len(name) == 1 {
prefix = "-"
Expand Down Expand Up @@ -408,3 +454,15 @@ func withEnvHint(envVar, str string) string {
}
return str + envText
}

func withRequiredHint(isRequired bool, str string) string {
if isRequired {
return str + " (required)"
}

return str
}

func withHints(envVar string, isRequired bool, str string) string {
return withRequiredHint(isRequired, withEnvHint(envVar, str))
}