diff --git a/contrib/database/sql/conn.go b/contrib/database/sql/conn.go index e32b4a1dea..ac0405b3a6 100644 --- a/contrib/database/sql/conn.go +++ b/contrib/database/sql/conn.go @@ -22,7 +22,8 @@ var _ driver.Conn = (*tracedConn)(nil) type queryType string const ( - queryTypeQuery queryType = "Query" + queryTypeConnect queryType = "Connect" + queryTypeQuery = "Query" queryTypePing = "Ping" queryTypePrepare = "Prepare" queryTypeExec = "Exec" diff --git a/contrib/database/sql/sql.go b/contrib/database/sql/sql.go index d80fa41869..dac44a143e 100644 --- a/contrib/database/sql/sql.go +++ b/contrib/database/sql/sql.go @@ -23,6 +23,7 @@ import ( "errors" "math" "reflect" + "time" "gopkg.in/DataDog/dd-trace-go.v1/contrib/database/sql/internal" "gopkg.in/DataDog/dd-trace-go.v1/internal/log" @@ -113,11 +114,7 @@ type tracedConnector struct { cfg *config } -func (t *tracedConnector) Connect(c context.Context) (driver.Conn, error) { - conn, err := t.connector.Connect(c) - if err != nil { - return nil, err - } +func (t *tracedConnector) Connect(ctx context.Context) (driver.Conn, error) { tp := &traceParams{ driverName: t.driverName, cfg: t.cfg, @@ -127,6 +124,12 @@ func (t *tracedConnector) Connect(c context.Context) (driver.Conn, error) { } else if t.cfg.dsn != "" { tp.meta, _ = internal.ParseDSN(t.driverName, t.cfg.dsn) } + start := time.Now() + conn, err := t.connector.Connect(ctx) + tp.tryTrace(ctx, queryTypeConnect, "", start, err) + if err != nil { + return nil, err + } return &tracedConn{conn, tp}, err } @@ -148,7 +151,7 @@ func (t dsnConnector) Driver() driver.Driver { return t.driver } -// OpenDB returns connection to a DB using a the traced version of the given driver. In order for OpenDB +// OpenDB returns connection to a DB using the traced version of the given driver. In order for OpenDB // to work, the driver must first be registered using Register. If this did not occur, OpenDB will panic. func OpenDB(c driver.Connector, opts ...Option) *sql.DB { name, ok := registeredDrivers.name(c.Driver()) @@ -176,7 +179,7 @@ func OpenDB(c driver.Connector, opts ...Option) *sql.DB { return sql.OpenDB(tc) } -// Open returns connection to a DB using a the traced version of the given driver. In order for Open +// Open returns connection to a DB using the traced version of the given driver. In order for Open // to work, the driver must first be registered using Register. If this did not occur, Open will // return an error. func Open(driverName, dataSourceName string, opts ...Option) (*sql.DB, error) { diff --git a/contrib/database/sql/sql_test.go b/contrib/database/sql/sql_test.go index d19f92f7bc..2a3859b746 100644 --- a/contrib/database/sql/sql_test.go +++ b/contrib/database/sql/sql_test.go @@ -6,11 +6,16 @@ package sql import ( + "context" + "database/sql/driver" + "errors" "fmt" + "gopkg.in/DataDog/dd-trace-go.v1/ddtrace/mocktracer" "log" "math" "os" "testing" + "time" "gopkg.in/DataDog/dd-trace-go.v1/contrib/internal/sqltest" "gopkg.in/DataDog/dd-trace-go.v1/ddtrace/ext" @@ -190,3 +195,56 @@ func TestMySQLUint64(t *testing.T) { assert.NoError(rows.Err()) assert.NoError(rows.Close()) } + +//hangingConnector hangs on Connect until ctx is cancelled. +type hangingConnector struct{} + +func (h *hangingConnector) Connect(ctx context.Context) (driver.Conn, error) { + select { + case <-ctx.Done(): + return &panicConn{}, errors.New("context cancelled") + } +} + +func (h *hangingConnector) Driver() driver.Driver { + panic("hangingConnector: Driver() not implemented") +} + +type panicConn struct{} + +func (p *panicConn) Prepare(query string) (driver.Stmt, error) { + panic("panicConn: Prepare called") +} + +func (p *panicConn) Close() error { + panic("panicConn: Close called") +} + +func (p *panicConn) Begin() (driver.Tx, error) { + panic("panicConn: Begin called") +} + +func TestConnectCancelledCtx(t *testing.T) { + mockTracer := mocktracer.Start() + defer mockTracer.Stop() + assert := assert.New(t) + tc := tracedConnector{ + connector: &hangingConnector{}, + driverName: "hangingConnector", + cfg: new(config), + } + ctx, cancelFunc := context.WithCancel(context.Background()) + + go func() { + tc.Connect(ctx) + }() + time.Sleep(time.Millisecond * 100) + cancelFunc() + time.Sleep(time.Millisecond * 100) + + spans := mockTracer.FinishedSpans() + assert.Len(spans, 1) + s := spans[0] + assert.Equal("hangingConnector.query", s.OperationName()) + assert.Equal("Connect", s.Tag("sql.query_type")) +} diff --git a/contrib/internal/sqltest/sqltest.go b/contrib/internal/sqltest/sqltest.go index 7c9d64a5c4..4aa1758d0c 100644 --- a/contrib/internal/sqltest/sqltest.go +++ b/contrib/internal/sqltest/sqltest.go @@ -49,6 +49,8 @@ func RunAll(t *testing.T, cfg *Config) { cfg.mockTracer = mocktracer.Start() defer cfg.mockTracer.Stop() + // Make sure testConnect runs first to ensure a connection is established + t.Run("Connect", testConnect(cfg)) for name, test := range map[string]func(*Config) func(*testing.T){ "Ping": testPing, "Query": testQuery, @@ -60,6 +62,24 @@ func RunAll(t *testing.T, cfg *Config) { } } +func testConnect(cfg *Config) func(*testing.T) { + return func(t *testing.T) { + cfg.mockTracer.Reset() + assert := assert.New(t) + err := cfg.DB.Ping() + assert.Nil(err) + spans := cfg.mockTracer.FinishedSpans() + assert.Len(spans, 2) + + span := spans[0] + assert.Equal(cfg.ExpectName, span.OperationName()) + cfg.ExpectTags["sql.query_type"] = "Connect" + for k, v := range cfg.ExpectTags { + assert.Equal(v, span.Tag(k), "Value mismatch on tag %s", k) + } + } +} + func testPing(cfg *Config) func(*testing.T) { return func(t *testing.T) { cfg.mockTracer.Reset()