From 2ae0c5804082f67e1f18492583b0390c3120645c Mon Sep 17 00:00:00 2001 From: Piotr Piotrowski Date: Tue, 15 Nov 2022 14:55:45 +0100 Subject: [PATCH] Add ConnectHandler invoked on initial successful connect --- nats.go | 29 ++++++++++-- test/conn_test.go | 111 +++++++++++++++++++++++++++++++++++++++++----- 2 files changed, 125 insertions(+), 15 deletions(-) diff --git a/nats.go b/nats.go index a5e2e4a4b..413d67467 100644 --- a/nats.go +++ b/nats.go @@ -375,6 +375,12 @@ type Options struct { // DisconnectedCB will not be called if DisconnectedErrCB is set DisconnectedErrCB ConnErrHandler + // ConnectedCB sets the connected handler called when the initial connection + // is established. It is not invoked on successful reconnects - for reconnections, + // use ReconnectedCB. ConnectedCB can be used in conjunction with RetryOnFailedConnect + // to detect whether the initial connect was successful. + ConnectedCB ConnHandler + // ReconnectedCB sets the reconnected handler called whenever // the connection is successfully reconnected. ReconnectedCB ConnHandler @@ -999,6 +1005,14 @@ func DisconnectHandler(cb ConnHandler) Option { } } +// ConnectHandler is an Option to set the connected handler. +func ConnectHandler(cb ConnHandler) Option { + return func(o *Options) error { + o.ConnectedCB = cb + return nil + } +} + // ReconnectHandler is an Option to set the reconnected handler. func ReconnectHandler(cb ConnHandler) Option { return func(o *Options) error { @@ -1367,13 +1381,18 @@ func (o Options) Connect() (*Conn, error) { // Create reader/writer nc.newReaderWriter() - if err := nc.connect(); err != nil { + connectionEstablished, err := nc.connect() + if err != nil { return nil, err } // Spin up the async cb dispatcher on success go nc.ach.asyncCBDispatcher() + if connectionEstablished && nc.Opts.ConnectedCB != nil { + nc.ach.push(func() { nc.Opts.ConnectedCB(nc) }) + } + return nc, nil } @@ -2114,9 +2133,10 @@ func (nc *Conn) processConnectInit() error { return nil } -// Main connect function. Will connect to the nats-server -func (nc *Conn) connect() error { +// Main connect function. Will connect to the nats-server. +func (nc *Conn) connect() (bool, error) { var err error + var connectionEstablished bool // Create actual socket connection // For first connect we walk all servers in the pool and try @@ -2162,6 +2182,7 @@ func (nc *Conn) connect() error { } if err == nil { + connectionEstablished = true nc.initc = false } else if nc.Opts.RetryOnFailedConnect { nc.setup() @@ -2173,7 +2194,7 @@ func (nc *Conn) connect() error { nc.current = nil } - return err + return connectionEstablished, err } // This will check to see if the connection should be diff --git a/test/conn_test.go b/test/conn_test.go index 463070b36..2726c24d2 100644 --- a/test/conn_test.go +++ b/test/conn_test.go @@ -668,15 +668,11 @@ func TestCallbacksOrder(t *testing.T) { defer s.Shutdown() firstDisconnect := true - dtime1 := time.Time{} - dtime2 := time.Time{} - rtime := time.Time{} - atime1 := time.Time{} - atime2 := time.Time{} - ctime := time.Time{} + var connTime, dtime1, dtime2, rtime, atime1, atime2, ctime time.Time cbErrors := make(chan error, 20) + connected := make(chan bool) reconnected := make(chan bool) closed := make(chan bool) asyncErr := make(chan bool, 2) @@ -684,6 +680,17 @@ func TestCallbacksOrder(t *testing.T) { recvCh1 := make(chan bool) recvCh2 := make(chan bool) + connCh := func(nc *nats.Conn) { + if err := isRunningInAsyncCBDispatcher(); err != nil { + cbErrors <- err + connected <- true + return + } + time.Sleep(50 * time.Millisecond) + connTime = time.Now() + connected <- true + } + dch := func(nc *nats.Conn) { if err := isRunningInAsyncCBDispatcher(); err != nil { cbErrors <- err @@ -738,6 +745,7 @@ func TestCallbacksOrder(t *testing.T) { url = "nats://" + url + "," + nats.DefaultURL nc, err := nats.Connect(url, + nats.ConnectHandler(connCh), nats.DisconnectHandler(dch), nats.ReconnectHandler(rch), nats.ClosedHandler(cch), @@ -751,6 +759,12 @@ func TestCallbacksOrder(t *testing.T) { } defer nc.Close() + // Wait for notification on connection established + err = Wait(connected) + if err != nil { + t.Fatal("Did not get the connected callback") + } + ncp, err := nats.Connect(nats.DefaultURL, nats.ReconnectWait(50*time.Millisecond)) if err != nil { @@ -771,8 +785,7 @@ func TestCallbacksOrder(t *testing.T) { t.Fatal("Did not get the reconnected callback") } - var sub1 *nats.Subscription - var sub2 *nats.Subscription + var sub1, sub2 *nats.Subscription recv := func(m *nats.Msg) { // Signal that one message is received @@ -840,12 +853,12 @@ func TestCallbacksOrder(t *testing.T) { t.Fatalf("%v", <-cbErrors) } - if (dtime1 == time.Time{}) || (dtime2 == time.Time{}) || (rtime == time.Time{}) || (atime1 == time.Time{}) || (atime2 == time.Time{}) || (ctime == time.Time{}) { + if (connTime == time.Time{}) || (dtime1 == time.Time{}) || (dtime2 == time.Time{}) || (rtime == time.Time{}) || (atime1 == time.Time{}) || (atime2 == time.Time{}) || (ctime == time.Time{}) { t.Fatalf("Some callbacks did not fire:\n%v\n%v\n%v\n%v\n%v\n%v", dtime1, rtime, atime1, atime2, dtime2, ctime) } - if rtime.Before(dtime1) || dtime2.Before(rtime) || atime2.Before(atime1) || ctime.Before(atime2) { - t.Fatalf("Wrong callback order:\n%v\n%v\n%v\n%v\n%v\n%v", dtime1, rtime, atime1, atime2, dtime2, ctime) + if dtime1.Before(connTime) || rtime.Before(dtime1) || dtime2.Before(rtime) || atime2.Before(atime1) || ctime.Before(atime2) { + t.Fatalf("Wrong callback order:\n%v\n%v\n%v\n%v\n%v\n%v\n%v", connTime, dtime1, rtime, atime1, atime2, dtime2, ctime) } // Close the other connection @@ -864,6 +877,82 @@ func TestCallbacksOrder(t *testing.T) { t.Fatalf("The async callback dispatcher(s) should have stopped") } +func TestConnectHandler(t *testing.T) { + t.Run("with RetryOnFailedConnect, connection established", func(t *testing.T) { + s := RunDefaultServer() + defer s.Shutdown() + + connected := make(chan bool) + connHandler := func(*nats.Conn) { + connected <- true + } + nc, err := nats.Connect(nats.DefaultURL, + nats.ConnectHandler(connHandler), + nats.RetryOnFailedConnect(true)) + + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + defer nc.Close() + if err = Wait(connected); err != nil { + t.Fatal("Timeout waiting for connect handler") + } + }) + t.Run("with RetryOnFailedConnect, connection failed", func(t *testing.T) { + connected := make(chan bool) + connHandler := func(*nats.Conn) { + connected <- true + } + _, err := nats.Connect(nats.DefaultURL, + nats.ConnectHandler(connHandler), + nats.RetryOnFailedConnect(true)) + + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + select { + case <-connected: + t.Fatalf("ConnectedCB invoked when no connection established") + case <-time.After(100 * time.Millisecond): + } + }) + t.Run("no RetryOnFailedConnect, connection established", func(t *testing.T) { + s := RunDefaultServer() + defer s.Shutdown() + + connected := make(chan bool) + connHandler := func(*nats.Conn) { + connected <- true + } + nc, err := nats.Connect(nats.DefaultURL, + nats.ConnectHandler(connHandler)) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + defer nc.Close() + if err = Wait(connected); err != nil { + t.Fatal("Timeout waiting for connect handler") + } + }) + t.Run("no RetryOnFailedConnect, connection failed", func(t *testing.T) { + connected := make(chan bool) + connHandler := func(*nats.Conn) { + connected <- true + } + _, err := nats.Connect(nats.DefaultURL, + nats.ConnectHandler(connHandler)) + + if err == nil { + t.Fatalf("Expected error on connect, got nil") + } + select { + case <-connected: + t.Fatalf("ConnectedCB invoked when no connection established") + case <-time.After(100 * time.Millisecond): + } + }) +} + func TestFlushReleaseOnClose(t *testing.T) { serverInfo := "INFO {\"server_id\":\"foobar\",\"host\":\"%s\",\"port\":%d,\"auth_required\":false,\"tls_required\":false,\"max_payload\":1048576}\r\n"