Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

contrib/database/sql: fix support for drivers using deprecated interfaces #1167

Merged
merged 3 commits into from Feb 28, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
9 changes: 9 additions & 0 deletions .circleci/config.yml
Expand Up @@ -188,6 +188,10 @@ jobs:
POSTGRES_PASSWORD: postgres
POSTGRES_USER: postgres
POSTGRES_DB: postgres
- image: mcr.microsoft.com/mssql/server:2019-latest
environment:
SA_PASSWORD: myPassw0rd
ACCEPT_EULA: Y
- image: consul:1.6.0
- image: redis:3.2
- image: elasticsearch:2
Expand Down Expand Up @@ -276,6 +280,7 @@ jobs:
# pin above.
go get gorm.io/driver/mysql@v1.2.3
go get gorm.io/driver/postgres@v1.2.3
go get gorm.io/driver/sqlserver@v1.2.1
go get github.com/zenazn/goji@v1.0.1

- run:
Expand All @@ -286,6 +291,10 @@ jobs:
name: Wait for Postgres
command: dockerize -wait tcp://localhost:5432 -timeout 1m

- run:
name: Wait for MS SQL Server
command: dockerize -wait tcp://localhost:1433 -timeout 1m

- run:
name: Wait for Redis
command: dockerize -wait tcp://localhost:6379 -timeout 1m
Expand Down
67 changes: 30 additions & 37 deletions contrib/database/sql/conn.go
Expand Up @@ -74,32 +74,28 @@ func (tc *tracedConn) PrepareContext(ctx context.Context, query string) (stmt dr
return &tracedStmt{stmt, tc.traceParams, ctx, query}, nil
}

func (tc *tracedConn) Exec(query string, args []driver.Value) (driver.Result, error) {
if execer, ok := tc.Conn.(driver.Execer); ok {
return execer.Exec(query, args)
}
return nil, driver.ErrSkip
}

