From 1c44bc8802061b319927aa48b50cba4ad02726ac Mon Sep 17 00:00:00 2001 From: thomassong Date: Fri, 10 Feb 2023 10:18:50 +0800 Subject: [PATCH] fix: avoid deadlock in publisher and subscriber (#1749) #1751 --- db.go | 25 ++++++++++++++++----- publisher.go | 25 +++++++++++++++------ publisher_test.go | 56 +++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 94 insertions(+), 12 deletions(-) diff --git a/db.go b/db.go index 900f03965..2304a04d9 100644 --- a/db.go +++ b/db.go @@ -1552,11 +1552,11 @@ func (db *DB) Subscribe(ctx context.Context, cb func(kv *KVList) error, prefixes } c := y.NewCloser(1) - recvCh, id := db.pub.newSubscriber(c, prefixes...) + s := db.pub.newSubscriber(c, prefixes...) slurp := func(batch *pb.KVList) error { for { select { - case kvs := <-recvCh: + case kvs := <-s.sendCh: batch.Kv = append(batch.Kv, kvs.Kv...) default: if len(batch.GetKv()) > 0 { @@ -1566,6 +1566,17 @@ func (db *DB) Subscribe(ctx context.Context, cb func(kv *KVList) error, prefixes } } } + + drain := func() { + for { + select { + case <-s.sendCh: + default: + return + } + } + } + for { select { case <-c.HasBeenClosed(): @@ -1577,15 +1588,19 @@ func (db *DB) Subscribe(ctx context.Context, cb func(kv *KVList) error, prefixes return err case <-ctx.Done(): c.Done() - db.pub.deleteSubscriber(id) + atomic.StoreUint64(s.active, 0) + drain() + db.pub.deleteSubscriber(s.id) // Delete the subscriber to avoid further updates. return ctx.Err() - case batch := <-recvCh: + case batch := <-s.sendCh: err := slurp(batch) if err != nil { c.Done() + atomic.StoreUint64(s.active, 0) + drain() // Delete the subscriber if there is an error by the callback. 
- db.pub.deleteSubscriber(id) + db.pub.deleteSubscriber(s.id) return err } } } diff --git a/publisher.go b/publisher.go index 7458b0d95..308d114ad 100644 --- a/publisher.go +++ b/publisher.go @@ -18,6 +18,7 @@ package badger import ( "sync" + "sync/atomic" "github.com/dgraph-io/badger/pb" "github.com/dgraph-io/badger/trie" @@ -25,15 +26,19 @@ import ( ) type subscriber struct { + id uint64 prefixes [][]byte - sendCh chan<- *pb.KVList + sendCh chan *pb.KVList subCloser *y.Closer + // active points to a flag that is read and written atomically: it is 1 while + // the subscriber is active and 0 once it should no longer be sent updates. + active *uint64 } type publisher struct { sync.Mutex pubCh chan requests - subscribers map[uint64]subscriber + subscribers map[uint64]*subscriber nextID uint64 indexer *trie.Trie } @@ -41,7 +46,7 @@ type publisher struct { func newPublisher() *publisher { return &publisher{ pubCh: make(chan requests, 1000), - subscribers: make(map[uint64]subscriber), + subscribers: make(map[uint64]*subscriber), nextID: 0, indexer: trie.NewTrie(), } @@ -104,26 +109,32 @@ func (p *publisher) publishUpdates(reqs requests) { } for id, kvs := range batchedUpdates { - p.subscribers[id].sendCh <- kvs + if atomic.LoadUint64(p.subscribers[id].active) == 1 { + p.subscribers[id].sendCh <- kvs + } } } -func (p *publisher) newSubscriber(c *y.Closer, prefixes ...[]byte) (<-chan *pb.KVList, uint64) { +func (p *publisher) newSubscriber(c *y.Closer, prefixes ...[]byte) *subscriber { p.Lock() defer p.Unlock() ch := make(chan *pb.KVList, 1000) id := p.nextID // Increment next ID. p.nextID++ - p.subscribers[id] = subscriber{ + active := uint64(1) + s := &subscriber{ + active: &active, + id: id, prefixes: prefixes, sendCh: ch, subCloser: c, } + p.subscribers[id] = s for _, prefix := range prefixes { p.indexer.Add(prefix, id) } - return ch, id + return s } // cleanSubscribers stops all the subscribers. Ideally, It should be called while closing DB. 
diff --git a/publisher_test.go b/publisher_test.go index ce61232e1..216e9b704 100644 --- a/publisher_test.go +++ b/publisher_test.go @@ -19,8 +19,11 @@ import ( "context" "fmt" "sync" + "sync/atomic" "testing" + "time" + "github.com/pkg/errors" "github.com/stretchr/testify/require" "github.com/dgraph-io/badger/pb" @@ -101,3 +104,56 @@ func TestMultiplePrefix(t *testing.T) { wg.Wait() }) } + +func TestPublisherDeadlock(t *testing.T) { + runBadgerTest(t, nil, func(t *testing.T, db *DB) { + var subWg sync.WaitGroup + subWg.Add(1) + + var firstUpdate sync.WaitGroup + firstUpdate.Add(1) + + var subDone sync.WaitGroup + subDone.Add(1) + go func() { + subWg.Done() + match := []byte("ke") + err := db.Subscribe(context.Background(), func(kvs *pb.KVList) error { + firstUpdate.Done() + time.Sleep(time.Second * 20) + return errors.New("error returned") + }, match) + require.Error(t, err, errors.New("error returned")) + subDone.Done() + }() + subWg.Wait() + go func() { + err := db.Update(func(txn *Txn) error { + e := NewEntry([]byte(fmt.Sprintf("key%d", 0)), []byte(fmt.Sprintf("value%d", 0))) + return txn.SetEntry(e) + }) + require.NoError(t, err) + }() + + firstUpdate.Wait() + req := int64(0) + for i := 1; i < 1110; i++ { + time.Sleep(time.Millisecond * 10) + go func(i int) { + err := db.Update(func(txn *Txn) error { + e := NewEntry([]byte(fmt.Sprintf("key%d", i)), []byte(fmt.Sprintf("value%d", i))) + return txn.SetEntry(e) + }) + require.NoError(t, err) + atomic.AddInt64(&req, 1) + }(i) + } + for { + if atomic.LoadInt64(&req) == 1109 { + break + } + time.Sleep(time.Second) + } + subDone.Wait() + }) +}