Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cleanup context.go #1264

Merged
merged 1 commit into from Apr 24, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
4 changes: 2 additions & 2 deletions app.go
Expand Up @@ -278,7 +278,7 @@ func (a *App) RunContext(ctx context.Context, arguments []string) (err error) {
return nil
}

cerr := checkRequiredFlags(a.Flags, context)
cerr := context.checkRequiredFlags(a.Flags)
if cerr != nil {
_ = ShowAppHelp(context)
return cerr
Expand Down Expand Up @@ -397,7 +397,7 @@ func (a *App) RunAsSubcommand(ctx *Context) (err error) {
}
}

cerr := checkRequiredFlags(a.Flags, context)
cerr := context.checkRequiredFlags(a.Flags)
if cerr != nil {
_ = ShowSubcommandHelp(context)
return cerr
Expand Down
2 changes: 1 addition & 1 deletion command.go
Expand Up @@ -127,7 +127,7 @@ func (c *Command) Run(ctx *Context) (err error) {
return nil
}

cerr := checkRequiredFlags(c.Flags, context)
cerr := context.checkRequiredFlags(c.Flags)
if cerr != nil {
_ = ShowCommandHelp(context, c.Name)
return cerr
Expand Down
130 changes: 31 additions & 99 deletions context.go
Expand Up @@ -2,9 +2,7 @@ package cli

import (
"context"
"errors"
"flag"
"fmt"
"strings"
)

Expand Down Expand Up @@ -53,20 +51,18 @@ func (c *Context) Set(name, value string) error {

// IsSet determines if the flag was actually set
func (c *Context) IsSet(name string) bool {
if fs := lookupFlagSet(name, c); fs != nil {
if fs := lookupFlagSet(name, c); fs != nil {
isSet := false
fs.Visit(func(f *flag.Flag) {
if f.Name == name {
isSet = true
}
})
if isSet {
return true
if fs := c.lookupFlagSet(name); fs != nil {
isSet := false
fs.Visit(func(f *flag.Flag) {
if f.Name == name {
isSet = true
}
})
if isSet {
return true
}

f := lookupFlag(name, c)
f := c.lookupFlag(name)
if f == nil {
return false
}
Expand Down Expand Up @@ -108,7 +104,7 @@ func (c *Context) Lineage() []*Context {

// Value returns the value of the flag corresponding to `name`
func (c *Context) Value(name string) interface{} {
if fs := lookupFlagSet(name, c); fs != nil {
if fs := c.lookupFlagSet(name); fs != nil {
return fs.Lookup(name).Value.(flag.Getter).Get()
}
return nil
Expand All @@ -125,7 +121,7 @@ func (c *Context) NArg() int {
return c.Args().Len()
}

func lookupFlag(name string, ctx *Context) Flag {
func (ctx *Context) lookupFlag(name string) Flag {
for _, c := range ctx.Lineage() {
if c.Command == nil {
continue
Expand Down Expand Up @@ -153,7 +149,7 @@ func lookupFlag(name string, ctx *Context) Flag {
return nil
}

func lookupFlagSet(name string, ctx *Context) *flag.FlagSet {
func (ctx *Context) lookupFlagSet(name string) *flag.FlagSet {
for _, c := range ctx.Lineage() {
if f := c.flagSet.Lookup(name); f != nil {
return c.flagSet
Expand All @@ -163,89 +159,7 @@ func lookupFlagSet(name string, ctx *Context) *flag.FlagSet {
return nil
}

func copyFlag(name string, ff *flag.Flag, set *flag.FlagSet) {
switch ff.Value.(type) {
case Serializer:
_ = set.Set(name, ff.Value.(Serializer).Serialize())
default:
_ = set.Set(name, ff.Value.String())
}
}

func normalizeFlags(flags []Flag, set *flag.FlagSet) error {
visited := make(map[string]bool)
set.Visit(func(f *flag.Flag) {
visited[f.Name] = true
})
for _, f := range flags {
parts := f.Names()
if len(parts) == 1 {
continue
}
var ff *flag.Flag
for _, name := range parts {
name = strings.Trim(name, " ")
if visited[name] {
if ff != nil {
return errors.New("Cannot use two forms of the same flag: " + name + " " + ff.Name)
}
ff = set.Lookup(name)
}
}
if ff == nil {
continue
}
for _, name := range parts {
name = strings.Trim(name, " ")
if !visited[name] {
copyFlag(name, ff, set)
}
}
}
return nil
}

func makeFlagNameVisitor(names *[]string) func(*flag.Flag) {
return func(f *flag.Flag) {
nameParts := strings.Split(f.Name, ",")
name := strings.TrimSpace(nameParts[0])

for _, part := range nameParts {
part = strings.TrimSpace(part)
if len(part) > len(name) {
name = part
}
}

if name != "" {
*names = append(*names, name)
}
}
}

type requiredFlagsErr interface {
error
getMissingFlags() []string
}

type errRequiredFlags struct {
missingFlags []string
}

func (e *errRequiredFlags) Error() string {
numberOfMissingFlags := len(e.missingFlags)
if numberOfMissingFlags == 1 {
return fmt.Sprintf("Required flag %q not set", e.missingFlags[0])
}
joinedMissingFlags := strings.Join(e.missingFlags, ", ")
return fmt.Sprintf("Required flags %q not set", joinedMissingFlags)
}

func (e *errRequiredFlags) getMissingFlags() []string {
return e.missingFlags
}

func checkRequiredFlags(flags []Flag, context *Context) requiredFlagsErr {
func (context *Context) checkRequiredFlags(flags []Flag) requiredFlagsErr {
var missingFlags []string
for _, f := range flags {
if rf, ok := f.(RequiredFlag); ok && rf.IsRequired() {
Expand Down Expand Up @@ -274,3 +188,21 @@ func checkRequiredFlags(flags []Flag, context *Context) requiredFlagsErr {

return nil
}

func makeFlagNameVisitor(names *[]string) func(*flag.Flag) {
return func(f *flag.Flag) {
nameParts := strings.Split(f.Name, ",")
name := strings.TrimSpace(nameParts[0])

for _, part := range nameParts {
part = strings.TrimSpace(part)
if len(part) > len(name) {
name = part
}
}

if name != "" {
*names = append(*names, name)
}
}
}
8 changes: 4 additions & 4 deletions context_test.go
Expand Up @@ -316,13 +316,13 @@ func TestContext_lookupFlagSet(t *testing.T) {
_ = set.Parse([]string{"--local-flag"})
_ = parentSet.Parse([]string{"--top-flag"})

fs := lookupFlagSet("top-flag", ctx)
fs := ctx.lookupFlagSet("top-flag")
expect(t, fs, parentCtx.flagSet)

fs = lookupFlagSet("local-flag", ctx)
fs = ctx.lookupFlagSet("local-flag")
expect(t, fs, ctx.flagSet)

if fs := lookupFlagSet("frob", ctx); fs != nil {
if fs := ctx.lookupFlagSet("frob"); fs != nil {
t.Fail()
}
}
Expand Down Expand Up @@ -576,7 +576,7 @@ func TestCheckRequiredFlags(t *testing.T) {
ctx.Command.Flags = test.flags

// logic under test
err := checkRequiredFlags(test.flags, ctx)
err := ctx.checkRequiredFlags(test.flags)

// assertions
if test.expectedAnError && err == nil {
Expand Down
22 changes: 22 additions & 0 deletions errors.go
Expand Up @@ -47,6 +47,28 @@ func (m *multiError) Errors() []error {
return errs
}

type requiredFlagsErr interface {
error
getMissingFlags() []string
}

type errRequiredFlags struct {
missingFlags []string
}

func (e *errRequiredFlags) Error() string {
numberOfMissingFlags := len(e.missingFlags)
if numberOfMissingFlags == 1 {
return fmt.Sprintf("Required flag %q not set", e.missingFlags[0])
}
joinedMissingFlags := strings.Join(e.missingFlags, ", ")
return fmt.Sprintf("Required flags %q not set", joinedMissingFlags)
}

func (e *errRequiredFlags) getMissingFlags() []string {
return e.missingFlags
}

// ErrorFormatter is the interface that will suitably format the error output
type ErrorFormatter interface {
Format(s fmt.State, verb rune)
Expand Down
43 changes: 43 additions & 0 deletions flag.go
@@ -1,6 +1,7 @@
package cli

import (
"errors"
"flag"
"fmt"
"io/ioutil"
Expand Down Expand Up @@ -130,6 +131,48 @@ func flagSet(name string, flags []Flag) (*flag.FlagSet, error) {
return set, nil
}

func copyFlag(name string, ff *flag.Flag, set *flag.FlagSet) {
switch ff.Value.(type) {
case Serializer:
_ = set.Set(name, ff.Value.(Serializer).Serialize())
default:
_ = set.Set(name, ff.Value.String())
}
}

func normalizeFlags(flags []Flag, set *flag.FlagSet) error {
visited := make(map[string]bool)
set.Visit(func(f *flag.Flag) {
visited[f.Name] = true
})
for _, f := range flags {
parts := f.Names()
if len(parts) == 1 {
continue
}
var ff *flag.Flag
for _, name := range parts {
name = strings.Trim(name, " ")
if visited[name] {
if ff != nil {
return errors.New("Cannot use two forms of the same flag: " + name + " " + ff.Name)
}
ff = set.Lookup(name)
}
}
if ff == nil {
continue
}
for _, name := range parts {
name = strings.Trim(name, " ")
if !visited[name] {
copyFlag(name, ff, set)
}
}
}
return nil
}

func visibleFlags(fl []Flag) []Flag {
var visible []Flag
for _, f := range fl {
Expand Down
2 changes: 1 addition & 1 deletion flag_bool.go
Expand Up @@ -87,7 +87,7 @@ func (f *BoolFlag) Apply(set *flag.FlagSet) error {
// Bool looks up the value of a local BoolFlag, returns
// false if not found
func (c *Context) Bool(name string) bool {
if fs := lookupFlagSet(name, c); fs != nil {
if fs := c.lookupFlagSet(name); fs != nil {
return lookupBool(name, fs)
}
return false
Expand Down
2 changes: 1 addition & 1 deletion flag_duration.go
Expand Up @@ -86,7 +86,7 @@ func (f *DurationFlag) Apply(set *flag.FlagSet) error {
// Duration looks up the value of a local DurationFlag, returns
// 0 if not found
func (c *Context) Duration(name string) time.Duration {
if fs := lookupFlagSet(name, c); fs != nil {
if fs := c.lookupFlagSet(name); fs != nil {
return lookupDuration(name, fs)
}
return 0
Expand Down
2 changes: 1 addition & 1 deletion flag_float64.go
Expand Up @@ -87,7 +87,7 @@ func (f *Float64Flag) Apply(set *flag.FlagSet) error {
// Float64 looks up the value of a local Float64Flag, returns
// 0 if not found
func (c *Context) Float64(name string) float64 {
if fs := lookupFlagSet(name, c); fs != nil {
if fs := c.lookupFlagSet(name); fs != nil {
return lookupFloat64(name, fs)
}
return 0
Expand Down
2 changes: 1 addition & 1 deletion flag_float64_slice.go
Expand Up @@ -146,7 +146,7 @@ func (f *Float64SliceFlag) Apply(set *flag.FlagSet) error {
// Float64Slice looks up the value of a local Float64SliceFlag, returns
// nil if not found
func (c *Context) Float64Slice(name string) []float64 {
if fs := lookupFlagSet(name, c); fs != nil {
if fs := c.lookupFlagSet(name); fs != nil {
return lookupFloat64Slice(name, fs)
}
return nil
Expand Down
2 changes: 1 addition & 1 deletion flag_generic.go
Expand Up @@ -89,7 +89,7 @@ func (f GenericFlag) Apply(set *flag.FlagSet) error {
// Generic looks up the value of a local GenericFlag, returns
// nil if not found
func (c *Context) Generic(name string) interface{} {
if fs := lookupFlagSet(name, c); fs != nil {
if fs := c.lookupFlagSet(name); fs != nil {
return lookupGeneric(name, fs)
}
return nil
Expand Down
2 changes: 1 addition & 1 deletion flag_int.go
Expand Up @@ -87,7 +87,7 @@ func (f *IntFlag) Apply(set *flag.FlagSet) error {
// Int looks up the value of a local IntFlag, returns
// 0 if not found
func (c *Context) Int(name string) int {
if fs := lookupFlagSet(name, c); fs != nil {
if fs := c.lookupFlagSet(name); fs != nil {
return lookupInt(name, fs)
}
return 0
Expand Down
2 changes: 1 addition & 1 deletion flag_int64.go
Expand Up @@ -86,7 +86,7 @@ func (f *Int64Flag) Apply(set *flag.FlagSet) error {
// Int64 looks up the value of a local Int64Flag, returns
// 0 if not found
func (c *Context) Int64(name string) int64 {
if fs := lookupFlagSet(name, c); fs != nil {
if fs := c.lookupFlagSet(name); fs != nil {
return lookupInt64(name, fs)
}
return 0
Expand Down
2 changes: 1 addition & 1 deletion flag_int64_slice.go
Expand Up @@ -145,7 +145,7 @@ func (f *Int64SliceFlag) Apply(set *flag.FlagSet) error {
// Int64Slice looks up the value of a local Int64SliceFlag, returns
// nil if not found
func (c *Context) Int64Slice(name string) []int64 {
if fs := lookupFlagSet(name, c); fs != nil {
if fs := c.lookupFlagSet(name); fs != nil {
return lookupInt64Slice(name, fs)
}
return nil
Expand Down
2 changes: 1 addition & 1 deletion flag_int_slice.go
Expand Up @@ -156,7 +156,7 @@ func (f *IntSliceFlag) Apply(set *flag.FlagSet) error {
// IntSlice looks up the value of a local IntSliceFlag, returns
// nil if not found
func (c *Context) IntSlice(name string) []int {
if fs := lookupFlagSet(name, c); fs != nil {
if fs := c.lookupFlagSet(name); fs != nil {
return lookupIntSlice(name, fs)
}
return nil
Expand Down