Skip to content

Commit

Permalink
lib/decode: fix hook to work with embedded squash struct
Browse files Browse the repository at this point in the history
The decode hook is not call for the embedded squashed struct, so we need to recurse when we
find squash tags.

See mitchellh/mapstructure#226
  • Loading branch information
dnephin committed Sep 22, 2021
1 parent 1e3ba26 commit d2274df
Show file tree
Hide file tree
Showing 2 changed files with 89 additions and 15 deletions.
50 changes: 39 additions & 11 deletions lib/decode/decode.go
Expand Up @@ -73,32 +73,60 @@ func translationsForType(to reflect.Type) map[string]string {
translations := map[string]string{}
for i := 0; i < to.NumField(); i++ {
field := to.Field(i)
tags := fieldTags(field)
if tags.squash {
embedded := field.Type
if embedded.Kind() == reflect.Ptr {
embedded = embedded.Elem()
}
if embedded.Kind() != reflect.Struct {
// mapstructure will handle reporting this error
continue
}

for k, v := range translationsForType(embedded) {
translations[k] = v
}
continue
}

tag, ok := field.Tag.Lookup("alias")
if !ok {
continue
}

canonKey := strings.ToLower(canonicalFieldKey(field))
canonKey := strings.ToLower(tags.name)
for _, alias := range strings.Split(tag, ",") {
translations[strings.ToLower(alias)] = canonKey
}
}
return translations
}

func canonicalFieldKey(field reflect.StructField) string {
func fieldTags(field reflect.StructField) mapstructureFieldTags {
tag, ok := field.Tag.Lookup("mapstructure")
if !ok {
return field.Name
return mapstructureFieldTags{name: field.Name}
}

tags := mapstructureFieldTags{name: field.Name}
parts := strings.Split(tag, ",")
if len(parts) == 0 {
return tags
}
if parts[0] != "" {
tags.name = parts[0]
}
parts := strings.SplitN(tag, ",", 2)
switch {
case len(parts) < 1:
return field.Name
case parts[0] == "":
return field.Name
for _, part := range parts[1:] {
if part == "squash" {
tags.squash = true
}
}
return parts[0]
return tags
}

type mapstructureFieldTags struct {
name string
squash bool
}

// HookWeakDecodeFromSlice looks for []map[string]interface{} and []interface{}
Expand Down
54 changes: 50 additions & 4 deletions lib/decode/decode_test.go
@@ -1,6 +1,7 @@
package decode

import (
"fmt"
"reflect"
"testing"

Expand Down Expand Up @@ -210,16 +211,29 @@ type translateExample struct {
FieldWithMapstructureTag string `alias:"second" mapstructure:"field_with_mapstruct_tag"`
FieldWithMapstructureTagOmit string `mapstructure:"field_with_mapstruct_omit,omitempty" alias:"third"`
FieldWithEmptyTag string `mapstructure:"" alias:"forth"`
EmbeddedStruct `mapstructure:",squash"`
*PtrEmbeddedStruct `mapstructure:",squash"`
BadField string `mapstructure:",squash"`
}

type EmbeddedStruct struct {
NextField string `alias:"next"`
}

type PtrEmbeddedStruct struct {
OtherNextField string `alias:"othernext"`
}

func TestTranslationsForType(t *testing.T) {
to := reflect.TypeOf(translateExample{})
actual := translationsForType(to)
expected := map[string]string{
"first": "fielddefaultcanonical",
"second": "field_with_mapstruct_tag",
"third": "field_with_mapstruct_omit",
"forth": "fieldwithemptytag",
"first": "fielddefaultcanonical",
"second": "field_with_mapstruct_tag",
"third": "field_with_mapstruct_omit",
"forth": "fieldwithemptytag",
"next": "nextfield",
"othernext": "othernextfield",
}
require.Equal(t, expected, actual)
}
Expand Down Expand Up @@ -389,3 +403,35 @@ service {
}
require.Equal(t, target, expected)
}

func TestFieldTags(t *testing.T) {
type testCase struct {
tags string
expected mapstructureFieldTags
}

fn := func(t *testing.T, tc testCase) {
tag := fmt.Sprintf(`mapstructure:"%v"`, tc.tags)
field := reflect.StructField{
Tag: reflect.StructTag(tag),
Name: "Original",
}
actual := fieldTags(field)
require.Equal(t, tc.expected, actual)
}

var testCases = []testCase{
{tags: "", expected: mapstructureFieldTags{name: "Original"}},
{tags: "just-a-name", expected: mapstructureFieldTags{name: "just-a-name"}},
{tags: "name,squash", expected: mapstructureFieldTags{name: "name", squash: true}},
{tags: ",squash", expected: mapstructureFieldTags{name: "Original", squash: true}},
{tags: ",omitempty,squash", expected: mapstructureFieldTags{name: "Original", squash: true}},
{tags: "named,omitempty,squash", expected: mapstructureFieldTags{name: "named", squash: true}},
}

for _, tc := range testCases {
t.Run(tc.tags, func(t *testing.T) {
fn(t, tc)
})
}
}

0 comments on commit d2274df

Please sign in to comment.