From 64cf625daf05e96a3b85c4eb0ccff7de45b67988 Mon Sep 17 00:00:00 2001 From: erezrokah Date: Thu, 27 Oct 2022 15:24:27 +0300 Subject: [PATCH] fix: Only consider pointer to structs when checking for embedded fields --- funk_test.go | 8 +++++++- retrieve.go | 13 +++---------- retrieve_test.go | 11 +++++++++-- 3 files changed, 19 insertions(+), 13 deletions(-) diff --git a/funk_test.go b/funk_test.go index d60f138..b8f8370 100644 --- a/funk_test.go +++ b/funk_test.go @@ -129,8 +129,14 @@ type EmbeddedStruct struct { EmbeddedField *string } -type RootStruct struct { +type RootStructPointer struct { *EmbeddedStruct RootField *string } + +type RootStructNotPointer struct { + EmbeddedStruct + + RootField *string +} diff --git a/retrieve.go b/retrieve.go index e523d0a..818dc08 100644 --- a/retrieve.go +++ b/retrieve.go @@ -110,14 +110,11 @@ func isNilIndirection(v reflect.Value, name string) bool { vType := v.Type() for i := 0; i < vType.NumField(); i++ { field := vType.Field(i) - if !isEmbeddedStructField(field) { + if !isEmbeddedStructPointerField(field) { return false } - fieldType := field.Type - if fieldType.Kind() == reflect.Ptr { - fieldType = field.Type.Elem() - } + fieldType := field.Type.Elem() _, found := fieldType.FieldByName(name) if found { @@ -128,14 +125,10 @@ func isNilIndirection(v reflect.Value, name string) bool { return false } -func isEmbeddedStructField(field reflect.StructField) bool { +func isEmbeddedStructPointerField(field reflect.StructField) bool { if !field.Anonymous { return false } - if field.Type.Kind() == reflect.Struct { - return true - } - return field.Type.Kind() == reflect.Ptr && field.Type.Elem().Kind() == reflect.Struct } diff --git a/retrieve_test.go b/retrieve_test.go index b5f4d4e..7222528 100644 --- a/retrieve_test.go +++ b/retrieve_test.go @@ -102,10 +102,17 @@ func TestGetOrElse(t *testing.T) { }) } -func TestEmbeddedStruct(t *testing.T) { +func TestEmbeddedStructPointer(t *testing.T) { is := assert.New(t) - root := RootStruct{} + root := RootStructPointer{} is.Equal(Get(root, "EmbeddedField"), nil) is.Equal(Get(root, "EmbeddedStruct.EmbeddedField"), nil) } + +func TestEmbeddedStructNotPointer(t *testing.T) { + is := assert.New(t) + + root := RootStructNotPointer{} + is.Equal(Get(root, "EmbeddedField"), nil) +}