Skip to content

Commit

Permalink
fix: update setupValuerAndSetter to use default values when pointer t…
Browse files Browse the repository at this point in the history
…ypes are nil
  • Loading branch information
waleed.masoom committed Apr 4, 2024
1 parent 1b48aa0 commit d66da4e
Show file tree
Hide file tree
Showing 3 changed files with 142 additions and 5 deletions.
41 changes: 40 additions & 1 deletion scan.go
Expand Up @@ -66,13 +66,21 @@ func (db *DB) scanIntoStruct(rows Rows, reflectValue reflect.Value, values []int
db.RowsAffected++
db.AddError(rows.Scan(values...))
joinedNestedSchemaMap := make(map[string]interface{})
fieldsWithValueMap := make(map[string]bool)
for idx, field := range fields {
if field == nil {
continue
}

if len(joinFields) == 0 || len(joinFields[idx]) == 0 {
db.AddError(field.Set(db.Statement.Context, reflectValue, values[idx]))
fieldIsEmbeddedPointerTypeStruct := len(field.BindNames) > 1 && len(field.StructField.Index) > 0 && field.StructField.Index[0] < 0
fieldValue := reflect.ValueOf(values[idx]).Elem()
if !fieldIsEmbeddedPointerTypeStruct && fieldValue.Kind() == reflect.Ptr && fieldValue.IsNil() {
db.AddError(field.Set(db.Statement.Context, reflectValue, field.DefaultValueInterface))
} else {
fieldsWithValueMap[field.BindName()] = fieldValue.Kind() == reflect.Ptr && !fieldValue.IsNil()
db.AddError(field.Set(db.Statement.Context, reflectValue, values[idx]))
}
} else { // joinFields count is larger than 2 when using join
var isNilPtrValue bool
var relValue reflect.Value
Expand Down Expand Up @@ -109,6 +117,37 @@ func (db *DB) scanIntoStruct(rows Rows, reflectValue reflect.Value, values []int
// release data to pool
field.NewValuePool.Put(values[idx])
}

if dest := reflect.Indirect(db.Statement.ReflectValue); len(db.Statement.Clauses) == 0 && dest.Kind() == reflect.Struct {
resetEmbeddedPointerTypeStruct(dest, db.Statement.Schema, fieldsWithValueMap)
}
}

func resetEmbeddedPointerTypeStruct(dest reflect.Value, schema *schema.Schema, fieldsWithValueMap map[string]bool) {
for i := 0; i < dest.NumField(); i++ {
field := schema.ParseField(dest.Type().Field(i))
if field.EmbeddedSchema != nil && field.FieldType.Kind() == reflect.Ptr {
if !wasValueScannedIntoEmbeddedStruct(field, fieldsWithValueMap) && dest.Field(i).Kind() == reflect.Ptr && !dest.Field(i).IsNil() {
dest.Field(i).Set(reflect.Zero(dest.Field(i).Type()))
}
}
}
}

func wasValueScannedIntoEmbeddedStruct(field *schema.Field, fieldsWithValueMap map[string]bool) bool {
if fieldsWithValueMap[field.BindName()] {
return true
}

if field.EmbeddedSchema != nil {
for _, embeddedField := range field.EmbeddedSchema.Fields {
if wasValueScannedIntoEmbeddedStruct(embeddedField, fieldsWithValueMap) {
return true
}
}
}

return false
}

// ScanMode scan data mode
Expand Down
2 changes: 1 addition & 1 deletion tests/go.mod
Expand Up @@ -17,7 +17,7 @@ require (
require (
filippo.io/edwards25519 v1.1.0 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/go-sql-driver/mysql v1.8.0 // indirect
github.com/go-sql-driver/mysql v1.8.1 // indirect
github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 // indirect
github.com/golang-sql/sqlexp v0.1.0 // indirect
github.com/jackc/pgpassfile v1.0.0 // indirect
Expand Down
104 changes: 101 additions & 3 deletions tests/scan_test.go
Expand Up @@ -5,6 +5,7 @@ import (
"sort"
"strings"
"testing"
"time"

"gorm.io/gorm"
. "gorm.io/gorm/utils/tests"
Expand Down Expand Up @@ -126,7 +127,7 @@ func TestScanRows(t *testing.T) {

rows, err := DB.Table("users").Where("name = ? or name = ?", user2.Name, user3.Name).Select("name, age").Rows()
if err != nil {
t.Errorf("Not error should happen, got %v", err)
t.Errorf("No error should happen, got %v", err)
}

type Result struct {
Expand All @@ -148,7 +149,7 @@ func TestScanRows(t *testing.T) {
})

if !reflect.DeepEqual(results, []Result{{Name: "ScanRowsUser2", Age: 10}, {Name: "ScanRowsUser3", Age: 20}}) {
t.Errorf("Should find expected results")
t.Errorf("Should find expected results, got %+v", results)
}

var ages int
Expand All @@ -158,7 +159,104 @@ func TestScanRows(t *testing.T) {

var name string
if err := DB.Table("users").Where("name = ?", user2.Name).Select("name").Scan(&name).Error; err != nil || name != user2.Name {
t.Fatalf("failed to scan ages, got error %v, ages: %v", err, name)
t.Fatalf("failed to scan name, got error %v, name: %v", err, name)
}
}

func TestScanRowsNullValuesScanToFieldDefault(t *testing.T) {
DB.Save(&User{})

rows, err := DB.Table("users").
Select(`
NULL AS bool_field,
NULL AS int_field,
NULL AS int8_field,
NULL AS int16_field,
NULL AS int32_field,
NULL AS int64_field,
NULL AS uint_field,
NULL AS uint8_field,
NULL AS uint16_field,
NULL AS uint32_field,
NULL AS uint64_field,
NULL AS float32_field,
NULL AS float64_field,
NULL AS string_field,
NULL AS time_field,
NULL AS time_ptr_field,
NULL AS embedded_int_field,
NULL AS nested_embedded_int_field,
NULL AS embedded_ptr_int_field
`).Rows()
if err != nil {
t.Errorf("No error should happen, got %v", err)
}

type NestedEmbeddedStruct struct {
NestedEmbeddedIntField int
}

type EmbeddedStruct struct {
EmbeddedIntField int
NestedEmbeddedStruct `gorm:"embedded"`
}

type EmbeddedPtrStruct struct {
EmbeddedPtrIntField int
*NestedEmbeddedStruct `gorm:"embedded"`
}

type Result struct {
BoolField bool
IntField int
Int8Field int8
Int16Field int16
Int32Field int32
Int64Field int64
UIntField uint
UInt8Field uint8
UInt16Field uint16
UInt32Field uint32
UInt64Field uint64
Float32Field float32
Float64Field float64
StringField string
TimeField time.Time
TimePtrField *time.Time
EmbeddedStruct `gorm:"embedded"`
*EmbeddedPtrStruct `gorm:"embedded"`
}

currTime := time.Now()
result := Result{
BoolField: true,
IntField: 1,
Int8Field: 1,
Int16Field: 1,
Int32Field: 1,
Int64Field: 1,
UIntField: 1,
UInt8Field: 1,
UInt16Field: 1,
UInt32Field: 1,
UInt64Field: 1,
Float32Field: 1.1,
Float64Field: 1.1,
StringField: "hello",
TimeField: currTime,
TimePtrField: &currTime,
EmbeddedStruct: EmbeddedStruct{EmbeddedIntField: 1, NestedEmbeddedStruct: NestedEmbeddedStruct{NestedEmbeddedIntField: 1}},
EmbeddedPtrStruct: &EmbeddedPtrStruct{EmbeddedPtrIntField: 1, NestedEmbeddedStruct: &NestedEmbeddedStruct{NestedEmbeddedIntField: 1}},
}

for rows.Next() {
if err := DB.ScanRows(rows, &result); err != nil {
t.Errorf("should get no error, but got %v", err)
}
}

if !reflect.DeepEqual(result, Result{}) {
t.Errorf("Should find zero values in struct fields, got %+v", result)
}
}

Expand Down

0 comments on commit d66da4e

Please sign in to comment.