Skip to content

Commit

Permalink
fix: prepare deadlock (#5568)
Browse files Browse the repository at this point in the history
* fix: prepare deadlock

* chore[ci skip]: code style

* chore[ci skip]: test remove unnecessary params

* fix: prepare deadlock

* fix: double check prepare

* test: more goroutines

* chore[ci skip]: improve code comments

Co-authored-by: Jinzhu <wosmvp@gmail.com>
  • Loading branch information
a631807682 and jinzhu committed Sep 30, 2022
1 parent a3cc6c6 commit 0b7113b
Show file tree
Hide file tree
Showing 3 changed files with 107 additions and 12 deletions.
2 changes: 1 addition & 1 deletion gorm.go
Expand Up @@ -179,7 +179,7 @@ func Open(dialector Dialector, opts ...Option) (db *DB, err error) {

preparedStmt := &PreparedStmtDB{
ConnPool: db.ConnPool,
Stmts: map[string]Stmt{},
Stmts: map[string](*Stmt){},
Mux: &sync.RWMutex{},
PreparedSQL: make([]string, 0, 100),
}
Expand Down
54 changes: 43 additions & 11 deletions prepare_stmt.go
Expand Up @@ -9,10 +9,12 @@ import (
type Stmt struct {
*sql.Stmt
Transaction bool
prepared chan struct{}
prepareErr error
}

type PreparedStmtDB struct {
Stmts map[string]Stmt
Stmts map[string]*Stmt
PreparedSQL []string
Mux *sync.RWMutex
ConnPool
Expand Down Expand Up @@ -46,27 +48,57 @@ func (db *PreparedStmtDB) prepare(ctx context.Context, conn ConnPool, isTransact
db.Mux.RLock()
if stmt, ok := db.Stmts[query]; ok && (!stmt.Transaction || isTransaction) {
db.Mux.RUnlock()
return stmt, nil
// wait for other goroutines prepared
<-stmt.prepared
if stmt.prepareErr != nil {
return Stmt{}, stmt.prepareErr
}

return *stmt, nil
}
db.Mux.RUnlock()

db.Mux.Lock()
defer db.Mux.Unlock()

// double check
if stmt, ok := db.Stmts[query]; ok && (!stmt.Transaction || isTransaction) {
return stmt, nil
} else if ok {
go stmt.Close()
db.Mux.Unlock()
// wait for other goroutines prepared
<-stmt.prepared
if stmt.prepareErr != nil {
return Stmt{}, stmt.prepareErr
}

return *stmt, nil
}

// cache preparing stmt first
cacheStmt := Stmt{Transaction: isTransaction, prepared: make(chan struct{})}
db.Stmts[query] = &cacheStmt
db.Mux.Unlock()

// prepare completed
defer close(cacheStmt.prepared)

// Reason why cannot lock conn.PrepareContext
// suppose the maxopen is 1, g1 is creating record and g2 is querying record.
// 1. g1 begin tx, g1 is requeued because of waiting for the system call, now `db.ConnPool` db.numOpen == 1.
// 2. g2 select lock `conn.PrepareContext(ctx, query)`, now db.numOpen == db.maxOpen , wait for release.
// 3. g1 tx exec insert, wait for unlock `conn.PrepareContext(ctx, query)` to finish tx and release.
stmt, err := conn.PrepareContext(ctx, query)
if err == nil {
db.Stmts[query] = Stmt{Stmt: stmt, Transaction: isTransaction}
db.PreparedSQL = append(db.PreparedSQL, query)
if err != nil {
cacheStmt.prepareErr = err
db.Mux.Lock()
delete(db.Stmts, query)
db.Mux.Unlock()
return Stmt{}, err
}

return db.Stmts[query], err
db.Mux.Lock()
cacheStmt.Stmt = stmt
db.PreparedSQL = append(db.PreparedSQL, query)
db.Mux.Unlock()

return cacheStmt, nil
}

func (db *PreparedStmtDB) BeginTx(ctx context.Context, opt *sql.TxOptions) (ConnPool, error) {
Expand Down
63 changes: 63 additions & 0 deletions tests/prepared_stmt_test.go
Expand Up @@ -2,6 +2,7 @@ package tests_test

import (
"context"
"sync"
"errors"
"testing"
"time"
Expand Down Expand Up @@ -90,6 +91,68 @@ func TestPreparedStmtFromTransaction(t *testing.T) {
tx2.Commit()
}

func TestPreparedStmtDeadlock(t *testing.T) {
tx, err := OpenTestConnection()
AssertEqual(t, err, nil)

sqlDB, _ := tx.DB()
sqlDB.SetMaxOpenConns(1)

tx = tx.Session(&gorm.Session{PrepareStmt: true})

wg := sync.WaitGroup{}
for i := 0; i < 100; i++ {
wg.Add(1)
go func() {
user := User{Name: "jinzhu"}
tx.Create(&user)

var result User
tx.First(&result)
wg.Done()
}()
}
wg.Wait()

conn, ok := tx.ConnPool.(*gorm.PreparedStmtDB)
AssertEqual(t, ok, true)
AssertEqual(t, len(conn.Stmts), 2)
for _, stmt := range conn.Stmts {
if stmt == nil {
t.Fatalf("stmt cannot bee nil")
}
}

AssertEqual(t, sqlDB.Stats().InUse, 0)
}

func TestPreparedStmtError(t *testing.T) {
tx, err := OpenTestConnection()
AssertEqual(t, err, nil)

sqlDB, _ := tx.DB()
sqlDB.SetMaxOpenConns(1)

tx = tx.Session(&gorm.Session{PrepareStmt: true})

wg := sync.WaitGroup{}
for i := 0; i < 10; i++ {
wg.Add(1)
go func() {
// err prepare
tag := Tag{Locale: "zh"}
tx.Table("users").Find(&tag)
wg.Done()
}()
}
wg.Wait()

conn, ok := tx.ConnPool.(*gorm.PreparedStmtDB)
AssertEqual(t, ok, true)
AssertEqual(t, len(conn.Stmts), 0)
AssertEqual(t, sqlDB.Stats().InUse, 0)
}

func TestPreparedStmtInTransaction(t *testing.T) {
user := User{Name: "jinzhu"}

Expand Down

0 comments on commit 0b7113b

Please sign in to comment.