diff --git a/finisher_api.go b/finisher_api.go index 5e4c3c5a5..fb765d33d 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -181,6 +181,21 @@ func (db *DB) FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, bat batch int ) + // user specified offset or limit + var totalSize int + if c, ok := tx.Statement.Clauses["LIMIT"]; ok { + if limit, ok := c.Expression.(clause.Limit); ok { + totalSize = limit.Limit + + if totalSize > 0 && batchSize > totalSize { + batchSize = totalSize + } + + // reset to offset to 0 in next batch + tx = tx.Offset(-1).Session(&Session{}) + } + } + for { result := queryDB.Limit(batchSize).Find(dest) rowsAffected += result.RowsAffected @@ -196,6 +211,15 @@ func (db *DB) FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, bat break } + if totalSize > 0 { + if totalSize <= int(rowsAffected) { + break + } + if totalSize/batchSize == batch { + batchSize = totalSize % batchSize + } + } + // Optimize for-break resultsValue := reflect.Indirect(reflect.ValueOf(dest)) if result.Statement.Schema.PrioritizedPrimaryField == nil { diff --git a/tests/query_test.go b/tests/query_test.go index af2b8d4b4..f66cf83a4 100644 --- a/tests/query_test.go +++ b/tests/query_test.go @@ -292,6 +292,68 @@ func TestFindInBatches(t *testing.T) { } } +func TestFindInBatchesWithOffsetLimit(t *testing.T) { + users := []User{ + *GetUser("find_in_batches_with_offset_limit", Config{}), + *GetUser("find_in_batches_with_offset_limit", Config{}), + *GetUser("find_in_batches_with_offset_limit", Config{}), + *GetUser("find_in_batches_with_offset_limit", Config{}), + *GetUser("find_in_batches_with_offset_limit", Config{}), + *GetUser("find_in_batches_with_offset_limit", Config{}), + *GetUser("find_in_batches_with_offset_limit", Config{}), + *GetUser("find_in_batches_with_offset_limit", Config{}), + *GetUser("find_in_batches_with_offset_limit", Config{}), + *GetUser("find_in_batches_with_offset_limit", Config{}), + } + + DB.Create(&users) + + var ( + sub, results []User + lastBatch int + ) + + // offset limit + if result := DB.Offset(3).Limit(5).Where("name = ?", users[0].Name).FindInBatches(&sub, 2, func(tx *gorm.DB, batch int) error { + results = append(results, sub...) + lastBatch = batch + return nil + }); result.Error != nil || result.RowsAffected != 5 { + t.Errorf("Failed to batch find, got error %v, rows affected: %v", result.Error, result.RowsAffected) + } + if lastBatch != 3 { + t.Fatalf("incorrect last batch, expected: %v, got: %v", 3, lastBatch) + } + + targetUsers := users[3:8] + for i := 0; i < len(targetUsers); i++ { + AssertEqual(t, results[i], targetUsers[i]) + } + + var sub1 []User + // limit < batchSize + if result := DB.Limit(5).Where("name = ?", users[0].Name).FindInBatches(&sub1, 10, func(tx *gorm.DB, batch int) error { + return nil + }); result.Error != nil || result.RowsAffected != 5 { + t.Errorf("Failed to batch find, got error %v, rows affected: %v", result.Error, result.RowsAffected) + } + + var sub2 []User + // only offset + if result := DB.Offset(3).Where("name = ?", users[0].Name).FindInBatches(&sub2, 2, func(tx *gorm.DB, batch int) error { + return nil + }); result.Error != nil || result.RowsAffected != 7 { + t.Errorf("Failed to batch find, got error %v, rows affected: %v", result.Error, result.RowsAffected) + } + + var sub3 []User + if result := DB.Limit(4).Where("name = ?", users[0].Name).FindInBatches(&sub3, 2, func(tx *gorm.DB, batch int) error { + return nil + }); result.Error != nil || result.RowsAffected != 4 { + t.Errorf("Failed to batch find, got error %v, rows affected: %v", result.Error, result.RowsAffected) + } +} + func TestFindInBatchesWithError(t *testing.T) { if name := DB.Dialector.Name(); name == "sqlserver" { t.Skip("skip sqlserver due to it will raise data race for invalid sql")