func (tc *tracedConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (r driver.Result, err error) {
start := time.Now()
if execContext, ok := tc.Conn.(driver.ExecerContext); ok {
r, err := execContext.ExecContext(ctx, query, args)
tc.tryTrace(ctx, queryTypeExec, query, start, err)
return r, err
}
dargs, err := namedValueToValue(args)
if err != nil {
return nil, err
}
select {
case <-ctx.Done():
return nil, ctx.Err()
default:
if execer, ok := tc.Conn.(driver.Execer); ok {
dargs, err := namedValueToValue(args)
if err != nil {
return nil, err
}
select {
case <-ctx.Done():
return nil, ctx.Err()
default:
}
r, err = execer.Exec(query, dargs)
tc.tryTrace(ctx, queryTypeExec, query, start, err)
return r, err
}
r, err = tc.Exec(query, dargs)
tc.tryTrace(ctx, queryTypeExec, query, start, err)
return r, err
return nil, driver.ErrSkip
}

// tracedConn has a Ping method in order to implement the pinger interface
Expand All @@ -112,32 +108,28 @@ func (tc *tracedConn) Ping(ctx context.Context) (err error) {
return err
}

func (tc *tracedConn) Query(query string, args []driver.Value) (driver.Rows, error) {
if queryer, ok := tc.Conn.(driver.Queryer); ok {
return queryer.Query(query, args)
}
return nil, driver.ErrSkip
}

func (tc *tracedConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (rows driver.Rows, err error) {
start := time.Now()
if queryerContext, ok := tc.Conn.(driver.QueryerContext); ok {
rows, err := queryerContext.QueryContext(ctx, query, args)
tc.tryTrace(ctx, queryTypeQuery, query, start, err)
return rows, err
}
dargs, err := namedValueToValue(args)
if err != nil {
return nil, err
}
select {
case <-ctx.Done():
return nil, ctx.Err()
default:
if queryer, ok := tc.Conn.(driver.Queryer); ok {
dargs, err := namedValueToValue(args)
if err != nil {
return nil, err
}
select {
case <-ctx.Done():
return nil, ctx.Err()
default:
}
rows, err = queryer.Query(query, dargs)
tc.tryTrace(ctx, queryTypeQuery, query, start, err)
return rows, err
}
rows, err = tc.Query(query, dargs)
tc.tryTrace(ctx, queryTypeQuery, query, start, err)
return rows, err
return nil, driver.ErrSkip
}

func (tc *tracedConn) CheckNamedValue(value *driver.NamedValue) error {
Expand All @@ -154,7 +146,8 @@ func (tc *tracedConn) ResetSession(ctx context.Context) error {
if resetter, ok := tc.Conn.(driver.SessionResetter); ok {
return resetter.ResetSession(ctx)
}
return driver.ErrSkip
// If driver doesn't implement driver.SessionResetter there's nothing to do
return nil
}

// traceParams stores all information related to tracing the driver.Conn
Expand Down
25 changes: 25 additions & 0 deletions contrib/database/sql/internal/dsn.go
Expand Up @@ -27,6 +27,11 @@ func ParseDSN(driverName, dsn string) (meta map[string]string, err error) {
if err != nil {
return
}
case "sqlserver":
meta, err = parseSQLServerDSN(dsn)
if err != nil {
return
}
default:
// not supported
}
Expand Down Expand Up @@ -86,3 +91,23 @@ func parsePostgresDSN(dsn string) (map[string]string, error) {
delete(meta, "password")
return meta, nil
}

// parseSQLServerDSN parses a sqlserver-type dsn into a map
func parseSQLServerDSN(dsn string) (map[string]string, error) {
var err error
var meta map[string]string
if strings.HasPrefix(dsn, "sqlserver://") {
// url form
meta, err = parseSQLServerURL(dsn)
if err != nil {
return nil, err
}
} else {
meta, err = parseSQLServerADO(dsn)
if err != nil {
return nil, err
}
}
delete(meta, "password")
return meta, nil
}
59 changes: 59 additions & 0 deletions contrib/database/sql/internal/dsn_test.go
Expand Up @@ -51,6 +51,16 @@ func TestParseDSN(t *testing.T) {
ext.DBUser: "dog",
},
},
{
driverName: "sqlserver",
dsn: "sqlserver://bob:secret@1.2.3.4:1433?database=mydb",
expected: map[string]string{
ext.DBUser: "bob",
ext.TargetHost: "1.2.3.4",
ext.TargetPort: "1433",
ext.DBName: "mydb",
},
},
} {
m, err := ParseDSN(tt.driverName, tt.dsn)
assert.Equal(nil, err)
Expand Down Expand Up @@ -104,3 +114,52 @@ func TestParsePostgresDSN(t *testing.T) {
assert.Equal(tt.expected, m)
}
}

func TestParseSqlServerDSN(t *testing.T) {
assert := assert.New(t)

for _, tt := range []struct {
dsn string
expected map[string]string
}{
{
dsn: "sqlserver://bob:secret@1.2.3.4:1433?database=mydb",
expected: map[string]string{
"user": "bob",
"host": "1.2.3.4",
"port": "1433",
"dbname": "mydb",
},
},
{
dsn: "sqlserver://alice:secret@localhost/SQLExpress?database=mydb",
expected: map[string]string{
"user": "alice",
"host": "localhost",
"dbname": "mydb",
"instanceName": "SQLExpress",
},
},
{
dsn: "server=1.2.3.4,1433;User Id=dog;Password=secret;Database=mydb;",
expected: map[string]string{
"user": "dog",
"port": "1433",
"host": "1.2.3.4",
"dbname": "mydb",
},
},
{
dsn: "ADDRESS=1.2.3.4;UID=cat;PASSWORD=secret;INITIAL CATALOG=mydb;",
expected: map[string]string{
"user": "cat",
"host": "1.2.3.4",
"dbname": "mydb",
},
},
} {
m, err := parseSQLServerDSN(tt.dsn)
assert.Equal(nil, err)
assert.Equal(tt.expected, m)
}
}
101 changes: 101 additions & 0 deletions contrib/database/sql/internal/sqlserver.go
@@ -0,0 +1,101 @@
// Unless explicitly stated otherwise all files in this repository are licensed
// under the Apache License Version 2.0.
// This product includes software developed at Datadog (https://www.datadoghq.com/).
// Copyright 2016 Datadog, Inc.

package internal

import (
"fmt"
"net"
nurl "net/url"
"strings"
)

func parseSQLServerURL(url string) (map[string]string, error) {
u, err := nurl.Parse(url)
if err != nil {
return nil, err
}

if u.Scheme != "sqlserver" {
return nil, fmt.Errorf("invalid connection protocol: %s", u.Scheme)
}

kvs := map[string]string{}
escaper := strings.NewReplacer(` `, `\ `, `'`, `\'`, `\`, `\\`)
accrue := func(k, v string) {
if v != "" {
kvs[k] = escaper.Replace(v)
}
}

if u.User != nil {
v := u.User.Username()
accrue("user", v)
}

if host, port, err := net.SplitHostPort(u.Host); err != nil {
accrue("host", u.Host)
} else {
accrue("host", host)
accrue("port", port)
}

if u.Path != "" {
accrue("instanceName", u.Path[1:])
}

q := u.Query()
for k := range q {
if k == "database" {
accrue("dbname", q.Get(k))
}
}

return kvs, nil
}

var keySynonyms = map[string]string{
"server": "host",
"data source": "host",
"address": "host",
"network address": "host",
"addr": "host",
"uid": "user",
"user id": "user",
"initial catalog": "dbname",
"database": "dbname",
}

func parseSQLServerADO(dsn string) (map[string]string, error) {
kvs := map[string]string{}
fields := strings.Split(dsn, ";")
for _, f := range fields {
if len(f) == 0 {
continue
}
pts := strings.SplitN(f, "=", 2)
key := strings.TrimSpace(strings.ToLower(pts[0]))
if len(key) == 0 {
continue
}
val := ""
if len(pts) > 1 {
val = strings.TrimSpace(pts[1])
}
if synonym, found := keySynonyms[key]; found {
key = synonym
}
if key == "host" {
val = strings.TrimPrefix(val, "tcp:")
hostParts := strings.Split(val, ",")
if len(hostParts) == 2 && len(hostParts[1]) > 0 {
val = hostParts[0]
kvs["port"] = hostParts[1]
}
}
kvs[key] = val
}
return kvs, nil
}
27 changes: 27 additions & 0 deletions contrib/database/sql/sql_test.go
Expand Up @@ -20,6 +20,7 @@ import (
"gopkg.in/DataDog/dd-trace-go.v1/ddtrace/ext"
"gopkg.in/DataDog/dd-trace-go.v1/ddtrace/mocktracer"

mssql "github.com/denisenkom/go-mssqldb"
"github.com/go-sql-driver/mysql"
"github.com/lib/pq"
"github.com/stretchr/testify/assert"
Expand All @@ -38,6 +39,32 @@ func TestMain(m *testing.M) {
os.Exit(m.Run())
}

func TestSqlServer(t *testing.T) {
Register("sqlserver", &mssql.Driver{})
db, err := Open("sqlserver", "sqlserver://sa:myPassw0rd@127.0.0.1:1433?database=master")
if err != nil {
log.Fatal(err)
}
defer db.Close()

testConfig := &sqltest.Config{
DB: db,
DriverName: "sqlserver",
TableName: tableName,
ExpectName: "sqlserver.query",
ExpectTags: map[string]interface{}{
ext.ServiceName: "sqlserver.db",
ext.SpanType: ext.SpanTypeSQL,
ext.TargetHost: "127.0.0.1",
ext.TargetPort: "1433",
ext.DBUser: "sa",
ext.DBName: "master",
ext.EventSampleRate: nil,
},
}
sqltest.RunAll(t, testConfig)
}

func TestMySQL(t *testing.T) {
Register("mysql", &mysql.MySQLDriver{})
db, err := Open("mysql", "test:test@tcp(127.0.0.1:3306)/test")
Expand Down