From 55b07a6b863b409ec54965a092f6cafd2b315a90 Mon Sep 17 00:00:00 2001 From: Andrew Glaude Date: Mon, 28 Feb 2022 13:21:52 -0500 Subject: [PATCH] contrib/database/sql: fix support for drivers using deprecated interfaces (#1167) Some drivers like the go-mssqldb driver do not implement newer database/sql interfaces like Queryer/QueryerContext. Previously the code would erroneously assume drivers that did not implement a QueryerContext interface would implement the Queryer interface. This lead to panics like in issue #1043. Fixes #1043 --- .circleci/config.yml | 9 ++ contrib/database/sql/conn.go | 67 ++++++------ contrib/database/sql/internal/dsn.go | 25 +++++ contrib/database/sql/internal/dsn_test.go | 59 +++++++++++ contrib/database/sql/internal/sqlserver.go | 101 +++++++++++++++++++ contrib/database/sql/sql_test.go | 27 +++++ contrib/gopkg.in/jinzhu/gorm.v1/gorm_test.go | 27 +++++ contrib/gorm.io/gorm.v1/gorm_test.go | 43 +++++++- contrib/internal/sqltest/sqltest.go | 62 ++++++++++-- contrib/jinzhu/gorm/gorm_test.go | 27 +++++ contrib/jmoiron/sqlx/sql_test.go | 26 +++++ 11 files changed, 426 insertions(+), 47 deletions(-) create mode 100644 contrib/database/sql/internal/sqlserver.go diff --git a/.circleci/config.yml b/.circleci/config.yml index f21bc86c38..bb6cdbdf30 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -188,6 +188,10 @@ jobs: POSTGRES_PASSWORD: postgres POSTGRES_USER: postgres POSTGRES_DB: postgres + - image: mcr.microsoft.com/mssql/server:2019-latest + environment: + SA_PASSWORD: myPassw0rd + ACCEPT_EULA: Y - image: consul:1.6.0 - image: redis:3.2 - image: elasticsearch:2 @@ -276,6 +280,7 @@ jobs: # pin above. go get gorm.io/driver/mysql@v1.2.3 go get gorm.io/driver/postgres@v1.2.3 + go get gorm.io/driver/sqlserver@v1.2.1 go get github.com/zenazn/goji@v1.0.1 - run: @@ -286,6 +291,10 @@ jobs: name: Wait for Postgres command: dockerize -wait tcp://localhost:5432 -timeout 1m + - run: + name: Wait for MS SQL Server + command: dockerize -wait tcp://localhost:1433 -timeout 1m + - run: name: Wait for Redis command: dockerize -wait tcp://localhost:6379 -timeout 1m diff --git a/contrib/database/sql/conn.go b/contrib/database/sql/conn.go index b95eec85ea..d382301a5d 100644 --- a/contrib/database/sql/conn.go +++ b/contrib/database/sql/conn.go @@ -74,13 +74,6 @@ func (tc *tracedConn) PrepareContext(ctx context.Context, query string) (stmt dr return &tracedStmt{stmt, tc.traceParams, ctx, query}, nil } -func (tc *tracedConn) Exec(query string, args []driver.Value) (driver.Result, error) { - if execer, ok := tc.Conn.(driver.Execer); ok { - return execer.Exec(query, args) - } - return nil, driver.ErrSkip -} - func (tc *tracedConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (r driver.Result, err error) { start := time.Now() if execContext, ok := tc.Conn.(driver.ExecerContext); ok { @@ -88,18 +81,21 @@ func (tc *tracedConn) ExecContext(ctx context.Context, query string, args []driv tc.tryTrace(ctx, queryTypeExec, query, start, err) return r, err } - dargs, err := namedValueToValue(args) - if err != nil { - return nil, err - } - select { - case <-ctx.Done(): - return nil, ctx.Err() - default: + if execer, ok := tc.Conn.(driver.Execer); ok { + dargs, err := namedValueToValue(args) + if err != nil { + return nil, err + } + select { + case <-ctx.Done(): + return nil, ctx.Err() + default: + } + r, err = execer.Exec(query, dargs) + tc.tryTrace(ctx, queryTypeExec, query, start, err) + return r, err } - r, err = tc.Exec(query, dargs) - tc.tryTrace(ctx, queryTypeExec, query, start, err) - return r, err + return nil, driver.ErrSkip } // tracedConn has a Ping method in order to implement the pinger interface @@ -112,13 +108,6 @@ func (tc *tracedConn) Ping(ctx context.Context) (err error) { return err } -func (tc *tracedConn) Query(query string, args []driver.Value) (driver.Rows, error) { - if queryer, ok := tc.Conn.(driver.Queryer); ok { - return queryer.Query(query, args) - } - return nil, driver.ErrSkip -} - func (tc *tracedConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (rows driver.Rows, err error) { start := time.Now() if queryerContext, ok := tc.Conn.(driver.QueryerContext); ok { @@ -126,18 +115,21 @@ func (tc *tracedConn) QueryContext(ctx context.Context, query string, args []dri tc.tryTrace(ctx, queryTypeQuery, query, start, err) return rows, err } - dargs, err := namedValueToValue(args) - if err != nil { - return nil, err - } - select { - case <-ctx.Done(): - return nil, ctx.Err() - default: + if queryer, ok := tc.Conn.(driver.Queryer); ok { + dargs, err := namedValueToValue(args) + if err != nil { + return nil, err + } + select { + case <-ctx.Done(): + return nil, ctx.Err() + default: + } + rows, err = queryer.Query(query, dargs) + tc.tryTrace(ctx, queryTypeQuery, query, start, err) + return rows, err } - rows, err = tc.Query(query, dargs) - tc.tryTrace(ctx, queryTypeQuery, query, start, err) - return rows, err + return nil, driver.ErrSkip } func (tc *tracedConn) CheckNamedValue(value *driver.NamedValue) error { @@ -154,7 +146,8 @@ func (tc *tracedConn) ResetSession(ctx context.Context) error { if resetter, ok := tc.Conn.(driver.SessionResetter); ok { return resetter.ResetSession(ctx) } - return driver.ErrSkip + // If driver doesn't implement driver.SessionResetter there's nothing to do + return nil } // traceParams stores all information related to tracing the driver.Conn diff --git a/contrib/database/sql/internal/dsn.go b/contrib/database/sql/internal/dsn.go index f087b4784a..da6f0da8a5 100644 --- a/contrib/database/sql/internal/dsn.go +++ b/contrib/database/sql/internal/dsn.go @@ -27,6 +27,11 @@ func ParseDSN(driverName, dsn string) (meta map[string]string, err error) { if err != nil { return } + case "sqlserver": + meta, err = parseSQLServerDSN(dsn) + if err != nil { + return + } default: // not supported } @@ -86,3 +91,23 @@ func parsePostgresDSN(dsn string) (map[string]string, error) { delete(meta, "password") return meta, nil } + +// parseSQLServerDSN parses a sqlserver-type dsn into a map +func parseSQLServerDSN(dsn string) (map[string]string, error) { + var err error + var meta map[string]string + if strings.HasPrefix(dsn, "sqlserver://") { + // url form + meta, err = parseSQLServerURL(dsn) + if err != nil { + return nil, err + } + } else { + meta, err = parseSQLServerADO(dsn) + if err != nil { + return nil, err + } + } + delete(meta, "password") + return meta, nil +} diff --git a/contrib/database/sql/internal/dsn_test.go b/contrib/database/sql/internal/dsn_test.go index ceba66317c..bd902b5fb9 100644 --- a/contrib/database/sql/internal/dsn_test.go +++ b/contrib/database/sql/internal/dsn_test.go @@ -51,6 +51,16 @@ func TestParseDSN(t *testing.T) { ext.DBUser: "dog", }, }, + { + driverName: "sqlserver", + dsn: "sqlserver://bob:secret@1.2.3.4:1433?database=mydb", + expected: map[string]string{ + ext.DBUser: "bob", + ext.TargetHost: "1.2.3.4", + ext.TargetPort: "1433", + ext.DBName: "mydb", + }, + }, } { m, err := ParseDSN(tt.driverName, tt.dsn) assert.Equal(nil, err) @@ -104,3 +114,52 @@ func TestParsePostgresDSN(t *testing.T) { assert.Equal(tt.expected, m) } } + +func TestParseSqlServerDSN(t *testing.T) { + assert := assert.New(t) + + for _, tt := range []struct { + dsn string + expected map[string]string + }{ + { + dsn: "sqlserver://bob:secret@1.2.3.4:1433?database=mydb", + expected: map[string]string{ + "user": "bob", + "host": "1.2.3.4", + "port": "1433", + "dbname": "mydb", + }, + }, + { + dsn: "sqlserver://alice:secret@localhost/SQLExpress?database=mydb", + expected: map[string]string{ + "user": "alice", + "host": "localhost", + "dbname": "mydb", + "instanceName": "SQLExpress", + }, + }, + { + dsn: "server=1.2.3.4,1433;User Id=dog;Password=secret;Database=mydb;", + expected: map[string]string{ + "user": "dog", + "port": "1433", + "host": "1.2.3.4", + "dbname": "mydb", + }, + }, + { + dsn: "ADDRESS=1.2.3.4;UID=cat;PASSWORD=secret;INITIAL CATALOG=mydb;", + expected: map[string]string{ + "user": "cat", + "host": "1.2.3.4", + "dbname": "mydb", + }, + }, + } { + m, err := parseSQLServerDSN(tt.dsn) + assert.Equal(nil, err) + assert.Equal(tt.expected, m) + } +} diff --git a/contrib/database/sql/internal/sqlserver.go b/contrib/database/sql/internal/sqlserver.go new file mode 100644 index 0000000000..d248f089d3 --- /dev/null +++ b/contrib/database/sql/internal/sqlserver.go @@ -0,0 +1,101 @@ +// Unless explicitly stated otherwise all files in this repository are licensed +// under the Apache License Version 2.0. +// This product includes software developed at Datadog (https://www.datadoghq.com/). +// Copyright 2016 Datadog, Inc. + +package internal + +import ( + "fmt" + "net" + nurl "net/url" + "strings" +) + +func parseSQLServerURL(url string) (map[string]string, error) { + u, err := nurl.Parse(url) + if err != nil { + return nil, err + } + + if u.Scheme != "sqlserver" { + return nil, fmt.Errorf("invalid connection protocol: %s", u.Scheme) + } + + kvs := map[string]string{} + escaper := strings.NewReplacer(` `, `\ `, `'`, `\'`, `\`, `\\`) + accrue := func(k, v string) { + if v != "" { + kvs[k] = escaper.Replace(v) + } + } + + if u.User != nil { + v := u.User.Username() + accrue("user", v) + } + + if host, port, err := net.SplitHostPort(u.Host); err != nil { + accrue("host", u.Host) + } else { + accrue("host", host) + accrue("port", port) + } + + if u.Path != "" { + accrue("instanceName", u.Path[1:]) + } + + q := u.Query() + for k := range q { + if k == "database" { + accrue("dbname", q.Get(k)) + } + } + + return kvs, nil +} + +var keySynonyms = map[string]string{ + "server": "host", + "data source": "host", + "address": "host", + "network address": "host", + "addr": "host", + "uid": "user", + "user id": "user", + "initial catalog": "dbname", + "database": "dbname", +} + +func parseSQLServerADO(dsn string) (map[string]string, error) { + kvs := map[string]string{} + fields := strings.Split(dsn, ";") + for _, f := range fields { + if len(f) == 0 { + continue + } + pts := strings.SplitN(f, "=", 2) + key := strings.TrimSpace(strings.ToLower(pts[0])) + if len(key) == 0 { + continue + } + val := "" + if len(pts) > 1 { + val = strings.TrimSpace(pts[1]) + } + if synonym, found := keySynonyms[key]; found { + key = synonym + } + if key == "host" { + val = strings.TrimPrefix(val, "tcp:") + hostParts := strings.Split(val, ",") + if len(hostParts) == 2 && len(hostParts[1]) > 0 { + val = hostParts[0] + kvs["port"] = hostParts[1] + } + } + kvs[key] = val + } + return kvs, nil +} diff --git a/contrib/database/sql/sql_test.go b/contrib/database/sql/sql_test.go index b953a9c04f..9022305927 100644 --- a/contrib/database/sql/sql_test.go +++ b/contrib/database/sql/sql_test.go @@ -20,6 +20,7 @@ import ( "gopkg.in/DataDog/dd-trace-go.v1/ddtrace/ext" "gopkg.in/DataDog/dd-trace-go.v1/ddtrace/mocktracer" + mssql "github.com/denisenkom/go-mssqldb" "github.com/go-sql-driver/mysql" "github.com/lib/pq" "github.com/stretchr/testify/assert" @@ -38,6 +39,32 @@ func TestMain(m *testing.M) { os.Exit(m.Run()) } +func TestSqlServer(t *testing.T) { + Register("sqlserver", &mssql.Driver{}) + db, err := Open("sqlserver", "sqlserver://sa:myPassw0rd@127.0.0.1:1433?database=master") + if err != nil { + log.Fatal(err) + } + defer db.Close() + + testConfig := &sqltest.Config{ + DB: db, + DriverName: "sqlserver", + TableName: tableName, + ExpectName: "sqlserver.query", + ExpectTags: map[string]interface{}{ + ext.ServiceName: "sqlserver.db", + ext.SpanType: ext.SpanTypeSQL, + ext.TargetHost: "127.0.0.1", + ext.TargetPort: "1433", + ext.DBUser: "sa", + ext.DBName: "master", + ext.EventSampleRate: nil, + }, + } + sqltest.RunAll(t, testConfig) +} + func TestMySQL(t *testing.T) { Register("mysql", &mysql.MySQLDriver{}) db, err := Open("mysql", "test:test@tcp(127.0.0.1:3306)/test") diff --git a/contrib/gopkg.in/jinzhu/gorm.v1/gorm_test.go b/contrib/gopkg.in/jinzhu/gorm.v1/gorm_test.go index c6f1b67214..168013b74f 100644 --- a/contrib/gopkg.in/jinzhu/gorm.v1/gorm_test.go +++ b/contrib/gopkg.in/jinzhu/gorm.v1/gorm_test.go @@ -19,6 +19,7 @@ import ( "gopkg.in/DataDog/dd-trace-go.v1/ddtrace/tracer" "gopkg.in/DataDog/dd-trace-go.v1/internal/globalconfig" + mssql "github.com/denisenkom/go-mssqldb" "github.com/go-sql-driver/mysql" "github.com/lib/pq" "github.com/stretchr/testify/assert" @@ -88,6 +89,32 @@ func TestPostgres(t *testing.T) { sqltest.RunAll(t, testConfig) } +func TestSqlServer(t *testing.T) { + sqltrace.Register("sqlserver", &mssql.Driver{}) + db, err := Open("sqlserver", "sqlserver://sa:myPassw0rd@127.0.0.1:1433?database=master") + if err != nil { + log.Fatal(err) + } + defer db.Close() + + testConfig := &sqltest.Config{ + DB: db.DB(), + DriverName: "sqlserver", + TableName: tableName, + ExpectName: "sqlserver.query", + ExpectTags: map[string]interface{}{ + ext.ServiceName: "sqlserver.db", + ext.SpanType: ext.SpanTypeSQL, + ext.TargetHost: "127.0.0.1", + ext.TargetPort: "1433", + ext.DBUser: "sa", + ext.DBName: "master", + ext.EventSampleRate: nil, + }, + } + sqltest.RunAll(t, testConfig) +} + type Product struct { gorm.Model Code string diff --git a/contrib/gorm.io/gorm.v1/gorm_test.go b/contrib/gorm.io/gorm.v1/gorm_test.go index 9d4d1fec19..1452d184d3 100644 --- a/contrib/gorm.io/gorm.v1/gorm_test.go +++ b/contrib/gorm.io/gorm.v1/gorm_test.go @@ -19,20 +19,23 @@ import ( "gopkg.in/DataDog/dd-trace-go.v1/ddtrace/tracer" "gopkg.in/DataDog/dd-trace-go.v1/internal/globalconfig" + mssql "github.com/denisenkom/go-mssqldb" "github.com/go-sql-driver/mysql" "github.com/jackc/pgx/v4/stdlib" _ "github.com/lib/pq" "github.com/stretchr/testify/assert" mysqlgorm "gorm.io/driver/mysql" "gorm.io/driver/postgres" + "gorm.io/driver/sqlserver" "gorm.io/gorm" ) // tableName holds the SQL table that these tests will be run against. It must be unique cross-repo. const ( - tableName = "testgorm" - pgConnString = "postgres://postgres:postgres@127.0.0.1:5432/postgres?sslmode=disable" - mysqlConnString = "test:test@tcp(127.0.0.1:3306)/test" + tableName = "testgorm" + pgConnString = "postgres://postgres:postgres@127.0.0.1:5432/postgres?sslmode=disable" + sqlServerConnString = "sqlserver://sa:myPassw0rd@127.0.0.1:1433?database=master" + mysqlConnString = "test:test@tcp(127.0.0.1:3306)/test" ) func TestMain(m *testing.M) { @@ -113,6 +116,40 @@ func TestPostgres(t *testing.T) { sqltest.RunAll(t, testConfig) } +func TestSQLServer(t *testing.T) { + sqltrace.Register("sqlserver", &mssql.Driver{}) + sqlDb, err := sqltrace.Open("sqlserver", sqlServerConnString) + if err != nil { + log.Fatal(err) + } + + db, err := Open(sqlserver.New(sqlserver.Config{Conn: sqlDb}), &gorm.Config{}) + if err != nil { + log.Fatal(err) + } + + internalDB, err := db.DB() + if err != nil { + log.Fatal(err) + } + + testConfig := &sqltest.Config{ + DB: internalDB, + DriverName: "sqlserver", + TableName: tableName, + ExpectName: "sqlserver.query", + ExpectTags: map[string]interface{}{ + ext.ServiceName: "sqlserver.db", + ext.SpanType: ext.SpanTypeSQL, + ext.TargetHost: "127.0.0.1", + ext.TargetPort: "1433", + ext.DBUser: "sa", + ext.DBName: "master", + }, + } + sqltest.RunAll(t, testConfig) +} + type Product struct { gorm.Model Code string diff --git a/contrib/internal/sqltest/sqltest.go b/contrib/internal/sqltest/sqltest.go index 0931ab5308..35dbd1ddd0 100644 --- a/contrib/internal/sqltest/sqltest.go +++ b/contrib/internal/sqltest/sqltest.go @@ -38,9 +38,17 @@ func Prepare(tableName string) func() { } postgres.Exec(queryDrop) postgres.Exec(queryCreate) + mssql, err := sql.Open("sqlserver", "sqlserver://sa:myPassw0rd@localhost:1433?database=master") + defer mssql.Close() + if err != nil { + log.Fatal(err) + } + mssql.Exec(queryDrop) + mssql.Exec(queryCreate) return func() { mysql.Exec(queryDrop) postgres.Exec(queryDrop) + mssql.Exec(queryDrop) } } @@ -101,7 +109,13 @@ func testPing(cfg *Config) func(*testing.T) { } func testQuery(cfg *Config) func(*testing.T) { - query := fmt.Sprintf("SELECT id, name FROM %s LIMIT 5", cfg.TableName) + var query string + switch cfg.DriverName { + case "postgres", "pgx", "mysql": + query = fmt.Sprintf("SELECT id, name FROM %s LIMIT 5", cfg.TableName) + case "sqlserver": + query = fmt.Sprintf("SELECT TOP 5 id, name FROM %s", cfg.TableName) + } return func(t *testing.T) { cfg.mockTracer.Reset() assert := assert.New(t) @@ -110,15 +124,29 @@ func testQuery(cfg *Config) func(*testing.T) { assert.Nil(err) spans := cfg.mockTracer.FinishedSpans() - assert.Len(spans, 2) + var querySpan mocktracer.Span + if cfg.DriverName == "sqlserver" { + //The mssql driver doesn't support non-prepared queries so there are 3 spans + //connect, prepare, and query + assert.Len(spans, 3) + span := spans[1] + cfg.ExpectTags["sql.query_type"] = "Prepare" + assert.Equal(cfg.ExpectName, span.OperationName()) + for k, v := range cfg.ExpectTags { + assert.Equal(v, span.Tag(k), "Value mismatch on tag %s", k) + } + querySpan = spans[2] - verifyConnectSpan(spans[0], assert, cfg) + } else { + assert.Len(spans, 2) + querySpan = spans[1] + } - span := spans[1] + verifyConnectSpan(spans[0], assert, cfg) cfg.ExpectTags["sql.query_type"] = "Query" - assert.Equal(cfg.ExpectName, span.OperationName()) + assert.Equal(cfg.ExpectName, querySpan.OperationName()) for k, v := range cfg.ExpectTags { - assert.Equal(v, span.Tag(k), "Value mismatch on tag %s", k) + assert.Equal(v, querySpan.Tag(k), "Value mismatch on tag %s", k) } } } @@ -130,6 +158,8 @@ func testStatement(cfg *Config) func(*testing.T) { query = fmt.Sprintf(query, cfg.TableName, "$1") case "mysql": query = fmt.Sprintf(query, cfg.TableName, "?") + case "sqlserver": + query = fmt.Sprintf(query, cfg.TableName, "@p1") } return func(t *testing.T) { cfg.mockTracer.Reset() @@ -220,7 +250,25 @@ func testExec(cfg *Config) func(*testing.T) { parent.Finish() // flush children spans := cfg.mockTracer.FinishedSpans() - assert.Len(spans, 5) + if cfg.DriverName == "sqlserver" { + //The mssql driver doesn't support non-prepared exec so there are 2 extra spans for the exec: + //prepare, exec, and then a close + assert.Len(spans, 7) + span := spans[2] + cfg.ExpectTags["sql.query_type"] = "Prepare" + assert.Equal(cfg.ExpectName, span.OperationName()) + for k, v := range cfg.ExpectTags { + assert.Equal(v, span.Tag(k), "Value mismatch on tag %s", k) + } + span = spans[4] + cfg.ExpectTags["sql.query_type"] = "Close" + assert.Equal(cfg.ExpectName, span.OperationName()) + for k, v := range cfg.ExpectTags { + assert.Equal(v, span.Tag(k), "Value mismatch on tag %s", k) + } + } else { + assert.Len(spans, 5) + } var span mocktracer.Span for _, s := range spans { diff --git a/contrib/jinzhu/gorm/gorm_test.go b/contrib/jinzhu/gorm/gorm_test.go index 70b48ae017..436b779451 100644 --- a/contrib/jinzhu/gorm/gorm_test.go +++ b/contrib/jinzhu/gorm/gorm_test.go @@ -19,6 +19,7 @@ import ( "gopkg.in/DataDog/dd-trace-go.v1/ddtrace/tracer" "gopkg.in/DataDog/dd-trace-go.v1/internal/globalconfig" + mssql "github.com/denisenkom/go-mssqldb" "github.com/go-sql-driver/mysql" "github.com/jinzhu/gorm" "github.com/lib/pq" @@ -38,6 +39,32 @@ func TestMain(m *testing.M) { os.Exit(m.Run()) } +func TestSqlServer(t *testing.T) { + sqltrace.Register("sqlserver", &mssql.Driver{}) + db, err := Open("sqlserver", "sqlserver://sa:myPassw0rd@127.0.0.1:1433?database=master") + if err != nil { + log.Fatal(err) + } + defer db.Close() + + testConfig := &sqltest.Config{ + DB: db.DB(), + DriverName: "sqlserver", + TableName: tableName, + ExpectName: "sqlserver.query", + ExpectTags: map[string]interface{}{ + ext.ServiceName: "sqlserver.db", + ext.SpanType: ext.SpanTypeSQL, + ext.TargetHost: "127.0.0.1", + ext.TargetPort: "1433", + ext.DBUser: "sa", + ext.DBName: "master", + ext.EventSampleRate: nil, + }, + } + sqltest.RunAll(t, testConfig) +} + func TestMySQL(t *testing.T) { sqltrace.Register("mysql", &mysql.MySQLDriver{}, sqltrace.WithServiceName("mysql-test")) db, err := Open("mysql", "test:test@tcp(127.0.0.1:3306)/test") diff --git a/contrib/jmoiron/sqlx/sql_test.go b/contrib/jmoiron/sqlx/sql_test.go index a2ca3be917..54322f839a 100644 --- a/contrib/jmoiron/sqlx/sql_test.go +++ b/contrib/jmoiron/sqlx/sql_test.go @@ -15,6 +15,7 @@ import ( "gopkg.in/DataDog/dd-trace-go.v1/contrib/internal/sqltest" "gopkg.in/DataDog/dd-trace-go.v1/ddtrace/ext" + mssql "github.com/denisenkom/go-mssqldb" "github.com/go-sql-driver/mysql" "github.com/lib/pq" ) @@ -82,6 +83,31 @@ func TestPostgres(t *testing.T) { sqltest.RunAll(t, testConfig) } +func TestSQLServer(t *testing.T) { + sqltrace.Register("sqlserver", &mssql.Driver{}) + dbx, err := Open("sqlserver", "sqlserver://sa:myPassw0rd@127.0.0.1:1433?database=master") + if err != nil { + log.Fatal(err) + } + defer dbx.Close() + + testConfig := &sqltest.Config{ + DB: dbx.DB, + DriverName: "sqlserver", + TableName: tableName, + ExpectName: "sqlserver.query", + ExpectTags: map[string]interface{}{ + ext.ServiceName: "sqlserver.db", + ext.SpanType: ext.SpanTypeSQL, + ext.TargetHost: "127.0.0.1", + ext.TargetPort: "1433", + ext.DBUser: "sa", + ext.DBName: "master", + }, + } + sqltest.RunAll(t, testConfig) +} + func TestOpenWithOptions(t *testing.T) { sqltrace.Register("mysql", &mysql.MySQLDriver{}, sqltrace.WithServiceName("mysql-test")) dbx, err := Open("mysql", "test:test@tcp(127.0.0.1:3306)/test", sqltrace.WithServiceName("other-service"))