Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

catch panic within transaction to complete rollback (fixes #748) #775

Merged
merged 2 commits into from Sep 24, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
21 changes: 19 additions & 2 deletions connection.go
Expand Up @@ -10,6 +10,7 @@ import (

"github.com/gobuffalo/pop/v6/internal/defaults"
"github.com/gobuffalo/pop/v6/internal/randx"
"github.com/gobuffalo/pop/v6/logging"
)

// Connections contains all available connections
Expand Down Expand Up @@ -151,21 +152,37 @@ func (c *Connection) Close() error {
// returns an error then the transaction will be rolled back, otherwise the transaction
// will automatically commit at the end.
func (c *Connection) Transaction(fn func(tx *Connection) error) error {
return c.Dialect.Lock(func() error {
return c.Dialect.Lock(func() (err error) {
var dberr error

log(logging.SQL, "--- BEGIN Transaction ---")
sio4 marked this conversation as resolved.
Show resolved Hide resolved
cn, err := c.NewTransaction()
if err != nil {
return err
}

defer func() {
if ex := recover(); ex != nil {
log(logging.SQL, "--- ROLLBACK Transaction (inner function panic) ---")
dberr = cn.TX.Rollback()
if dberr != nil {
err = fmt.Errorf("database error while inner panic rollback: %w", dberr)
}
err = fmt.Errorf("transaction was rolled back due to inner panic")
}
}()
sio4 marked this conversation as resolved.
Show resolved Hide resolved

err = fn(cn)
if err != nil {
log(logging.SQL, "--- ROLLBACK Transaction ---")
dberr = cn.TX.Rollback()
} else {
log(logging.SQL, "--- END Transaction ---")
dberr = cn.TX.Commit()
}

if dberr != nil {
return fmt.Errorf("error committing or rolling back transaction: %w", dberr)
return fmt.Errorf("database error on committing or rolling back transaction: %w", dberr)
}

return err
Expand Down
35 changes: 34 additions & 1 deletion connection_test.go
Expand Up @@ -5,6 +5,7 @@ package pop

import (
"context"
"fmt"
"testing"

"github.com/stretchr/testify/require"
Expand Down Expand Up @@ -55,7 +56,7 @@ func Test_Connection_Open_BadDriver(t *testing.T) {
r.Error(err)
}

func Test_Connection_Transaction(t *testing.T) {
func Test_Connection_NewTransaction(t *testing.T) {
r := require.New(t)
ctx := context.WithValue(context.Background(), "test", "test")

Expand Down Expand Up @@ -97,3 +98,35 @@ func Test_Connection_Transaction(t *testing.T) {
r.NoError(tx.TX.Rollback())
})
}

func Test_Connection_Transaction(t *testing.T) {
r := require.New(t)

c, err := NewConnection(&ConnectionDetails{
URL: "sqlite://file::memory:?_fk=true",
})
r.NoError(err)
r.NoError(c.Open())

t.Run("Success", func(t *testing.T) {
err = c.Transaction(func(c *Connection) error {
return nil
})
r.NoError(err)
})

t.Run("Failed", func(t *testing.T) {
err = c.Transaction(func(c *Connection) error {
return fmt.Errorf("failed")
})
r.ErrorContains(err, "failed")
})

t.Run("Panic", func(t *testing.T) {
err = c.Transaction(func(c *Connection) error {
panic("inner function panic")
})
r.ErrorContains(err, "panic")
r.ErrorContains(err, "rolled back")
})
}