diff --git a/database/sqlserver/README.md b/database/sqlserver/README.md index c4ef5a3a3..696feae75 100644 --- a/database/sqlserver/README.md +++ b/database/sqlserver/README.md @@ -17,6 +17,7 @@ | `encrypt` | | `disable` - Data send between client and server is not encrypted. `false` - Data sent between client and server is not encrypted beyond the login packet (Default). `true` - Data sent between client and server is encrypted. | | `app+name` || The application name (default is go-mssqldb). | | `useMsi` | | `true` - Use Azure MSI Authentication for connecting to Sql Server. Must be running from an Azure VM/an instance with MSI enabled. `false` - Use password authentication (Default). See [here for Azure MSI Auth details](https://docs.microsoft.com/en-us/azure/app-service/app-service-web-tutorial-connect-msi). NOTE: Since this cannot be tested locally, this is not officially supported. +| `x-batch` | | Enable batch statements (default: false) | See https://github.com/denisenkom/go-mssqldb for full parameter list. diff --git a/database/sqlserver/sqlserver.go b/database/sqlserver/sqlserver.go index 024001871..046f9e733 100644 --- a/database/sqlserver/sqlserver.go +++ b/database/sqlserver/sqlserver.go @@ -4,6 +4,7 @@ import ( "context" "database/sql" "fmt" + "github.com/denisenkom/go-mssqldb/batch" "io" "io/ioutil" nurl "net/url" @@ -43,9 +44,10 @@ var lockErrorMap = map[mssql.ReturnStatus]string{ // Config for database type Config struct { - MigrationsTable string - DatabaseName string - SchemaName string + MigrationsTable string + DatabaseName string + SchemaName string + BatchStatementEnabled bool } // SQL Server connection @@ -169,9 +171,18 @@ func (ss *SQLServer) Open(url string) (database.Driver, error) { migrationsTable := purl.Query().Get("x-migrations-table") + batchStatementEnabled := false + if s := purl.Query().Get("x-batch"); len(s) > 0 { + batchStatementEnabled, err = strconv.ParseBool(s) + if err != nil { + return nil, fmt.Errorf("Unable to parse option x-batch: %w", err) + } + } + px, err := WithInstance(db, &Config{ - DatabaseName: purl.Path, - MigrationsTable: migrationsTable, + DatabaseName: purl.Path, + MigrationsTable: migrationsTable, + BatchStatementEnabled: batchStatementEnabled, }) if err != nil { @@ -242,15 +253,23 @@ func (ss *SQLServer) Run(migration io.Reader) error { // run migration query := string(migr[:]) - if _, err := ss.conn.ExecContext(context.Background(), query); err != nil { - if msErr, ok := err.(mssql.Error); ok { - message := fmt.Sprintf("migration failed: %s", msErr.Message) - if msErr.ProcName != "" { - message = fmt.Sprintf("%s (proc name %s)", msErr.Message, msErr.ProcName) + scripts := []string{query} + + if ss.config.BatchStatementEnabled { + scripts = batch.Split(query, "go") + } + + for _, script := range scripts { + if _, err := ss.conn.ExecContext(context.Background(), script); err != nil { + if msErr, ok := err.(mssql.Error); ok { + message := fmt.Sprintf("migration failed: %s", msErr.Message) + if msErr.ProcName != "" { + message = fmt.Sprintf("%s (proc name %s)", msErr.Message, msErr.ProcName) + } + return database.Error{OrigErr: err, Err: message, Query: []byte(script), Line: uint(msErr.LineNo)} } - return database.Error{OrigErr: err, Err: message, Query: migr, Line: uint(msErr.LineNo)} + return database.Error{OrigErr: err, Err: "migration failed", Query: []byte(script)} } - return database.Error{OrigErr: err, Err: "migration failed", Query: migr} } return nil diff --git a/database/sqlserver/sqlserver_test.go b/database/sqlserver/sqlserver_test.go index 2426629a8..030f90f6d 100644 --- a/database/sqlserver/sqlserver_test.go +++ b/database/sqlserver/sqlserver_test.go @@ -49,8 +49,9 @@ var ( } ) -func msConnectionString(host, port string) string { - return fmt.Sprintf("sqlserver://sa:%v@%v:%v?database=master", saPassword, host, port) +func msConnectionString(host, port string, options ...string) string { + options = append(options, "database=master") + return fmt.Sprintf("sqlserver://sa:%v@%v:%v?%s", saPassword, host, port, strings.Join(options, "&")) } func msConnectionStringMsiWithPassword(host, port string, useMsi bool) string { @@ -183,6 +184,49 @@ func TestMultiStatement(t *testing.T) { }) } +func TestBatchedStatement(t *testing.T) { + dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) { + ip, port, err := c.Port(defaultPort) + if err != nil { + t.Fatal(err) + } + + addr := msConnectionString(ip, port, "x-batch=true") + ms := &SQLServer{} + d, err := ms.Open(addr) + if err != nil { + t.Fatal(err) + } + defer func() { + if err := d.Close(); err != nil { + t.Error(err) + } + }() + if err := d.Run(strings.NewReader(`CREATE PROCEDURE uspA +AS +BEGIN + SELECT 1; +END; +GO +CREATE PROCEDURE uspB +AS +BEGIN + SELECT 2; +END`)); err != nil { + t.Fatalf("expected err to be nil, got %v", err) + } + + // make sure second proc exists + var exists int + if err := d.(*SQLServer).conn.QueryRowContext(context.Background(), "Select COUNT(1) from sysobjects where type = 'P' and category = 0 and [NAME] = 'uspB'").Scan(&exists); err != nil { + t.Fatal(err) + } + if exists != 1 { + t.Fatalf("expected proc uspB to exist") + } + }) +} + func TestErrorParsing(t *testing.T) { dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) { SkipIfUnsupportedArch(t, c)