Skip to content

Commit

Permalink
wrr: improve randomWRR performance (#5067)
Browse files Browse the repository at this point in the history
  • Loading branch information
huangchong94 committed Jan 12, 2022
1 parent 0145b50 commit f231ac5
Show file tree
Hide file tree
Showing 2 changed files with 105 additions and 18 deletions.
44 changes: 28 additions & 16 deletions internal/wrr/random.go
Expand Up @@ -19,15 +19,17 @@ package wrr

import (
"fmt"
"sort"
"sync"

"google.golang.org/grpc/internal/grpcrand"
)

// weightedItem is a wrapped weighted item that is used to implement weighted random algorithm.
type weightedItem struct {
Item interface{}
Weight int64
item interface{}
weight int64
accumulatedWeight int64
}

func (w *weightedItem) String() string {
Expand All @@ -36,9 +38,10 @@ func (w *weightedItem) String() string {

// randomWRR is a struct that contains weighted items implement weighted random algorithm.
type randomWRR struct {
mu sync.RWMutex
items []*weightedItem
sumOfWeights int64
mu sync.RWMutex
items []*weightedItem
// Are all item's weights equal
equalWeights bool
}

// NewRandom creates a new WRR with random.
Expand All @@ -51,27 +54,36 @@ var grpcrandInt63n = grpcrand.Int63n
func (rw *randomWRR) Next() (item interface{}) {
rw.mu.RLock()
defer rw.mu.RUnlock()
if rw.sumOfWeights == 0 {
if len(rw.items) == 0 {
return nil
}
// Random number in [0, sum).
randomWeight := grpcrandInt63n(rw.sumOfWeights)
for _, item := range rw.items {
randomWeight = randomWeight - item.Weight
if randomWeight < 0 {
return item.Item
}
if rw.equalWeights {
return rw.items[grpcrandInt63n(int64(len(rw.items)))].item
}

return rw.items[len(rw.items)-1].Item
sumOfWeights := rw.items[len(rw.items)-1].accumulatedWeight
// Random number in [0, sumOfWeights).
randomWeight := grpcrandInt63n(sumOfWeights)
// Item's accumulated weights are in ascending order, because item's weight >= 0.
// Binary search rw.items to find first item whose accumulatedWeight > randomWeight
// The return i is guaranteed to be in range [0, len(rw.items)) because randomWeight < last item's accumulatedWeight
i := sort.Search(len(rw.items), func(i int) bool { return rw.items[i].accumulatedWeight > randomWeight })
return rw.items[i].item
}

func (rw *randomWRR) Add(item interface{}, weight int64) {
rw.mu.Lock()
defer rw.mu.Unlock()
rItem := &weightedItem{Item: item, Weight: weight}
accumulatedWeight := weight
equalWeights := true
if len(rw.items) > 0 {
lastItem := rw.items[len(rw.items)-1]
accumulatedWeight = lastItem.accumulatedWeight + weight
equalWeights = rw.equalWeights && weight == lastItem.weight
}
rw.equalWeights = equalWeights
rItem := &weightedItem{item: item, weight: weight, accumulatedWeight: accumulatedWeight}
rw.items = append(rw.items, rItem)
rw.sumOfWeights += weight
}

func (rw *randomWRR) String() string {
Expand Down
79 changes: 77 additions & 2 deletions internal/wrr/wrr_test.go
Expand Up @@ -21,6 +21,7 @@ import (
"errors"
"math"
"math/rand"
"strconv"
"testing"

"github.com/google/go-cmp/cmp"
Expand Down Expand Up @@ -70,12 +71,22 @@ func testWRRNext(t *testing.T, newWRR func() WRR) {
name: "17-23-37",
weights: []int64{17, 23, 37},
},
{
name: "no items",
weights: []int64{},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var sumOfWeights int64

w := newWRR()
if len(tt.weights) == 0 {
if next := w.Next(); next != nil {
t.Fatalf("w.Next returns non nil value:%v when there is no item", next)
}
return
}

var sumOfWeights int64
for i, weight := range tt.weights {
w.Add(i, weight)
sumOfWeights += weight
Expand Down Expand Up @@ -112,6 +123,70 @@ func (s) TestEdfWrrNext(t *testing.T) {
testWRRNext(t, NewEDF)
}

func BenchmarkRandomWRRNext(b *testing.B) {
for _, n := range []int{100, 500, 1000} {
b.Run("equal-weights-"+strconv.Itoa(n)+"-items", func(b *testing.B) {
w := NewRandom()
sumOfWeights := n
for i := 0; i < n; i++ {
w.Add(i, 1)
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
for i := 0; i < sumOfWeights; i++ {
w.Next()
}
}
})
}

var maxWeight int64 = 1024
for _, n := range []int{100, 500, 1000} {
b.Run("random-weights-"+strconv.Itoa(n)+"-items", func(b *testing.B) {
w := NewRandom()
var sumOfWeights int64
for i := 0; i < n; i++ {
weight := rand.Int63n(maxWeight + 1)
w.Add(i, weight)
sumOfWeights += weight
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
for i := 0; i < int(sumOfWeights); i++ {
w.Next()
}
}
})
}

itemsNum := 200
heavyWeight := int64(itemsNum)
lightWeight := int64(1)
heavyIndices := []int{0, itemsNum / 2, itemsNum - 1}
for _, heavyIndex := range heavyIndices {
b.Run("skew-weights-heavy-index-"+strconv.Itoa(heavyIndex), func(b *testing.B) {
w := NewRandom()
var sumOfWeights int64
for i := 0; i < itemsNum; i++ {
var weight int64
if i == heavyIndex {
weight = heavyWeight
} else {
weight = lightWeight
}
sumOfWeights += weight
w.Add(i, weight)
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
for i := 0; i < int(sumOfWeights); i++ {
w.Next()
}
}
})
}
}

func init() {
r := rand.New(rand.NewSource(0))
grpcrandInt63n = r.Int63n
Expand Down

0 comments on commit f231ac5

Please sign in to comment.