diff --git a/cursor_test.go b/cursor_test.go index 08ef7755..c211245b 100644 --- a/cursor_test.go +++ b/cursor_test.go @@ -304,3 +304,89 @@ func TestScanMulti_fieldsError(t *testing.T) { assert.Equal(t, err, scanMulti(cur, keyField, keyType, cols)) cur.AssertExpectations(t) } + +func TestScanMulti_multipleTimes(t *testing.T) { + var ( + users = make([][]User, 6) + cur = &testCursor{} + keyField = "id" + keyType = reflect.TypeOf(0) + cols = map[interface{}][]slice{ + 10: {NewCollection(&users[0]), NewCollection(&users[1])}, + 11: {NewCollection(&users[2])}, + 12: {NewCollection(&users[3]), NewCollection(&users[4])}, + 13: {NewCollection(&users[5])}, + } + now = Now() + ) + + cur.On("Close").Return(nil).Once() + cur.On("Fields").Return([]string{"id", "name", "age", "created_at", "updated_at"}, nil).Once() + + cur.On("Next").Return(true).Twice() + cur.MockScan(10, "Del Piero", nil, now, nil).Times(3) + cur.MockScan(11, "Nedved", 46, now, now).Twice() + cur.On("Next").Return(false).Once() + + assert.Nil(t, scanMulti(cur, keyField, keyType, cols)) + assert.Len(t, users[0], 1) + assert.Equal(t, User{ + ID: 10, + Name: "Del Piero", + CreatedAt: now, + }, users[0][0]) + assert.Len(t, users[1], 1) + assert.Equal(t, User{ + ID: 10, + Name: "Del Piero", + CreatedAt: now, + }, users[1][0]) + assert.Len(t, users[2], 1) + assert.Equal(t, User{ + ID: 11, + Name: "Nedved", + Age: 46, + CreatedAt: now, + UpdatedAt: now, + }, users[2][0]) + + cur.AssertExpectations(t) + + // Continue with a new cursor but the same cols -> works only if the ids in + // the subsequent calls did not occur yet. + cur = &testCursor{} + + cur.On("Close").Return(nil).Once() + cur.On("Fields").Return([]string{"id", "name", "age", "created_at", "updated_at"}, nil).Once() + + cur.On("Next").Return(true).Twice() + cur.MockScan(12, "Linus Torvalds", 52, now, nil).Times(3) + cur.MockScan(13, "Tim Cook", 61, now, now).Twice() + cur.On("Next").Return(false).Once() + + assert.Nil(t, scanMulti(cur, keyField, keyType, cols)) + assert.Len(t, users[3], 1) + assert.Equal(t, User{ + ID: 12, + Name: "Linus Torvalds", + Age: 52, + CreatedAt: now, + }, users[3][0]) + assert.Len(t, users[4], 1) + assert.Equal(t, User{ + ID: 12, + Age: 52, + Name: "Linus Torvalds", + CreatedAt: now, + }, users[4][0]) + assert.Len(t, users[5], 1) + assert.Equal(t, User{ + ID: 13, + Name: "Tim Cook", + Age: 61, + CreatedAt: now, + UpdatedAt: now, + }, users[5][0]) + + cur.AssertExpectations(t) +} diff --git a/repository.go b/repository.go index f4a4102d..437967fb 100644 --- a/repository.go +++ b/repository.go @@ -1030,25 +1030,49 @@ func (r repository) preload(cw contextWrapper, records slice, field string, quer path = strings.Split(field, ".") targets, table, keyField, keyType, ddata, loaded = r.mapPreloadTargets(records, path) ids = r.targetIDs(targets) - query = Build(table, append(queriers, In(keyField, ids...))...) + inClauseLength = 999 ) - if len(targets) == 0 || loaded && !bool(query.ReloadQuery) { - return nil - } + // Create separate queries if the amount of ids is more than inClauseLength. + for { + if len(ids) == 0 { + break + } - var ( - cur, err = cw.adapter.Query(cw.ctx, r.withDefaultScope(ddata, query, false)) - ) + // necessary check to avoid slicing beyond + // slice capacity + if len(ids) < inClauseLength { + inClauseLength = len(ids) + } - if err != nil { - return err - } + idsChunk := ids[0:inClauseLength] + ids = ids[inClauseLength:] - scanFinish := r.instrumenter.Observe(cw.ctx, "rel-scan-multi", "scanning all records to multiple targets") - defer scanFinish(nil) + query := Build(table, append(queriers, In(keyField, idsChunk...))...) + if len(targets) == 0 || loaded && !bool(query.ReloadQuery) { + return nil + } - return scanMulti(cur, keyField, keyType, targets) + var ( + cur, err = cw.adapter.Query(cw.ctx, r.withDefaultScope(ddata, query, false)) + ) + + if err != nil { + return err + } + + scanFinish := r.instrumenter.Observe(cw.ctx, "rel-scan-multi", "scanning all records to multiple targets") + // Note: Calling scanMulti multiple times with the same targets works + // only if the cursor of each execution only contains a new set of keys. + // That is here the case as each select is with a unique set of ids. + err = scanMulti(cur, keyField, keyType, targets) + scanFinish(err) + if err != nil { + return err + } + } + + return nil } func (r repository) MustPreload(ctx context.Context, records interface{}, field string, queriers ...Querier) { diff --git a/repository_test.go b/repository_test.go index ed7b1e66..57693407 100644 --- a/repository_test.go +++ b/repository_test.go @@ -3,6 +3,7 @@ package rel import ( "context" "errors" + "fmt" "io" "testing" "time" @@ -3128,6 +3129,48 @@ func TestRepository_Preload_hasOne(t *testing.T) { cur.AssertExpectations(t) } +func TestRepository_Preload_splitSelects(t *testing.T) { + var ( + adapter = &testAdapter{} + repo = New(adapter) + users = make([]User, 1100) + cur = &testCursor{} + ) + + for i := range users { + id := i + 1 + users[i] = User{ + ID: id, + Name: fmt.Sprintf("name%v", id), + } + } + + // Use mock.Anything instead of the actual select, as the order is random and not predictable + // as they are retrieved from map-keys. + // -> This test can only test if two selects were made, but not how they look like exactly. + adapter.On("Query", mock.Anything).Return(cur, nil).Once() + + cur.On("Close").Return(nil).Once() + cur.On("Fields").Return([]string{"id", "user_id"}, nil).Once() + cur.On("Next").Return(true).Once() + cur.On("Scan", mock.Anything, mock.Anything).Return(nil).Once() + cur.On("Next").Return(false).Once() + + // Same here. + adapter.On("Query", mock.Anything).Return(cur, nil).Once() + + cur.On("Close").Return(nil).Once() + cur.On("Fields").Return([]string{"id", "user_id"}, nil).Once() + cur.On("Next").Return(true).Once() + cur.On("Scan", mock.Anything, mock.Anything).Return(nil).Once() + cur.On("Next").Return(false).Once() + + assert.Nil(t, repo.Preload(context.TODO(), &users, "emails")) + + adapter.AssertExpectations(t) + cur.AssertExpectations(t) +} + func TestRepository_Preload_ptrHasOne(t *testing.T) { var ( adapter = &testAdapter{} @@ -3893,6 +3936,30 @@ func TestRepository_Preload_queryError(t *testing.T) { cur.AssertExpectations(t) } +func TestRepository_Preload_scanErrors(t *testing.T) { + var ( + adapter = &testAdapter{} + repo = New(adapter) + user = User{ID: 10} + address = Address{ID: 100, UserID: &user.ID} + cur = &testCursor{} + err = errors.New("an error") + expected *Address = nil + ) + + adapter.On("Query", From("user_addresses").Where(In("user_id", 10).AndNil("deleted_at"))).Return(cur, nil).Once() + + cur.On("Close").Return(nil).Once() + cur.On("Fields").Return([]string{"id", "user_id"}, nil).Once() + cur.On("Next").Return(true).Once() + cur.MockScan(address.ID, *address.UserID).Return(err).Once() + assert.ErrorIs(t, repo.Preload(context.TODO(), &user, "work_address"), err) + assert.Equal(t, expected, user.WorkAddress) + + adapter.AssertExpectations(t) + cur.AssertExpectations(t) +} + func TestRepository_MustPreload(t *testing.T) { var ( adapter = &testAdapter{}