diff --git a/statsd/worker.go b/statsd/worker.go index 3fdf3634..4741c5ac 100644 --- a/statsd/worker.go +++ b/statsd/worker.go @@ -3,12 +3,14 @@ package statsd import ( "math/rand" "sync" + "time" ) type worker struct { pool *bufferPool buffer *statsdBuffer sender *sender + random *rand.Rand sync.Mutex inputMetrics chan metric @@ -16,10 +18,20 @@ type worker struct { } func newWorker(pool *bufferPool, sender *sender) *worker { + // Each worker uses its own random source to prevent workers in separate + // goroutines from contending for the lock on the "math/rand" package-global + // random source (e.g. calls like "rand.Float64()" must acquire a shared + // lock to get the next pseudorandom number). + // Note that calling "time.Now().UnixNano()" repeatedly quickly may return + // very similar values. That's fine for seeding the worker-specific random + // source because we just need an evenly distributed stream of float values. + // Do not use this random source for cryptographic randomness. + random := rand.New(rand.NewSource(time.Now().UnixNano())) return &worker{ pool: pool, sender: sender, buffer: pool.borrowBuffer(), + random: random, stop: make(chan struct{}), } } @@ -59,7 +71,7 @@ func (w *worker) processMetric(m metric) error { } func (w *worker) shouldSample(rate float64) bool { - if rate < 1 && rand.Float64() > rate { + if rate < 1 && w.random.Float64() > rate { return false } return true diff --git a/statsd/worker_test.go b/statsd/worker_test.go new file mode 100644 index 00000000..58a8d881 --- /dev/null +++ b/statsd/worker_test.go @@ -0,0 +1,38 @@ +package statsd + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestShouldSample(t *testing.T) { + rates := []float64{0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.99, 1.0} + iterations := 50_000 + + for _, rate := range rates { + rate := rate // Capture range variable. + t.Run(fmt.Sprintf("Rate %0.2f", rate), func(t *testing.T) { + t.Parallel() + + worker := newWorker(newBufferPool(1, 1, 1), nil) + count := 0 + for i := 0; i < iterations; i++ { + if worker.shouldSample(rate) { + count++ + } + } + assert.InDelta(t, rate, float64(count)/float64(iterations), 0.01) + }) + } +} + +func BenchmarkShouldSample(b *testing.B) { + b.RunParallel(func(pb *testing.PB) { + worker := newWorker(newBufferPool(1, 1, 1), nil) + for pb.Next() { + worker.shouldSample(0.1) + } + }) +}