Skip to content

Commit

Permalink
zmq4: resend subscriptions in socket.addConn
Browse files Browse the repository at this point in the history
  • Loading branch information
thielepaul authored and sbinet committed Jun 17, 2022
1 parent 16d169c commit c17962e
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 7 deletions.
21 changes: 14 additions & 7 deletions socket.go
Expand Up @@ -30,12 +30,13 @@ var (

// socket implements the ZeroMQ socket interface
type socket struct {
ep string // socket end-point
typ SocketType
id SocketIdentity
retry time.Duration
sec Security
log *log.Logger
ep string // socket end-point
typ SocketType
id SocketIdentity
retry time.Duration
sec Security
log *log.Logger
subTopics func() []string

mu sync.RWMutex
ids map[string]*Conn // ZMTP connection IDs
Expand Down Expand Up @@ -273,6 +274,7 @@ connect:

func (sck *socket) addConn(c *Conn) {
sck.mu.Lock()
defer sck.mu.Unlock()
sck.conns = append(sck.conns, c)
uuid, ok := c.Peer.Meta[sysSockID]
if !ok {
Expand All @@ -286,7 +288,12 @@ func (sck *socket) addConn(c *Conn) {
if sck.r != nil {
sck.r.addConn(c)
}
sck.mu.Unlock()
// resend subscriptions for topics if there are any
if sck.subTopics != nil {
for _, topic := range sck.subTopics() {
_ = sck.Send(NewMsg(append([]byte{1}, topic...)))
}
}
}

func (sck *socket) rmConn(c *Conn) {
Expand Down
40 changes: 40 additions & 0 deletions socket_test.go
Expand Up @@ -9,6 +9,7 @@ import (
"fmt"
"io"
"net"
"sync"
"testing"
"time"

Expand Down Expand Up @@ -220,3 +221,42 @@ func TestConnReaperDeadlock(t *testing.T) {
clients[i].Close()
}
}

func TestSocketSendSubscriptionOnConnect(t *testing.T) {
endpoint := "inproc://test-resub"
message := "test"

sub := zmq4.NewSub(context.Background())
defer sub.Close()
pub := zmq4.NewPub(context.Background())
defer pub.Close()
sub.SetOption(zmq4.OptionSubscribe, message)
if err := sub.Listen(endpoint); err != nil {
t.Fatalf("Sub Dial failed: %v", err)
}
if err := pub.Dial(endpoint); err != nil {
t.Fatalf("Pub Dial failed: %v", err)
}
wg := new(sync.WaitGroup)
defer wg.Wait()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
wg.Add(1)
go func() {
defer wg.Done()
for {
pub.Send(zmq4.NewMsgFromString([]string{message}))
if ctx.Err() != nil {
return
}
time.Sleep(1 * time.Millisecond)
}
}()
msg, err := sub.Recv()
if err != nil {
t.Fatalf("Recv failed: %v", err)
}
if string(msg.Frames[0]) != message {
t.Fatalf("invalid message received: got '%s', wanted '%s'", msg.Frames[0], message)
}
}
1 change: 1 addition & 0 deletions sub.go
Expand Up @@ -16,6 +16,7 @@ import (
func NewSub(ctx context.Context, opts ...Option) Socket {
sub := &subSocket{sck: newSocket(ctx, Sub, opts...)}
sub.sck.r = newQReader(sub.sck.ctx)
sub.sck.subTopics = sub.Topics
sub.topics = make(map[string]struct{})
return sub
}
Expand Down

0 comments on commit c17962e

Please sign in to comment.