Skip to content

Commit

Permalink
catch panic within transaction to complete rollback (fixes #748)
Browse files Browse the repository at this point in the history
  • Loading branch information
sio4 committed Sep 24, 2022
1 parent 2cbff73 commit 3f5bb8a
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 3 deletions.
21 changes: 19 additions & 2 deletions connection.go
Expand Up @@ -10,6 +10,7 @@ import (

"github.com/gobuffalo/pop/v6/internal/defaults"
"github.com/gobuffalo/pop/v6/internal/randx"
"github.com/gobuffalo/pop/v6/logging"
)

// Connections contains all available connections
Expand Down Expand Up @@ -151,21 +152,37 @@ func (c *Connection) Close() error {
// returns an error then the transaction will be rolled back, otherwise the transaction
// will automatically commit at the end.
func (c *Connection) Transaction(fn func(tx *Connection) error) error {
return c.Dialect.Lock(func() error {
return c.Dialect.Lock(func() (err error) {
var dberr error

log(logging.SQL, "--- BEGIN Transaction ---")
cn, err := c.NewTransaction()
if err != nil {
return err
}

defer func() {
if ex := recover(); ex != nil {
log(logging.SQL, "--- ROLLBACK Transaction (inner function panic) ---")
dberr = cn.TX.Rollback()
if dberr != nil {
err = fmt.Errorf("database error while inner panic rollback: %w", dberr)
}
err = fmt.Errorf("transaction was rolled back due to inner panic")
}
}()

err = fn(cn)
if err != nil {
log(logging.SQL, "--- ROLLBACK Transaction ---")
dberr = cn.TX.Rollback()
} else {
log(logging.SQL, "--- END Transaction ---")
dberr = cn.TX.Commit()
}

if dberr != nil {
return fmt.Errorf("error committing or rolling back transaction: %w", dberr)
return fmt.Errorf("database error on committing or rolling back transaction: %w", dberr)
}

return err
Expand Down
35 changes: 34 additions & 1 deletion connection_test.go
Expand Up @@ -5,6 +5,7 @@ package pop

import (
"context"
"fmt"
"testing"

"github.com/stretchr/testify/require"
Expand Down Expand Up @@ -55,7 +56,7 @@ func Test_Connection_Open_BadDriver(t *testing.T) {
r.Error(err)
}

func Test_Connection_Transaction(t *testing.T) {
func Test_Connection_NewTransaction(t *testing.T) {
r := require.New(t)
ctx := context.WithValue(context.Background(), "test", "test")

Expand Down Expand Up @@ -97,3 +98,35 @@ func Test_Connection_Transaction(t *testing.T) {
r.NoError(tx.TX.Rollback())
})
}

func Test_Connection_Transaction(t *testing.T) {
r := require.New(t)

c, err := NewConnection(&ConnectionDetails{
URL: "sqlite://file::memory:?_fk=true",
})
r.NoError(err)
r.NoError(c.Open())

t.Run("Success", func(t *testing.T) {
err = c.Transaction(func(c *Connection) error {
return nil
})
r.NoError(err)
})

t.Run("Failed", func(t *testing.T) {
err = c.Transaction(func(c *Connection) error {
return fmt.Errorf("failed")
})
r.ErrorContains(err, "failed")
})

t.Run("Panic", func(t *testing.T) {
err = c.Transaction(func(c *Connection) error {
panic("inner function panic")
})
r.ErrorContains(err, "panic")
r.ErrorContains(err, "rolled back")
})
}

0 comments on commit 3f5bb8a

Please sign in to comment.