Skip to content

Commit

Permalink
Make the MSSQL batch statements optional with x-batch
Browse files Browse the repository at this point in the history
  • Loading branch information
glebteterin committed Apr 20, 2022
1 parent de9e600 commit c63745f
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 9 deletions.
27 changes: 21 additions & 6 deletions database/sqlserver/sqlserver.go
Expand Up @@ -44,9 +44,10 @@ var lockErrorMap = map[mssql.ReturnStatus]string{

// Config for database
type Config struct {
MigrationsTable string
DatabaseName string
SchemaName string
MigrationsTable string
DatabaseName string
SchemaName string
BatchStatementEnabled bool
}

// SQL Server connection
Expand Down Expand Up @@ -170,9 +171,18 @@ func (ss *SQLServer) Open(url string) (database.Driver, error) {

migrationsTable := purl.Query().Get("x-migrations-table")

batchStatementEnabled := false
if s := purl.Query().Get("x-batch"); len(s) > 0 {
batchStatementEnabled, err = strconv.ParseBool(s)
if err != nil {
return nil, fmt.Errorf("Unable to parse option x-batch: %w", err)
}
}

px, err := WithInstance(db, &Config{
DatabaseName: purl.Path,
MigrationsTable: migrationsTable,
DatabaseName: purl.Path,
MigrationsTable: migrationsTable,
BatchStatementEnabled: batchStatementEnabled,
})

if err != nil {
Expand Down Expand Up @@ -243,7 +253,12 @@ func (ss *SQLServer) Run(migration io.Reader) error {

// run migration
query := string(migr[:])
scripts := batch.Split(query, "go")
scripts := []string{query}

if ss.config.BatchStatementEnabled {
scripts = batch.Split(query, "go")
}

for _, script := range scripts {
if _, err := ss.conn.ExecContext(context.Background(), script); err != nil {
if msErr, ok := err.(mssql.Error); ok {
Expand Down
7 changes: 4 additions & 3 deletions database/sqlserver/sqlserver_test.go
Expand Up @@ -49,8 +49,9 @@ var (
}
)

func msConnectionString(host, port string) string {
return fmt.Sprintf("sqlserver://sa:%v@%v:%v?database=master", saPassword, host, port)
func msConnectionString(host, port string, options ...string) string {
options = append(options, "database=master")
return fmt.Sprintf("sqlserver://sa:%v@%v:%v?%s", saPassword, host, port, strings.Join(options, "&"))
}

func msConnectionStringMsiWithPassword(host, port string, useMsi bool) string {
Expand Down Expand Up @@ -190,7 +191,7 @@ func TestBatchedStatement(t *testing.T) {
t.Fatal(err)
}

addr := msConnectionString(ip, port)
addr := msConnectionString(ip, port, "x-batch=true")
ms := &SQLServer{}
d, err := ms.Open(addr)
if err != nil {
Expand Down

0 comments on commit c63745f

Please sign in to comment.