From 2cf5ecdb244879743394378dcb62fc5580ba7748 Mon Sep 17 00:00:00 2001 From: Andrew Wilkins Date: Fri, 21 Aug 2020 17:40:22 +0800 Subject: [PATCH 1/2] Record {transaction,sample}.sample_rate Introduce the new ExtendedSampler interface, which Samplers may implement to return the effective sampling rate. This is implemented by the built-in ratioSampler. When starting a root transaction we now call the ExtendedSampler method if implemented, and store the effective sampling rate in the transaction's tracestate under our "es" vendor key. When receiving tracestate, we parse our "es" vendor value and extract the sample rate. When encoding transactions and spans we record the sample rate (from tracestate) in the transaction and span events. --- config.go | 2 + model/marshal_fastjson.go | 8 +++ model/model.go | 8 +++ modelwriter.go | 6 ++ module/apmgrpc/client_test.go | 2 +- module/apmot/harness_test.go | 3 +- sampler.go | 50 ++++++++++++++-- sampler_test.go | 20 +++++++ span_test.go | 22 +++++++ tracecontext.go | 78 +++++++++++++++++++++++- tracecontext_test.go | 11 ++++ tracer.go | 2 + transaction.go | 26 +++++++- transaction_test.go | 110 ++++++++++++++++++++++++++++++++++ utils_go10.go | 26 ++++++++ utils_go9.go | 33 ++++++++++ 16 files changed, 397 insertions(+), 10 deletions(-) create mode 100644 utils_go10.go create mode 100644 utils_go9.go diff --git a/config.go b/config.go index 5467b53e6..9e402ee6a 100644 --- a/config.go +++ b/config.go @@ -388,6 +388,7 @@ func (t *Tracer) updateRemoteConfig(logger WarningLogger, old, attrs map[string] } else { updates = append(updates, func(cfg *instrumentationConfig) { cfg.sampler = sampler + cfg.extendedSampler, _ = sampler.(ExtendedSampler) }) } default: @@ -479,6 +480,7 @@ type instrumentationConfigValues struct { recording bool captureBody CaptureBodyMode captureHeaders bool + extendedSampler ExtendedSampler maxSpans int sampler Sampler spanFramesMinDuration time.Duration diff --git a/model/marshal_fastjson.go b/model/marshal_fastjson.go index d4ea9d1fd..8572bb41c 100644 --- a/model/marshal_fastjson.go +++ b/model/marshal_fastjson.go @@ -390,6 +390,10 @@ func (v *Transaction) MarshalFastJSON(w *fastjson.Writer) error { w.RawString(",\"result\":") w.String(v.Result) } + if v.SampleRate != nil { + w.RawString(",\"sample_rate\":") + w.Float64(*v.SampleRate) + } if v.Sampled != nil { w.RawString(",\"sampled\":") w.Bool(*v.Sampled) @@ -445,6 +449,10 @@ func (v *Span) MarshalFastJSON(w *fastjson.Writer) error { firstErr = err } } + if v.SampleRate != nil { + w.RawString(",\"sample_rate\":") + w.Float64(*v.SampleRate) + } if v.Stacktrace != nil { w.RawString(",\"stacktrace\":") w.RawByte('[') diff --git a/model/model.go b/model/model.go index 3f06b85c6..e41823c21 100644 --- a/model/model.go +++ b/model/model.go @@ -205,6 +205,10 @@ type Transaction struct { // it to true. Sampled *bool `json:"sampled,omitempty"` + // SampleRate holds the sample rate in effect when the trace was started, + // if known. This is used by the server to aggregate transaction metrics. + SampleRate *float64 `json:"sample_rate,omitempty"` + // SpanCount holds statistics on spans within a transaction. SpanCount SpanCount `json:"span_count"` } @@ -254,6 +258,10 @@ type Span struct { // ParentID holds the ID of the span's parent (span or transaction). ParentID SpanID `json:"parent_id,omitempty"` + // SampleRate holds the sample rate in effect when the trace was started, + // if known. This is used by the server to aggregate span metrics. + SampleRate *float64 `json:"sample_rate,omitempty"` + // Context holds contextual information relating to the span. Context *SpanContext `json:"context,omitempty"` diff --git a/modelwriter.go b/modelwriter.go index e78d9be8f..fd327e099 100644 --- a/modelwriter.go +++ b/modelwriter.go @@ -109,6 +109,9 @@ func (w *modelWriter) buildModelTransaction(out *model.Transaction, tx *Transact if !sampled { out.Sampled = ¬Sampled } + if tx.traceContext.State.haveSampleRate { + out.SampleRate = &tx.traceContext.State.sampleRate + } out.ParentID = model.SpanID(td.parentSpan) out.Name = truncateString(td.Name) @@ -137,6 +140,9 @@ func (w *modelWriter) buildModelSpan(out *model.Span, span *Span, sd *SpanData) out.ID = model.SpanID(span.traceContext.Span) out.TraceID = model.TraceID(span.traceContext.Trace) out.TransactionID = model.SpanID(span.transactionID) + if span.traceContext.State.haveSampleRate { + out.SampleRate = &span.traceContext.State.sampleRate + } out.ParentID = model.SpanID(sd.parentID) out.Name = truncateString(sd.Name) diff --git a/module/apmgrpc/client_test.go b/module/apmgrpc/client_test.go index 89aae5475..ab098d7bd 100644 --- a/module/apmgrpc/client_test.go +++ b/module/apmgrpc/client_test.go @@ -93,7 +93,7 @@ func testClientSpan(t *testing.T, traceparentHeaders ...string) { } assert.Equal(t, clientSpans[0].TraceID, serverTransactions[1].TraceID) assert.Equal(t, clientSpans[0].ID, serverTransactions[1].ParentID) - assert.Equal(t, "server_span", serverSpans[0].Name) // no tracestate + assert.Equal(t, "es=s:1", serverSpans[0].Name) // automatically created tracestate assert.Equal(t, "vendor=tracestate", serverSpans[1].Name) traceparentValue := apmhttp.FormatTraceparentHeader(apm.TraceContext{ diff --git a/module/apmot/harness_test.go b/module/apmot/harness_test.go index 330979f08..2fc471bf4 100644 --- a/module/apmot/harness_test.go +++ b/module/apmot/harness_test.go @@ -115,5 +115,6 @@ func (harnessAPIProbe) SameSpanContext(span opentracing.Span, sc opentracing.Spa if !ok { return false } - return ctx1.traceContext == ctx2.traceContext + return ctx1.traceContext.Trace == ctx2.traceContext.Trace && + ctx1.traceContext.Span == ctx2.traceContext.Span } diff --git a/sampler.go b/sampler.go index 3cf4591c6..f3268e3ab 100644 --- a/sampler.go +++ b/sampler.go @@ -35,6 +35,37 @@ type Sampler interface { Sample(TraceContext) bool } +// ExtendedSampler may be implemented by Samplers, providing +// a method for sampling and returning an extended SampleResult. +// +// TODO(axw) in v2.0.0, replace the Sampler interface with this. +type ExtendedSampler interface { + // SampleExtended indicates whether or not a transaction + // should be sampled, and the sampling rate in effect at + // the time. This method will be invoked by calls to + // Tracer.StartTransaction for the root of a trace, so it + // must be goroutine-safe, and should avoid synchronization + // as far as possible. + SampleExtended(SampleParams) SampleResult +} + +// SampleParams holds parameters for SampleExtended. +type SampleParams struct { + // TraceContext holds the newly-generated TraceContext + // for the root transaction which is being sampled. + TraceContext TraceContext +} + +// SampleResult holds information about a sampling decision. +type SampleResult struct { + // Sampled holds the sampling decision. + Sampled bool + + // SampleRate holds the sample rate in effect at the + // time of the sampling decision. + SampleRate float64 +} + // NewRatioSampler returns a new Sampler with the given ratio // // A ratio of 1.0 samples 100% of transactions, a ratio of 0.5 @@ -51,16 +82,27 @@ func NewRatioSampler(r float64) Sampler { x.SetUint64(math.MaxUint64) x.Mul(&x, big.NewFloat(r)) ceil, _ := x.Uint64() - return ratioSampler{ceil} + return ratioSampler{r, ceil} } type ratioSampler struct { - ceil uint64 + ratio float64 + ceil uint64 } // Sample samples the transaction according to the configured // ratio and pseudo-random source. func (s ratioSampler) Sample(c TraceContext) bool { - v := binary.BigEndian.Uint64(c.Span[:]) - return v > 0 && v-1 < s.ceil + return s.SampleExtended(SampleParams{TraceContext: c}).Sampled +} + +// SampleExtended samples the transaction according to the configured +// ratio and pseudo-random source. +func (s ratioSampler) SampleExtended(args SampleParams) SampleResult { + v := binary.BigEndian.Uint64(args.TraceContext.Span[:]) + result := SampleResult{ + Sampled: v > 0 && v-1 < s.ceil, + SampleRate: s.ratio, + } + return result } diff --git a/sampler_test.go b/sampler_test.go index 4b57a6a76..9dea45f12 100644 --- a/sampler_test.go +++ b/sampler_test.go @@ -86,3 +86,23 @@ func TestRatioSamplerNever(t *testing.T) { Span: apm.SpanID{255, 255, 255, 255, 255, 255, 255, 255}, })) } + +func TestRatioSamplerExtended(t *testing.T) { + s := apm.NewRatioSampler(0.5).(apm.ExtendedSampler) + + result := s.SampleExtended(apm.SampleParams{ + TraceContext: apm.TraceContext{Span: apm.SpanID{255, 0, 0, 0, 0, 0, 0, 0}}, + }) + assert.Equal(t, apm.SampleResult{ + Sampled: false, + SampleRate: 0.5, + }, result) + + result = s.SampleExtended(apm.SampleParams{ + TraceContext: apm.TraceContext{Span: apm.SpanID{1, 0, 0, 0, 0, 0, 0, 0}}, + }) + assert.Equal(t, apm.SampleResult{ + Sampled: true, + SampleRate: 0.5, + }, result) +} diff --git a/span_test.go b/span_test.go index 118eaad8b..805e0d3e6 100644 --- a/span_test.go +++ b/span_test.go @@ -146,3 +146,25 @@ func TestTracerStartSpanIDSpecified(t *testing.T) { require.Len(t, spans, 1) assert.Equal(t, model.SpanID(spanID), spans[0].ID) } + +func TestSpanSampleRate(t *testing.T) { + tracer := apmtest.NewRecordingTracer() + defer tracer.Close() + tracer.SetSampler(apm.NewRatioSampler(0.5555)) + + tx := tracer.StartTransactionOptions("name", "type", apm.TransactionOptions{ + // Use a known transaction ID for deterministic sampling. + TransactionID: apm.SpanID{1, 2, 3, 4, 5, 6, 7, 8}, + }) + s1 := tx.StartSpan("name", "type", nil) + s2 := tx.StartSpan("name", "type", s1) + s2.End() + s1.End() + tx.End() + tracer.Flush(nil) + + payloads := tracer.Payloads() + assert.Equal(t, 0.556, *payloads.Transactions[0].SampleRate) + assert.Equal(t, 0.556, *payloads.Spans[0].SampleRate) + assert.Equal(t, 0.556, *payloads.Spans[1].SampleRate) +} diff --git a/tracecontext.go b/tracecontext.go index 2983e85d6..dc6ab86af 100644 --- a/tracecontext.go +++ b/tracecontext.go @@ -22,11 +22,17 @@ import ( "encoding/hex" "fmt" "regexp" + "strconv" + "strings" "unicode" "github.com/pkg/errors" ) +const ( + elasticTracestateVendorKey = "es" +) + var ( errZeroTraceID = errors.New("zero trace-id is invalid") errZeroSpanID = errors.New("zero span-id is invalid") @@ -152,6 +158,13 @@ func (o TraceOptions) WithRecorded(recorded bool) TraceOptions { // TraceState holds vendor-specific state for a trace. type TraceState struct { head *TraceStateEntry + + // Fields related to parsing the Elastic ("es") tracestate entry. + // + // These must not be modified after NewTraceState returns. + parseElasticTracestateError error + haveSampleRate bool + sampleRate float64 } // NewTraceState returns a TraceState based on entries. @@ -167,9 +180,55 @@ func NewTraceState(entries ...TraceStateEntry) TraceState { } last = &e } + for _, e := range entries { + if e.Key != elasticTracestateVendorKey { + continue + } + out.parseElasticTracestateError = out.parseElasticTracestate(e) + break + } return out } +// parseElasticTracestate parses an Elastic ("es") tracestate entry. +// +// Per https://github.com/elastic/apm/blob/master/specs/agents/tracing-distributed-tracing.md, +// the "es" tracestate value format is: "key:value;key:value...". Unknown keys are ignored. +func (s *TraceState) parseElasticTracestate(e TraceStateEntry) error { + if err := e.Validate(); err != nil { + return err + } + value := e.Value + for value != "" { + kv := value + end := strings.IndexRune(value, ';') + if end >= 0 { + kv = value[:end] + value = value[end+1:] + } else { + value = "" + } + sep := strings.IndexRune(kv, ':') + if sep == -1 { + return errors.New("malformed 'es' tracestate entry") + } + k, v := kv[:sep], kv[sep+1:] + switch k { + case "s": + sampleRate, err := strconv.ParseFloat(v, 64) + if err != nil { + return err + } + if sampleRate < 0 || sampleRate > 1 { + return fmt.Errorf("sample rate %q out of range", v) + } + s.sampleRate = sampleRate + s.haveSampleRate = true + } + } + return nil +} + // String returns s as a comma-separated list of key-value pairs. func (s TraceState) String() string { if s.head == nil { @@ -199,8 +258,16 @@ func (s TraceState) Validate() error { if i == 32 { return errors.New("tracestate contains more than the maximum allowed number of entries, 32") } - if err := e.Validate(); err != nil { - return errors.Wrapf(err, "invalid tracestate entry at position %d", i) + if e.Key == elasticTracestateVendorKey { + // s.parseElasticTracestateError holds a general e.Validate error if any + // occurred, or any other error specific to the Elastic tracestate format. + if err := s.parseElasticTracestateError; err != nil { + return errors.Wrapf(err, "invalid tracestate entry at position %d", i) + } + } else { + if err := e.Validate(); err != nil { + return errors.Wrapf(err, "invalid tracestate entry at position %d", i) + } } if prev, ok := recorded[e.Key]; ok { return fmt.Errorf("duplicate tracestate key %q at positions %d and %d", e.Key, prev, i) @@ -261,3 +328,10 @@ func (e *TraceStateEntry) validateValue() error { } return nil } + +func formatElasticTracestateValue(sampleRate float64) string { + // 0 -> "s:0" + // 1 -> "s:1" + // 0.5555 -> "s:0.555" (any rounding should be applied prior) + return fmt.Sprintf("s:%.3g", sampleRate) +} diff --git a/tracecontext_test.go b/tracecontext_test.go index 0b9dc9606..49a79068e 100644 --- a/tracecontext_test.go +++ b/tracecontext_test.go @@ -105,3 +105,14 @@ func TestTraceStateInvalidValueCharacter(t *testing.T) { `invalid tracestate entry at position 0: invalid value for key "oy": value contains invalid character '\x00'`) } } + +func TestTraceStateInvalidElasticEntry(t *testing.T) { + ts := apm.NewTraceState(apm.TraceStateEntry{Key: "es", Value: "foo"}) + assert.EqualError(t, ts.Validate(), `invalid tracestate entry at position 0: malformed 'es' tracestate entry`) + + ts = apm.NewTraceState(apm.TraceStateEntry{Key: "es", Value: "s:foo"}) + assert.EqualError(t, ts.Validate(), `invalid tracestate entry at position 0: strconv.ParseFloat: parsing "foo": invalid syntax`) + + ts = apm.NewTraceState(apm.TraceStateEntry{Key: "es", Value: "s:1.5"}) + assert.EqualError(t, ts.Validate(), `invalid tracestate entry at position 0: sample rate "1.5" out of range`) +} diff --git a/tracer.go b/tracer.go index d9269dfbc..1256e3390 100644 --- a/tracer.go +++ b/tracer.go @@ -412,6 +412,7 @@ func newTracer(opts TracerOptions) *Tracer { }) t.setLocalInstrumentationConfig(envTransactionSampleRate, func(cfg *instrumentationConfigValues) { cfg.sampler = opts.sampler + cfg.extendedSampler, _ = opts.sampler.(ExtendedSampler) }) t.setLocalInstrumentationConfig(envSpanFramesMinDuration, func(cfg *instrumentationConfigValues) { cfg.spanFramesMinDuration = opts.spanFramesMinDuration @@ -664,6 +665,7 @@ func (t *Tracer) SetRecording(r bool) { func (t *Tracer) SetSampler(s Sampler) { t.setLocalInstrumentationConfig(envTransactionSampleRate, func(cfg *instrumentationConfigValues) { cfg.sampler = s + cfg.extendedSampler, _ = s.(ExtendedSampler) }) } diff --git a/transaction.go b/transaction.go index 5536919cd..e5617fbe5 100644 --- a/transaction.go +++ b/transaction.go @@ -97,8 +97,30 @@ func (t *Tracer) StartTransactionOptions(name, transactionType string, opts Tran } if root { - sampler := instrumentationConfig.sampler - if sampler == nil || sampler.Sample(tx.traceContext) { + var result SampleResult + if instrumentationConfig.extendedSampler != nil { + result = instrumentationConfig.extendedSampler.SampleExtended(SampleParams{ + TraceContext: tx.traceContext, + }) + if !result.Sampled { + // Special case: for unsampled transactions we + // report a sample rate of 0, so that we do not + // count them in aggregations in the server. + // This is necessary to avoid overcounting, as + // we will scale the sampled transactions. + result.SampleRate = 0 + } + sampleRate := round(1000*result.SampleRate) / 1000 + tx.traceContext.State = NewTraceState(TraceStateEntry{ + Key: elasticTracestateVendorKey, + Value: formatElasticTracestateValue(sampleRate), + }) + } else if instrumentationConfig.sampler != nil { + result.Sampled = instrumentationConfig.sampler.Sample(tx.traceContext) + } else { + result.Sampled = true + } + if result.Sampled { o := tx.traceContext.Options.WithRecorded(true) tx.traceContext.Options = o } diff --git a/transaction_test.go b/transaction_test.go index 4692ff0d3..291efd477 100644 --- a/transaction_test.go +++ b/transaction_test.go @@ -175,6 +175,116 @@ func TestTransactionNotRecording(t *testing.T) { require.Empty(t, payloads.Transactions) } +func TestTransactionSampleRate(t *testing.T) { + type test struct { + actualSampleRate float64 + recordedSampleRate float64 + expectedTraceState string + } + tests := []test{ + {0, 0, "es=s:0"}, + {1, 1, "es=s:1"}, + {0.5555, 0.556, "es=s:0.556"}, + } + for _, test := range tests { + test := test // copy for closure + t.Run(fmt.Sprintf("%v", test.actualSampleRate), func(t *testing.T) { + tracer := apmtest.NewRecordingTracer() + defer tracer.Close() + + tracer.SetSampler(apm.NewRatioSampler(test.actualSampleRate)) + tx := tracer.StartTransactionOptions("name", "type", apm.TransactionOptions{ + // Use a known transaction ID for deterministic sampling. + TransactionID: apm.SpanID{1, 2, 3, 4, 5, 6, 7, 8}, + }) + tx.End() + tracer.Flush(nil) + + payloads := tracer.Payloads() + assert.Equal(t, test.recordedSampleRate, *payloads.Transactions[0].SampleRate) + assert.Equal(t, test.expectedTraceState, tx.TraceContext().State.String()) + }) + } +} + +func TestTransactionUnsampledSampleRate(t *testing.T) { + tracer := apmtest.NewRecordingTracer() + defer tracer.Close() + tracer.SetSampler(apm.NewRatioSampler(0)) + + tx := tracer.StartTransactionOptions("name", "type", apm.TransactionOptions{}) + tx.End() + tracer.Flush(nil) + + payloads := tracer.Payloads() + assert.Equal(t, float64(0), *payloads.Transactions[0].SampleRate) + assert.Equal(t, "es=s:0", tx.TraceContext().State.String()) +} + +func TestTransactionSampleRatePropagation(t *testing.T) { + tracer := apmtest.NewRecordingTracer() + defer tracer.Close() + + for _, tracestate := range []apm.TraceState{ + apm.NewTraceState(apm.TraceStateEntry{Key: "es", Value: "s:0.5"}), + apm.NewTraceState(apm.TraceStateEntry{Key: "es", Value: "x:y;s:0.5;zz:y"}), + apm.NewTraceState( + apm.TraceStateEntry{Key: "other", Value: "s:1.0"}, + apm.TraceStateEntry{Key: "es", Value: "s:0.5"}, + ), + } { + tx := tracer.StartTransactionOptions("name", "type", apm.TransactionOptions{ + TraceContext: apm.TraceContext{ + Trace: apm.TraceID{1}, + Span: apm.SpanID{1}, + State: tracestate, + }, + }) + tx.End() + } + tracer.Flush(nil) + + payloads := tracer.Payloads() + assert.Len(t, payloads.Transactions, 3) + for _, tx := range payloads.Transactions { + assert.Equal(t, 0.5, *tx.SampleRate) + } +} + +func TestTransactionSampleRateOmission(t *testing.T) { + tracer := apmtest.NewRecordingTracer() + defer tracer.Close() + + // For downstream transactions, sample_rate should be + // omitted if a valid value is not found in tracestate. + for _, tracestate := range []apm.TraceState{ + apm.TraceState{}, // empty + apm.NewTraceState(apm.TraceStateEntry{Key: "other", Value: "s:1.0"}), // not "es", ignored + apm.NewTraceState(apm.TraceStateEntry{Key: "es", Value: "s:123.0"}), // out of range + apm.NewTraceState(apm.TraceStateEntry{Key: "es", Value: ""}), // 's' missing + apm.NewTraceState(apm.TraceStateEntry{Key: "es", Value: "wat"}), // malformed + } { + for _, sampled := range []bool{false, true} { + tx := tracer.StartTransactionOptions("name", "type", apm.TransactionOptions{ + TraceContext: apm.TraceContext{ + Trace: apm.TraceID{1}, + Span: apm.SpanID{1}, + Options: apm.TraceOptions(0).WithRecorded(sampled), + State: tracestate, + }, + }) + tx.End() + } + } + tracer.Flush(nil) + + payloads := tracer.Payloads() + assert.Len(t, payloads.Transactions, 10) + for _, tx := range payloads.Transactions { + assert.Nil(t, tx.SampleRate) + } +} + func BenchmarkTransaction(b *testing.B) { tracer, err := apm.NewTracer("service", "") require.NoError(b, err) diff --git a/utils_go10.go b/utils_go10.go new file mode 100644 index 000000000..d2c1bbfc4 --- /dev/null +++ b/utils_go10.go @@ -0,0 +1,26 @@ +// Licensed to Elasticsearch B.V. under one or more contributor +// license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright +// ownership. Elasticsearch B.V. licenses this file to you under +// the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// +build go1.10 + +package apm + +import "math" + +func round(x float64) float64 { + return math.Round(x) +} diff --git a/utils_go9.go b/utils_go9.go new file mode 100644 index 000000000..6ff880977 --- /dev/null +++ b/utils_go9.go @@ -0,0 +1,33 @@ +// Licensed to Elasticsearch B.V. under one or more contributor +// license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright +// ownership. Elasticsearch B.V. licenses this file to you under +// the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// +build !go1.10 + +package apm + +import "math" + +// Implementation of math.Round for Go < 1.10. +// +// Code shamelessly copied from pkg/math. +func round(x float64) float64 { + t := math.Trunc(x) + if math.Abs(x-t) >= 0.5 { + return t + math.Copysign(1, x) + } + return t +} From 690f23a6a73dba8af2e3e36d61c3681ab25c9076 Mon Sep 17 00:00:00 2001 From: Andrew Wilkins Date: Thu, 3 Sep 2020 17:19:44 +0800 Subject: [PATCH 2/2] Improve test for special case --- transaction_test.go | 20 ++++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/transaction_test.go b/transaction_test.go index 291efd477..fa0813ca1 100644 --- a/transaction_test.go +++ b/transaction_test.go @@ -210,10 +210,22 @@ func TestTransactionSampleRate(t *testing.T) { func TestTransactionUnsampledSampleRate(t *testing.T) { tracer := apmtest.NewRecordingTracer() defer tracer.Close() - tracer.SetSampler(apm.NewRatioSampler(0)) - - tx := tracer.StartTransactionOptions("name", "type", apm.TransactionOptions{}) - tx.End() + tracer.SetSampler(apm.NewRatioSampler(0.5)) + + // Create transactions until we get an unsampled one. + // + // Even though the configured sampling rate is 0.5, + // we record sample_rate=0 to ensure the server does + // not count the transaction toward metrics. + var tx *apm.Transaction + for { + tx = tracer.StartTransactionOptions("name", "type", apm.TransactionOptions{}) + if !tx.Sampled() { + tx.End() + break + } + tx.Discard() + } tracer.Flush(nil) payloads := tracer.Payloads()