Skip to content

Commit

Permalink
Support MSSQL batch statements (Resolves #652)
Browse files Browse the repository at this point in the history
  • Loading branch information
glebteterin committed Dec 8, 2021
1 parent 0c500eb commit b934d03
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 7 deletions.
18 changes: 11 additions & 7 deletions database/sqlserver/sqlserver.go
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"database/sql"
"fmt"
"github.com/denisenkom/go-mssqldb/batch"
"io"
"io/ioutil"
nurl "net/url"
Expand Down Expand Up @@ -242,15 +243,18 @@ func (ss *SQLServer) Run(migration io.Reader) error {

// run migration
query := string(migr[:])
if _, err := ss.conn.ExecContext(context.Background(), query); err != nil {
if msErr, ok := err.(mssql.Error); ok {
message := fmt.Sprintf("migration failed: %s", msErr.Message)
if msErr.ProcName != "" {
message = fmt.Sprintf("%s (proc name %s)", msErr.Message, msErr.ProcName)
scripts := batch.Split(query, "go")
for _, script := range scripts {
if _, err := ss.conn.ExecContext(context.Background(), script); err != nil {
if msErr, ok := err.(mssql.Error); ok {
message := fmt.Sprintf("migration failed: %s", msErr.Message)
if msErr.ProcName != "" {
message = fmt.Sprintf("%s (proc name %s)", msErr.Message, msErr.ProcName)
}
return database.Error{OrigErr: err, Err: message, Query: migr, Line: uint(msErr.LineNo)}
}
return database.Error{OrigErr: err, Err: message, Query: migr, Line: uint(msErr.LineNo)}
return database.Error{OrigErr: err, Err: "migration failed", Query: migr}
}
return database.Error{OrigErr: err, Err: "migration failed", Query: migr}
}

return nil
Expand Down
43 changes: 43 additions & 0 deletions database/sqlserver/sqlserver_test.go
Expand Up @@ -159,6 +159,49 @@ func TestMultiStatement(t *testing.T) {
})
}

func TestBatchedStatement(t *testing.T) {
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
ip, port, err := c.FirstPort()
if err != nil {
t.Fatal(err)
}

addr := msConnectionString(ip, port)
ms := &SQLServer{}
d, err := ms.Open(addr)
if err != nil {
t.Fatal(err)
}
defer func() {
if err := d.Close(); err != nil {
t.Error(err)
}
}()
if err := d.Run(strings.NewReader(`CREATE PROCEDURE uspA
AS
BEGIN
SELECT 1;
END;
GO
CREATE PROCEDURE uspB
AS
BEGIN
SELECT 2;
END`)); err != nil {
t.Fatalf("expected err to be nil, got %v", err)
}

// make sure second proc exists
var exists int
if err := d.(*SQLServer).conn.QueryRowContext(context.Background(), "Select COUNT(1) from sysobjects where type = 'P' and category = 0 and [NAME] = 'uspB'").Scan(&exists); err != nil {
t.Fatal(err)
}
if exists != 1 {
t.Fatalf("expected proc uspB to exist")
}
})
}

func TestErrorParsing(t *testing.T) {
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
ip, port, err := c.FirstPort()
Expand Down

0 comments on commit b934d03

Please sign in to comment.