Skip to content

Commit

Permalink
fix: fix issue with has-many join and pointer fields (#950)
Browse files Browse the repository at this point in the history
  • Loading branch information
jeffreydwalter committed Apr 29, 2024
1 parent b005dc2 commit 93d0cdd
Show file tree
Hide file tree
Showing 2 changed files with 130 additions and 2 deletions.
110 changes: 110 additions & 0 deletions internal/dbtest/db_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,8 @@ func TestDB(t *testing.T) {
{testFKViolation},
{testWithForeignKeysAndRules},
{testWithForeignKeys},
{testWithForeignKeysHasMany},
{testWithPointerForeignKeysHasMany},
{testInterfaceAny},
{testInterfaceJSON},
{testScanRawMessage},
Expand Down Expand Up @@ -1043,6 +1045,114 @@ func testWithForeignKeys(t *testing.T, db *bun.DB) {
require.Equal(t, d.User.Name, "root")
}

func testWithForeignKeysHasMany(t *testing.T, db *bun.DB) {
type User struct {
ID int `bun:",pk"`
DeckID int
Name string
}
type Deck struct {
ID int `bun:",pk"`
Users []*User `bun:"rel:has-many,join:id=deck_id"`
}

if db.Dialect().Name() == dialect.SQLite {
_, err := db.Exec("PRAGMA foreign_keys = ON;")
require.NoError(t, err)
}

for _, model := range []interface{}{(*Deck)(nil), (*User)(nil)} {
_, err := db.NewDropTable().Model(model).IfExists().Exec(ctx)
require.NoError(t, err)
}

mustResetModel(t, ctx, db, (*User)(nil))
_, err := db.NewCreateTable().
Model((*Deck)(nil)).
IfNotExists().
WithForeignKeys().
Exec(ctx)
require.NoError(t, err)
mustDropTableOnCleanup(t, ctx, db, (*Deck)(nil))

deckID := 1
deck := Deck{ID: deckID}
_, err = db.NewInsert().Model(&deck).Exec(ctx)
require.NoError(t, err)

userID1 := 1
userID2 := 2
users := []*User{
{ID: userID1, DeckID: deckID, Name: "user 1"},
{ID: userID2, DeckID: deckID, Name: "user 2"},
}

res, err := db.NewInsert().Model(&users).Exec(ctx)
require.NoError(t, err)

affected, err := res.RowsAffected()
require.NoError(t, err)
require.Equal(t, int64(2), affected)

err = db.NewSelect().Model(&deck).Relation("Users").Scan(ctx)
require.NoError(t, err)
require.Len(t, deck.Users, 2)
}

func testWithPointerForeignKeysHasMany(t *testing.T, db *bun.DB) {
type User struct {
ID *int `bun:",pk"`
DeckID *int
Name string
}
type Deck struct {
ID *int `bun:",pk"`
Users []*User `bun:"rel:has-many,join:id=deck_id"`
}

if db.Dialect().Name() == dialect.SQLite {
_, err := db.Exec("PRAGMA foreign_keys = ON;")
require.NoError(t, err)
}

for _, model := range []interface{}{(*Deck)(nil), (*User)(nil)} {
_, err := db.NewDropTable().Model(model).IfExists().Exec(ctx)
require.NoError(t, err)
}

mustResetModel(t, ctx, db, (*User)(nil))
_, err := db.NewCreateTable().
Model((*Deck)(nil)).
IfNotExists().
WithForeignKeys().
Exec(ctx)
require.NoError(t, err)
mustDropTableOnCleanup(t, ctx, db, (*Deck)(nil))

deckID := 1
deck := Deck{ID: &deckID}
_, err = db.NewInsert().Model(&deck).Exec(ctx)
require.NoError(t, err)

userID1 := 1
userID2 := 2
users := []*User{
{ID: &userID1, DeckID: &deckID, Name: "user 1"},
{ID: &userID2, DeckID: &deckID, Name: "user 2"},
}

res, err := db.NewInsert().Model(&users).Exec(ctx)
require.NoError(t, err)

affected, err := res.RowsAffected()
require.NoError(t, err)
require.Equal(t, int64(2), affected)

err = db.NewSelect().Model(&deck).Relation("Users").Scan(ctx)
require.NoError(t, err)
require.Len(t, deck.Users, 2)
}

func testInterfaceAny(t *testing.T, db *bun.DB) {
switch db.Dialect().Name() {
case dialect.MySQL:
Expand Down
22 changes: 20 additions & 2 deletions model_table_has_many.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ func (m *hasManyModel) Scan(src interface{}) error {

for _, f := range m.rel.JoinFields {
if f.Name == field.Name {
m.structKey = append(m.structKey, field.Value(m.strct).Interface())
m.structKey = append(m.structKey, getFieldValue(field.Value(m.strct)))
break
}
}
Expand All @@ -103,6 +103,7 @@ func (m *hasManyModel) Scan(src interface{}) error {
}

func (m *hasManyModel) parkStruct() error {

baseValues, ok := m.baseValues[internal.NewMapKey(m.structKey)]
if !ok {
return fmt.Errorf(
Expand Down Expand Up @@ -143,7 +144,24 @@ func baseValues(model TableModel, fields []*schema.Field) map[internal.MapKey][]

func modelKey(key []interface{}, strct reflect.Value, fields []*schema.Field) []interface{} {
for _, f := range fields {
key = append(key, f.Value(strct).Interface())
key = append(key, getFieldValue(f.Value(strct)))
}
return key
}

// getFieldValue extracts the value from a reflect.Value, handling pointer types appropriately.
func getFieldValue(fieldValue reflect.Value) interface{} {
var keyValue interface{}

if fieldValue.Kind() == reflect.Ptr {
if !fieldValue.IsNil() {
keyValue = fieldValue.Elem().Interface()
} else {
keyValue = nil
}
} else {
keyValue = fieldValue.Interface()
}

return keyValue
}

0 comments on commit 93d0cdd

Please sign in to comment.