Skip to content

Commit

Permalink
contrib/database/sql: fix support for drivers using deprecated interf…
Browse files Browse the repository at this point in the history
…aces (#1167)

Some drivers like the go-mssqldb driver do not implement newer database/sql
interfaces like Queryer/QueryerContext. Previously the code would
erroneously assume drivers that did not implement a QueryerContext
interface would implement the Queryer interface. This lead to panics like
in issue #1043.

Fixes #1043
  • Loading branch information
ajgajg1134 committed Feb 28, 2022
1 parent 41a7cd6 commit 55b07a6
Show file tree
Hide file tree
Showing 11 changed files with 426 additions and 47 deletions.
9 changes: 9 additions & 0 deletions .circleci/config.yml
Expand Up @@ -188,6 +188,10 @@ jobs:
POSTGRES_PASSWORD: postgres
POSTGRES_USER: postgres
POSTGRES_DB: postgres
- image: mcr.microsoft.com/mssql/server:2019-latest
environment:
SA_PASSWORD: myPassw0rd
ACCEPT_EULA: Y
- image: consul:1.6.0
- image: redis:3.2
- image: elasticsearch:2
Expand Down Expand Up @@ -276,6 +280,7 @@ jobs:
# pin above.
go get gorm.io/driver/mysql@v1.2.3
go get gorm.io/driver/postgres@v1.2.3
go get gorm.io/driver/sqlserver@v1.2.1
go get github.com/zenazn/goji@v1.0.1
- run:
Expand All @@ -286,6 +291,10 @@ jobs:
name: Wait for Postgres
command: dockerize -wait tcp://localhost:5432 -timeout 1m

- run:
name: Wait for MS SQL Server
command: dockerize -wait tcp://localhost:1433 -timeout 1m

- run:
name: Wait for Redis
command: dockerize -wait tcp://localhost:6379 -timeout 1m
Expand Down
67 changes: 30 additions & 37 deletions contrib/database/sql/conn.go
Expand Up @@ -74,32 +74,28 @@ func (tc *tracedConn) PrepareContext(ctx context.Context, query string) (stmt dr
return &tracedStmt{stmt, tc.traceParams, ctx, query}, nil
}

func (tc *tracedConn) Exec(query string, args []driver.Value) (driver.Result, error) {
if execer, ok := tc.Conn.(driver.Execer); ok {
return execer.Exec(query, args)
}
return nil, driver.ErrSkip
}

func (tc *tracedConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (r driver.Result, err error) {
start := time.Now()
if execContext, ok := tc.Conn.(driver.ExecerContext); ok {
r, err := execContext.ExecContext(ctx, query, args)
tc.tryTrace(ctx, queryTypeExec, query, start, err)
return r, err
}
dargs, err := namedValueToValue(args)
if err != nil {
return nil, err
}
select {
case <-ctx.Done():
return nil, ctx.Err()
default:
if execer, ok := tc.Conn.(driver.Execer); ok {
dargs, err := namedValueToValue(args)
if err != nil {
return nil, err
}
select {
case <-ctx.Done():
return nil, ctx.Err()
default:
}
r, err = execer.Exec(query, dargs)
tc.tryTrace(ctx, queryTypeExec, query, start, err)
return r, err
}
r, err = tc.Exec(query, dargs)
tc.tryTrace(ctx, queryTypeExec, query, start, err)
return r, err
return nil, driver.ErrSkip
}

// tracedConn has a Ping method in order to implement the pinger interface
Expand All @@ -112,32 +108,28 @@ func (tc *tracedConn) Ping(ctx context.Context) (err error) {
return err
}

func (tc *tracedConn) Query(query string, args []driver.Value) (driver.Rows, error) {
if queryer, ok := tc.Conn.(driver.Queryer); ok {
return queryer.Query(query, args)
}
return nil, driver.ErrSkip
}

func (tc *tracedConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (rows driver.Rows, err error) {
start := time.Now()
if queryerContext, ok := tc.Conn.(driver.QueryerContext); ok {
rows, err := queryerContext.QueryContext(ctx, query, args)
tc.tryTrace(ctx, queryTypeQuery, query, start, err)
return rows, err
}
dargs, err := namedValueToValue(args)
if err != nil {
return nil, err
}
select {
case <-ctx.Done():
return nil, ctx.Err()
default:
if queryer, ok := tc.Conn.(driver.Queryer); ok {
dargs, err := namedValueToValue(args)
if err != nil {
return nil, err
}
select {
case <-ctx.Done():
return nil, ctx.Err()
default:
}
rows, err = queryer.Query(query, dargs)
tc.tryTrace(ctx, queryTypeQuery, query, start, err)
return rows, err
}
rows, err = tc.Query(query, dargs)
tc.tryTrace(ctx, queryTypeQuery, query, start, err)
return rows, err
return nil, driver.ErrSkip
}

func (tc *tracedConn) CheckNamedValue(value *driver.NamedValue) error {
Expand All @@ -154,7 +146,8 @@ func (tc *tracedConn) ResetSession(ctx context.Context) error {
if resetter, ok := tc.Conn.(driver.SessionResetter); ok {
return resetter.ResetSession(ctx)
}
return driver.ErrSkip
// If driver doesn't implement driver.SessionResetter there's nothing to do
return nil
}

// traceParams stores all information related to tracing the driver.Conn
Expand Down
25 changes: 25 additions & 0 deletions contrib/database/sql/internal/dsn.go
Expand Up @@ -27,6 +27,11 @@ func ParseDSN(driverName, dsn string) (meta map[string]string, err error) {
if err != nil {
return
}
case "sqlserver":
meta, err = parseSQLServerDSN(dsn)
if err != nil {
return
}
default:
// not supported
}
Expand Down Expand Up @@ -86,3 +91,23 @@ func parsePostgresDSN(dsn string) (map[string]string, error) {
delete(meta, "password")
return meta, nil
}

// parseSQLServerDSN parses a sqlserver-type dsn into a map
func parseSQLServerDSN(dsn string) (map[string]string, error) {
var err error
var meta map[string]string
if strings.HasPrefix(dsn, "sqlserver://") {
// url form
meta, err = parseSQLServerURL(dsn)
if err != nil {
return nil, err
}
} else {
meta, err = parseSQLServerADO(dsn)
if err != nil {
return nil, err
}
}
delete(meta, "password")
return meta, nil
}
59 changes: 59 additions & 0 deletions contrib/database/sql/internal/dsn_test.go
Expand Up @@ -51,6 +51,16 @@ func TestParseDSN(t *testing.T) {
ext.DBUser: "dog",
},
},
{
driverName: "sqlserver",
dsn: "sqlserver://bob:secret@1.2.3.4:1433?database=mydb",
expected: map[string]string{
ext.DBUser: "bob",
ext.TargetHost: "1.2.3.4",
ext.TargetPort: "1433",
ext.DBName: "mydb",
},
},
} {
m, err := ParseDSN(tt.driverName, tt.dsn)
assert.Equal(nil, err)
Expand Down Expand Up @@ -104,3 +114,52 @@ func TestParsePostgresDSN(t *testing.T) {
assert.Equal(tt.expected, m)
}
}

func TestParseSqlServerDSN(t *testing.T) {
assert := assert.New(t)

for _, tt := range []struct {
dsn string
expected map[string]string
}{
{
dsn: "sqlserver://bob:secret@1.2.3.4:1433?database=mydb",
expected: map[string]string{
"user": "bob",
"host": "1.2.3.4",
"port": "1433",
"dbname": "mydb",
},
},
{
dsn: "sqlserver://alice:secret@localhost/SQLExpress?database=mydb",
expected: map[string]string{
"user": "alice",
"host": "localhost",
"dbname": "mydb",
"instanceName": "SQLExpress",
},
},
{
dsn: "server=1.2.3.4,1433;User Id=dog;Password=secret;Database=mydb;",
expected: map[string]string{
"user": "dog",
"port": "1433",
"host": "1.2.3.4",
"dbname": "mydb",
},
},
{
dsn: "ADDRESS=1.2.3.4;UID=cat;PASSWORD=secret;INITIAL CATALOG=mydb;",
expected: map[string]string{
"user": "cat",
"host": "1.2.3.4",
"dbname": "mydb",
},
},
} {
m, err := parseSQLServerDSN(tt.dsn)
assert.Equal(nil, err)
assert.Equal(tt.expected, m)
}
}
101 changes: 101 additions & 0 deletions contrib/database/sql/internal/sqlserver.go
@@ -0,0 +1,101 @@
// Unless explicitly stated otherwise all files in this repository are licensed
// under the Apache License Version 2.0.
// This product includes software developed at Datadog (https://www.datadoghq.com/).
// Copyright 2016 Datadog, Inc.

package internal

import (
"fmt"
"net"
nurl "net/url"
"strings"
)

func parseSQLServerURL(url string) (map[string]string, error) {
u, err := nurl.Parse(url)
if err != nil {
return nil, err
}

if u.Scheme != "sqlserver" {
return nil, fmt.Errorf("invalid connection protocol: %s", u.Scheme)
}

kvs := map[string]string{}
escaper := strings.NewReplacer(` `, `\ `, `'`, `\'`, `\`, `\\`)
accrue := func(k, v string) {
if v != "" {
kvs[k] = escaper.Replace(v)
}
}

if u.User != nil {
v := u.User.Username()
accrue("user", v)
}

if host, port, err := net.SplitHostPort(u.Host); err != nil {
accrue("host", u.Host)
} else {
accrue("host", host)
accrue("port", port)
}

if u.Path != "" {
accrue("instanceName", u.Path[1:])
}

q := u.Query()
for k := range q {
if k == "database" {
accrue("dbname", q.Get(k))
}
}

return kvs, nil
}

var keySynonyms = map[string]string{
"server": "host",
"data source": "host",
"address": "host",
"network address": "host",
"addr": "host",
"uid": "user",
"user id": "user",
"initial catalog": "dbname",
"database": "dbname",
}

func parseSQLServerADO(dsn string) (map[string]string, error) {
kvs := map[string]string{}
fields := strings.Split(dsn, ";")
for _, f := range fields {
if len(f) == 0 {
continue
}
pts := strings.SplitN(f, "=", 2)
key := strings.TrimSpace(strings.ToLower(pts[0]))
if len(key) == 0 {
continue
}
val := ""
if len(pts) > 1 {
val = strings.TrimSpace(pts[1])
}
if synonym, found := keySynonyms[key]; found {
key = synonym
}
if key == "host" {
val = strings.TrimPrefix(val, "tcp:")
hostParts := strings.Split(val, ",")
if len(hostParts) == 2 && len(hostParts[1]) > 0 {
val = hostParts[0]
kvs["port"] = hostParts[1]
}
}
kvs[key] = val
}
return kvs, nil
}
27 changes: 27 additions & 0 deletions contrib/database/sql/sql_test.go
Expand Up @@ -20,6 +20,7 @@ import (
"gopkg.in/DataDog/dd-trace-go.v1/ddtrace/ext"
"gopkg.in/DataDog/dd-trace-go.v1/ddtrace/mocktracer"

mssql "github.com/denisenkom/go-mssqldb"
"github.com/go-sql-driver/mysql"
"github.com/lib/pq"
"github.com/stretchr/testify/assert"
Expand All @@ -38,6 +39,32 @@ func TestMain(m *testing.M) {
os.Exit(m.Run())
}

func TestSqlServer(t *testing.T) {
Register("sqlserver", &mssql.Driver{})
db, err := Open("sqlserver", "sqlserver://sa:myPassw0rd@127.0.0.1:1433?database=master")
if err != nil {
log.Fatal(err)
}
defer db.Close()

testConfig := &sqltest.Config{
DB: db,
DriverName: "sqlserver",
TableName: tableName,
ExpectName: "sqlserver.query",
ExpectTags: map[string]interface{}{
ext.ServiceName: "sqlserver.db",
ext.SpanType: ext.SpanTypeSQL,
ext.TargetHost: "127.0.0.1",
ext.TargetPort: "1433",
ext.DBUser: "sa",
ext.DBName: "master",
ext.EventSampleRate: nil,
},
}
sqltest.RunAll(t, testConfig)
}

func TestMySQL(t *testing.T) {
Register("mysql", &mysql.MySQLDriver{})
db, err := Open("mysql", "test:test@tcp(127.0.0.1:3306)/test")
Expand Down

0 comments on commit 55b07a6

Please sign in to comment.