diff --git a/lib/decode/decode.go b/lib/decode/decode.go index 7a6534e64d9a..0461962c3c14 100644 --- a/lib/decode/decode.go +++ b/lib/decode/decode.go @@ -73,12 +73,28 @@ func translationsForType(to reflect.Type) map[string]string { translations := map[string]string{} for i := 0; i < to.NumField(); i++ { field := to.Field(i) + tags := fieldTags(field) + if tags.squash { + embedded := field.Type + if embedded.Kind() == reflect.Ptr { + embedded = embedded.Elem() + } + if embedded.Kind() != reflect.Struct { + // mapstructure will handle reporting this error + continue + } + + for k, v := range translationsForType(embedded) { + translations[k] = v + } + continue + } + tag, ok := field.Tag.Lookup("alias") if !ok { continue } - - canonKey := strings.ToLower(canonicalFieldKey(field)) + canonKey := strings.ToLower(tags.name) for _, alias := range strings.Split(tag, ",") { translations[strings.ToLower(alias)] = canonKey } @@ -86,19 +102,31 @@ func translationsForType(to reflect.Type) map[string]string { return translations } -func canonicalFieldKey(field reflect.StructField) string { +func fieldTags(field reflect.StructField) mapstructureFieldTags { tag, ok := field.Tag.Lookup("mapstructure") if !ok { - return field.Name + return mapstructureFieldTags{name: field.Name} + } + + tags := mapstructureFieldTags{name: field.Name} + parts := strings.Split(tag, ",") + if len(parts) == 0 { + return tags + } + if parts[0] != "" { + tags.name = parts[0] } - parts := strings.SplitN(tag, ",", 2) - switch { - case len(parts) < 1: - return field.Name - case parts[0] == "": - return field.Name + for _, part := range parts[1:] { + if part == "squash" { + tags.squash = true + } } - return parts[0] + return tags +} + +type mapstructureFieldTags struct { + name string + squash bool } // HookWeakDecodeFromSlice looks for []map[string]interface{} and []interface{} diff --git a/lib/decode/decode_test.go b/lib/decode/decode_test.go index 8c1e6da5c562..b8243233d124 100644 --- a/lib/decode/decode_test.go +++ b/lib/decode/decode_test.go @@ -1,6 +1,7 @@ package decode import ( + "fmt" "reflect" "testing" @@ -210,16 +211,29 @@ type translateExample struct { FieldWithMapstructureTag string `alias:"second" mapstructure:"field_with_mapstruct_tag"` FieldWithMapstructureTagOmit string `mapstructure:"field_with_mapstruct_omit,omitempty" alias:"third"` FieldWithEmptyTag string `mapstructure:"" alias:"forth"` + EmbeddedStruct `mapstructure:",squash"` + *PtrEmbeddedStruct `mapstructure:",squash"` + BadField string `mapstructure:",squash"` +} + +type EmbeddedStruct struct { + NextField string `alias:"next"` +} + +type PtrEmbeddedStruct struct { + OtherNextField string `alias:"othernext"` } func TestTranslationsForType(t *testing.T) { to := reflect.TypeOf(translateExample{}) actual := translationsForType(to) expected := map[string]string{ - "first": "fielddefaultcanonical", - "second": "field_with_mapstruct_tag", - "third": "field_with_mapstruct_omit", - "forth": "fieldwithemptytag", + "first": "fielddefaultcanonical", + "second": "field_with_mapstruct_tag", + "third": "field_with_mapstruct_omit", + "forth": "fieldwithemptytag", + "next": "nextfield", + "othernext": "othernextfield", } require.Equal(t, expected, actual) } @@ -389,3 +403,35 @@ service { } require.Equal(t, target, expected) } + +func TestFieldTags(t *testing.T) { + type testCase struct { + tags string + expected mapstructureFieldTags + } + + fn := func(t *testing.T, tc testCase) { + tag := fmt.Sprintf(`mapstructure:"%v"`, tc.tags) + field := reflect.StructField{ + Tag: reflect.StructTag(tag), + Name: "Original", + } + actual := fieldTags(field) + require.Equal(t, tc.expected, actual) + } + + var testCases = []testCase{ + {tags: "", expected: mapstructureFieldTags{name: "Original"}}, + {tags: "just-a-name", expected: mapstructureFieldTags{name: "just-a-name"}}, + {tags: "name,squash", expected: mapstructureFieldTags{name: "name", squash: true}}, + {tags: ",squash", expected: mapstructureFieldTags{name: "Original", squash: true}}, + {tags: ",omitempty,squash", expected: mapstructureFieldTags{name: "Original", squash: true}}, + {tags: "named,omitempty,squash", expected: mapstructureFieldTags{name: "named", squash: true}}, + } + + for _, tc := range testCases { + t.Run(tc.tags, func(t *testing.T) { + fn(t, tc) + }) + } +}