Skip to content

Commit

Permalink
Merge pull request #161 from erezrokah/fix/nil_indirection
Browse files Browse the repository at this point in the history
fix: Only consider pointer to structs when checking for embedded fields
  • Loading branch information
thoas committed Dec 26, 2022
2 parents df86593 + 64cf625 commit df1ff15
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 13 deletions.
8 changes: 7 additions & 1 deletion funk_test.go
Expand Up @@ -129,8 +129,14 @@ type EmbeddedStruct struct {
EmbeddedField *string
}

type RootStruct struct {
type RootStructPointer struct {
*EmbeddedStruct

RootField *string
}

type RootStructNotPointer struct {
EmbeddedStruct

RootField *string
}
13 changes: 3 additions & 10 deletions retrieve.go
Expand Up @@ -112,14 +112,11 @@ func isNilIndirection(v reflect.Value, name string) bool {
vType := v.Type()
for i := 0; i < vType.NumField(); i++ {
field := vType.Field(i)
if !isEmbeddedStructField(field) {
if !isEmbeddedStructPointerField(field) {
return false
}

fieldType := field.Type
if fieldType.Kind() == reflect.Ptr {
fieldType = field.Type.Elem()
}
fieldType := field.Type.Elem()

_, found := fieldType.FieldByName(name)
if found {
Expand All @@ -130,14 +127,10 @@ func isNilIndirection(v reflect.Value, name string) bool {
return false
}

func isEmbeddedStructField(field reflect.StructField) bool {
func isEmbeddedStructPointerField(field reflect.StructField) bool {
if !field.Anonymous {
return false
}

if field.Type.Kind() == reflect.Struct {
return true
}

return field.Type.Kind() == reflect.Ptr && field.Type.Elem().Kind() == reflect.Struct
}
11 changes: 9 additions & 2 deletions retrieve_test.go
Expand Up @@ -117,10 +117,17 @@ func TestGetOrElse(t *testing.T) {
})
}

func TestEmbeddedStruct(t *testing.T) {
func TestEmbeddedStructPointer(t *testing.T) {
is := assert.New(t)

root := RootStruct{}
root := RootStructPointer{}
is.Equal(Get(root, "EmbeddedField"), nil)
is.Equal(Get(root, "EmbeddedStruct.EmbeddedField"), nil)
}

func TestEmbeddedStructNotPointer(t *testing.T) {
is := assert.New(t)

root := RootStructNotPointer{}
is.Equal(Get(root, "EmbeddedField"), nil)
}

0 comments on commit df1ff15

Please sign in to comment.