Skip to content

Commit

Permalink
Handle field set value error
Browse files Browse the repository at this point in the history
  • Loading branch information
jinzhu committed Mar 23, 2022
1 parent a7b3b59 commit f92e674
Show file tree
Hide file tree
Showing 8 changed files with 34 additions and 33 deletions.
14 changes: 7 additions & 7 deletions callbacks/associations.go
Original file line number Diff line number Diff line change
Expand Up @@ -159,9 +159,9 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) {
for _, ref := range rel.References {
if ref.OwnPrimaryKey {
fv, _ := ref.PrimaryKey.ValueOf(db.Statement.Context, db.Statement.ReflectValue)
ref.ForeignKey.Set(db.Statement.Context, f, fv)
db.AddError(ref.ForeignKey.Set(db.Statement.Context, f, fv))
} else if ref.PrimaryValue != "" {
ref.ForeignKey.Set(db.Statement.Context, f, ref.PrimaryValue)
db.AddError(ref.ForeignKey.Set(db.Statement.Context, f, ref.PrimaryValue))
}
assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName)
}
Expand Down Expand Up @@ -193,9 +193,9 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) {
for _, ref := range rel.References {
if ref.OwnPrimaryKey {
pv, _ := ref.PrimaryKey.ValueOf(db.Statement.Context, v)
ref.ForeignKey.Set(db.Statement.Context, elem, pv)
db.AddError(ref.ForeignKey.Set(db.Statement.Context, elem, pv))
} else if ref.PrimaryValue != "" {
ref.ForeignKey.Set(db.Statement.Context, elem, ref.PrimaryValue)
db.AddError(ref.ForeignKey.Set(db.Statement.Context, elem, ref.PrimaryValue))
}
}

Expand Down Expand Up @@ -261,12 +261,12 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) {
for _, ref := range rel.References {
if ref.OwnPrimaryKey {
fv, _ := ref.PrimaryKey.ValueOf(db.Statement.Context, obj)
ref.ForeignKey.Set(db.Statement.Context, joinValue, fv)
db.AddError(ref.ForeignKey.Set(db.Statement.Context, joinValue, fv))
} else if ref.PrimaryValue != "" {
ref.ForeignKey.Set(db.Statement.Context, joinValue, ref.PrimaryValue)
db.AddError(ref.ForeignKey.Set(db.Statement.Context, joinValue, ref.PrimaryValue))
} else {
fv, _ := ref.PrimaryKey.ValueOf(db.Statement.Context, elem)
ref.ForeignKey.Set(db.Statement.Context, joinValue, fv)
db.AddError(ref.ForeignKey.Set(db.Statement.Context, joinValue, fv))
}
}
joins = reflect.Append(joins, joinValue)
Expand Down
18 changes: 9 additions & 9 deletions callbacks/create.go
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ func Create(config *Config) func(db *gorm.DB) {

_, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.Context, rv)
if isZero {
db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.Context, rv, insertID)
db.AddError(db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.Context, rv, insertID))
insertID -= db.Statement.Schema.PrioritizedPrimaryField.AutoIncrementIncrement
}
}
Expand All @@ -133,15 +133,15 @@ func Create(config *Config) func(db *gorm.DB) {
}

