diff --git a/connection.go b/connection.go index 932b4777..f51fd367 100644 --- a/connection.go +++ b/connection.go @@ -10,6 +10,7 @@ import ( "github.com/gobuffalo/pop/v6/internal/defaults" "github.com/gobuffalo/pop/v6/internal/randx" + "github.com/gobuffalo/pop/v6/logging" ) // Connections contains all available connections @@ -151,21 +152,37 @@ func (c *Connection) Close() error { // returns an error then the transaction will be rolled back, otherwise the transaction // will automatically commit at the end. func (c *Connection) Transaction(fn func(tx *Connection) error) error { - return c.Dialect.Lock(func() error { + return c.Dialect.Lock(func() (err error) { var dberr error + + log(logging.SQL, "--- BEGIN Transaction ---") cn, err := c.NewTransaction() if err != nil { return err } + + defer func() { + if ex := recover(); ex != nil { + log(logging.SQL, "--- ROLLBACK Transaction (inner function panic) ---") + dberr = cn.TX.Rollback() + if dberr != nil { + err = fmt.Errorf("database error while inner panic rollback: %w", dberr) + } + err = fmt.Errorf("transaction was rolled back due to inner panic") + } + }() + err = fn(cn) if err != nil { + log(logging.SQL, "--- ROLLBACK Transaction ---") dberr = cn.TX.Rollback() } else { + log(logging.SQL, "--- END Transaction ---") dberr = cn.TX.Commit() } if dberr != nil { - return fmt.Errorf("error committing or rolling back transaction: %w", dberr) + return fmt.Errorf("database error on committing or rolling back transaction: %w", dberr) } return err diff --git a/connection_test.go b/connection_test.go index a923a281..55c801ea 100644 --- a/connection_test.go +++ b/connection_test.go @@ -5,6 +5,7 @@ package pop import ( "context" + "fmt" "testing" "github.com/stretchr/testify/require" @@ -55,7 +56,7 @@ func Test_Connection_Open_BadDriver(t *testing.T) { r.Error(err) } -func Test_Connection_Transaction(t *testing.T) { +func Test_Connection_NewTransaction(t *testing.T) { r := require.New(t) ctx := context.WithValue(context.Background(), "test", "test") @@ -97,3 +98,35 @@ func Test_Connection_Transaction(t *testing.T) { r.NoError(tx.TX.Rollback()) }) } + +func Test_Connection_Transaction(t *testing.T) { + r := require.New(t) + + c, err := NewConnection(&ConnectionDetails{ + URL: "sqlite://file::memory:?_fk=true", + }) + r.NoError(err) + r.NoError(c.Open()) + + t.Run("Success", func(t *testing.T) { + err = c.Transaction(func(c *Connection) error { + return nil + }) + r.NoError(err) + }) + + t.Run("Failed", func(t *testing.T) { + err = c.Transaction(func(c *Connection) error { + return fmt.Errorf("failed") + }) + r.ErrorContains(err, "failed") + }) + + t.Run("Panic", func(t *testing.T) { + err = c.Transaction(func(c *Connection) error { + panic("inner function panic") + }) + r.ErrorContains(err, "panic") + r.ErrorContains(err, "rolled back") + }) +}