Skip to content

Commit

Permalink
wrr: use binary search to improve randomWRR performance
Browse files Browse the repository at this point in the history
  • Loading branch information
huangchong94 committed Dec 17, 2021
1 parent 51835dc commit 4dc83ca
Showing 1 changed file with 23 additions and 18 deletions.
41 changes: 23 additions & 18 deletions internal/wrr/random.go
Expand Up @@ -19,15 +19,20 @@ package wrr

import (
"fmt"
"sort"
"sync"

"google.golang.org/grpc/internal/grpcrand"
)

// weightedItem is a wrapped weighted item that is used to implement weighted random algorithm.
type weightedItem struct {
Item interface{}
Weight int64
Item interface{}
// Delete Weight? This field is not necessary for randomWRR to work properly.
// But without this field, if we want to know an item's weight, we have to
// calculate it. weight = items.AccumulatedWeight - previousItem.AccumulatedWeight
Weight int64
AccumulatedWeight int64
}

func (w *weightedItem) String() string {
Expand All @@ -36,9 +41,8 @@ func (w *weightedItem) String() string {

// randomWRR is a struct that contains weighted items implement weighted random algorithm.
type randomWRR struct {
mu sync.RWMutex
items []*weightedItem
sumOfWeights int64
mu sync.RWMutex
items []*weightedItem
}

// NewRandom creates a new WRR with random.
Expand All @@ -51,27 +55,28 @@ var grpcrandInt63n = grpcrand.Int63n
func (rw *randomWRR) Next() (item interface{}) {
rw.mu.RLock()
defer rw.mu.RUnlock()
if rw.sumOfWeights == 0 {
sumOfWeights := rw.items[len(rw.items)-1].AccumulatedWeight
if sumOfWeights == 0 {
return nil
}
// Random number in [0, sum).
randomWeight := grpcrandInt63n(rw.sumOfWeights)
for _, item := range rw.items {
randomWeight = randomWeight - item.Weight
if randomWeight < 0 {
return item.Item
}
}

return rw.items[len(rw.items)-1].Item
// Random number in [0, sumOfWeights).
randomWeight := grpcrandInt63n(sumOfWeights)
// Item's accumulated weights are in ascending order, because item's weight >= 0.
// Binary search rw.items to find first item whose AccumulatedWeight > randomWeight
// The return i is guaranteed to be in range [0, len(rw.items)) because randomWeight < last item's AccumulatedWeight
i := sort.Search(len(rw.items), func(i int) bool { return rw.items[i].AccumulatedWeight > randomWeight })
return rw.items[i].Item
}

func (rw *randomWRR) Add(item interface{}, weight int64) {
rw.mu.Lock()
defer rw.mu.Unlock()
rItem := &weightedItem{Item: item, Weight: weight}
accumulatedWeight := weight
if len(rw.items) > 0 {
accumulatedWeight = rw.items[len(rw.items)-1].AccumulatedWeight + weight
}
rItem := &weightedItem{Item: item, Weight: weight, AccumulatedWeight: accumulatedWeight}
rw.items = append(rw.items, rItem)
rw.sumOfWeights += weight
}

func (rw *randomWRR) String() string {
Expand Down

0 comments on commit 4dc83ca

Please sign in to comment.