diff --git a/balancer_conn_wrappers.go b/balancer_conn_wrappers.go index 8a5b89bc39b..5356194c340 100644 --- a/balancer_conn_wrappers.go +++ b/balancer_conn_wrappers.go @@ -25,6 +25,7 @@ import ( "google.golang.org/grpc/balancer" "google.golang.org/grpc/connectivity" "google.golang.org/grpc/grpclog" + "google.golang.org/grpc/internal/buffer" "google.golang.org/grpc/internal/grpcsync" "google.golang.org/grpc/resolver" ) @@ -35,62 +36,14 @@ type scStateUpdate struct { state connectivity.State } -// scStateUpdateBuffer is an unbounded channel for scStateChangeTuple. -// TODO make a general purpose buffer that uses interface{}. -type scStateUpdateBuffer struct { - c chan *scStateUpdate - mu sync.Mutex - backlog []*scStateUpdate -} - -func newSCStateUpdateBuffer() *scStateUpdateBuffer { - return &scStateUpdateBuffer{ - c: make(chan *scStateUpdate, 1), - } -} - -func (b *scStateUpdateBuffer) put(t *scStateUpdate) { - b.mu.Lock() - defer b.mu.Unlock() - if len(b.backlog) == 0 { - select { - case b.c <- t: - return - default: - } - } - b.backlog = append(b.backlog, t) -} - -func (b *scStateUpdateBuffer) load() { - b.mu.Lock() - defer b.mu.Unlock() - if len(b.backlog) > 0 { - select { - case b.c <- b.backlog[0]: - b.backlog[0] = nil - b.backlog = b.backlog[1:] - default: - } - } -} - -// get returns the channel that the scStateUpdate will be sent to. -// -// Upon receiving, the caller should call load to send another -// scStateChangeTuple onto the channel if there is any. -func (b *scStateUpdateBuffer) get() <-chan *scStateUpdate { - return b.c -} - // ccBalancerWrapper is a wrapper on top of cc for balancers. // It implements balancer.ClientConn interface. type ccBalancerWrapper struct { - cc *ClientConn - balancerMu sync.Mutex // synchronizes calls to the balancer - balancer balancer.Balancer - stateChangeQueue *scStateUpdateBuffer - done *grpcsync.Event + cc *ClientConn + balancerMu sync.Mutex // synchronizes calls to the balancer + balancer balancer.Balancer + scBuffer *buffer.Unbounded + done *grpcsync.Event mu sync.Mutex subConns map[*acBalancerWrapper]struct{} @@ -98,10 +51,10 @@ type ccBalancerWrapper struct { func newCCBalancerWrapper(cc *ClientConn, b balancer.Builder, bopts balancer.BuildOptions) *ccBalancerWrapper { ccb := &ccBalancerWrapper{ - cc: cc, - stateChangeQueue: newSCStateUpdateBuffer(), - done: grpcsync.NewEvent(), - subConns: make(map[*acBalancerWrapper]struct{}), + cc: cc, + scBuffer: buffer.NewUnbounded(), + done: grpcsync.NewEvent(), + subConns: make(map[*acBalancerWrapper]struct{}), } go ccb.watcher() ccb.balancer = b.Build(ccb, bopts) @@ -113,16 +66,17 @@ func newCCBalancerWrapper(cc *ClientConn, b balancer.Builder, bopts balancer.Bui func (ccb *ccBalancerWrapper) watcher() { for { select { - case t := <-ccb.stateChangeQueue.get(): - ccb.stateChangeQueue.load() + case t := <-ccb.scBuffer.Get(): + ccb.scBuffer.Load() if ccb.done.HasFired() { break } ccb.balancerMu.Lock() + su := t.(*scStateUpdate) if ub, ok := ccb.balancer.(balancer.V2Balancer); ok { - ub.UpdateSubConnState(t.sc, balancer.SubConnState{ConnectivityState: t.state}) + ub.UpdateSubConnState(su.sc, balancer.SubConnState{ConnectivityState: su.state}) } else { - ccb.balancer.HandleSubConnStateChange(t.sc, t.state) + ccb.balancer.HandleSubConnStateChange(su.sc, su.state) } ccb.balancerMu.Unlock() case <-ccb.done.Done(): @@ -158,7 +112,7 @@ func (ccb *ccBalancerWrapper) handleSubConnStateChange(sc balancer.SubConn, s co if sc == nil { return } - ccb.stateChangeQueue.put(&scStateUpdate{ + ccb.scBuffer.Put(&scStateUpdate{ sc: sc, state: s, }) diff --git a/internal/buffer/unbounded.go b/internal/buffer/unbounded.go new file mode 100644 index 00000000000..2cb3109d807 --- /dev/null +++ b/internal/buffer/unbounded.go @@ -0,0 +1,78 @@ +/* + * Copyright 2019 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +// Package buffer provides an implementation of an unbounded buffer. +package buffer + +import "sync" + +// Unbounded is an implementation of an unbounded buffer which does not use +// extra goroutines. This is typically used for passing updates from one entity +// to another within gRPC. +// +// All methods on this type are thread-safe and don't block on anything except +// the underlying mutex used for synchronization. +type Unbounded struct { + c chan interface{} + mu sync.Mutex + backlog []interface{} +} + +// NewUnbounded returns a new instance of Unbounded. +func NewUnbounded() *Unbounded { + return &Unbounded{c: make(chan interface{}, 1)} +} + +// Put adds t to the unbounded buffer. +func (b *Unbounded) Put(t interface{}) { + b.mu.Lock() + if len(b.backlog) == 0 { + select { + case b.c <- t: + b.mu.Unlock() + return + default: + } + } + b.backlog = append(b.backlog, t) + b.mu.Unlock() +} + +// Load sends the earliest buffered data, if any, onto the read channel +// returned by Get(). Users are expected to call this every time they read a +// value from the read channel. +func (b *Unbounded) Load() { + b.mu.Lock() + if len(b.backlog) > 0 { + select { + case b.c <- b.backlog[0]: + b.backlog[0] = nil + b.backlog = b.backlog[1:] + default: + } + } + b.mu.Unlock() +} + +// Get returns a read channel on which values added to the buffer, via Put(), +// are sent on. +// +// Upon reading a value from this channel, users are expected to call Load() to +// send the next buffered value onto the channel if there is any. +func (b *Unbounded) Get() <-chan interface{} { + return b.c +} diff --git a/internal/buffer/unbounded_test.go b/internal/buffer/unbounded_test.go new file mode 100644 index 00000000000..c8067019ba8 --- /dev/null +++ b/internal/buffer/unbounded_test.go @@ -0,0 +1,111 @@ +/* + * Copyright 2019 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package buffer + +import ( + "reflect" + "sort" + "sync" + "testing" +) + +const ( + numWriters = 10 + numWrites = 10 +) + +// wantReads contains the set of values expected to be read by the reader +// goroutine in the tests. +var wantReads []int + +func init() { + for i := 0; i < numWriters; i++ { + for j := 0; j < numWrites; j++ { + wantReads = append(wantReads, i) + } + } +} + +// TestSingleWriter starts one reader and one writer goroutine and makes sure +// that the reader gets all the value added to the buffer by the writer. +func TestSingleWriter(t *testing.T) { + ub := NewUnbounded() + reads := []int{} + + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + ch := ub.Get() + for i := 0; i < numWriters*numWrites; i++ { + r := <-ch + reads = append(reads, r.(int)) + ub.Load() + } + }() + + wg.Add(1) + go func() { + defer wg.Done() + for i := 0; i < numWriters; i++ { + for j := 0; j < numWrites; j++ { + ub.Put(i) + } + } + }() + + wg.Wait() + if !reflect.DeepEqual(reads, wantReads) { + t.Errorf("reads: %#v, wantReads: %#v", reads, wantReads) + } +} + +// TestMultipleWriters starts multiple writers and one reader goroutine and +// makes sure that the reader gets all the data written by all writers. +func TestMultipleWriters(t *testing.T) { + ub := NewUnbounded() + reads := []int{} + + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + ch := ub.Get() + for i := 0; i < numWriters*numWrites; i++ { + r := <-ch + reads = append(reads, r.(int)) + ub.Load() + } + }() + + wg.Add(numWriters) + for i := 0; i < numWriters; i++ { + go func(index int) { + defer wg.Done() + for j := 0; j < numWrites; j++ { + ub.Put(index) + } + }(i) + } + + wg.Wait() + sort.Ints(reads) + if !reflect.DeepEqual(reads, wantReads) { + t.Errorf("reads: %#v, wantReads: %#v", reads, wantReads) + } +}