diff --git a/contrib/database/sql/conn.go b/contrib/database/sql/conn.go index fa214bbe15..b95eec85ea 100644 --- a/contrib/database/sql/conn.go +++ b/contrib/database/sql/conn.go @@ -22,7 +22,8 @@ var _ driver.Conn = (*tracedConn)(nil) type queryType string const ( - queryTypeQuery queryType = "Query" + queryTypeConnect queryType = "Connect" + queryTypeQuery = "Query" queryTypePing = "Ping" queryTypePrepare = "Prepare" queryTypeExec = "Exec" diff --git a/contrib/database/sql/conn_test.go b/contrib/database/sql/conn_test.go index d08aa45d4e..3203bf0e27 100644 --- a/contrib/database/sql/conn_test.go +++ b/contrib/database/sql/conn_test.go @@ -92,9 +92,16 @@ func TestWithSpanTags(t *testing.T) { rows.Close() spans := mt.FinishedSpans() - assert.Len(t, spans, 1) + assert.Len(t, spans, 2) - span := spans[0] + connectSpan := spans[0] + assert.Equal(t, tt.want.opName, connectSpan.OperationName()) + assert.Equal(t, "Connect", connectSpan.Tag("sql.query_type")) + for k, v := range tt.want.ctxTags { + assert.Equal(t, v, connectSpan.Tag(k), "Value mismatch on tag %s", k) + } + + span := spans[1] assert.Equal(t, tt.want.opName, span.OperationName()) for k, v := range tt.want.ctxTags { assert.Equal(t, v, span.Tag(k), "Value mismatch on tag %s", k) diff --git a/contrib/database/sql/sql.go b/contrib/database/sql/sql.go index 755bcd9eff..2383f4a499 100644 --- a/contrib/database/sql/sql.go +++ b/contrib/database/sql/sql.go @@ -23,6 +23,7 @@ import ( "errors" "math" "reflect" + "time" "gopkg.in/DataDog/dd-trace-go.v1/contrib/database/sql/internal" "gopkg.in/DataDog/dd-trace-go.v1/internal/log" @@ -128,11 +129,7 @@ type tracedConnector struct { cfg *config } -func (t *tracedConnector) Connect(c context.Context) (driver.Conn, error) { - conn, err := t.connector.Connect(c) - if err != nil { - return nil, err - } +func (t *tracedConnector) Connect(ctx context.Context) (driver.Conn, error) { tp := &traceParams{ driverName: t.driverName, cfg: t.cfg, @@ -142,6 +139,12 @@ func (t *tracedConnector) Connect(c context.Context) (driver.Conn, error) { } else if t.cfg.dsn != "" { tp.meta, _ = internal.ParseDSN(t.driverName, t.cfg.dsn) } + start := time.Now() + conn, err := t.connector.Connect(ctx) + tp.tryTrace(ctx, queryTypeConnect, "", start, err) + if err != nil { + return nil, err + } return &tracedConn{conn, tp}, err } @@ -163,7 +166,7 @@ func (t dsnConnector) Driver() driver.Driver { return t.driver } -// OpenDB returns connection to a DB using a the traced version of the given driver. In order for OpenDB +// OpenDB returns connection to a DB using the traced version of the given driver. In order for OpenDB // to work, the driver must first be registered using Register. If this did not occur, OpenDB will panic. func OpenDB(c driver.Connector, opts ...Option) *sql.DB { name, ok := registeredDrivers.name(c.Driver()) @@ -192,7 +195,7 @@ func OpenDB(c driver.Connector, opts ...Option) *sql.DB { return sql.OpenDB(tc) } -// Open returns connection to a DB using a the traced version of the given driver. In order for Open +// Open returns connection to a DB using the traced version of the given driver. In order for Open // to work, the driver must first be registered using Register. If this did not occur, Open will // return an error. func Open(driverName, dataSourceName string, opts ...Option) (*sql.DB, error) { diff --git a/contrib/database/sql/sql_test.go b/contrib/database/sql/sql_test.go index d19f92f7bc..b953a9c04f 100644 --- a/contrib/database/sql/sql_test.go +++ b/contrib/database/sql/sql_test.go @@ -6,14 +6,19 @@ package sql import ( + "context" + "database/sql/driver" + "errors" "fmt" "log" "math" "os" "testing" + "time" "gopkg.in/DataDog/dd-trace-go.v1/contrib/internal/sqltest" "gopkg.in/DataDog/dd-trace-go.v1/ddtrace/ext" + "gopkg.in/DataDog/dd-trace-go.v1/ddtrace/mocktracer" "github.com/go-sql-driver/mysql" "github.com/lib/pq" @@ -190,3 +195,42 @@ func TestMySQLUint64(t *testing.T) { assert.NoError(rows.Err()) assert.NoError(rows.Close()) } + +// hangingConnector hangs on Connect until ctx is cancelled. +type hangingConnector struct{} + +func (h *hangingConnector) Connect(ctx context.Context) (driver.Conn, error) { + select { + case <-ctx.Done(): + return nil, errors.New("context cancelled") + } +} + +func (h *hangingConnector) Driver() driver.Driver { + panic("hangingConnector: Driver() not implemented") +} + +func TestConnectCancelledCtx(t *testing.T) { + mockTracer := mocktracer.Start() + defer mockTracer.Stop() + assert := assert.New(t) + tc := tracedConnector{ + connector: &hangingConnector{}, + driverName: "hangingConnector", + cfg: new(config), + } + ctx, cancelFunc := context.WithCancel(context.Background()) + + go func() { + tc.Connect(ctx) + }() + time.Sleep(time.Millisecond * 100) + cancelFunc() + time.Sleep(time.Millisecond * 100) + + spans := mockTracer.FinishedSpans() + assert.Len(spans, 1) + s := spans[0] + assert.Equal("hangingConnector.query", s.OperationName()) + assert.Equal("Connect", s.Tag("sql.query_type")) +} diff --git a/contrib/internal/sqltest/sqltest.go b/contrib/internal/sqltest/sqltest.go index 7c9d64a5c4..0931ab5308 100644 --- a/contrib/internal/sqltest/sqltest.go +++ b/contrib/internal/sqltest/sqltest.go @@ -48,8 +48,10 @@ func Prepare(tableName string) func() { func RunAll(t *testing.T, cfg *Config) { cfg.mockTracer = mocktracer.Start() defer cfg.mockTracer.Stop() + cfg.DB.SetMaxIdleConns(0) for name, test := range map[string]func(*Config) func(*testing.T){ + "Connect": testConnect, "Ping": testPing, "Query": testQuery, "Statement": testStatement, @@ -60,17 +62,37 @@ func RunAll(t *testing.T, cfg *Config) { } } -func testPing(cfg *Config) func(*testing.T) { +func testConnect(cfg *Config) func(*testing.T) { return func(t *testing.T) { cfg.mockTracer.Reset() assert := assert.New(t) err := cfg.DB.Ping() assert.Nil(err) spans := cfg.mockTracer.FinishedSpans() - assert.Len(spans, 1) + assert.Len(spans, 2) span := spans[0] assert.Equal(cfg.ExpectName, span.OperationName()) + cfg.ExpectTags["sql.query_type"] = "Connect" + for k, v := range cfg.ExpectTags { + assert.Equal(v, span.Tag(k), "Value mismatch on tag %s", k) + } + } +} + +func testPing(cfg *Config) func(*testing.T) { + return func(t *testing.T) { + cfg.mockTracer.Reset() + assert := assert.New(t) + err := cfg.DB.Ping() + assert.Nil(err) + spans := cfg.mockTracer.FinishedSpans() + assert.Len(spans, 2) + + verifyConnectSpan(spans[0], assert, cfg) + + span := spans[1] + assert.Equal(cfg.ExpectName, span.OperationName()) cfg.ExpectTags["sql.query_type"] = "Ping" for k, v := range cfg.ExpectTags { assert.Equal(v, span.Tag(k), "Value mismatch on tag %s", k) @@ -88,9 +110,11 @@ func testQuery(cfg *Config) func(*testing.T) { assert.Nil(err) spans := cfg.mockTracer.FinishedSpans() - assert.Len(spans, 1) + assert.Len(spans, 2) - span := spans[0] + verifyConnectSpan(spans[0], assert, cfg) + + span := spans[1] cfg.ExpectTags["sql.query_type"] = "Query" assert.Equal(cfg.ExpectName, span.OperationName()) for k, v := range cfg.ExpectTags { @@ -114,9 +138,11 @@ func testStatement(cfg *Config) func(*testing.T) { assert.Equal(nil, err) spans := cfg.mockTracer.FinishedSpans() - assert.Len(spans, 1) + assert.Len(spans, 3) - span := spans[0] + verifyConnectSpan(spans[0], assert, cfg) + + span := spans[1] assert.Equal(cfg.ExpectName, span.OperationName()) cfg.ExpectTags["sql.query_type"] = "Prepare" for k, v := range cfg.ExpectTags { @@ -128,8 +154,8 @@ func testStatement(cfg *Config) func(*testing.T) { assert.Equal(nil, err2) spans = cfg.mockTracer.FinishedSpans() - assert.Len(spans, 1) - span = spans[0] + assert.Len(spans, 4) + span = spans[2] assert.Equal(cfg.ExpectName, span.OperationName()) cfg.ExpectTags["sql.query_type"] = "Exec" for k, v := range cfg.ExpectTags { @@ -147,9 +173,11 @@ func testBeginRollback(cfg *Config) func(*testing.T) { assert.Equal(nil, err) spans := cfg.mockTracer.FinishedSpans() - assert.Len(spans, 1) + assert.Len(spans, 2) - span := spans[0] + verifyConnectSpan(spans[0], assert, cfg) + + span := spans[1] assert.Equal(cfg.ExpectName, span.OperationName()) cfg.ExpectTags["sql.query_type"] = "Begin" for k, v := range cfg.ExpectTags { @@ -192,7 +220,7 @@ func testExec(cfg *Config) func(*testing.T) { parent.Finish() // flush children spans := cfg.mockTracer.FinishedSpans() - assert.Len(spans, 4) + assert.Len(spans, 5) var span mocktracer.Span for _, s := range spans { @@ -218,6 +246,14 @@ func testExec(cfg *Config) func(*testing.T) { } } +func verifyConnectSpan(span mocktracer.Span, assert *assert.Assertions, cfg *Config) { + assert.Equal(cfg.ExpectName, span.OperationName()) + cfg.ExpectTags["sql.query_type"] = "Connect" + for k, v := range cfg.ExpectTags { + assert.Equal(v, span.Tag(k), "Value mismatch on tag %s", k) + } +} + // Config holds the test configuration. type Config struct { *sql.DB