diff --git a/enum_var_test.go b/enum_var_test.go index 26189e5..3e7eafe 100644 --- a/enum_var_test.go +++ b/enum_var_test.go @@ -25,7 +25,7 @@ func TestEnumVarPositive(t *testing.T) { } err := flagSet.Parse() assert.Nil(t, err) - assert.Equal(t, enumString, "type1") + assert.Equal(t, "type1", enumString) tearDown(t.Name()) } diff --git a/goflags.go b/goflags.go index 496ae6b..d54994f 100644 --- a/goflags.go +++ b/goflags.go @@ -154,8 +154,8 @@ func (flagSet *FlagSet) generateDefaultConfig() []byte { configBuffer.WriteString(dv) case flag.Value: configBuffer.WriteString(dv.String()) - case StringSlice: - configBuffer.WriteString(dv.String()) + case []string: + configBuffer.WriteString(ToString(dv)) } configBuffer.WriteString("\n\n") @@ -342,7 +342,7 @@ func (flagSet *FlagSet) IntVar(field *int, long string, defaultValue int, usage // StringSliceVarP adds a string slice flag with a shortname and longname // Use options to customize the behavior -func (flagSet *FlagSet) StringSliceVarP(field *StringSlice, long, short string, defaultValue StringSlice, usage string, options Options) *FlagData { +func (flagSet *FlagSet) StringSliceVarP(field *StringSlice, long, short string, defaultValue []string, usage string, options Options) *FlagData { optionMap[field] = options for _, defaultItem := range defaultValue { values, _ := ToStringSlice(defaultItem, options) @@ -350,6 +350,8 @@ func (flagSet *FlagSet) StringSliceVarP(field *StringSlice, long, short string, _ = field.Set(value) } } + field.Default = true + flagData := &FlagData{ usage: usage, long: long, @@ -377,6 +379,8 @@ func (flagSet *FlagSet) StringSliceVarConfigOnly(field *StringSlice, long string for _, item := range defaultValue { _ = field.Set(item) } + field.Default = true + flagData := &FlagData{ usage: usage, long: long, @@ -426,6 +430,7 @@ func (flagSet *FlagSet) PortVarP(field *Port, long, short string, defaultValue [ for _, item := range defaultValue { _ = field.Set(item) } + field.Default = true flagData := &FlagData{ usage: usage, @@ -673,6 +678,10 @@ func createUsageDefaultValue(data *FlagData, currentFlag *flag.Flag, valueType r if !isZeroValue(currentFlag, currentFlag.DefValue) { defaultValueTemplate := " (default " switch valueType.String() { // ugly hack because "flag.stringValue" is not exported from the parent library + case "*goflags.StringSlice": + return defaultValueTemplate + ToString(data.defaultValue.([]string)) + ")" + case "goflags.StringSlice": + return defaultValueTemplate + ToString(data.defaultValue.([]string)) + ")" case "*flag.stringValue": defaultValueTemplate += "%q" default: @@ -685,21 +694,29 @@ func createUsageDefaultValue(data *FlagData, currentFlag *flag.Flag, valueType r } func createUsageTypeAndDescription(currentFlag *flag.Flag, valueType reflect.Type) string { - var result string + var ( + result string + usage string + flagDisplayType string + ) + flagDisplayType, usage = flag.UnquoteUsage(currentFlag) - flagDisplayType, usage := flag.UnquoteUsage(currentFlag) if len(flagDisplayType) > 0 { if flagDisplayType == "value" { // hardcoded in the goflags library - switch valueType.Kind() { - case reflect.Ptr: - pointerTypeElement := valueType.Elem() - switch pointerTypeElement.Kind() { - case reflect.Slice, reflect.Array: - switch pointerTypeElement.Elem().Kind() { - case reflect.String: - flagDisplayType = "string[]" - default: - flagDisplayType = "value[]" + if strings.Contains(valueType.String(), "StringSlice") { + flagDisplayType = "string[]" + } else { + switch valueType.Kind() { + case reflect.Ptr: + pointerTypeElement := valueType.Elem() + switch pointerTypeElement.Kind() { + case reflect.Slice, reflect.Array: + switch pointerTypeElement.Elem().Kind() { + case reflect.String: + flagDisplayType = "string[]" + default: + flagDisplayType = "value[]" + } } } } diff --git a/goflags_test.go b/goflags_test.go index b2be42c..057346f 100644 --- a/goflags_test.go +++ b/goflags_test.go @@ -70,7 +70,7 @@ duration-value: 1h` require.Nil(t, err, "could not merge temporary config") require.Equal(t, "test", data, "could not get correct string") - require.Equal(t, StringSlice{"test", "test2"}, data2, "could not get correct string slice") + require.Equal(t, StringSlice{Value: []string{"test", "test2"}}, data2, "could not get correct string slice") require.Equal(t, 543, data3, "could not get correct int") require.Equal(t, true, data4, "could not get correct bool") require.Equal(t, time.Hour, data5, "could not get correct duration") @@ -146,7 +146,7 @@ BOOLEAN: -bool-with-default-value Bool with default value example (default true) -bwdv, -bool-with-default-value2 Bool with default value example #2 (default true) ` - assert.Equal(t, actual, expected) + assert.Equal(t, expected, actual) tearDown(t.Name()) } @@ -234,7 +234,7 @@ func TestParseStringSlice(t *testing.T) { err := flagSet.Parse() assert.Nil(t, err) - assert.Equal(t, StringSlice{header1, header2, header3}, stringSlice) + assert.Equal(t, StringSlice{Value: []string{header1, header2, header3}}, stringSlice) tearDown(t.Name()) } @@ -258,7 +258,7 @@ func TestParseCommaSeparatedStringSlice(t *testing.T) { err := flagSet.Parse() assert.Nil(t, err) - assert.Equal(t, csStringSlice, StringSlice{value1, value2, value3}) + assert.Equal(t, csStringSlice, StringSlice{Value: []string{value1, value2, value3}}) tearDown(t.Name()) } @@ -288,7 +288,7 @@ value3` err = flagSet.Parse() assert.Nil(t, err) - assert.Equal(t, csStringSlice, StringSlice{value1, value2, value3}) + assert.Equal(t, csStringSlice, StringSlice{Value: []string{value1, value2, value3}}) tearDown(t.Name()) } @@ -312,7 +312,39 @@ config-only: err = flagSet.MergeConfigFile("test.yaml") require.Nil(t, err, "could not merge temporary config") - require.Equal(t, StringSlice{"test", "test2"}, data, "could not get correct string slice") + require.Equal(t, StringSlice{Value: []string{"test", "test2"}}, data, "could not get correct string slice") + tearDown(t.Name()) +} + +func TestSetDefaultStringSliceValue(t *testing.T) { + var data StringSlice + flagSet := NewFlagSet() + flagSet.StringSliceVar(&data, "test", []string{"A,A,A"}, "Default value for a test flag example", CommaSeparatedStringSliceOptions) + err := flagSet.CommandLine.Parse([]string{"-test", "item1"}) + require.Nil(t, err) + require.Equal(t, StringSlice{Value: []string{"item1"}}, data, "could not get correct string slice") + + var data2 StringSlice + flagSet2 := NewFlagSet() + flagSet2.StringSliceVar(&data2, "test", []string{"A"}, "Default value for a test flag example", CommaSeparatedStringSliceOptions) + err = flagSet2.CommandLine.Parse([]string{"-test", "item1,item2"}) + require.Nil(t, err) + require.Equal(t, StringSlice{Value: []string{"item1", "item2"}}, data2, "could not get correct string slice") + + var data3 StringSlice + flagSet3 := NewFlagSet() + flagSet3.StringSliceVar(&data3, "test", []string{}, "Default value for a test flag example", CommaSeparatedStringSliceOptions) + err = flagSet3.CommandLine.Parse([]string{"-test", "item1,item2"}) + require.Nil(t, err) + require.Equal(t, StringSlice{Value: []string{"item1", "item2"}}, data3, "could not get correct string slice") + + var data4 StringSlice + flagSet4 := NewFlagSet() + flagSet4.StringSliceVar(&data4, "test", nil, "Default value for a test flag example", CommaSeparatedStringSliceOptions) + err = flagSet4.CommandLine.Parse([]string{"-test", "item1,\"item2\""}) + require.Nil(t, err) + require.Equal(t, StringSlice{Value: []string{"item1", "item2"}}, data4, "could not get correct string slice") + tearDown(t.Name()) } diff --git a/port.go b/port.go index 4a4b172..1d2c800 100644 --- a/port.go +++ b/port.go @@ -10,7 +10,8 @@ import ( // Port is a list of unique ports in a normalized format type Port struct { - kv map[int]struct{} + kv map[int]struct{} + Default bool } func (port Port) String() string { @@ -28,6 +29,10 @@ func (port Port) String() string { // Set inserts a value to the port map. A number of formats are accepted. func (port *Port) Set(value string) error { + if port.Default { + port.kv = map[int]struct{}{} + port.Default = false + } if port.kv == nil { port.kv = make(map[int]struct{}) } diff --git a/port_test.go b/port_test.go index b473112..96e91bf 100644 --- a/port_test.go +++ b/port_test.go @@ -1,6 +1,7 @@ package goflags import ( + "fmt" "testing" "github.com/stretchr/testify/require" @@ -60,3 +61,31 @@ func TestPortType(t *testing.T) { require.ElementsMatch(t, port.AsPorts(), []int{443, 53}, "could not get correct ports") }) } + +func TestSetDefaultPortValue(t *testing.T) { + var data Port + flagSet := NewFlagSet() + flagSet.PortVarP(&data, "port", "p", []string{"1,3"}, "Default value for a test flag example") + err := flagSet.CommandLine.Parse([]string{"-p", "11"}) + require.Nil(t, err) + fmt.Println(data) + require.Equal(t, Port{kv: map[int]struct{}{11: {}}}, data, "could not get correct string slice") + + var data2 Port + flagSet2 := NewFlagSet() + flagSet2.PortVarP(&data2, "port", "p", []string{"1,3"}, "Default value for a test flag example") + err = flagSet2.CommandLine.Parse([]string{"-p", "11,12"}) + require.Nil(t, err) + fmt.Println(data2) + require.Equal(t, Port{kv: map[int]struct{}{11: {}, 12: {}}}, data2, "could not get correct string slice") + + var data3 Port + flagSet3 := NewFlagSet() + flagSet3.PortVarP(&data3, "port", "p", nil, "Default value for a test flag example") + err = flagSet3.CommandLine.Parse([]string{"-p", "11,12"}) + fmt.Println(data2) + require.Nil(t, err) + require.Equal(t, Port{kv: map[int]struct{}{11: {}, 12: {}}}, data3, "could not get correct string slice") + + tearDown(t.Name()) +} diff --git a/string_slice.go b/string_slice.go index fd5449c..e02c4ff 100644 --- a/string_slice.go +++ b/string_slice.go @@ -7,10 +7,17 @@ func init() { } // StringSlice is a slice of strings -type StringSlice []string +type StringSlice struct { + Value []string + Default bool +} // Set appends a value to the string slice. func (stringSlice *StringSlice) Set(value string) error { + if stringSlice.Default { + stringSlice.Value = []string{} + stringSlice.Default = false + } option, ok := optionMap[stringSlice] if !ok { option = StringSliceOptions @@ -19,10 +26,10 @@ func (stringSlice *StringSlice) Set(value string) error { if err != nil { return err } - *stringSlice = append(*stringSlice, values...) + stringSlice.Value = append(stringSlice.Value, values...) return nil } func (stringSlice StringSlice) String() string { - return ToString(stringSlice) + return ToString(stringSlice.Value) } diff --git a/string_slice_test.go b/string_slice_test.go index 5727f18..50b93c6 100644 --- a/string_slice_test.go +++ b/string_slice_test.go @@ -48,7 +48,7 @@ func TestNormalizedStringSlicePositive(t *testing.T) { result, err := ToStringSlice(value, NormalizedStringSliceOptions) fmt.Println(result) assert.Nil(t, err) - assert.Equal(t, result, expected) + assert.Equal(t, expected, result) } }