Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix: Getting the value of a StringToString pflag #874

Merged
merged 7 commits into from May 9, 2020
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
22 changes: 22 additions & 0 deletions viper.go
Expand Up @@ -1083,6 +1083,8 @@ func (v *Viper) find(lcaseKey string, flagDefault bool) interface{} {
s = strings.TrimSuffix(s, "]")
res, _ := readAsCSV(s)
return cast.ToIntSlice(res)
case "stringToString":
return parseStringToStringFlagValue(flag.ValueString())
default:
return flag.ValueString()
}
Expand Down Expand Up @@ -1158,6 +1160,8 @@ func (v *Viper) find(lcaseKey string, flagDefault bool) interface{} {
s = strings.TrimSuffix(s, "]")
res, _ := readAsCSV(s)
return cast.ToIntSlice(res)
case "stringToString":
return parseStringToStringFlagValue(flag.ValueString())
default:
return flag.ValueString()
}
Expand All @@ -1177,6 +1181,24 @@ func readAsCSV(val string) ([]string, error) {
return csvReader.Read()
}

func parseStringToStringFlagValue(val string) map[string]interface{} {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about copying the same function from pflag and keeping the implementation as close and possible?

https://github.com/spf13/pflag/blob/master/string_to_string.go#L79

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@sagikazarmark i've made this change. i had to tweak the signature and return types slightly to work with GetStringMap that uses cast.ToStringMap

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure thing, thanks!

s := strings.TrimPrefix(val, "[")
s = strings.TrimSuffix(s, "]")
if s == "" {
return nil
}
elements := strings.Split(s, ",")
result := make(map[string]interface{}, len(elements))
for _, element := range elements {
pair := strings.SplitN(element, "=", 2)
if len(pair) != 2 {
return nil
}
result[pair[0]] = pair[1]
}
return result
}

// IsSet checks to see if the key has been set in any of the data locations.
// IsSet is case-insensitive for a key.
func IsSet(key string) bool { return v.IsSet(key) }
Expand Down
47 changes: 47 additions & 0 deletions viper_test.go
Expand Up @@ -970,6 +970,53 @@ func TestBindPFlag(t *testing.T) {
assert.Equal(t, "testing_mutate", Get("testvalue"))
}

func TestBindPFlagStringToString(t *testing.T) {
tests := []struct {
Expected map[string]string
Value string
}{
{nil, ""},
{map[string]string{"yo": "hi"}, "yo=hi"},
{map[string]string{"yo": "hi", "oh": "hi=there"}, "yo=hi,oh=hi=there"},
{map[string]string{"yo": ""}, "yo="},
{map[string]string{"yo": "", "oh": "hi=there"}, "yo=,oh=hi=there"},
}

v := New() // create independent Viper object
defaultVal := map[string]string{}
v.SetDefault("stringtostring", defaultVal)

for _, testValue := range tests {
flagSet := pflag.NewFlagSet("test", pflag.ContinueOnError)
flagSet.StringToString("stringtostring", testValue.Expected, "test")

for _, changed := range []bool{true, false} {
flagSet.VisitAll(func(f *pflag.Flag) {
f.Value.Set(testValue.Value)
f.Changed = changed
})

err := v.BindPFlags(flagSet)
if err != nil {
t.Fatalf("error binding flag set, %v", err)
}

type TestMap struct {
StringToString map[string]string
}
val := &TestMap{}
if err := v.Unmarshal(val); err != nil {
t.Fatalf("%+#v cannot unmarshal: %s", testValue.Value, err)
}
if changed {
assert.Equal(t, testValue.Expected, val.StringToString)
} else {
assert.Equal(t, defaultVal, val.StringToString)
}
}
}
}

func TestBoundCaseSensitivity(t *testing.T) {
assert.Equal(t, "brown", Get("eyes"))

Expand Down