diff --git a/contrib/database/sql/sql.go b/contrib/database/sql/sql.go index 8a85fc7ba3..dac44a143e 100644 --- a/contrib/database/sql/sql.go +++ b/contrib/database/sql/sql.go @@ -151,7 +151,7 @@ func (t dsnConnector) Driver() driver.Driver { return t.driver } -// OpenDB returns connection to a DB using a the traced version of the given driver. In order for OpenDB +// OpenDB returns connection to a DB using the traced version of the given driver. In order for OpenDB // to work, the driver must first be registered using Register. If this did not occur, OpenDB will panic. func OpenDB(c driver.Connector, opts ...Option) *sql.DB { name, ok := registeredDrivers.name(c.Driver()) @@ -179,7 +179,7 @@ func OpenDB(c driver.Connector, opts ...Option) *sql.DB { return sql.OpenDB(tc) } -// Open returns connection to a DB using a the traced version of the given driver. In order for Open +// Open returns connection to a DB using the traced version of the given driver. In order for Open // to work, the driver must first be registered using Register. If this did not occur, Open will // return an error. func Open(driverName, dataSourceName string, opts ...Option) (*sql.DB, error) { diff --git a/contrib/database/sql/sql_test.go b/contrib/database/sql/sql_test.go index d19f92f7bc..2a3859b746 100644 --- a/contrib/database/sql/sql_test.go +++ b/contrib/database/sql/sql_test.go @@ -6,11 +6,16 @@ package sql import ( + "context" + "database/sql/driver" + "errors" "fmt" + "gopkg.in/DataDog/dd-trace-go.v1/ddtrace/mocktracer" "log" "math" "os" "testing" + "time" "gopkg.in/DataDog/dd-trace-go.v1/contrib/internal/sqltest" "gopkg.in/DataDog/dd-trace-go.v1/ddtrace/ext" @@ -190,3 +195,56 @@ func TestMySQLUint64(t *testing.T) { assert.NoError(rows.Err()) assert.NoError(rows.Close()) } + +//hangingConnector hangs on Connect until ctx is cancelled. +type hangingConnector struct{} + +func (h *hangingConnector) Connect(ctx context.Context) (driver.Conn, error) { + select { + case <-ctx.Done(): + return &panicConn{}, errors.New("context cancelled") + } +} + +func (h *hangingConnector) Driver() driver.Driver { + panic("hangingConnector: Driver() not implemented") +} + +type panicConn struct{} + +func (p *panicConn) Prepare(query string) (driver.Stmt, error) { + panic("panicConn: Prepare called") +} + +func (p *panicConn) Close() error { + panic("panicConn: Close called") +} + +func (p *panicConn) Begin() (driver.Tx, error) { + panic("panicConn: Begin called") +} + +func TestConnectCancelledCtx(t *testing.T) { + mockTracer := mocktracer.Start() + defer mockTracer.Stop() + assert := assert.New(t) + tc := tracedConnector{ + connector: &hangingConnector{}, + driverName: "hangingConnector", + cfg: new(config), + } + ctx, cancelFunc := context.WithCancel(context.Background()) + + go func() { + tc.Connect(ctx) + }() + time.Sleep(time.Millisecond * 100) + cancelFunc() + time.Sleep(time.Millisecond * 100) + + spans := mockTracer.FinishedSpans() + assert.Len(spans, 1) + s := spans[0] + assert.Equal("hangingConnector.query", s.OperationName()) + assert.Equal("Connect", s.Tag("sql.query_type")) +}