Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement WithConnection for sqlserver database driver #794

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
49 changes: 36 additions & 13 deletions database/sqlserver/sqlserver.go
Expand Up @@ -58,22 +58,28 @@ type SQLServer struct {
config *Config
}

// WithInstance returns a database instance from an already created database connection.
// WithConnection returns a database driver instance from an already created database connection.
// The connection will be closed when the database driver is closed.
//
// Note that the deprecated `mssql` driver is not supported. Please use the newer `sqlserver` driver.
func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) {
func WithConnection(ctx context.Context, conn *sql.Conn, config *Config) (*SQLServer, error) {
if config == nil {
return nil, ErrNilConfig
}

if err := instance.Ping(); err != nil {
if err := conn.PingContext(ctx); err != nil {
return nil, err
}

ss := SQLServer{
conn: conn,
config: config,
}

if config.DatabaseName == "" {
query := `SELECT DB_NAME()`
var databaseName string
if err := instance.QueryRow(query).Scan(&databaseName); err != nil {
if err := conn.QueryRowContext(ctx, query).Scan(&databaseName); err != nil {
return nil, &database.Error{OrigErr: err, Query: []byte(query)}
}

Expand All @@ -87,7 +93,7 @@ func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) {
if config.SchemaName == "" {
query := `SELECT SCHEMA_NAME()`
var schemaName string
if err := instance.QueryRow(query).Scan(&schemaName); err != nil {
if err := conn.QueryRowContext(ctx, query).Scan(&schemaName); err != nil {
return nil, &database.Error{OrigErr: err, Query: []byte(query)}
}

Expand All @@ -102,22 +108,36 @@ func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) {
config.MigrationsTable = DefaultMigrationsTable
}

conn, err := instance.Conn(context.Background())
if err := ss.ensureVersionTable(); err != nil {
return nil, err
}

if err != nil {
return &ss, nil
}

// WithInstance returns a database driver instance from an already created database handle.
// The database handle will be closed when the database driver is closed.
//
// Note that the deprecated `mssql` driver is not supported. Please use the newer `sqlserver` driver.
func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) {
ctx := context.Background()

if err := instance.Ping(); err != nil {
return nil, err
}

ss := &SQLServer{
conn: conn,
db: instance,
config: config,
conn, err := instance.Conn(ctx)
if err != nil {
return nil, err
}

if err := ss.ensureVersionTable(); err != nil {
ss, err := WithConnection(ctx, conn, config)
if err != nil {
return nil, err
}

ss.db = instance

return ss, nil
}

Expand Down Expand Up @@ -183,7 +203,10 @@ func (ss *SQLServer) Open(url string) (database.Driver, error) {
// Close the database connection
func (ss *SQLServer) Close() error {
connErr := ss.conn.Close()
dbErr := ss.db.Close()
var dbErr error
if ss.db != nil {
dbErr = ss.db.Close()
}
if connErr != nil || dbErr != nil {
return fmt.Errorf("conn: %v, db: %v", connErr, dbErr)
}
Expand Down
47 changes: 47 additions & 0 deletions database/sqlserver/sqlserver_test.go
Expand Up @@ -120,6 +120,53 @@ func Test(t *testing.T) {
})
}

func TestWithConnection(t *testing.T) {
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
ctx := context.Background()

ip, port, err := c.Port(defaultPort)
if err != nil {
t.Fatal(err)
}

db, err := sql.Open("sqlserver", msConnectionString(ip, port))
if err != nil {
t.Fatal(err)
}
defer func() {
if err := db.Close(); err != nil {
t.Error(err)
}
}()

conn, err := db.Conn(ctx)
if err != nil {
t.Fatal(err)
}

p, err := WithConnection(ctx, conn, &Config{})
if err != nil {
t.Fatal(err)
}

defer func() {
if err := p.Close(); err != nil {
t.Error(err)
}
// Ensure connection is closed after database provider close
_, err := conn.QueryContext(ctx, "SELECT 1")
if err != sql.ErrConnDone {
t.Error("connection not closed")
}
_, err = db.QueryContext(ctx, "SELECT 1")
if err != nil {
t.Error("database handle should not be closed")
}
}()
dt.Test(t, p, []byte("SELECT 1"))
})
}

func TestMigrate(t *testing.T) {
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
SkipIfUnsupportedArch(t, c)
Expand Down