Skip to content

Commit

Permalink
zmq4: add option for automatic reconnect
Browse files Browse the repository at this point in the history
  • Loading branch information
thielepaul committed Jun 17, 2022
1 parent c17962e commit d0375cc
Show file tree
Hide file tree
Showing 3 changed files with 93 additions and 10 deletions.
8 changes: 8 additions & 0 deletions options.go
Expand Up @@ -51,6 +51,14 @@ func WithLogger(msg *log.Logger) Option {
}
}

// WithAutomaticReconnect allows to configure a socket to automatically
// reconnect on connection loss.
func WithAutomaticReconnect(automaticReconnect bool) Option {
return func(s *socket) {
s.autoReconnect = automaticReconnect
}
}

/*
// TODO(sbinet)
Expand Down
29 changes: 19 additions & 10 deletions socket.go
Expand Up @@ -30,13 +30,14 @@ var (

// socket implements the ZeroMQ socket interface
type socket struct {
ep string // socket end-point
typ SocketType
id SocketIdentity
retry time.Duration
sec Security
log *log.Logger
subTopics func() []string
ep string // socket end-point
typ SocketType
id SocketIdentity
retry time.Duration
sec Security
log *log.Logger
subTopics func() []string
autoReconnect bool

mu sync.RWMutex
ids map[string]*Conn // ZMTP connection IDs
Expand All @@ -51,8 +52,9 @@ type socket struct {
listener net.Listener
dialer net.Dialer

closedConns []*Conn
reaperCond *sync.Cond
closedConns []*Conn
reaperCond *sync.Cond
reaperStarted bool
}

func newDefaultSocket(ctx context.Context, sockType SocketType) *socket {
Expand Down Expand Up @@ -267,7 +269,10 @@ connect:
return fmt.Errorf("zmq4: got a nil ZMTP connection to %q", endpoint)
}

go sck.connReaper()
if !sck.reaperStarted {
go sck.connReaper()
sck.reaperStarted = true
}
sck.addConn(zconn)
return nil
}
Expand Down Expand Up @@ -326,6 +331,10 @@ func (sck *socket) scheduleRmConn(c *Conn) {
sck.closedConns = append(sck.closedConns, c)
sck.reaperCond.Signal()
sck.reaperCond.L.Unlock()

if sck.autoReconnect {
sck.Dial(sck.ep)
}
}

// Type returns the type of this Socket (PUB, SUB, ...)
Expand Down
66 changes: 66 additions & 0 deletions socket_test.go
Expand Up @@ -6,6 +6,7 @@ package zmq4_test

import (
"context"
"errors"
"fmt"
"io"
"net"
Expand Down Expand Up @@ -260,3 +261,68 @@ func TestSocketSendSubscriptionOnConnect(t *testing.T) {
t.Fatalf("invalid message received: got '%s', wanted '%s'", msg.Frames[0], message)
}
}

func TestSocketAutomaticReconnect(t *testing.T) {
listenEndpoint := "tcp://*:1234"
dialEndpoint := "tcp://localhost:1234"
message := "test"

ctx, cancel := context.WithCancel(context.Background())

wg := new(sync.WaitGroup)
defer wg.Wait()
defer cancel()
sendMessages := func(socket zmq4.Socket) {
wg.Add(1)
go func(t *testing.T) {
defer wg.Done()
for {
socket.Send(zmq4.NewMsgFromString([]string{message}))
if ctx.Err() != nil {
return
}
time.Sleep(1 * time.Millisecond)
}
}(t)
}

sub := zmq4.NewSub(context.Background(), zmq4.WithAutomaticReconnect(true))
defer sub.Close()
sub.SetOption(zmq4.OptionSubscribe, message)
pub := zmq4.NewPub(context.Background())
if err := pub.Listen(dialEndpoint); err != nil {
t.Fatalf("Pub Dial failed: %v", err)
}
if err := sub.Dial(listenEndpoint); err != nil {
t.Fatalf("Sub Dial failed: %v", err)
}

sendMessages(pub)

checkConnectionWorking := func(socket zmq4.Socket) {
for {
msg, err := socket.Recv()
if errors.Is(err, io.EOF) {
continue
}
if err != nil {
t.Fatalf("Recv failed: %v", err)
}
if string(msg.Frames[0]) != message {
t.Fatalf("invalid message received: got '%s', wanted '%s'", msg.Frames[0], message)
}
return
}
}

checkConnectionWorking(sub)
pub.Close()

pub2 := zmq4.NewPub(context.Background())
defer pub2.Close()
if err := pub2.Listen(listenEndpoint); err != nil {
t.Fatalf("Sub Listen failed: %v", err)
}
sendMessages(pub2)
checkConnectionWorking(sub)
}

0 comments on commit d0375cc

Please sign in to comment.