Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support MSSQL batch statements (Resolves #652) #666

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions database/sqlserver/README.md
Expand Up @@ -17,6 +17,7 @@
| `encrypt` | | `disable` - Data send between client and server is not encrypted. `false` - Data sent between client and server is not encrypted beyond the login packet (Default). `true` - Data sent between client and server is encrypted. |
| `app+name` || The application name (default is go-mssqldb). |
| `useMsi` | | `true` - Use Azure MSI Authentication for connecting to Sql Server. Must be running from an Azure VM/an instance with MSI enabled. `false` - Use password authentication (Default). See [here for Azure MSI Auth details](https://docs.microsoft.com/en-us/azure/app-service/app-service-web-tutorial-connect-msi). NOTE: Since this cannot be tested locally, this is not officially supported.
| `x-batch` | | Enable batch statements (default: false) |

See https://github.com/denisenkom/go-mssqldb for full parameter list.

Expand Down
43 changes: 31 additions & 12 deletions database/sqlserver/sqlserver.go
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"database/sql"
"fmt"
"github.com/denisenkom/go-mssqldb/batch"
"io"
"io/ioutil"
nurl "net/url"
Expand Down Expand Up @@ -43,9 +44,10 @@ var lockErrorMap = map[mssql.ReturnStatus]string{

// Config for database
type Config struct {
MigrationsTable string
DatabaseName string
SchemaName string
MigrationsTable string
DatabaseName string
SchemaName string
BatchStatementEnabled bool
}

// SQL Server connection
Expand Down Expand Up @@ -169,9 +171,18 @@ func (ss *SQLServer) Open(url string) (database.Driver, error) {

migrationsTable := purl.Query().Get("x-migrations-table")

batchStatementEnabled := false
if s := purl.Query().Get("x-batch"); len(s) > 0 {
batchStatementEnabled, err = strconv.ParseBool(s)
if err != nil {
return nil, fmt.Errorf("Unable to parse option x-batch: %w", err)
}
}

px, err := WithInstance(db, &Config{
DatabaseName: purl.Path,
MigrationsTable: migrationsTable,
DatabaseName: purl.Path,
MigrationsTable: migrationsTable,
BatchStatementEnabled: batchStatementEnabled,
})

if err != nil {
Expand Down Expand Up @@ -242,15 +253,23 @@ func (ss *SQLServer) Run(migration io.Reader) error {

// run migration
query := string(migr[:])
if _, err := ss.conn.ExecContext(context.Background(), query); err != nil {
if msErr, ok := err.(mssql.Error); ok {
message := fmt.Sprintf("migration failed: %s", msErr.Message)
if msErr.ProcName != "" {
message = fmt.Sprintf("%s (proc name %s)", msErr.Message, msErr.ProcName)
scripts := []string{query}

if ss.config.BatchStatementEnabled {
scripts = batch.Split(query, "go")
}

for _, script := range scripts {
if _, err := ss.conn.ExecContext(context.Background(), script); err != nil {
if msErr, ok := err.(mssql.Error); ok {
message := fmt.Sprintf("migration failed: %s", msErr.Message)
if msErr.ProcName != "" {
message = fmt.Sprintf("%s (proc name %s)", msErr.Message, msErr.ProcName)
}
return database.Error{OrigErr: err, Err: message, Query: []byte(script), Line: uint(msErr.LineNo)}
}
return database.Error{OrigErr: err, Err: message, Query: migr, Line: uint(msErr.LineNo)}
return database.Error{OrigErr: err, Err: "migration failed", Query: []byte(script)}
}
return database.Error{OrigErr: err, Err: "migration failed", Query: migr}
}

return nil
Expand Down
48 changes: 46 additions & 2 deletions database/sqlserver/sqlserver_test.go
Expand Up @@ -49,8 +49,9 @@ var (
}
)

func msConnectionString(host, port string) string {
return fmt.Sprintf("sqlserver://sa:%v@%v:%v?database=master", saPassword, host, port)
func msConnectionString(host, port string, options ...string) string {
options = append(options, "database=master")
return fmt.Sprintf("sqlserver://sa:%v@%v:%v?%s", saPassword, host, port, strings.Join(options, "&"))
}

func msConnectionStringMsiWithPassword(host, port string, useMsi bool) string {
Expand Down Expand Up @@ -183,6 +184,49 @@ func TestMultiStatement(t *testing.T) {
})
}

func TestBatchedStatement(t *testing.T) {
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
ip, port, err := c.Port(defaultPort)
if err != nil {
t.Fatal(err)
}

addr := msConnectionString(ip, port, "x-batch=true")
ms := &SQLServer{}
d, err := ms.Open(addr)
if err != nil {
t.Fatal(err)
}
defer func() {
if err := d.Close(); err != nil {
t.Error(err)
}
}()
if err := d.Run(strings.NewReader(`CREATE PROCEDURE uspA
AS
BEGIN
SELECT 1;
END;
GO
CREATE PROCEDURE uspB
AS
BEGIN
SELECT 2;
END`)); err != nil {
t.Fatalf("expected err to be nil, got %v", err)
}

// make sure second proc exists
var exists int
if err := d.(*SQLServer).conn.QueryRowContext(context.Background(), "Select COUNT(1) from sysobjects where type = 'P' and category = 0 and [NAME] = 'uspB'").Scan(&exists); err != nil {
t.Fatal(err)
}
if exists != 1 {
t.Fatalf("expected proc uspB to exist")
}
})
}

func TestErrorParsing(t *testing.T) {
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
SkipIfUnsupportedArch(t, c)
Expand Down