diff --git a/options.go b/options.go index d11445c..952347c 100644 --- a/options.go +++ b/options.go @@ -51,6 +51,14 @@ func WithLogger(msg *log.Logger) Option { } } +// WithAutomaticReconnect allows to configure a socket to automatically +// reconnect on connection loss. +func WithAutomaticReconnect(automaticReconnect bool) Option { + return func(s *socket) { + s.autoReconnect = automaticReconnect + } +} + /* // TODO(sbinet) diff --git a/socket.go b/socket.go index 55a4f56..2c7033c 100644 --- a/socket.go +++ b/socket.go @@ -30,13 +30,14 @@ var ( // socket implements the ZeroMQ socket interface type socket struct { - ep string // socket end-point - typ SocketType - id SocketIdentity - retry time.Duration - sec Security - log *log.Logger - subTopics func() []string + ep string // socket end-point + typ SocketType + id SocketIdentity + retry time.Duration + sec Security + log *log.Logger + subTopics func() []string + autoReconnect bool mu sync.RWMutex ids map[string]*Conn // ZMTP connection IDs @@ -51,8 +52,9 @@ type socket struct { listener net.Listener dialer net.Dialer - closedConns []*Conn - reaperCond *sync.Cond + closedConns []*Conn + reaperCond *sync.Cond + reaperStarted bool } func newDefaultSocket(ctx context.Context, sockType SocketType) *socket { @@ -267,7 +269,10 @@ connect: return fmt.Errorf("zmq4: got a nil ZMTP connection to %q", endpoint) } - go sck.connReaper() + if !sck.reaperStarted { + go sck.connReaper() + sck.reaperStarted = true + } sck.addConn(zconn) return nil } @@ -326,6 +331,10 @@ func (sck *socket) scheduleRmConn(c *Conn) { sck.closedConns = append(sck.closedConns, c) sck.reaperCond.Signal() sck.reaperCond.L.Unlock() + + if sck.autoReconnect { + sck.Dial(sck.ep) + } } // Type returns the type of this Socket (PUB, SUB, ...) diff --git a/socket_test.go b/socket_test.go index 8ba617a..95c51ec 100644 --- a/socket_test.go +++ b/socket_test.go @@ -6,6 +6,7 @@ package zmq4_test import ( "context" + "errors" "fmt" "io" "net" @@ -260,3 +261,68 @@ func TestSocketSendSubscriptionOnConnect(t *testing.T) { t.Fatalf("invalid message received: got '%s', wanted '%s'", msg.Frames[0], message) } } + +func TestSocketAutomaticReconnect(t *testing.T) { + listenEndpoint := "tcp://*:1234" + dialEndpoint := "tcp://localhost:1234" + message := "test" + + ctx, cancel := context.WithCancel(context.Background()) + + wg := new(sync.WaitGroup) + defer wg.Wait() + defer cancel() + sendMessages := func(socket zmq4.Socket) { + wg.Add(1) + go func(t *testing.T) { + defer wg.Done() + for { + socket.Send(zmq4.NewMsgFromString([]string{message})) + if ctx.Err() != nil { + return + } + time.Sleep(1 * time.Millisecond) + } + }(t) + } + + sub := zmq4.NewSub(context.Background(), zmq4.WithAutomaticReconnect(true)) + defer sub.Close() + sub.SetOption(zmq4.OptionSubscribe, message) + pub := zmq4.NewPub(context.Background()) + if err := pub.Listen(dialEndpoint); err != nil { + t.Fatalf("Pub Dial failed: %v", err) + } + if err := sub.Dial(listenEndpoint); err != nil { + t.Fatalf("Sub Dial failed: %v", err) + } + + sendMessages(pub) + + checkConnectionWorking := func(socket zmq4.Socket) { + for { + msg, err := socket.Recv() + if errors.Is(err, io.EOF) { + continue + } + if err != nil { + t.Fatalf("Recv failed: %v", err) + } + if string(msg.Frames[0]) != message { + t.Fatalf("invalid message received: got '%s', wanted '%s'", msg.Frames[0], message) + } + return + } + } + + checkConnectionWorking(sub) + pub.Close() + + pub2 := zmq4.NewPub(context.Background()) + defer pub2.Close() + if err := pub2.Listen(listenEndpoint); err != nil { + t.Fatalf("Sub Listen failed: %v", err) + } + sendMessages(pub2) + checkConnectionWorking(sub) +}