Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: prepare deadlock #5568

Merged
merged 8 commits into from Sep 30, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion gorm.go
Expand Up @@ -179,7 +179,7 @@ func Open(dialector Dialector, opts ...Option) (db *DB, err error) {

preparedStmt := &PreparedStmtDB{
ConnPool: db.ConnPool,
Stmts: map[string]Stmt{},
Stmts: map[string](*Stmt){},
Mux: &sync.RWMutex{},
PreparedSQL: make([]string, 0, 100),
}
Expand Down
54 changes: 43 additions & 11 deletions prepare_stmt.go
Expand Up @@ -9,10 +9,12 @@ import (
type Stmt struct {
*sql.Stmt
Transaction bool
prepared chan struct{}
prepareErr error
}

type PreparedStmtDB struct {
Stmts map[string]Stmt
Stmts map[string]*Stmt
PreparedSQL []string
Mux *sync.RWMutex
ConnPool
Expand Down Expand Up @@ -46,27 +48,57 @@ func (db *PreparedStmtDB) prepare(ctx context.Context, conn ConnPool, isTransact
db.Mux.RLock()
if stmt, ok := db.Stmts[query]; ok && (!stmt.Transaction || isTransaction) {
db.Mux.RUnlock()
return stmt, nil
// wait for other goroutines prepared
<-stmt.prepared
if stmt.prepareErr != nil {
return Stmt{}, stmt.prepareErr
}

return *stmt, nil
}
db.Mux.RUnlock()

db.Mux.Lock()
defer db.Mux.Unlock()

// double check
if stmt, ok := db.Stmts[query]; ok && (!stmt.Transaction || isTransaction) {
return stmt, nil
} else if ok {
go stmt.Close()
a631807682 marked this conversation as resolved.
Show resolved Hide resolved
db.Mux.Unlock()
// wait for other goroutines prepared
<-stmt.prepared
if stmt.prepareErr != nil {
return Stmt{}, stmt.prepareErr
}

return *stmt, nil
}

// cache preparing stmt first
cacheStmt := Stmt{Transaction: isTransaction, prepared: make(chan struct{})}
db.Stmts[query] = &cacheStmt
db.Mux.Unlock()

// prepare completed
defer close(cacheStmt.prepared)

// Reason why cannot lock conn.PrepareContext
// suppose the maxopen is 1, g1 is creating record and g2 is querying record.
// 1. g1 begin tx, g1 is requeued because of waiting for the system call, now `db.ConnPool` db.numOpen == 1.
// 2. g2 select lock `conn.PrepareContext(ctx, query)`, now db.numOpen == db.maxOpen , wait for release.
// 3. g1 tx exec insert, wait for unlock `conn.PrepareContext(ctx, query)` to finish tx and release.
stmt, err := conn.PrepareContext(ctx, query)
a631807682 marked this conversation as resolved.
Show resolved Hide resolved
if err == nil {
db.Stmts[query] = Stmt{Stmt: stmt, Transaction: isTransaction}
db.PreparedSQL = append(db.PreparedSQL, query)
if err != nil {
cacheStmt.prepareErr = err
db.Mux.Lock()
delete(db.Stmts, query)
db.Mux.Unlock()
return Stmt{}, err
}

return db.Stmts[query], err
db.Mux.Lock()
cacheStmt.Stmt = stmt
db.PreparedSQL = append(db.PreparedSQL, query)
db.Mux.Unlock()

return cacheStmt, nil
}

func (db *PreparedStmtDB) BeginTx(ctx context.Context, opt *sql.TxOptions) (ConnPool, error) {
Expand Down
63 changes: 63 additions & 0 deletions tests/prepared_stmt_test.go
Expand Up @@ -2,6 +2,7 @@ package tests_test

import (
"context"
"sync"
"errors"
"testing"
"time"
Expand Down Expand Up @@ -90,6 +91,68 @@ func TestPreparedStmtFromTransaction(t *testing.T) {
tx2.Commit()
}

func TestPreparedStmtDeadlock(t *testing.T) {
tx, err := OpenTestConnection()
AssertEqual(t, err, nil)

sqlDB, _ := tx.DB()
sqlDB.SetMaxOpenConns(1)

tx = tx.Session(&gorm.Session{PrepareStmt: true})

wg := sync.WaitGroup{}
for i := 0; i < 100; i++ {
wg.Add(1)
go func() {
user := User{Name: "jinzhu"}
tx.Create(&user)

var result User
tx.First(&result)
wg.Done()
}()
}
wg.Wait()

conn, ok := tx.ConnPool.(*gorm.PreparedStmtDB)
AssertEqual(t, ok, true)
AssertEqual(t, len(conn.Stmts), 2)
for _, stmt := range conn.Stmts {
if stmt == nil {
t.Fatalf("stmt cannot bee nil")
}
}

AssertEqual(t, sqlDB.Stats().InUse, 0)
}

func TestPreparedStmtError(t *testing.T) {
tx, err := OpenTestConnection()
AssertEqual(t, err, nil)

sqlDB, _ := tx.DB()
sqlDB.SetMaxOpenConns(1)

tx = tx.Session(&gorm.Session{PrepareStmt: true})

wg := sync.WaitGroup{}
for i := 0; i < 10; i++ {
wg.Add(1)
go func() {
// err prepare
tag := Tag{Locale: "zh"}
tx.Table("users").Find(&tag)
wg.Done()
}()
}
wg.Wait()

conn, ok := tx.ConnPool.(*gorm.PreparedStmtDB)
AssertEqual(t, ok, true)
AssertEqual(t, len(conn.Stmts), 0)
AssertEqual(t, sqlDB.Stats().InUse, 0)
}

func TestPreparedStmtInTransaction(t *testing.T) {
user := User{Name: "jinzhu"}

Expand Down