From c17962e2add883a73102240d973d48323a918ae3 Mon Sep 17 00:00:00 2001 From: Paul Thiele Date: Thu, 9 Jun 2022 21:16:57 +0200 Subject: [PATCH] zmq4: resend subscriptions in socket.addConn --- socket.go | 21 ++++++++++++++------- socket_test.go | 40 ++++++++++++++++++++++++++++++++++++++++ sub.go | 1 + 3 files changed, 55 insertions(+), 7 deletions(-) diff --git a/socket.go b/socket.go index 6ea5719..55a4f56 100644 --- a/socket.go +++ b/socket.go @@ -30,12 +30,13 @@ var ( // socket implements the ZeroMQ socket interface type socket struct { - ep string // socket end-point - typ SocketType - id SocketIdentity - retry time.Duration - sec Security - log *log.Logger + ep string // socket end-point + typ SocketType + id SocketIdentity + retry time.Duration + sec Security + log *log.Logger + subTopics func() []string mu sync.RWMutex ids map[string]*Conn // ZMTP connection IDs @@ -273,6 +274,7 @@ connect: func (sck *socket) addConn(c *Conn) { sck.mu.Lock() + defer sck.mu.Unlock() sck.conns = append(sck.conns, c) uuid, ok := c.Peer.Meta[sysSockID] if !ok { @@ -286,7 +288,12 @@ func (sck *socket) addConn(c *Conn) { if sck.r != nil { sck.r.addConn(c) } - sck.mu.Unlock() + // resend subscriptions for topics if there are any + if sck.subTopics != nil { + for _, topic := range sck.subTopics() { + _ = sck.Send(NewMsg(append([]byte{1}, topic...))) + } + } } func (sck *socket) rmConn(c *Conn) { diff --git a/socket_test.go b/socket_test.go index a1ea891..8ba617a 100644 --- a/socket_test.go +++ b/socket_test.go @@ -9,6 +9,7 @@ import ( "fmt" "io" "net" + "sync" "testing" "time" @@ -220,3 +221,42 @@ func TestConnReaperDeadlock(t *testing.T) { clients[i].Close() } } + +func TestSocketSendSubscriptionOnConnect(t *testing.T) { + endpoint := "inproc://test-resub" + message := "test" + + sub := zmq4.NewSub(context.Background()) + defer sub.Close() + pub := zmq4.NewPub(context.Background()) + defer pub.Close() + sub.SetOption(zmq4.OptionSubscribe, message) + if err := sub.Listen(endpoint); err != nil { + t.Fatalf("Sub Dial failed: %v", err) + } + if err := pub.Dial(endpoint); err != nil { + t.Fatalf("Pub Dial failed: %v", err) + } + wg := new(sync.WaitGroup) + defer wg.Wait() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + wg.Add(1) + go func() { + defer wg.Done() + for { + pub.Send(zmq4.NewMsgFromString([]string{message})) + if ctx.Err() != nil { + return + } + time.Sleep(1 * time.Millisecond) + } + }() + msg, err := sub.Recv() + if err != nil { + t.Fatalf("Recv failed: %v", err) + } + if string(msg.Frames[0]) != message { + t.Fatalf("invalid message received: got '%s', wanted '%s'", msg.Frames[0], message) + } +} diff --git a/sub.go b/sub.go index c530226..008d6b1 100644 --- a/sub.go +++ b/sub.go @@ -16,6 +16,7 @@ import ( func NewSub(ctx context.Context, opts ...Option) Socket { sub := &subSocket{sck: newSocket(ctx, Sub, opts...)} sub.sck.r = newQReader(sub.sck.ctx) + sub.sck.subTopics = sub.Topics sub.topics = make(map[string]struct{}) return sub }