Skip to content

Commit

Permalink
Add optional Opt{} param to basicflag (#259)
Browse files Browse the repository at this point in the history
* Add optional `Opt{}` param that takes a koanf instance to match posflag's default value behaviour in basic flag. Closes #255.
  • Loading branch information
knadh committed Dec 19, 2023
1 parent 27e15c0 commit 09d28ae
Show file tree
Hide file tree
Showing 2 changed files with 146 additions and 41 deletions.
85 changes: 77 additions & 8 deletions providers/basicflag/basicflag.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,30 +9,63 @@ import (
"github.com/knadh/koanf/maps"
)

// Opt represents optional options (yup) passed to the provider.
type Opt struct {
KeyMap KoanfIntf
}

// KoanfIntf is an interface that represents a small subset of methods
// used by this package from Koanf{}.
type KoanfIntf interface {
Exists(string) bool
}

// Pflag implements a pflag command line provider.
type Pflag struct {
delim string
flagset *flag.FlagSet
cb func(key string, value string) (string, interface{})
opt *Opt
}

// Provider returns a commandline flags provider that returns
// a nested map[string]interface{} of environment variable where the
// nesting hierarchy of keys are defined by delim. For instance, the
// delim "." will convert the key `parent.child.key: 1`
// to `{parent: {child: {key: 1}}}`.
func Provider(f *flag.FlagSet, delim string) *Pflag {
return &Pflag{
//
// It takes an optional (but recommended) Opt{} argument containing a Koanf instance.
// It checks if the defined flags have been set by other providers (e.g., a config file).
// If not, default flag values are merged. If they exist, flag values are merged only if
// explicitly set in the command line. The function is variadic to maintain backward compatibility.
// See https://github.com/knadh/koanf/issues/255
func Provider(f *flag.FlagSet, delim string, opt ...*Opt) *Pflag {
pf := &Pflag{
flagset: f,
delim: delim,
}

if len(opt) > 0 {
pf.opt = opt[0]
}

return pf
}

// ProviderWithValue works exactly the same as Provider except the callback
// takes a (key, value) with the variable name and value and allows you
// to modify both. This is useful for cases where you may want to return
// other types like a string slice instead of just a string.
func ProviderWithValue(f *flag.FlagSet, delim string, cb func(key string, value string) (string, interface{})) *Pflag {
//
// It takes an optional Opt{} (but recommended) argument with a Koanf instance (opt.KeyMap) to see if the
// the flags defined have been set from other providers, for instance,
// a config file. If they are not, then the default values of the flags
// are merged. If they do exist, the flag values are not merged but only
// the values that have been explicitly set in the command line are merged.
// It is a variadic function as a hack to ensure backwards compatibility with the
// function definition.
// See https://github.com/knadh/koanf/issues/255
func ProviderWithValue(f *flag.FlagSet, delim string, cb func(key string, value string) (string, interface{}), ko ...KoanfIntf) *Pflag {
return &Pflag{
flagset: f,
delim: delim,
Expand All @@ -42,18 +75,54 @@ func ProviderWithValue(f *flag.FlagSet, delim string, cb func(key string, value

// Read reads the flag variables and returns a nested conf map.
func (p *Pflag) Read() (map[string]interface{}, error) {
var changed map[string]struct{}

// Prepare a map of flags that have been explicitly set by the user as aa KeyMap instance of Koanf
// has been provided.
if p.opt != nil && p.opt.KeyMap != nil {
changed = map[string]struct{}{}

p.flagset.Visit(func(f *flag.Flag) {
key := f.Name
if p.cb != nil {
key, _ = p.cb(f.Name, "")
}
if key == "" {
return
}

changed[key] = struct{}{}
})
}

mp := make(map[string]interface{})
p.flagset.VisitAll(func(f *flag.Flag) {
var (
key = f.Name
val interface{} = f.Value.String()
)
if p.cb != nil {
key, value := p.cb(f.Name, f.Value.String())
k, v := p.cb(f.Name, f.Value.String())
// If the callback blanked the key, it should be omitted
if key == "" {
if k == "" {
return
}
mp[key] = value
} else {
mp[f.Name] = f.Value.String()

key = k
val = v
}

// If the default value of the flag was never changed by the user,
// it should not override the value in the conf map (if it exists in the first place).
if changed != nil {
if _, ok := changed[key]; !ok {
if p.opt.KeyMap.Exists(key) {
return
}
}
}

mp[key] = val
})
return maps.Unflatten(mp, p.delim), nil
}
Expand Down
102 changes: 69 additions & 33 deletions tests/koanf_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -587,11 +587,9 @@ func TestLoadMerge(t *testing.T) {
func TestFlags(t *testing.T) {
var (
assert = assert.New(t)
k = koanf.New(delim)
def = koanf.New(delim)
)
assert.Nil(k.Load(file.Provider(mockJSON), json.Parser()), "error loading file")
k2 := k.Copy()
k3 := k.Copy()
assert.Nil(def.Load(file.Provider(mockJSON), json.Parser()), "error loading file")

// Override with the posflag provider.
f := pflag.NewFlagSet("test", pflag.ContinueOnError)
Expand All @@ -610,41 +608,79 @@ func TestFlags(t *testing.T) {

// Initialize the provider with the Koanf instance passed where default values
// will merge if the keys are not present in the conf map.
assert.Nil(k.Load(posflag.Provider(f, ".", k), nil), "error loading posflag")
assert.Equal("flag", k.String("parent1.child1.type"), "types don't match")
assert.Equal("flag", k.String("flagkey"), "value doesn't match")
assert.NotEqual("flag", k.String("parent1.name"), "value doesn't match")
assert.Equal([]string{"a", "b", "c"}, k.Strings("stringslice"), "value doesn't match")
assert.Equal([]int{1, 2, 3}, k.Ints("intslice"), "value doesn't match")
{
k := def.Copy()
assert.Nil(k.Load(posflag.Provider(f, ".", k), nil), "error loading posflag")
assert.Equal("flag", k.String("parent1.child1.type"), "types don't match")
assert.Equal("flag", k.String("flagkey"), "value doesn't match")
assert.NotEqual("flag", k.String("parent1.name"), "value doesn't match")
assert.Equal([]string{"a", "b", "c"}, k.Strings("stringslice"), "value doesn't match")
assert.Equal([]int{1, 2, 3}, k.Ints("intslice"), "value doesn't match")
}

// Test the posflag provider can mutate the value to upper case
assert.Nil(k3.Load(posflag.ProviderWithValue(f, ".", nil, func(k string, v string) (string, interface{}) {
return strings.Replace(strings.ToLower(k), "prefix_", "", -1), strings.ToUpper(v)
}), nil), "error loading posflag")
assert.Equal("FLAG", k3.String("parent1.child1.type"), "types don't match")
{
k := def.Copy()
assert.Nil(k.Load(posflag.ProviderWithValue(f, ".", nil, func(k string, v string) (string, interface{}) {
return strings.Replace(strings.ToLower(k), "prefix_", "", -1), strings.ToUpper(v)
}), nil), "error loading posflag")
assert.Equal("FLAG", k.String("parent1.child1.type"), "types don't match")
}

// Test without passing the Koanf instance where default values will not merge.
assert.Nil(k2.Load(posflag.Provider(f, ".", nil), nil), "error loading posflag")
assert.Equal("flag", k2.String("parent1.child1.type"), "types don't match")
assert.Equal("", k2.String("flagkey"), "value doesn't match")
assert.NotEqual("", k2.String("parent1.name"), "value doesn't match")

// Override with the flag provider.
bf := flag.NewFlagSet("test", flag.ContinueOnError)
bf.String("parent1.child1.type", "flag", "")
bf.Set("parent1.child1.type", "basicflag")
assert.Nil(k.Load(basicflag.Provider(bf, "."), nil), "error loading basicflag")
assert.Equal("basicflag", k.String("parent1.child1.type"), "types don't match")
{
k := def.Copy()
assert.Nil(k.Load(posflag.Provider(f, ".", nil), nil), "error loading posflag")
assert.Equal("flag", k.String("parent1.child1.type"), "types don't match")
assert.Equal("", k.String("flagkey"), "value doesn't match")
assert.NotEqual("", k.String("parent1.name"), "value doesn't match")
}

// Test the basicflag provider can mutate the value to upper case
bf2 := flag.NewFlagSet("test", flag.ContinueOnError)
bf2.String("parent1.child1.type", "flag", "")
bf2.Set("parent1.child1.type", "basicflag")
assert.Nil(k.Load(basicflag.ProviderWithValue(bf2, ".", func(k string, v string) (string, interface{}) {
return strings.Replace(strings.ToLower(k), "prefix_", "", -1), strings.ToUpper(v)
}), nil), "error loading basicflag")
assert.Equal("BASICFLAG", k.String("parent1.child1.type"), "types don't match")
// Override with the basicflag provider.
{
k := def.Copy()
bf := flag.NewFlagSet("test", flag.ContinueOnError)
bf.String("parent1.child1.type", "flag", "")
bf.String("parent2.child2.name", "override-default", "")
bf.Set("parent1.child1.type", "basicflag")
assert.Nil(k.Load(basicflag.Provider(bf, "."), nil), "error loading basicflag")
assert.Equal("basicflag", k.String("parent1.child1.type"), "types don't match")
assert.Equal("override-default", k.String("parent2.child2.name"), "basicflag default value override failed")
}

// No defualt-value override behaviour.
{
k := def.Copy()
bf := flag.NewFlagSet("test", flag.ContinueOnError)
bf.String("parent1.child1.name", "override-default", "")
bf.String("parent2.child2.name", "override-default", "")
bf.Set("parent2.child2.name", "custom")
assert.Nil(k.Load(basicflag.Provider(bf, ".", &basicflag.Opt{KeyMap: def}), nil), "error loading basicflag")
assert.Equal("child1", k.String("parent1.child1.name"), "basicflag default overwrote")
assert.Equal("custom", k.String("parent2.child2.name"), "basicflag set failed")
}

// Override with the basicflag provider.
{
k := def.Copy()
bf := flag.NewFlagSet("test", flag.ContinueOnError)
bf.String("parent1.child1.type", "flag", "")
bf.Set("parent1.child1.type", "basicflag")
assert.Nil(k.Load(basicflag.Provider(bf, "."), nil), "error loading basicflag")
assert.Equal("basicflag", k.String("parent1.child1.type"), "types don't match")
}

// Test the basicflag provider can mutate the value to upper case
{
k := def.Copy()
bf := flag.NewFlagSet("test", flag.ContinueOnError)
bf.String("parent1.child1.type", "flag", "")
bf.Set("parent1.child1.type", "basicflag")
assert.Nil(k.Load(basicflag.ProviderWithValue(bf, ".", func(k string, v string) (string, interface{}) {
return strings.Replace(strings.ToLower(k), "prefix_", "", -1), strings.ToUpper(v)
}), nil), "error loading basicflag")
assert.Equal("BASICFLAG", k.String("parent1.child1.type"), "types don't match")
}
}

func TestConfMapValues(t *testing.T) {
Expand Down

0 comments on commit 09d28ae

Please sign in to comment.