diff --git a/.gitignore b/.gitignore index 74d9ca99..efb00cf9 100644 --- a/.gitignore +++ b/.gitignore @@ -33,6 +33,7 @@ testdata/migrations/schema.sql vendor/ .env .release-env +.vscode/ # test data cockroach-data/ diff --git a/connection.go b/connection.go index f8b81d31..ecb3e87c 100644 --- a/connection.go +++ b/connection.go @@ -2,6 +2,7 @@ package pop import ( "context" + "database/sql" "errors" "fmt" "sync/atomic" @@ -185,21 +186,26 @@ func (c *Connection) Rollback(fn func(tx *Connection)) error { // NewTransaction starts a new transaction on the connection func (c *Connection) NewTransaction() (*Connection, error) { + return c.NewTransactionContextOptions(c.Context(), nil) +} + +// NewTransactionContext starts a new transaction on the connection using the provided context +func (c *Connection) NewTransactionContext(ctx context.Context) (*Connection, error) { + return c.NewTransactionContextOptions(ctx, nil) +} + +// NewTransactionContextOptions starts a new transaction on the connection using the provided context and transaction options +func (c *Connection) NewTransactionContextOptions(ctx context.Context, options *sql.TxOptions) (*Connection, error) { var cn *Connection if c.TX == nil { - tx, err := c.Store.Transaction() + tx, err := c.Store.TransactionContextOptions(ctx, options) if err != nil { return cn, fmt.Errorf("couldn't start a new transaction: %w", err) } - var store store = tx - // Rewrap the store if it was a context store - if cs, ok := c.Store.(contextStore); ok { - store = contextStore{store: store, ctx: cs.ctx} - } cn = &Connection{ ID: randx.String(30), - Store: store, + Store: contextStore{store: tx, ctx: ctx}, Dialect: c.Dialect, TX: tx, } diff --git a/connection_test.go b/connection_test.go index dfdb9756..a923a281 100644 --- a/connection_test.go +++ b/connection_test.go @@ -1,8 +1,10 @@ +//go:build sqlite // +build sqlite package pop import ( + "context" "testing" "github.com/stretchr/testify/require" @@ -52,3 +54,46 @@ func Test_Connection_Open_BadDriver(t *testing.T) { err = c.Open() r.Error(err) } + +func Test_Connection_Transaction(t *testing.T) { + r := require.New(t) + ctx := context.WithValue(context.Background(), "test", "test") + + c, err := NewConnection(&ConnectionDetails{ + URL: "sqlite://file::memory:?_fk=true", + }) + r.NoError(err) + r.NoError(c.Open()) + c = c.WithContext(ctx) + + t.Run("func=NewTransaction", func(t *testing.T) { + r := require.New(t) + tx, err := c.NewTransaction() + r.NoError(err) + + // has transaction and context + r.NotNil(tx.TX) + r.Nil(c.TX) + r.Equal(ctx, tx.Context()) + + // does not start a new transaction + ntx, err := tx.NewTransaction() + r.Equal(tx, ntx) + + r.NoError(tx.TX.Rollback()) + }) + + t.Run("func=NewTransactionContext", func(t *testing.T) { + r := require.New(t) + nctx := context.WithValue(ctx, "nested", "test") + tx, err := c.NewTransactionContext(nctx) + r.NoError(err) + + // has transaction and context + r.NotNil(tx.TX) + r.Nil(c.TX) + r.Equal(nctx, tx.Context()) + + r.NoError(tx.TX.Rollback()) + }) +} diff --git a/dialect_sqlite.go b/dialect_sqlite.go index d8400a50..35647c3b 100644 --- a/dialect_sqlite.go +++ b/dialect_sqlite.go @@ -182,22 +182,32 @@ func (m *sqlite) locker(l *sync.Mutex, fn func() error) error { } func (m *sqlite) CreateDB() error { - _, err := os.Stat(m.ConnectionDetails.Database) + durl := m.ConnectionDetails.Database + + // Checking whether the url specifies in-memory mode + // as specified in https://github.com/mattn/go-sqlite3#faq + if strings.Contains(durl, ":memory:") || strings.Contains(durl, "mode=memory") { + log(logging.Info, "in memory db selected, no database file created.") + + return nil + } + + _, err := os.Stat(durl) if err == nil { - return fmt.Errorf("could not create SQLite database '%s'; database exists", m.ConnectionDetails.Database) + return fmt.Errorf("could not create SQLite database '%s'; database exists", durl) } - dir := filepath.Dir(m.ConnectionDetails.Database) + dir := filepath.Dir(durl) err = os.MkdirAll(dir, 0766) if err != nil { - return fmt.Errorf("could not create SQLite database '%s': %w", m.ConnectionDetails.Database, err) + return fmt.Errorf("could not create SQLite database '%s': %w", durl, err) } - f, err := os.Create(m.ConnectionDetails.Database) + f, err := os.Create(durl) if err != nil { - return fmt.Errorf("could not create SQLite database '%s': %w", m.ConnectionDetails.Database, err) + return fmt.Errorf("could not create SQLite database '%s': %w", durl, err) } _ = f.Close() - log(logging.Info, "created database '%s'", m.ConnectionDetails.Database) + log(logging.Info, "created database '%s'", durl) return nil } diff --git a/dialect_sqlite_test.go b/dialect_sqlite_test.go index 83a0294a..ac5db21a 100644 --- a/dialect_sqlite_test.go +++ b/dialect_sqlite_test.go @@ -1,3 +1,4 @@ +//go:build sqlite // +build sqlite package pop @@ -144,18 +145,51 @@ func Test_ConnectionDetails_FinalizeOSPath(t *testing.T) { func TestSqlite_CreateDB(t *testing.T) { r := require.New(t) - dir := t.TempDir() - p := filepath.Join(dir, "testdb.sqlite") - cd := &ConnectionDetails{ - Dialect: "sqlite", - Database: p, - } + + cd := &ConnectionDetails{Dialect: "sqlite"} dialect, err := newSQLite(cd) r.NoError(err) - r.NoError(dialect.CreateDB()) - // Creating DB twice should produce an error - r.EqualError(dialect.CreateDB(), fmt.Sprintf("could not create SQLite database '%s'; database exists", p)) + t.Run("CreateFile", func(t *testing.T) { + dir := t.TempDir() + cd.Database = filepath.Join(dir, "testdb.sqlite") + + r.NoError(dialect.CreateDB()) + r.FileExists(cd.Database) + }) + + t.Run("MemoryDB_tag", func(t *testing.T) { + dir := t.TempDir() + cd.Database = filepath.Join(dir, "file::memory:?cache=shared") + + r.NoError(dialect.CreateDB()) + r.NoFileExists(cd.Database) + }) + + t.Run("MemoryDB_only", func(t *testing.T) { + dir := t.TempDir() + cd.Database = filepath.Join(dir, ":memory:") + + r.NoError(dialect.CreateDB()) + r.NoFileExists(cd.Database) + }) + + t.Run("MemoryDB_param", func(t *testing.T) { + dir := t.TempDir() + cd.Database = filepath.Join(dir, "file:foobar?mode=memory&cache=shared") + + r.NoError(dialect.CreateDB()) + r.NoFileExists(cd.Database) + }) + + t.Run("CreateFile_ExistingDB", func(t *testing.T) { + dir := t.TempDir() + cd.Database = filepath.Join(dir, "testdb.sqlite") + + r.NoError(dialect.CreateDB()) + r.EqualError(dialect.CreateDB(), fmt.Sprintf("could not create SQLite database '%s'; database exists", cd.Database)) + }) + } func TestSqlite_NewDriver(t *testing.T) {