diff --git a/contrib/database/sql/conn.go b/contrib/database/sql/conn.go index cf2236082c..6a073f6a41 100644 --- a/contrib/database/sql/conn.go +++ b/contrib/database/sql/conn.go @@ -43,7 +43,7 @@ func (tc *tracedConn) BeginTx(ctx context.Context, opts driver.TxOptions) (tx dr start := time.Now() if connBeginTx, ok := tc.Conn.(driver.ConnBeginTx); ok { tx, err = connBeginTx.BeginTx(ctx, opts) - span := tc.tryStartTrace(ctx, queryTypeBegin, "", start, &tracer.SQLCommentCarrier{}, err) + span := tc.tryStartTrace(ctx, queryTypeBegin, "", start, nil, err) if span != nil { defer func() { span.Finish(tracer.WithError(err)) @@ -55,7 +55,7 @@ func (tc *tracedConn) BeginTx(ctx context.Context, opts driver.TxOptions) (tx dr return &tracedTx{tx, tc.traceParams, ctx}, nil } tx, err = tc.Conn.Begin() - span := tc.tryStartTrace(ctx, queryTypeBegin, "", start, &tracer.SQLCommentCarrier{}, err) + span := tc.tryStartTrace(ctx, queryTypeBegin, "", start, nil, err) if span != nil { defer func() { span.Finish(tracer.WithError(err)) @@ -84,7 +84,7 @@ func (tc *tracedConn) PrepareContext(ctx context.Context, query string) (stmt dr return &tracedStmt{Stmt: stmt, traceParams: tc.traceParams, ctx: ctx, query: query}, nil } - sqlCommentCarrier := tracer.SQLCommentCarrier{} + sqlCommentCarrier := tracer.SQLCommentCarrier{DiscardDynamicTags: true} span := tc.tryStartTrace(ctx, queryTypePrepare, query, start, &sqlCommentCarrier, err) if span != nil { go func() { @@ -139,7 +139,7 @@ func (tc *tracedConn) ExecContext(ctx context.Context, query string, args []driv func (tc *tracedConn) Ping(ctx context.Context) (err error) { start := time.Now() if pinger, ok := tc.Conn.(driver.Pinger); ok { - span := tc.tryStartTrace(ctx, queryTypePing, "", start, &tracer.SQLCommentCarrier{}, err) + span := tc.tryStartTrace(ctx, queryTypePing, "", start, nil, err) if span != nil { go func() { span.Finish(tracer.WithError(err)) @@ -225,7 +225,6 @@ func WithSpanTags(ctx context.Context, tags map[string]string) context.Context { // tryStartTrace will create a span using the given arguments, but will act as a no-op when err is driver.ErrSkip. func (tp *traceParams) tryStartTrace(ctx context.Context, qtype queryType, query string, startTime time.Time, sqlCommentCarrier *tracer.SQLCommentCarrier, err error) (span tracer.Span) { - fmt.Printf("Executing query type %s\n", qtype) if err == driver.ErrSkip { // Not a user error: driver is telling sql package that an // optional interface method is not implemented. There is @@ -261,7 +260,7 @@ func (tp *traceParams) tryStartTrace(ctx context.Context, qtype queryType, query } } - if tp.cfg.commentInjectionMode != commentInjectionDisabled { + if sqlCommentCarrier != nil && tp.cfg.commentInjectionMode != commentInjectionDisabled { injectionOpts := injectionOptionsForMode(tp.cfg.commentInjectionMode, sqlCommentCarrier.DiscardDynamicTags) err = tracer.InjectWithOptions(span.Context(), sqlCommentCarrier, injectionOpts...) if err != nil { diff --git a/contrib/database/sql/sql.go b/contrib/database/sql/sql.go index 4683111f31..85e6313e0e 100644 --- a/contrib/database/sql/sql.go +++ b/contrib/database/sql/sql.go @@ -142,7 +142,7 @@ func (t *tracedConnector) Connect(ctx context.Context) (c driver.Conn, err error tp.meta, _ = internal.ParseDSN(t.driverName, t.cfg.dsn) } start := time.Now() - span := tp.tryStartTrace(ctx, queryTypeConnect, "", start, &tracer.SQLCommentCarrier{}, err) + span := tp.tryStartTrace(ctx, queryTypeConnect, "", start, nil, err) if span != nil { go func() { span.Finish(tracer.WithError(err)) @@ -193,6 +193,9 @@ func OpenDB(c driver.Connector, opts ...Option) *sql.DB { if math.IsNaN(cfg.analyticsRate) { cfg.analyticsRate = rc.analyticsRate } + if cfg.commentInjectionMode == 0 { + cfg.commentInjectionMode = rc.commentInjectionMode + } cfg.childSpansOnly = rc.childSpansOnly tc := &tracedConnector{ connector: c, diff --git a/contrib/database/sql/sql_test.go b/contrib/database/sql/sql_test.go index 9022305927..1785112526 100644 --- a/contrib/database/sql/sql_test.go +++ b/contrib/database/sql/sql_test.go @@ -202,6 +202,81 @@ func TestOpenOptions(t *testing.T) { }) } +func TestCommentInjectionModes(t *testing.T) { + testCases := []struct { + name string + options []Option + expectedInjectedTags sqltest.TagInjectionExpectation + }{ + { + name: "default (no injection)", + options: []Option{}, + expectedInjectedTags: sqltest.TagInjectionExpectation{ + StaticTags: false, + DynamicTags: false, + }, + }, + { + name: "explicit no injection", + options: []Option{WithoutCommentInjection()}, + expectedInjectedTags: sqltest.TagInjectionExpectation{ + StaticTags: false, + DynamicTags: false, + }, + }, + { + name: "static tags injection", + options: []Option{WithStaticTagsCommentInjection()}, + expectedInjectedTags: sqltest.TagInjectionExpectation{ + StaticTags: true, + DynamicTags: false, + }, + }, + { + name: "dynamic tags injection", + options: []Option{WithCommentInjection()}, + expectedInjectedTags: sqltest.TagInjectionExpectation{ + StaticTags: true, + DynamicTags: true, + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + mockTracer := mocktracer.Start() + defer mockTracer.Stop() + + Register("postgres", &pq.Driver{}, append(tc.options, WithServiceName("postgres-test"))...) + defer unregister("postgres") + + db, err := Open("postgres", "postgres://postgres:postgres@127.0.0.1:5432/postgres?sslmode=disable") + if err != nil { + log.Fatal(err) + } + defer db.Close() + + testConfig := &sqltest.Config{ + DB: db, + DriverName: "postgres", + TableName: tableName, + ExpectName: "postgres.query", + ExpectTags: map[string]interface{}{ + ext.ServiceName: "postgres-test", + ext.SpanType: ext.SpanTypeSQL, + ext.TargetHost: "127.0.0.1", + ext.TargetPort: "5432", + ext.DBUser: "postgres", + ext.DBName: "postgres", + }, + ExpectTagInjection: tc.expectedInjectedTags, + } + + sqltest.RunAll(t, testConfig) + }) + } +} + func TestMySQLUint64(t *testing.T) { Register("mysql", &mysql.MySQLDriver{}) db, err := Open("mysql", "test:test@tcp(127.0.0.1:3306)/test") diff --git a/contrib/database/sql/stmt.go b/contrib/database/sql/stmt.go index bf07ce7e82..0c8c6c99b3 100644 --- a/contrib/database/sql/stmt.go +++ b/contrib/database/sql/stmt.go @@ -27,7 +27,7 @@ type tracedStmt struct { // Close sends a span before closing a statement func (s *tracedStmt) Close() (err error) { start := time.Now() - span := s.tryStartTrace(s.ctx, queryTypeClose, "", start, &tracer.SQLCommentCarrier{}, err) + span := s.tryStartTrace(s.ctx, queryTypeClose, "", start, nil, err) if span != nil { go func() { span.Finish(tracer.WithError(err)) diff --git a/contrib/database/sql/tx.go b/contrib/database/sql/tx.go index 4c43264067..4d0497821f 100644 --- a/contrib/database/sql/tx.go +++ b/contrib/database/sql/tx.go @@ -25,7 +25,7 @@ type tracedTx struct { // Commit sends a span at the end of the transaction func (t *tracedTx) Commit() (err error) { start := time.Now() - span := t.tryStartTrace(t.ctx, queryTypeCommit, "", start, &tracer.SQLCommentCarrier{}, err) + span := t.tryStartTrace(t.ctx, queryTypeCommit, "", start, nil, err) if span != nil { go func() { span.Finish(tracer.WithError(err)) @@ -38,7 +38,7 @@ func (t *tracedTx) Commit() (err error) { // Rollback sends a span if the connection is aborted func (t *tracedTx) Rollback() (err error) { start := time.Now() - span := t.tryStartTrace(t.ctx, queryTypeRollback, "", start, &tracer.SQLCommentCarrier{}, err) + span := t.tryStartTrace(t.ctx, queryTypeRollback, "", start, nil, err) if span != nil { go func() { span.Finish(tracer.WithError(err)) diff --git a/contrib/internal/sqltest/sqltest.go b/contrib/internal/sqltest/sqltest.go index 07db4e39e5..0d236285c9 100644 --- a/contrib/internal/sqltest/sqltest.go +++ b/contrib/internal/sqltest/sqltest.go @@ -150,6 +150,8 @@ func testQuery(cfg *Config) func(*testing.T) { for k, v := range cfg.ExpectTags { assert.Equal(v, querySpan.Tag(k), "Value mismatch on tag %s", k) } + + assertInjectedComments(t, cfg, false) } } @@ -181,6 +183,8 @@ func testStatement(cfg *Config) func(*testing.T) { assert.Equal(v, span.Tag(k), "Value mismatch on tag %s", k) } + assertInjectedComments(t, cfg, true) + cfg.mockTracer.Reset() _, err2 := stmt.Exec("New York") assert.Equal(nil, err2) @@ -288,6 +292,8 @@ func testExec(cfg *Config) func(*testing.T) { span = s } } + assertInjectedComments(t, cfg, false) + assert.NotNil(span, "span not found") cfg.ExpectTags["sql.query_type"] = "Commit" for k, v := range cfg.ExpectTags { @@ -296,6 +302,37 @@ func testExec(cfg *Config) func(*testing.T) { } } +func assertInjectedComments(t *testing.T, cfg *Config, discardDynamicTags bool) { + c := cfg.mockTracer.InjectedComments() + carrier := tracer.SQLCommentCarrier{} + for k, v := range expectedInjectedTags(cfg, discardDynamicTags) { + carrier.Set(k, v) + } + + if !cfg.ExpectTagInjection.StaticTags && !cfg.ExpectTagInjection.DynamicTags { + assert.Len(t, c, 0) + } else { + require.Len(t, c, 1) + assert.Equal(t, carrier.CommentedQuery(""), c[0]) + } +} + +func expectedInjectedTags(cfg *Config, discardDynamicTags bool) map[string]string { + expectedInjectedTags := make(map[string]string) + // Prepare statements should never have dynamic tags injected so we only check if static tags are expected + if cfg.ExpectTagInjection.StaticTags { + expectedInjectedTags[tracer.ServiceNameSQLCommentKey] = "test-service" + expectedInjectedTags[tracer.ServiceEnvironmentSQLCommentKey] = "test-env" + expectedInjectedTags[tracer.ServiceVersionSQLCommentKey] = "v-test" + } + if cfg.ExpectTagInjection.DynamicTags && !discardDynamicTags { + expectedInjectedTags[tracer.SamplingPrioritySQLCommentKey] = "0" + expectedInjectedTags[tracer.TraceIDSQLCommentKey] = "test-trace-id" + expectedInjectedTags[tracer.SpanIDSQLCommentKey] = "test-span-id" + } + return expectedInjectedTags +} + func verifyConnectSpan(span mocktracer.Span, assert *assert.Assertions, cfg *Config) { assert.Equal(cfg.ExpectName, span.OperationName()) cfg.ExpectTags["sql.query_type"] = "Connect" @@ -304,12 +341,19 @@ func verifyConnectSpan(span mocktracer.Span, assert *assert.Assertions, cfg *Con } } +// TagInjectionExpectation holds expectations relating to tag injection +type TagInjectionExpectation struct { + StaticTags bool + DynamicTags bool +} + // Config holds the test configuration. type Config struct { *sql.DB - mockTracer mocktracer.Tracer - DriverName string - TableName string - ExpectName string - ExpectTags map[string]interface{} + mockTracer mocktracer.Tracer + DriverName string + TableName string + ExpectName string + ExpectTags map[string]interface{} + ExpectTagInjection TagInjectionExpectation } diff --git a/ddtrace/mocktracer/mockspan.go b/ddtrace/mocktracer/mockspan.go index 39a91b703b..7d95264403 100644 --- a/ddtrace/mocktracer/mockspan.go +++ b/ddtrace/mocktracer/mockspan.go @@ -118,6 +118,9 @@ func (s *mockspan) SetTag(key string, value interface{}) { if s.tags == nil { s.tags = make(map[string]interface{}, 1) } + if key == "sql.query_type" { + fmt.Printf("New span for type %s\n", value) + } if key == ext.SamplingPriority { switch p := value.(type) { case int: @@ -188,6 +191,7 @@ func (s *mockspan) SetBaggageItem(key, val string) { // Finish finishes the current span with the given options. func (s *mockspan) Finish(opts ...ddtrace.FinishOption) { + fmt.Printf("Finishing span with type %v\n", s.Tag("sql.query_type")) var cfg ddtrace.FinishConfig for _, fn := range opts { fn(&cfg) @@ -212,6 +216,7 @@ func (s *mockspan) Finish(opts ...ddtrace.FinishOption) { s.finished = true s.finishTime = t s.tracer.addFinishedSpan(s) + //fmt.Printf("Finished span with type %v\n", s.Tag("sql.query_type")) } // String implements fmt.Stringer. diff --git a/ddtrace/mocktracer/mocktracer.go b/ddtrace/mocktracer/mocktracer.go index 38caf31404..d2555bd55e 100644 --- a/ddtrace/mocktracer/mocktracer.go +++ b/ddtrace/mocktracer/mocktracer.go @@ -13,12 +13,11 @@ package mocktracer import ( + "fmt" "strconv" "strings" "sync" - "gopkg.in/DataDog/dd-trace-go.v1/ddtrace/ext" - "gopkg.in/DataDog/dd-trace-go.v1/ddtrace" "gopkg.in/DataDog/dd-trace-go.v1/ddtrace/internal" "gopkg.in/DataDog/dd-trace-go.v1/ddtrace/tracer" @@ -105,6 +104,9 @@ func (t *mocktracer) OpenSpans() []Span { func (t *mocktracer) FinishedSpans() []Span { t.RLock() defer t.RUnlock() + for _, s := range t.finishedSpans { + fmt.Printf("returning finished span of type %v\n", s.Tag("sql.query_type")) + } return t.finishedSpans } @@ -222,41 +224,30 @@ func (t *mocktracer) InjectWithOptions(context ddtrace.SpanContext, carrier inte } if cfg.TraceIDKey != "" { - writer.Set(cfg.TraceIDKey, strconv.FormatUint(ctx.traceID, 10)) + writer.Set(cfg.TraceIDKey, "test-trace-id") } if cfg.SpanIDKey != "" { - writer.Set(cfg.SpanIDKey, strconv.FormatUint(ctx.spanID, 10)) + writer.Set(cfg.SpanIDKey, "test-span-id") } if cfg.SamplingPriorityKey != "" { - if ctx.hasSamplingPriority() { - writer.Set(cfg.SamplingPriorityKey, strconv.Itoa(ctx.priority)) - } + writer.Set(cfg.SamplingPriorityKey, strconv.Itoa(ctx.priority)) } if cfg.EnvKey != "" { - envRaw := ctx.span.Tag(ext.Environment) - if env, ok := envRaw.(string); ok { - writer.Set(cfg.EnvKey, env) - } + writer.Set(cfg.EnvKey, "test-env") } if cfg.ParentVersionKey != "" { - versionRaw := ctx.span.Tag(ext.ParentVersion) - if version, ok := versionRaw.(string); ok { - writer.Set(cfg.ParentVersionKey, version) - } + writer.Set(cfg.ParentVersionKey, "v-test") } if cfg.ServiceNameKey != "" { - serviceNameRaw := ctx.span.Tag(ext.ServiceName) - if serviceName, ok := serviceNameRaw.(string); ok { - writer.Set(cfg.ServiceNameKey, serviceName) - } + writer.Set(cfg.ServiceNameKey, "test-service") } - sqlCommentCarrier, ok := carrier.(tracer.SQLCommentCarrier) + sqlCommentCarrier, ok := carrier.(*tracer.SQLCommentCarrier) if ok { // Save injected comments to assert the sql commenting behavior t.injectedComments = append(t.injectedComments, sqlCommentCarrier.CommentedQuery("")) diff --git a/ddtrace/tracer/sqlcomment.go b/ddtrace/tracer/sqlcomment.go index b2a198fb80..a8f04b9d51 100644 --- a/ddtrace/tracer/sqlcomment.go +++ b/ddtrace/tracer/sqlcomment.go @@ -15,6 +15,7 @@ type SQLCommentCarrier struct { tags map[string]string } +// Values for sql comment keys const ( SamplingPrioritySQLCommentKey = "ddsp" TraceIDSQLCommentKey = "ddtid" @@ -53,10 +54,14 @@ func commentWithTags(tags map[string]string) (comment string) { func (c *SQLCommentCarrier) CommentedQuery(query string) (commented string) { comment := commentWithTags(c.tags) - if comment == "" || query == "" { + if comment == "" { return query } + if query == "" { + return comment + } + return fmt.Sprintf("%s %s", comment, query) }