diff --git a/callbacks.go b/callbacks.go index c060ea709..f835e5049 100644 --- a/callbacks.go +++ b/callbacks.go @@ -2,6 +2,7 @@ package gorm import ( "context" + "database/sql" "errors" "fmt" "reflect" @@ -15,12 +16,13 @@ import ( func initializeCallbacks(db *DB) *callbacks { return &callbacks{ processors: map[string]*processor{ - "create": {db: db}, - "query": {db: db}, - "update": {db: db}, - "delete": {db: db}, - "row": {db: db}, - "raw": {db: db}, + "create": {db: db}, + "query": {db: db}, + "update": {db: db}, + "delete": {db: db}, + "row": {db: db}, + "raw": {db: db}, + "transaction": {db: db}, }, } } @@ -72,6 +74,29 @@ func (cs *callbacks) Raw() *processor { return cs.processors["raw"] } +func (cs *callbacks) Transaction() *processor { + return cs.processors["transaction"] +} + +func (p *processor) Begin(tx *DB, opt *sql.TxOptions) *DB { + var err error + + switch beginner := tx.Statement.ConnPool.(type) { + case TxBeginner: + tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt) + case ConnPoolBeginner: + tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt) + default: + err = ErrInvalidTransaction + } + + if err != nil { + _ = tx.AddError(err) + } + + return tx +} + func (p *processor) Execute(db *DB) *DB { // call scopes for len(db.Statement.scopes) > 0 { diff --git a/finisher_api.go b/finisher_api.go index 7a3f27bae..3e406c1cc 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -619,27 +619,13 @@ func (db *DB) Begin(opts ...*sql.TxOptions) *DB { // clone statement tx = db.getInstance().Session(&Session{Context: db.Statement.Context, NewDB: db.clone == 1}) opt *sql.TxOptions - err error ) if len(opts) > 0 { opt = opts[0] } - switch beginner := tx.Statement.ConnPool.(type) { - case TxBeginner: - tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt) - case ConnPoolBeginner: - tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt) - default: - err = ErrInvalidTransaction - } - - if err != nil { - tx.AddError(err) - } - - return tx + return tx.callbacks.Transaction().Begin(tx, opt) } // Commit commit a transaction