From 342310fba4fc56decf3d417925326db483734d7e Mon Sep 17 00:00:00 2001 From: Cr <631807682@qq.com> Date: Mon, 21 Nov 2022 10:49:27 +0800 Subject: [PATCH] fix(FindInBatches): throw err if pk not exists (#5868) --- finisher_api.go | 11 ++++++++--- tests/query_test.go | 7 +++++++ 2 files changed, 15 insertions(+), 3 deletions(-) diff --git a/finisher_api.go b/finisher_api.go index 5516c0a14..cc07a126e 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -231,7 +231,11 @@ func (db *DB) FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, bat break } - primaryValue, _ := result.Statement.Schema.PrioritizedPrimaryField.ValueOf(tx.Statement.Context, resultsValue.Index(resultsValue.Len()-1)) + primaryValue, zero := result.Statement.Schema.PrioritizedPrimaryField.ValueOf(tx.Statement.Context, resultsValue.Index(resultsValue.Len()-1)) + if zero { + tx.AddError(ErrPrimaryKeyRequired) + break + } queryDB = tx.Clauses(clause.Gt{Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, Value: primaryValue}) } @@ -514,8 +518,9 @@ func (db *DB) Scan(dest interface{}) (tx *DB) { } // Pluck queries a single column from a model, returning in the slice dest. E.g.: -// var ages []int64 -// db.Model(&users).Pluck("age", &ages) +// +// var ages []int64 +// db.Model(&users).Pluck("age", &ages) func (db *DB) Pluck(column string, dest interface{}) (tx *DB) { tx = db.getInstance() if tx.Statement.Model != nil { diff --git a/tests/query_test.go b/tests/query_test.go index eccf0133d..fa8f09e8b 100644 --- a/tests/query_test.go +++ b/tests/query_test.go @@ -408,6 +408,13 @@ func TestFindInBatchesWithError(t *testing.T) { if totalBatch != 0 { t.Fatalf("incorrect total batch, expected: %v, got: %v", 0, totalBatch) } + + if result := DB.Omit("id").Where("name = ?", users[0].Name).FindInBatches(&results, 2, func(tx *gorm.DB, batch int) error { + totalBatch += batch + return nil + }); result.Error != gorm.ErrPrimaryKeyRequired { + t.Fatal("expected errors to have occurred, but nothing happened") + } } func TestFillSmallerStruct(t *testing.T) {