From 638805ed62a1b4a61f6ac475dafbd242a2ce1d3f Mon Sep 17 00:00:00 2001 From: a631807682 <631807682@qq.com> Date: Wed, 19 Oct 2022 13:45:43 +0800 Subject: [PATCH 1/4] feat(PreparedStmtDB): support reset --- prepare_stmt.go | 6 ++++++ tests/prepared_stmt_test.go | 27 ++++++++++++++++++++++++++- 2 files changed, 32 insertions(+), 1 deletion(-) diff --git a/prepare_stmt.go b/prepare_stmt.go index 3934bb97f..0d2ed15eb 100644 --- a/prepare_stmt.go +++ b/prepare_stmt.go @@ -44,6 +44,12 @@ func (db *PreparedStmtDB) Close() { } } +func (db *PreparedStmtDB) Reset() { + db.Close() + db.PreparedSQL = make([]string, 0, 100) + db.Stmts = map[string](*Stmt){} +} + func (db *PreparedStmtDB) prepare(ctx context.Context, conn ConnPool, isTransaction bool, query string) (Stmt, error) { db.Mux.RLock() if stmt, ok := db.Stmts[query]; ok && (!stmt.Transaction || isTransaction) { diff --git a/tests/prepared_stmt_test.go b/tests/prepared_stmt_test.go index c7f251f2c..6c141851f 100644 --- a/tests/prepared_stmt_test.go +++ b/tests/prepared_stmt_test.go @@ -2,8 +2,8 @@ package tests_test import ( "context" - "sync" "errors" + "sync" "testing" "time" @@ -168,3 +168,28 @@ func TestPreparedStmtInTransaction(t *testing.T) { t.Errorf("Failed, got error: %v", err) } } + +func TestPreparedStmtReset(t *testing.T) { + tx := DB.Session(&gorm.Session{PrepareStmt: true}) + pdb, ok := tx.ConnPool.(*gorm.PreparedStmtDB) + if !ok { + t.Fatalf("should assign PreparedStatement Manager back to database when using PrepareStmt mode") + } + + user := *GetUser("prepared_stmt_reset", Config{}) + tx.Create(&user) + + pdb.Mux.Lock() + if len(pdb.PreparedSQL) == 0 || len(pdb.Stmts) == 0 { + pdb.Mux.Unlock() + t.Fatalf("prepared stmt can not be empty") + } + pdb.Mux.Unlock() + + pdb.Reset() + pdb.Mux.Lock() + defer pdb.Mux.Unlock() + if len(pdb.PreparedSQL) != 0 || len(pdb.Stmts) != 0 { + t.Fatalf("prepared stmt should be empty") + } +} From 6ae2f6b00a9dc496707dcf3b9f762a09b7b84695 Mon Sep 17 00:00:00 2001 From: a631807682 <631807682@qq.com> Date: Wed, 19 Oct 2022 14:21:07 +0800 Subject: [PATCH 2/4] fix: close all stmt --- prepare_stmt.go | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/prepare_stmt.go b/prepare_stmt.go index 0d2ed15eb..a59552d25 100644 --- a/prepare_stmt.go +++ b/prepare_stmt.go @@ -45,7 +45,12 @@ func (db *PreparedStmtDB) Close() { } func (db *PreparedStmtDB) Reset() { - db.Close() + db.Mux.Lock() + defer db.Mux.Unlock() + for _, stmt := range db.Stmts { + go stmt.Close() + } + db.PreparedSQL = make([]string, 0, 100) db.Stmts = map[string](*Stmt){} } From 678adb3e2f0ab54fc8b001dc379247ceb26d559a Mon Sep 17 00:00:00 2001 From: a631807682 <631807682@qq.com> Date: Wed, 19 Oct 2022 14:27:35 +0800 Subject: [PATCH 3/4] test: fix test --- tests/prepared_stmt_test.go | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/tests/prepared_stmt_test.go b/tests/prepared_stmt_test.go index 6c141851f..64baa01be 100644 --- a/tests/prepared_stmt_test.go +++ b/tests/prepared_stmt_test.go @@ -171,16 +171,17 @@ func TestPreparedStmtInTransaction(t *testing.T) { func TestPreparedStmtReset(t *testing.T) { tx := DB.Session(&gorm.Session{PrepareStmt: true}) + + user := *GetUser("prepared_stmt_reset", Config{}) + tx = tx.Create(&user) + pdb, ok := tx.ConnPool.(*gorm.PreparedStmtDB) if !ok { t.Fatalf("should assign PreparedStatement Manager back to database when using PrepareStmt mode") } - user := *GetUser("prepared_stmt_reset", Config{}) - tx.Create(&user) - pdb.Mux.Lock() - if len(pdb.PreparedSQL) == 0 || len(pdb.Stmts) == 0 { + if len(pdb.Stmts) == 0 { pdb.Mux.Unlock() t.Fatalf("prepared stmt can not be empty") } @@ -189,7 +190,7 @@ func TestPreparedStmtReset(t *testing.T) { pdb.Reset() pdb.Mux.Lock() defer pdb.Mux.Unlock() - if len(pdb.PreparedSQL) != 0 || len(pdb.Stmts) != 0 { + if len(pdb.Stmts) != 0 { t.Fatalf("prepared stmt should be empty") } } From d679d620ca095a19873b0b11db5c46f845eb9e5c Mon Sep 17 00:00:00 2001 From: a631807682 <631807682@qq.com> Date: Wed, 19 Oct 2022 14:44:14 +0800 Subject: [PATCH 4/4] fix: delete one by one --- prepare_stmt.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/prepare_stmt.go b/prepare_stmt.go index a59552d25..7591e5331 100644 --- a/prepare_stmt.go +++ b/prepare_stmt.go @@ -47,7 +47,8 @@ func (db *PreparedStmtDB) Close() { func (db *PreparedStmtDB) Reset() { db.Mux.Lock() defer db.Mux.Unlock() - for _, stmt := range db.Stmts { + for query, stmt := range db.Stmts { + delete(db.Stmts, query) go stmt.Close() }