Skip to content

Commit

Permalink
Split preloading-IN query into multiple queries (#283) (#285)
Browse files Browse the repository at this point in the history
* Split preloading-IN query into multiple queries (#283)

* Cleanup code, pass err to scanFinish (#283)

* Add test for calling scanMulti multiple times with the same cols (#283)

* Add test for Preload (#283)
  • Loading branch information
aligator committed May 27, 2022
1 parent d4f89cf commit ae7d739
Show file tree
Hide file tree
Showing 3 changed files with 190 additions and 13 deletions.
86 changes: 86 additions & 0 deletions cursor_test.go
Expand Up @@ -304,3 +304,89 @@ func TestScanMulti_fieldsError(t *testing.T) {
assert.Equal(t, err, scanMulti(cur, keyField, keyType, cols))
cur.AssertExpectations(t)
}

func TestScanMulti_multipleTimes(t *testing.T) {
var (
users = make([][]User, 6)
cur = &testCursor{}
keyField = "id"
keyType = reflect.TypeOf(0)
cols = map[interface{}][]slice{
10: {NewCollection(&users[0]), NewCollection(&users[1])},
11: {NewCollection(&users[2])},
12: {NewCollection(&users[3]), NewCollection(&users[4])},
13: {NewCollection(&users[5])},
}
now = Now()
)

cur.On("Close").Return(nil).Once()
cur.On("Fields").Return([]string{"id", "name", "age", "created_at", "updated_at"}, nil).Once()

cur.On("Next").Return(true).Twice()
cur.MockScan(10, "Del Piero", nil, now, nil).Times(3)
cur.MockScan(11, "Nedved", 46, now, now).Twice()
cur.On("Next").Return(false).Once()

assert.Nil(t, scanMulti(cur, keyField, keyType, cols))
assert.Len(t, users[0], 1)
assert.Equal(t, User{
ID: 10,
Name: "Del Piero",
CreatedAt: now,
}, users[0][0])
assert.Len(t, users[1], 1)
assert.Equal(t, User{
ID: 10,
Name: "Del Piero",
CreatedAt: now,
}, users[1][0])
assert.Len(t, users[2], 1)
assert.Equal(t, User{
ID: 11,
Name: "Nedved",
Age: 46,
CreatedAt: now,
UpdatedAt: now,
}, users[2][0])

cur.AssertExpectations(t)

// Continue with a new cursor but the same cols -> works only if the ids in
// the subsequent calls did not occur yet.
cur = &testCursor{}

cur.On("Close").Return(nil).Once()
cur.On("Fields").Return([]string{"id", "name", "age", "created_at", "updated_at"}, nil).Once()

cur.On("Next").Return(true).Twice()
cur.MockScan(12, "Linus Torvalds", 52, now, nil).Times(3)
cur.MockScan(13, "Tim Cook", 61, now, now).Twice()
cur.On("Next").Return(false).Once()

assert.Nil(t, scanMulti(cur, keyField, keyType, cols))
assert.Len(t, users[3], 1)
assert.Equal(t, User{
ID: 12,
Name: "Linus Torvalds",
Age: 52,
CreatedAt: now,
}, users[3][0])
assert.Len(t, users[4], 1)
assert.Equal(t, User{
ID: 12,
Age: 52,
Name: "Linus Torvalds",
CreatedAt: now,
}, users[4][0])
assert.Len(t, users[5], 1)
assert.Equal(t, User{
ID: 13,
Name: "Tim Cook",
Age: 61,
CreatedAt: now,
UpdatedAt: now,
}, users[5][0])

cur.AssertExpectations(t)
}
50 changes: 37 additions & 13 deletions repository.go
Expand Up @@ -1030,25 +1030,49 @@ func (r repository) preload(cw contextWrapper, records slice, field string, quer
path = strings.Split(field, ".")
targets, table, keyField, keyType, ddata, loaded = r.mapPreloadTargets(records, path)
ids = r.targetIDs(targets)
query = Build(table, append(queriers, In(keyField, ids...))...)
inClauseLength = 999
)

if len(targets) == 0 || loaded && !bool(query.ReloadQuery) {
return nil
}
// Create separate queries if the amount of ids is more than inClauseLength.
for {
if len(ids) == 0 {
break
}

var (
cur, err = cw.adapter.Query(cw.ctx, r.withDefaultScope(ddata, query, false))
)
// necessary check to avoid slicing beyond
// slice capacity
if len(ids) < inClauseLength {
inClauseLength = len(ids)
}

if err != nil {
return err
}
idsChunk := ids[0:inClauseLength]
ids = ids[inClauseLength:]

scanFinish := r.instrumenter.Observe(cw.ctx, "rel-scan-multi", "scanning all records to multiple targets")
defer scanFinish(nil)
query := Build(table, append(queriers, In(keyField, idsChunk...))...)
if len(targets) == 0 || loaded && !bool(query.ReloadQuery) {
return nil
}

return scanMulti(cur, keyField, keyType, targets)
var (
cur, err = cw.adapter.Query(cw.ctx, r.withDefaultScope(ddata, query, false))
)

if err != nil {
return err
}

scanFinish := r.instrumenter.Observe(cw.ctx, "rel-scan-multi", "scanning all records to multiple targets")
// Note: Calling scanMulti multiple times with the same targets works
// only if the cursor of each execution only contains a new set of keys.
// That is here the case as each select is with a unique set of ids.
err = scanMulti(cur, keyField, keyType, targets)
scanFinish(err)
if err != nil {
return err
}
}

return nil
}

func (r repository) MustPreload(ctx context.Context, records interface{}, field string, queriers ...Querier) {
Expand Down
67 changes: 67 additions & 0 deletions repository_test.go
Expand Up @@ -3,6 +3,7 @@ package rel
import (
"context"
"errors"
"fmt"
"io"
"testing"
"time"
Expand Down Expand Up @@ -3128,6 +3129,48 @@ func TestRepository_Preload_hasOne(t *testing.T) {
cur.AssertExpectations(t)
}

func TestRepository_Preload_splitSelects(t *testing.T) {
var (
adapter = &testAdapter{}
repo = New(adapter)
users = make([]User, 1100)
cur = &testCursor{}
)

for i := range users {
id := i + 1
users[i] = User{
ID: id,
Name: fmt.Sprintf("name%v", id),
}
}

// Use mock.Anything instead of the actual select, as the order is random and not predictable
// as they are retrieved from map-keys.
// -> This test can only test if two selects were made, but not how they look like exactly.
adapter.On("Query", mock.Anything).Return(cur, nil).Once()

cur.On("Close").Return(nil).Once()
cur.On("Fields").Return([]string{"id", "user_id"}, nil).Once()
cur.On("Next").Return(true).Once()
cur.On("Scan", mock.Anything, mock.Anything).Return(nil).Once()
cur.On("Next").Return(false).Once()

// Same here.
adapter.On("Query", mock.Anything).Return(cur, nil).Once()

cur.On("Close").Return(nil).Once()
cur.On("Fields").Return([]string{"id", "user_id"}, nil).Once()
cur.On("Next").Return(true).Once()
cur.On("Scan", mock.Anything, mock.Anything).Return(nil).Once()
cur.On("Next").Return(false).Once()

assert.Nil(t, repo.Preload(context.TODO(), &users, "emails"))

adapter.AssertExpectations(t)
cur.AssertExpectations(t)
}

func TestRepository_Preload_ptrHasOne(t *testing.T) {
var (
adapter = &testAdapter{}
Expand Down Expand Up @@ -3893,6 +3936,30 @@ func TestRepository_Preload_queryError(t *testing.T) {
cur.AssertExpectations(t)
}

func TestRepository_Preload_scanErrors(t *testing.T) {
var (
adapter = &testAdapter{}
repo = New(adapter)
user = User{ID: 10}
address = Address{ID: 100, UserID: &user.ID}
cur = &testCursor{}
err = errors.New("an error")
expected *Address = nil
)

adapter.On("Query", From("user_addresses").Where(In("user_id", 10).AndNil("deleted_at"))).Return(cur, nil).Once()

cur.On("Close").Return(nil).Once()
cur.On("Fields").Return([]string{"id", "user_id"}, nil).Once()
cur.On("Next").Return(true).Once()
cur.MockScan(address.ID, *address.UserID).Return(err).Once()
assert.ErrorIs(t, repo.Preload(context.TODO(), &user, "work_address"), err)
assert.Equal(t, expected, user.WorkAddress)

adapter.AssertExpectations(t)
cur.AssertExpectations(t)
}

func TestRepository_MustPreload(t *testing.T) {
var (
adapter = &testAdapter{}
Expand Down

0 comments on commit ae7d739

Please sign in to comment.