diff --git a/config.go b/config.go index 5467b53e6..9e402ee6a 100644 --- a/config.go +++ b/config.go @@ -388,6 +388,7 @@ func (t *Tracer) updateRemoteConfig(logger WarningLogger, old, attrs map[string] } else { updates = append(updates, func(cfg *instrumentationConfig) { cfg.sampler = sampler + cfg.extendedSampler, _ = sampler.(ExtendedSampler) }) } default: @@ -479,6 +480,7 @@ type instrumentationConfigValues struct { recording bool captureBody CaptureBodyMode captureHeaders bool + extendedSampler ExtendedSampler maxSpans int sampler Sampler spanFramesMinDuration time.Duration diff --git a/model/marshal_fastjson.go b/model/marshal_fastjson.go index d4ea9d1fd..8572bb41c 100644 --- a/model/marshal_fastjson.go +++ b/model/marshal_fastjson.go @@ -390,6 +390,10 @@ func (v *Transaction) MarshalFastJSON(w *fastjson.Writer) error { w.RawString(",\"result\":") w.String(v.Result) } + if v.SampleRate != nil { + w.RawString(",\"sample_rate\":") + w.Float64(*v.SampleRate) + } if v.Sampled != nil { w.RawString(",\"sampled\":") w.Bool(*v.Sampled) @@ -445,6 +449,10 @@ func (v *Span) MarshalFastJSON(w *fastjson.Writer) error { firstErr = err } } + if v.SampleRate != nil { + w.RawString(",\"sample_rate\":") + w.Float64(*v.SampleRate) + } if v.Stacktrace != nil { w.RawString(",\"stacktrace\":") w.RawByte('[') diff --git a/model/model.go b/model/model.go index 3f06b85c6..e41823c21 100644 --- a/model/model.go +++ b/model/model.go @@ -205,6 +205,10 @@ type Transaction struct { // it to true. Sampled *bool `json:"sampled,omitempty"` + // SampleRate holds the sample rate in effect when the trace was started, + // if known. This is used by the server to aggregate transaction metrics. + SampleRate *float64 `json:"sample_rate,omitempty"` + // SpanCount holds statistics on spans within a transaction. SpanCount SpanCount `json:"span_count"` } @@ -254,6 +258,10 @@ type Span struct { // ParentID holds the ID of the span's parent (span or transaction). 
ParentID SpanID `json:"parent_id,omitempty"` + // SampleRate holds the sample rate in effect when the trace was started, + // if known. This is used by the server to aggregate span metrics. + SampleRate *float64 `json:"sample_rate,omitempty"` + // Context holds contextual information relating to the span. Context *SpanContext `json:"context,omitempty"` diff --git a/modelwriter.go b/modelwriter.go index e78d9be8f..fd327e099 100644 --- a/modelwriter.go +++ b/modelwriter.go @@ -109,6 +109,9 @@ func (w *modelWriter) buildModelTransaction(out *model.Transaction, tx *Transact if !sampled { out.Sampled = &notSampled } + if tx.traceContext.State.haveSampleRate { + out.SampleRate = &tx.traceContext.State.sampleRate + } out.ParentID = model.SpanID(td.parentSpan) out.Name = truncateString(td.Name) @@ -137,6 +140,9 @@ func (w *modelWriter) buildModelSpan(out *model.Span, span *Span, sd *SpanData) out.ID = model.SpanID(span.traceContext.Span) out.TraceID = model.TraceID(span.traceContext.Trace) out.TransactionID = model.SpanID(span.transactionID) + if span.traceContext.State.haveSampleRate { + out.SampleRate = &span.traceContext.State.sampleRate + } out.ParentID = model.SpanID(sd.parentID) out.Name = truncateString(sd.Name) diff --git a/module/apmgrpc/client_test.go b/module/apmgrpc/client_test.go index 89aae5475..ab098d7bd 100644 --- a/module/apmgrpc/client_test.go +++ b/module/apmgrpc/client_test.go @@ -93,7 +93,7 @@ func testClientSpan(t *testing.T, traceparentHeaders ...string) { } assert.Equal(t, clientSpans[0].TraceID, serverTransactions[1].TraceID) assert.Equal(t, clientSpans[0].ID, serverTransactions[1].ParentID) - assert.Equal(t, "server_span", serverSpans[0].Name) // no tracestate + assert.Equal(t, "es=s:1", serverSpans[0].Name) // automatically created tracestate assert.Equal(t, "vendor=tracestate", serverSpans[1].Name) traceparentValue := apmhttp.FormatTraceparentHeader(apm.TraceContext{ diff --git a/module/apmot/harness_test.go b/module/apmot/harness_test.go index 
330979f08..2fc471bf4 100644 --- a/module/apmot/harness_test.go +++ b/module/apmot/harness_test.go @@ -115,5 +115,6 @@ func (harnessAPIProbe) SameSpanContext(span opentracing.Span, sc opentracing.Spa if !ok { return false } - return ctx1.traceContext == ctx2.traceContext + return ctx1.traceContext.Trace == ctx2.traceContext.Trace && + ctx1.traceContext.Span == ctx2.traceContext.Span } diff --git a/sampler.go b/sampler.go index 3cf4591c6..f3268e3ab 100644 --- a/sampler.go +++ b/sampler.go @@ -35,6 +35,37 @@ type Sampler interface { Sample(TraceContext) bool } +// ExtendedSampler may be implemented by Samplers, providing +// a method for sampling and returning an extended SampleResult. +// +// TODO(axw) in v2.0.0, replace the Sampler interface with this. +type ExtendedSampler interface { + // SampleExtended indicates whether or not a transaction + // should be sampled, and the sampling rate in effect at + // the time. This method will be invoked by calls to + // Tracer.StartTransaction for the root of a trace, so it + // must be goroutine-safe, and should avoid synchronization + // as far as possible. + SampleExtended(SampleParams) SampleResult +} + +// SampleParams holds parameters for SampleExtended. +type SampleParams struct { + // TraceContext holds the newly-generated TraceContext + // for the root transaction which is being sampled. + TraceContext TraceContext +} + +// SampleResult holds information about a sampling decision. +type SampleResult struct { + // Sampled holds the sampling decision. + Sampled bool + + // SampleRate holds the sample rate in effect at the + // time of the sampling decision. 
+ SampleRate float64 +} + // NewRatioSampler returns a new Sampler with the given ratio // // A ratio of 1.0 samples 100% of transactions, a ratio of 0.5 @@ -51,16 +82,27 @@ func NewRatioSampler(r float64) Sampler { x.SetUint64(math.MaxUint64) x.Mul(&x, big.NewFloat(r)) ceil, _ := x.Uint64() - return ratioSampler{ceil} + return ratioSampler{r, ceil} } type ratioSampler struct { - ceil uint64 + ratio float64 + ceil uint64 } // Sample samples the transaction according to the configured // ratio and pseudo-random source. func (s ratioSampler) Sample(c TraceContext) bool { - v := binary.BigEndian.Uint64(c.Span[:]) - return v > 0 && v-1 < s.ceil + return s.SampleExtended(SampleParams{TraceContext: c}).Sampled +} + +// SampleExtended samples the transaction according to the configured +// ratio and pseudo-random source. +func (s ratioSampler) SampleExtended(args SampleParams) SampleResult { + v := binary.BigEndian.Uint64(args.TraceContext.Span[:]) + result := SampleResult{ + Sampled: v > 0 && v-1 < s.ceil, + SampleRate: s.ratio, + } + return result } diff --git a/sampler_test.go b/sampler_test.go index 4b57a6a76..9dea45f12 100644 --- a/sampler_test.go +++ b/sampler_test.go @@ -86,3 +86,23 @@ func TestRatioSamplerNever(t *testing.T) { Span: apm.SpanID{255, 255, 255, 255, 255, 255, 255, 255}, })) } + +func TestRatioSamplerExtended(t *testing.T) { + s := apm.NewRatioSampler(0.5).(apm.ExtendedSampler) + + result := s.SampleExtended(apm.SampleParams{ + TraceContext: apm.TraceContext{Span: apm.SpanID{255, 0, 0, 0, 0, 0, 0, 0}}, + }) + assert.Equal(t, apm.SampleResult{ + Sampled: false, + SampleRate: 0.5, + }, result) + + result = s.SampleExtended(apm.SampleParams{ + TraceContext: apm.TraceContext{Span: apm.SpanID{1, 0, 0, 0, 0, 0, 0, 0}}, + }) + assert.Equal(t, apm.SampleResult{ + Sampled: true, + SampleRate: 0.5, + }, result) +} diff --git a/span_test.go b/span_test.go index 118eaad8b..805e0d3e6 100644 --- a/span_test.go +++ b/span_test.go @@ -146,3 +146,25 @@ func 
TestTracerStartSpanIDSpecified(t *testing.T) { require.Len(t, spans, 1) assert.Equal(t, model.SpanID(spanID), spans[0].ID) } + +func TestSpanSampleRate(t *testing.T) { + tracer := apmtest.NewRecordingTracer() + defer tracer.Close() + tracer.SetSampler(apm.NewRatioSampler(0.5555)) + + tx := tracer.StartTransactionOptions("name", "type", apm.TransactionOptions{ + // Use a known transaction ID for deterministic sampling. + TransactionID: apm.SpanID{1, 2, 3, 4, 5, 6, 7, 8}, + }) + s1 := tx.StartSpan("name", "type", nil) + s2 := tx.StartSpan("name", "type", s1) + s2.End() + s1.End() + tx.End() + tracer.Flush(nil) + + payloads := tracer.Payloads() + assert.Equal(t, 0.556, *payloads.Transactions[0].SampleRate) + assert.Equal(t, 0.556, *payloads.Spans[0].SampleRate) + assert.Equal(t, 0.556, *payloads.Spans[1].SampleRate) +} diff --git a/tracecontext.go b/tracecontext.go index 2983e85d6..dc6ab86af 100644 --- a/tracecontext.go +++ b/tracecontext.go @@ -22,11 +22,17 @@ import ( "encoding/hex" "fmt" "regexp" + "strconv" + "strings" "unicode" "github.com/pkg/errors" ) +const ( + elasticTracestateVendorKey = "es" +) + var ( errZeroTraceID = errors.New("zero trace-id is invalid") errZeroSpanID = errors.New("zero span-id is invalid") @@ -152,6 +158,13 @@ func (o TraceOptions) WithRecorded(recorded bool) TraceOptions { // TraceState holds vendor-specific state for a trace. type TraceState struct { head *TraceStateEntry + + // Fields related to parsing the Elastic ("es") tracestate entry. + // + // These must not be modified after NewTraceState returns. + parseElasticTracestateError error + haveSampleRate bool + sampleRate float64 } // NewTraceState returns a TraceState based on entries. 
@@ -167,9 +180,55 @@ func NewTraceState(entries ...TraceStateEntry) TraceState { } last = &e } + for _, e := range entries { + if e.Key != elasticTracestateVendorKey { + continue + } + out.parseElasticTracestateError = out.parseElasticTracestate(e) + break + } return out } +// parseElasticTracestate parses an Elastic ("es") tracestate entry. +// +// Per https://github.com/elastic/apm/blob/master/specs/agents/tracing-distributed-tracing.md, +// the "es" tracestate value format is: "key:value;key:value...". Unknown keys are ignored. +func (s *TraceState) parseElasticTracestate(e TraceStateEntry) error { + if err := e.Validate(); err != nil { + return err + } + value := e.Value + for value != "" { + kv := value + end := strings.IndexRune(value, ';') + if end >= 0 { + kv = value[:end] + value = value[end+1:] + } else { + value = "" + } + sep := strings.IndexRune(kv, ':') + if sep == -1 { + return errors.New("malformed 'es' tracestate entry") + } + k, v := kv[:sep], kv[sep+1:] + switch k { + case "s": + sampleRate, err := strconv.ParseFloat(v, 64) + if err != nil { + return err + } + if sampleRate < 0 || sampleRate > 1 { + return fmt.Errorf("sample rate %q out of range", v) + } + s.sampleRate = sampleRate + s.haveSampleRate = true + } + } + return nil +} + // String returns s as a comma-separated list of key-value pairs. func (s TraceState) String() string { if s.head == nil { @@ -199,8 +258,16 @@ func (s TraceState) Validate() error { if i == 32 { return errors.New("tracestate contains more than the maximum allowed number of entries, 32") } - if err := e.Validate(); err != nil { - return errors.Wrapf(err, "invalid tracestate entry at position %d", i) + if e.Key == elasticTracestateVendorKey { + // s.parseElasticTracestateError holds a general e.Validate error if any + // occurred, or any other error specific to the Elastic tracestate format. 
+ if err := s.parseElasticTracestateError; err != nil { + return errors.Wrapf(err, "invalid tracestate entry at position %d", i) + } + } else { + if err := e.Validate(); err != nil { + return errors.Wrapf(err, "invalid tracestate entry at position %d", i) + } } if prev, ok := recorded[e.Key]; ok { return fmt.Errorf("duplicate tracestate key %q at positions %d and %d", e.Key, prev, i) @@ -261,3 +328,10 @@ func (e *TraceStateEntry) validateValue() error { } return nil } + +func formatElasticTracestateValue(sampleRate float64) string { + // 0 -> "s:0" + // 1 -> "s:1" + // 0.5555 -> "s:0.555" (any rounding should be applied prior) + return fmt.Sprintf("s:%.3g", sampleRate) +} diff --git a/tracecontext_test.go b/tracecontext_test.go index 0b9dc9606..49a79068e 100644 --- a/tracecontext_test.go +++ b/tracecontext_test.go @@ -105,3 +105,14 @@ func TestTraceStateInvalidValueCharacter(t *testing.T) { `invalid tracestate entry at position 0: invalid value for key "oy": value contains invalid character '\x00'`) } } + +func TestTraceStateInvalidElasticEntry(t *testing.T) { + ts := apm.NewTraceState(apm.TraceStateEntry{Key: "es", Value: "foo"}) + assert.EqualError(t, ts.Validate(), `invalid tracestate entry at position 0: malformed 'es' tracestate entry`) + + ts = apm.NewTraceState(apm.TraceStateEntry{Key: "es", Value: "s:foo"}) + assert.EqualError(t, ts.Validate(), `invalid tracestate entry at position 0: strconv.ParseFloat: parsing "foo": invalid syntax`) + + ts = apm.NewTraceState(apm.TraceStateEntry{Key: "es", Value: "s:1.5"}) + assert.EqualError(t, ts.Validate(), `invalid tracestate entry at position 0: sample rate "1.5" out of range`) +} diff --git a/tracer.go b/tracer.go index d9269dfbc..1256e3390 100644 --- a/tracer.go +++ b/tracer.go @@ -412,6 +412,7 @@ func newTracer(opts TracerOptions) *Tracer { }) t.setLocalInstrumentationConfig(envTransactionSampleRate, func(cfg *instrumentationConfigValues) { cfg.sampler = opts.sampler + cfg.extendedSampler, _ = 
opts.sampler.(ExtendedSampler) }) t.setLocalInstrumentationConfig(envSpanFramesMinDuration, func(cfg *instrumentationConfigValues) { cfg.spanFramesMinDuration = opts.spanFramesMinDuration @@ -664,6 +665,7 @@ func (t *Tracer) SetRecording(r bool) { func (t *Tracer) SetSampler(s Sampler) { t.setLocalInstrumentationConfig(envTransactionSampleRate, func(cfg *instrumentationConfigValues) { cfg.sampler = s + cfg.extendedSampler, _ = s.(ExtendedSampler) }) } diff --git a/transaction.go b/transaction.go index 5536919cd..e5617fbe5 100644 --- a/transaction.go +++ b/transaction.go @@ -97,8 +97,30 @@ func (t *Tracer) StartTransactionOptions(name, transactionType string, opts Tran } if root { - sampler := instrumentationConfig.sampler - if sampler == nil || sampler.Sample(tx.traceContext) { + var result SampleResult + if instrumentationConfig.extendedSampler != nil { + result = instrumentationConfig.extendedSampler.SampleExtended(SampleParams{ + TraceContext: tx.traceContext, + }) + if !result.Sampled { + // Special case: for unsampled transactions we + // report a sample rate of 0, so that we do not + // count them in aggregations in the server. + // This is necessary to avoid overcounting, as + // we will scale the sampled transactions. 
+ result.SampleRate = 0 + } + sampleRate := round(1000*result.SampleRate) / 1000 + tx.traceContext.State = NewTraceState(TraceStateEntry{ + Key: elasticTracestateVendorKey, + Value: formatElasticTracestateValue(sampleRate), + }) + } else if instrumentationConfig.sampler != nil { + result.Sampled = instrumentationConfig.sampler.Sample(tx.traceContext) + } else { + result.Sampled = true + } + if result.Sampled { o := tx.traceContext.Options.WithRecorded(true) tx.traceContext.Options = o } diff --git a/transaction_test.go b/transaction_test.go index 4692ff0d3..fa0813ca1 100644 --- a/transaction_test.go +++ b/transaction_test.go @@ -175,6 +175,128 @@ func TestTransactionNotRecording(t *testing.T) { require.Empty(t, payloads.Transactions) } +func TestTransactionSampleRate(t *testing.T) { + type test struct { + actualSampleRate float64 + recordedSampleRate float64 + expectedTraceState string + } + tests := []test{ + {0, 0, "es=s:0"}, + {1, 1, "es=s:1"}, + {0.5555, 0.556, "es=s:0.556"}, + } + for _, test := range tests { + test := test // copy for closure + t.Run(fmt.Sprintf("%v", test.actualSampleRate), func(t *testing.T) { + tracer := apmtest.NewRecordingTracer() + defer tracer.Close() + + tracer.SetSampler(apm.NewRatioSampler(test.actualSampleRate)) + tx := tracer.StartTransactionOptions("name", "type", apm.TransactionOptions{ + // Use a known transaction ID for deterministic sampling. + TransactionID: apm.SpanID{1, 2, 3, 4, 5, 6, 7, 8}, + }) + tx.End() + tracer.Flush(nil) + + payloads := tracer.Payloads() + assert.Equal(t, test.recordedSampleRate, *payloads.Transactions[0].SampleRate) + assert.Equal(t, test.expectedTraceState, tx.TraceContext().State.String()) + }) + } +} + +func TestTransactionUnsampledSampleRate(t *testing.T) { + tracer := apmtest.NewRecordingTracer() + defer tracer.Close() + tracer.SetSampler(apm.NewRatioSampler(0.5)) + + // Create transactions until we get an unsampled one. 
+ // + // Even though the configured sampling rate is 0.5, + // we record sample_rate=0 to ensure the server does + // not count the transaction toward metrics. + var tx *apm.Transaction + for { + tx = tracer.StartTransactionOptions("name", "type", apm.TransactionOptions{}) + if !tx.Sampled() { + tx.End() + break + } + tx.Discard() + } + tracer.Flush(nil) + + payloads := tracer.Payloads() + assert.Equal(t, float64(0), *payloads.Transactions[0].SampleRate) + assert.Equal(t, "es=s:0", tx.TraceContext().State.String()) +} + +func TestTransactionSampleRatePropagation(t *testing.T) { + tracer := apmtest.NewRecordingTracer() + defer tracer.Close() + + for _, tracestate := range []apm.TraceState{ + apm.NewTraceState(apm.TraceStateEntry{Key: "es", Value: "s:0.5"}), + apm.NewTraceState(apm.TraceStateEntry{Key: "es", Value: "x:y;s:0.5;zz:y"}), + apm.NewTraceState( + apm.TraceStateEntry{Key: "other", Value: "s:1.0"}, + apm.TraceStateEntry{Key: "es", Value: "s:0.5"}, + ), + } { + tx := tracer.StartTransactionOptions("name", "type", apm.TransactionOptions{ + TraceContext: apm.TraceContext{ + Trace: apm.TraceID{1}, + Span: apm.SpanID{1}, + State: tracestate, + }, + }) + tx.End() + } + tracer.Flush(nil) + + payloads := tracer.Payloads() + assert.Len(t, payloads.Transactions, 3) + for _, tx := range payloads.Transactions { + assert.Equal(t, 0.5, *tx.SampleRate) + } +} + +func TestTransactionSampleRateOmission(t *testing.T) { + tracer := apmtest.NewRecordingTracer() + defer tracer.Close() + + // For downstream transactions, sample_rate should be + // omitted if a valid value is not found in tracestate. 
+ for _, tracestate := range []apm.TraceState{ + apm.TraceState{}, // empty + apm.NewTraceState(apm.TraceStateEntry{Key: "other", Value: "s:1.0"}), // not "es", ignored + apm.NewTraceState(apm.TraceStateEntry{Key: "es", Value: "s:123.0"}), // out of range + apm.NewTraceState(apm.TraceStateEntry{Key: "es", Value: ""}), // 's' missing + apm.NewTraceState(apm.TraceStateEntry{Key: "es", Value: "wat"}), // malformed + } { + for _, sampled := range []bool{false, true} { + tx := tracer.StartTransactionOptions("name", "type", apm.TransactionOptions{ + TraceContext: apm.TraceContext{ + Trace: apm.TraceID{1}, + Span: apm.SpanID{1}, + Options: apm.TraceOptions(0).WithRecorded(sampled), + State: tracestate, + }, + }) + tx.End() + } + } + tracer.Flush(nil) + + payloads := tracer.Payloads() + assert.Len(t, payloads.Transactions, 10) + for _, tx := range payloads.Transactions { + assert.Nil(t, tx.SampleRate) + } +} + func BenchmarkTransaction(b *testing.B) { tracer, err := apm.NewTracer("service", "") require.NoError(b, err) diff --git a/utils_go10.go b/utils_go10.go new file mode 100644 index 000000000..d2c1bbfc4 --- /dev/null +++ b/utils_go10.go @@ -0,0 +1,26 @@ +// Licensed to Elasticsearch B.V. under one or more contributor +// license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright +// ownership. Elasticsearch B.V. licenses this file to you under +// the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +// +build go1.10 + +package apm + +import "math" + +func round(x float64) float64 { + return math.Round(x) +} diff --git a/utils_go9.go b/utils_go9.go new file mode 100644 index 000000000..6ff880977 --- /dev/null +++ b/utils_go9.go @@ -0,0 +1,33 @@ +// Licensed to Elasticsearch B.V. under one or more contributor +// license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright +// ownership. Elasticsearch B.V. licenses this file to you under +// the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// +build !go1.10 + +package apm + +import "math" + +// Implementation of math.Round for Go < 1.10. +// +// Code shamelessly copied from pkg/math. +func round(x float64) float64 { + t := math.Trunc(x) + if math.Abs(x-t) >= 0.5 { + return t + math.Copysign(1, x) + } + return t +}