Skip to content

Commit

Permalink
fix: avoid deadlock in publisher and subscriber (dgraph-io#1749) dgra…
Browse files Browse the repository at this point in the history
  • Loading branch information
mYmNeo committed Feb 13, 2023
1 parent 10be6ca commit 1c44bc8
Show file tree
Hide file tree
Showing 3 changed files with 94 additions and 12 deletions.
25 changes: 20 additions & 5 deletions db.go
Expand Up @@ -1552,11 +1552,11 @@ func (db *DB) Subscribe(ctx context.Context, cb func(kv *KVList) error, prefixes
}

c := y.NewCloser(1)
recvCh, id := db.pub.newSubscriber(c, prefixes...)
s := db.pub.newSubscriber(c, prefixes...)
slurp := func(batch *pb.KVList) error {
for {
select {
case kvs := <-recvCh:
case kvs := <-s.sendCh:
batch.Kv = append(batch.Kv, kvs.Kv...)
default:
if len(batch.GetKv()) > 0 {
Expand All @@ -1566,6 +1566,17 @@ func (db *DB) Subscribe(ctx context.Context, cb func(kv *KVList) error, prefixes
}
}
}

drain := func() {
for {
select {
case <-s.sendCh:
default:
return
}
}
}

for {
select {
case <-c.HasBeenClosed():
Expand All @@ -1577,15 +1588,19 @@ func (db *DB) Subscribe(ctx context.Context, cb func(kv *KVList) error, prefixes
return err
case <-ctx.Done():
c.Done()
db.pub.deleteSubscriber(id)
atomic.StoreUint64(s.active, 0)
drain()
db.pub.deleteSubscriber(s.id)
// Delete the subscriber to avoid further updates.
return ctx.Err()
case batch := <-recvCh:
case batch := <-s.sendCh:
err := slurp(batch)
if err != nil {
c.Done()
atomic.StoreUint64(s.active, 0)
drain()
// Delete the subscriber if there is an error by the callback.
db.pub.deleteSubscriber(id)
db.pub.deleteSubscriber(s.id)
return err
}
}
Expand Down
25 changes: 18 additions & 7 deletions publisher.go
Expand Up @@ -18,30 +18,35 @@ package badger

import (
"sync"
"sync/atomic"

"github.com/dgraph-io/badger/pb"
"github.com/dgraph-io/badger/trie"
"github.com/dgraph-io/badger/y"
)

// subscriber holds the per-subscription state that the publisher keeps for
// one caller of DB.Subscribe.
type subscriber struct {
// id is the key under which this subscriber is stored in publisher.subscribers.
id uint64
// prefixes are the key prefixes passed to newSubscriber; each one is added
// to the publisher's indexer together with id.
prefixes [][]byte
// sendCh carries batches of updates: publishUpdates sends on it and
// DB.Subscribe receives from (and drains) it.
sendCh chan<- *pb.KVList
sendCh chan *pb.KVList
// subCloser is the closer handed to newSubscriber by DB.Subscribe.
subCloser *y.Closer
// active points to a flag that is read and written with sync/atomic.
// newSubscriber initialises it to 1; DB.Subscribe stores 0 before draining
// sendCh and deleting the subscriber, and publishUpdates only sends to
// sendCh while the flag is 1.
active *uint64
}

// publisher keeps track of the registered subscribers and the prefixes they
// are interested in. The embedded mutex is taken by newSubscriber while it
// modifies nextID, subscribers and indexer.
type publisher struct {
sync.Mutex
// pubCh is the channel of pending requests; newPublisher creates it with a
// buffer of 1000.
pubCh chan requests
// subscribers maps a subscriber id to its state.
subscribers map[uint64]subscriber
subscribers map[uint64]*subscriber
// nextID is the id handed to the next subscriber; newSubscriber increments
// it after every registration.
nextID uint64
// indexer stores the subscribed prefixes together with the subscriber id.
indexer *trie.Trie
}

func newPublisher() *publisher {
return &publisher{
pubCh: make(chan requests, 1000),
subscribers: make(map[uint64]subscriber),
subscribers: make(map[uint64]*subscriber),
nextID: 0,
indexer: trie.NewTrie(),
}
Expand Down Expand Up @@ -104,26 +109,32 @@ func (p *publisher) publishUpdates(reqs requests) {
}

for id, kvs := range batchedUpdates {
p.subscribers[id].sendCh <- kvs
if atomic.LoadUint64(p.subscribers[id].active) == 1 {
p.subscribers[id].sendCh <- kvs
}
}
}

// newSubscriber registers a new subscriber for the given prefixes. While
// holding the publisher lock it takes the next free id, creates a channel
// buffered to 1000 batches, stores the subscriber (with its active flag set
// to 1) in p.subscribers and adds every prefix to the indexer under that id.
// c is kept as the subscriber's closer.
func (p *publisher) newSubscriber(c *y.Closer, prefixes ...[]byte) (<-chan *pb.KVList, uint64) {
func (p *publisher) newSubscriber(c *y.Closer, prefixes ...[]byte) *subscriber {
p.Lock()
defer p.Unlock()
ch := make(chan *pb.KVList, 1000)
id := p.nextID
// Increment next ID.
p.nextID++
p.subscribers[id] = subscriber{
// A freshly created subscriber starts out active (1); the flag is shared
// through a pointer so it can be updated atomically later.
active := uint64(1)
s := &subscriber{
active: &active,
id: id,
prefixes: prefixes,
sendCh: ch,
subCloser: c,
}
p.subscribers[id] = s
// Index every prefix so that updates can be matched to this subscriber id.
for _, prefix := range prefixes {
p.indexer.Add(prefix, id)
}
return ch, id
return s
}

// cleanSubscribers stops all the subscribers. Ideally, it should be called while closing the DB.
Expand Down
56 changes: 56 additions & 0 deletions publisher_test.go
Expand Up @@ -19,8 +19,11 @@ import (
"context"
"fmt"
"sync"
"sync/atomic"
"testing"
"time"

"github.com/pkg/errors"
"github.com/stretchr/testify/require"

"github.com/dgraph-io/badger/pb"
Expand Down Expand Up @@ -101,3 +104,56 @@ func TestMultiplePrefix(t *testing.T) {
wg.Wait()
})
}

// TestPublisherDeadlock checks that a slow subscriber whose callback finally
// returns an error cannot wedge the publisher. While the callback sleeps,
// more updates are written than the subscriber channel can buffer (1000), so
// the publisher ends up blocked on the channel; once the callback fails,
// Subscribe must be able to tear the subscription down and every pending
// Update must still complete.
//
// Failures inside helper goroutines are reported with t.Errorf rather than
// require.*, because FailNow may only be called from the goroutine running
// the test, and every goroutine signals completion through a deferred Done so
// that a failure cannot leave the test waiting forever.
func TestPublisherDeadlock(t *testing.T) {
	runBadgerTest(t, nil, func(t *testing.T, db *DB) {
		// Number of updates issued while the subscriber callback is sleeping.
		const numUpdates = 1109

		var subWg sync.WaitGroup
		subWg.Add(1)

		var firstUpdate sync.WaitGroup
		firstUpdate.Add(1)

		var subDone sync.WaitGroup
		subDone.Add(1)
		go func() {
			// Deferred so that subDone.Wait() below returns even if the check fails.
			defer subDone.Done()
			subWg.Done()
			match := []byte("ke")
			err := db.Subscribe(context.Background(), func(kvs *pb.KVList) error {
				firstUpdate.Done()
				// Stay busy long enough for the publisher to fill the channel.
				time.Sleep(time.Second * 20)
				return errors.New("error returned")
			}, match)
			if err == nil {
				t.Errorf("expected Subscribe to return an error, got nil")
			}
		}()
		subWg.Wait()

		var updates sync.WaitGroup
		var succeeded int64
		update := func(i int) {
			defer updates.Done()
			err := db.Update(func(txn *Txn) error {
				e := NewEntry([]byte(fmt.Sprintf("key%d", i)), []byte(fmt.Sprintf("value%d", i)))
				return txn.SetEntry(e)
			})
			if err != nil {
				t.Errorf("update %d failed: %v", i, err)
				return
			}
			atomic.AddInt64(&succeeded, 1)
		}

		// The first update triggers the subscriber callback.
		updates.Add(1)
		go update(0)

		firstUpdate.Wait()
		for i := 1; i <= numUpdates; i++ {
			time.Sleep(time.Millisecond * 10)
			updates.Add(1)
			go update(i)
		}

		// With the deadlock present the updates never finish and this blocks
		// until the test timeout fires.
		updates.Wait()
		subDone.Wait()
		// +1 accounts for the initial update issued before the loop.
		require.Equal(t, int64(numUpdates+1), atomic.LoadInt64(&succeeded))
	})
}

0 comments on commit 1c44bc8

Please sign in to comment.