diff --git a/database/sqlserver/sqlserver.go b/database/sqlserver/sqlserver.go index 24b2c2345..046f9e733 100644 --- a/database/sqlserver/sqlserver.go +++ b/database/sqlserver/sqlserver.go @@ -44,9 +44,10 @@ var lockErrorMap = map[mssql.ReturnStatus]string{ // Config for database type Config struct { - MigrationsTable string - DatabaseName string - SchemaName string + MigrationsTable string + DatabaseName string + SchemaName string + BatchStatementEnabled bool } // SQL Server connection @@ -170,9 +171,18 @@ func (ss *SQLServer) Open(url string) (database.Driver, error) { migrationsTable := purl.Query().Get("x-migrations-table") + batchStatementEnabled := false + if s := purl.Query().Get("x-batch"); len(s) > 0 { + batchStatementEnabled, err = strconv.ParseBool(s) + if err != nil { + return nil, fmt.Errorf("Unable to parse option x-batch: %w", err) + } + } + px, err := WithInstance(db, &Config{ - DatabaseName: purl.Path, - MigrationsTable: migrationsTable, + DatabaseName: purl.Path, + MigrationsTable: migrationsTable, + BatchStatementEnabled: batchStatementEnabled, }) if err != nil { @@ -243,7 +253,12 @@ func (ss *SQLServer) Run(migration io.Reader) error { // run migration query := string(migr[:]) - scripts := batch.Split(query, "go") + scripts := []string{query} + + if ss.config.BatchStatementEnabled { + scripts = batch.Split(query, "go") + } + for _, script := range scripts { if _, err := ss.conn.ExecContext(context.Background(), script); err != nil { if msErr, ok := err.(mssql.Error); ok { diff --git a/database/sqlserver/sqlserver_test.go b/database/sqlserver/sqlserver_test.go index 41b9700a6..57cd56ddd 100644 --- a/database/sqlserver/sqlserver_test.go +++ b/database/sqlserver/sqlserver_test.go @@ -49,8 +49,9 @@ var ( } ) -func msConnectionString(host, port string) string { - return fmt.Sprintf("sqlserver://sa:%v@%v:%v?database=master", saPassword, host, port) +func msConnectionString(host, port string, options ...string) string { + options = append(options, "database=master") + return fmt.Sprintf("sqlserver://sa:%v@%v:%v?%s", saPassword, host, port, strings.Join(options, "&")) } func msConnectionStringMsiWithPassword(host, port string, useMsi bool) string { @@ -190,7 +191,7 @@ func TestBatchedStatement(t *testing.T) { t.Fatal(err) } - addr := msConnectionString(ip, port) + addr := msConnectionString(ip, port, "x-batch=true") ms := &SQLServer{} d, err := ms.Open(addr) if err != nil {