if _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.Context, rv); isZero {
db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.Context, rv, insertID)
db.AddError(db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.Context, rv, insertID))
insertID += db.Statement.Schema.PrioritizedPrimaryField.AutoIncrementIncrement
}
}
}
case reflect.Struct:
_, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.Context, db.Statement.ReflectValue)
if isZero {
db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.Context, db.Statement.ReflectValue, insertID)
db.AddError(db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.Context, db.Statement.ReflectValue, insertID))
}
}
}
Expand Down Expand Up @@ -227,13 +227,13 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) {
if values.Values[i][idx], isZero = field.ValueOf(stmt.Context, rv); isZero {
if field.DefaultValueInterface != nil {
values.Values[i][idx] = field.DefaultValueInterface
field.Set(stmt.Context, rv, field.DefaultValueInterface)
stmt.AddError(field.Set(stmt.Context, rv, field.DefaultValueInterface))
} else if field.AutoCreateTime > 0 || field.AutoUpdateTime > 0 {
field.Set(stmt.Context, rv, curTime)
stmt.AddError(field.Set(stmt.Context, rv, curTime))
values.Values[i][idx], _ = field.ValueOf(stmt.Context, rv)
}
} else if field.AutoUpdateTime > 0 && updateTrackTime {
field.Set(stmt.Context, rv, curTime)
stmt.AddError(field.Set(stmt.Context, rv, curTime))
values.Values[i][idx], _ = field.ValueOf(stmt.Context, rv)
}
}
Expand Down Expand Up @@ -267,13 +267,13 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) {
if values.Values[0][idx], isZero = field.ValueOf(stmt.Context, stmt.ReflectValue); isZero {
if field.DefaultValueInterface != nil {
values.Values[0][idx] = field.DefaultValueInterface
field.Set(stmt.Context, stmt.ReflectValue, field.DefaultValueInterface)
stmt.AddError(field.Set(stmt.Context, stmt.ReflectValue, field.DefaultValueInterface))
} else if field.AutoCreateTime > 0 || field.AutoUpdateTime > 0 {
field.Set(stmt.Context, stmt.ReflectValue, curTime)
stmt.AddError(field.Set(stmt.Context, stmt.ReflectValue, curTime))
values.Values[0][idx], _ = field.ValueOf(stmt.Context, stmt.ReflectValue)
}
} else if field.AutoUpdateTime > 0 && updateTrackTime {
field.Set(stmt.Context, stmt.ReflectValue, curTime)
stmt.AddError(field.Set(stmt.Context, stmt.ReflectValue, curTime))
values.Values[0][idx], _ = field.ValueOf(stmt.Context, stmt.ReflectValue)
}
}
Expand Down
14 changes: 7 additions & 7 deletions callbacks/preload.go
Original file line number Diff line number Diff line change
Expand Up @@ -123,17 +123,17 @@ func preload(tx *gorm.DB, rel *schema.Relationship, conds []interface{}, preload
case reflect.Struct:
switch rel.Type {
case schema.HasMany, schema.Many2Many:
rel.Field.Set(tx.Statement.Context, reflectValue, reflect.MakeSlice(rel.Field.IndirectFieldType, 0, 10).Interface())
tx.AddError(rel.Field.Set(tx.Statement.Context, reflectValue, reflect.MakeSlice(rel.Field.IndirectFieldType, 0, 10).Interface()))
default:
rel.Field.Set(tx.Statement.Context, reflectValue, reflect.New(rel.Field.FieldType).Interface())
tx.AddError(rel.Field.Set(tx.Statement.Context, reflectValue, reflect.New(rel.Field.FieldType).Interface()))
}
case reflect.Slice, reflect.Array:
for i := 0; i < reflectValue.Len(); i++ {
switch rel.Type {
case schema.HasMany, schema.Many2Many:
rel.Field.Set(tx.Statement.Context, reflectValue.Index(i), reflect.MakeSlice(rel.Field.IndirectFieldType, 0, 10).Interface())
tx.AddError(rel.Field.Set(tx.Statement.Context, reflectValue.Index(i), reflect.MakeSlice(rel.Field.IndirectFieldType, 0, 10).Interface()))
default:
rel.Field.Set(tx.Statement.Context, reflectValue.Index(i), reflect.New(rel.Field.FieldType).Interface())
tx.AddError(rel.Field.Set(tx.Statement.Context, reflectValue.Index(i), reflect.New(rel.Field.FieldType).Interface()))
}
}
}
Expand All @@ -158,12 +158,12 @@ func preload(tx *gorm.DB, rel *schema.Relationship, conds []interface{}, preload
reflectFieldValue = reflect.Indirect(reflectFieldValue)
switch reflectFieldValue.Kind() {
case reflect.Struct:
rel.Field.Set(tx.Statement.Context, data, elem.Interface())
tx.AddError(rel.Field.Set(tx.Statement.Context, data, elem.Interface()))
case reflect.Slice, reflect.Array:
if reflectFieldValue.Type().Elem().Kind() == reflect.Ptr {
rel.Field.Set(tx.Statement.Context, data, reflect.Append(reflectFieldValue, elem).Interface())
tx.AddError(rel.Field.Set(tx.Statement.Context, data, reflect.Append(reflectFieldValue, elem).Interface()))
} else {
rel.Field.Set(tx.Statement.Context, data, reflect.Append(reflectFieldValue, elem.Elem()).Interface())
tx.AddError(rel.Field.Set(tx.Statement.Context, data, reflect.Append(reflectFieldValue, elem.Elem()).Interface()))
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion callbacks/update.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ func SetupUpdateReflectValue(db *gorm.DB) {
if dest, ok := db.Statement.Dest.(map[string]interface{}); ok {
for _, rel := range db.Statement.Schema.Relationships.BelongsTo {
if _, ok := dest[rel.Name]; ok {
rel.Field.Set(db.Statement.Context, db.Statement.ReflectValue, dest[rel.Name])
db.AddError(rel.Field.Set(db.Statement.Context, db.Statement.ReflectValue, dest[rel.Name]))
}
}
}
Expand Down
4 changes: 2 additions & 2 deletions scan.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ func (db *DB) scanIntoStruct(rows *sql.Rows, reflectValue reflect.Value, values
for idx, field := range fields {
if field != nil {
if len(joinFields) == 0 || joinFields[idx][0] == nil {
field.Set(db.Statement.Context, reflectValue, values[idx])
db.AddError(field.Set(db.Statement.Context, reflectValue, values[idx]))
} else {
relValue := joinFields[idx][0].ReflectValueOf(db.Statement.Context, reflectValue)
if relValue.Kind() == reflect.Ptr && relValue.IsNil() {
Expand All @@ -79,7 +79,7 @@ func (db *DB) scanIntoStruct(rows *sql.Rows, reflectValue reflect.Value, values

relValue.Set(reflect.New(relValue.Type().Elem()))
}
joinFields[idx][1].Set(db.Statement.Context, relValue, values[idx])
db.AddError(joinFields[idx][1].Set(db.Statement.Context, relValue, values[idx]))
}

// release data to pool
Expand Down
5 changes: 3 additions & 2 deletions schema/field.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"time"

"github.com/jinzhu/now"
"gorm.io/gorm/clause"
"gorm.io/gorm/utils"
)

Expand Down Expand Up @@ -567,8 +568,8 @@ func (field *Field) setupValuerAndSetter() {
if v, err = valuer.Value(); err == nil {
err = setter(ctx, value, v)
}
} else {
return fmt.Errorf("failed to set value %+v to field %s", v, field.Name)
} else if _, ok := v.(clause.Expr); !ok {
return fmt.Errorf("failed to set value %#v to field %s", v, field.Name)
}
}

Expand Down
8 changes: 4 additions & 4 deletions statement.go
Original file line number Diff line number Diff line change
Expand Up @@ -562,7 +562,7 @@ func (stmt *Statement) SetColumn(name string, value interface{}, fromCallbacks .

switch destValue.Kind() {
case reflect.Struct:
field.Set(stmt.Context, destValue, value)
stmt.AddError(field.Set(stmt.Context, destValue, value))
default:
stmt.AddError(ErrInvalidData)
}
Expand All @@ -572,18 +572,18 @@ func (stmt *Statement) SetColumn(name string, value interface{}, fromCallbacks .
case reflect.Slice, reflect.Array:
if len(fromCallbacks) > 0 {
for i := 0; i < stmt.ReflectValue.Len(); i++ {
field.Set(stmt.Context, stmt.ReflectValue.Index(i), value)
stmt.AddError(field.Set(stmt.Context, stmt.ReflectValue.Index(i), value))
}
} else {
field.Set(stmt.Context, stmt.ReflectValue.Index(stmt.CurDestIndex), value)
stmt.AddError(field.Set(stmt.Context, stmt.ReflectValue.Index(stmt.CurDestIndex), value))
}
case reflect.Struct:
if !stmt.ReflectValue.CanAddr() {
stmt.AddError(ErrInvalidValue)
return
}

field.Set(stmt.Context, stmt.ReflectValue, value)
stmt.AddError(field.Set(stmt.Context, stmt.ReflectValue, value))
}
} else {
stmt.AddError(ErrInvalidField)
Expand Down
2 changes: 1 addition & 1 deletion tests/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ require (
github.com/jinzhu/now v1.1.5
github.com/lib/pq v1.10.4
github.com/mattn/go-sqlite3 v1.14.12 // indirect
golang.org/x/crypto v0.0.0-20220315160706-3147a52a75dd // indirect
golang.org/x/crypto v0.0.0-20220321153916-2c7772ba3064 // indirect
gorm.io/driver/mysql v1.3.2
gorm.io/driver/postgres v1.3.1
gorm.io/driver/sqlite v1.3.1
Expand Down

0 comments on commit f92e674

Please sign in to comment.