Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ADDED] ConnectHandler invoked on initial successful connect #1133

Merged
merged 1 commit into from Nov 15, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
29 changes: 25 additions & 4 deletions nats.go
Expand Up @@ -375,6 +375,12 @@ type Options struct {
// DisconnectedCB will not be called if DisconnectedErrCB is set
DisconnectedErrCB ConnErrHandler

// ConnectedCB sets the connected handler called when the initial connection
// is established. It is not invoked on successful reconnects - for reconnections,
// use ReconnectedCB. ConnectedCB can be used in conjunction with RetryOnFailedConnect
// to detect whether the initial connect was successful.
ConnectedCB ConnHandler

// ReconnectedCB sets the reconnected handler called whenever
// the connection is successfully reconnected.
ReconnectedCB ConnHandler
Expand Down Expand Up @@ -999,6 +1005,14 @@ func DisconnectHandler(cb ConnHandler) Option {
}
}

// ConnectHandler is an Option to set the connected handler.
func ConnectHandler(cb ConnHandler) Option {
return func(o *Options) error {
o.ConnectedCB = cb
return nil
}
}

// ReconnectHandler is an Option to set the reconnected handler.
func ReconnectHandler(cb ConnHandler) Option {
return func(o *Options) error {
Expand Down Expand Up @@ -1367,13 +1381,18 @@ func (o Options) Connect() (*Conn, error) {
// Create reader/writer
nc.newReaderWriter()

if err := nc.connect(); err != nil {
connectionEstablished, err := nc.connect()
if err != nil {
return nil, err
}

// Spin up the async cb dispatcher on success
go nc.ach.asyncCBDispatcher()

if connectionEstablished && nc.Opts.ConnectedCB != nil {
nc.ach.push(func() { nc.Opts.ConnectedCB(nc) })
}

return nc, nil
}

Expand Down Expand Up @@ -2114,9 +2133,10 @@ func (nc *Conn) processConnectInit() error {
return nil
}

// Main connect function. Will connect to the nats-server
func (nc *Conn) connect() error {
// Main connect function. Will connect to the nats-server.
func (nc *Conn) connect() (bool, error) {
var err error
var connectionEstablished bool

// Create actual socket connection
// For first connect we walk all servers in the pool and try
Expand Down Expand Up @@ -2162,6 +2182,7 @@ func (nc *Conn) connect() error {
}

if err == nil {
connectionEstablished = true
nc.initc = false
} else if nc.Opts.RetryOnFailedConnect {
nc.setup()
Expand All @@ -2173,7 +2194,7 @@ func (nc *Conn) connect() error {
nc.current = nil
}

return err
return connectionEstablished, err
}

// This will check to see if the connection should be
Expand Down
111 changes: 100 additions & 11 deletions test/conn_test.go
Expand Up @@ -668,22 +668,29 @@ func TestCallbacksOrder(t *testing.T) {
defer s.Shutdown()

firstDisconnect := true
dtime1 := time.Time{}
dtime2 := time.Time{}
rtime := time.Time{}
atime1 := time.Time{}
atime2 := time.Time{}
ctime := time.Time{}
var connTime, dtime1, dtime2, rtime, atime1, atime2, ctime time.Time

cbErrors := make(chan error, 20)

connected := make(chan bool)
reconnected := make(chan bool)
closed := make(chan bool)
asyncErr := make(chan bool, 2)
recvCh := make(chan bool, 2)
recvCh1 := make(chan bool)
recvCh2 := make(chan bool)

connCh := func(nc *nats.Conn) {
if err := isRunningInAsyncCBDispatcher(); err != nil {
cbErrors <- err
connected <- true
return
}
time.Sleep(50 * time.Millisecond)
connTime = time.Now()
connected <- true
}

dch := func(nc *nats.Conn) {
if err := isRunningInAsyncCBDispatcher(); err != nil {
cbErrors <- err
Expand Down Expand Up @@ -738,6 +745,7 @@ func TestCallbacksOrder(t *testing.T) {
url = "nats://" + url + "," + nats.DefaultURL

nc, err := nats.Connect(url,
nats.ConnectHandler(connCh),
nats.DisconnectHandler(dch),
nats.ReconnectHandler(rch),
nats.ClosedHandler(cch),
Expand All @@ -751,6 +759,12 @@ func TestCallbacksOrder(t *testing.T) {
}
defer nc.Close()

// Wait for notification on connection established
err = Wait(connected)
if err != nil {
t.Fatal("Did not get the connected callback")
}

ncp, err := nats.Connect(nats.DefaultURL,
nats.ReconnectWait(50*time.Millisecond))
if err != nil {
Expand All @@ -771,8 +785,7 @@ func TestCallbacksOrder(t *testing.T) {
t.Fatal("Did not get the reconnected callback")
}

var sub1 *nats.Subscription
var sub2 *nats.Subscription
var sub1, sub2 *nats.Subscription

recv := func(m *nats.Msg) {
// Signal that one message is received
Expand Down Expand Up @@ -840,12 +853,12 @@ func TestCallbacksOrder(t *testing.T) {
t.Fatalf("%v", <-cbErrors)
}

if (dtime1 == time.Time{}) || (dtime2 == time.Time{}) || (rtime == time.Time{}) || (atime1 == time.Time{}) || (atime2 == time.Time{}) || (ctime == time.Time{}) {
if (connTime == time.Time{}) || (dtime1 == time.Time{}) || (dtime2 == time.Time{}) || (rtime == time.Time{}) || (atime1 == time.Time{}) || (atime2 == time.Time{}) || (ctime == time.Time{}) {
t.Fatalf("Some callbacks did not fire:\n%v\n%v\n%v\n%v\n%v\n%v", dtime1, rtime, atime1, atime2, dtime2, ctime)
}

if rtime.Before(dtime1) || dtime2.Before(rtime) || atime2.Before(atime1) || ctime.Before(atime2) {
t.Fatalf("Wrong callback order:\n%v\n%v\n%v\n%v\n%v\n%v", dtime1, rtime, atime1, atime2, dtime2, ctime)
if dtime1.Before(connTime) || rtime.Before(dtime1) || dtime2.Before(rtime) || atime2.Before(atime1) || ctime.Before(atime2) {
t.Fatalf("Wrong callback order:\n%v\n%v\n%v\n%v\n%v\n%v\n%v", connTime, dtime1, rtime, atime1, atime2, dtime2, ctime)
}

// Close the other connection
Expand All @@ -864,6 +877,82 @@ func TestCallbacksOrder(t *testing.T) {
t.Fatalf("The async callback dispatcher(s) should have stopped")
}

func TestConnectHandler(t *testing.T) {
t.Run("with RetryOnFailedConnect, connection established", func(t *testing.T) {
s := RunDefaultServer()
defer s.Shutdown()

connected := make(chan bool)
connHandler := func(*nats.Conn) {
connected <- true
}
nc, err := nats.Connect(nats.DefaultURL,
nats.ConnectHandler(connHandler),
nats.RetryOnFailedConnect(true))

if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
defer nc.Close()
if err = Wait(connected); err != nil {
t.Fatal("Timeout waiting for connect handler")
}
})
t.Run("with RetryOnFailedConnect, connection failed", func(t *testing.T) {
connected := make(chan bool)
connHandler := func(*nats.Conn) {
connected <- true
}
_, err := nats.Connect(nats.DefaultURL,
nats.ConnectHandler(connHandler),
nats.RetryOnFailedConnect(true))

if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
select {
case <-connected:
t.Fatalf("ConnectedCB invoked when no connection established")
case <-time.After(100 * time.Millisecond):
}
})
t.Run("no RetryOnFailedConnect, connection established", func(t *testing.T) {
s := RunDefaultServer()
defer s.Shutdown()

connected := make(chan bool)
connHandler := func(*nats.Conn) {
connected <- true
}
nc, err := nats.Connect(nats.DefaultURL,
nats.ConnectHandler(connHandler))
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
defer nc.Close()
if err = Wait(connected); err != nil {
t.Fatal("Timeout waiting for connect handler")
}
})
t.Run("no RetryOnFailedConnect, connection failed", func(t *testing.T) {
connected := make(chan bool)
connHandler := func(*nats.Conn) {
connected <- true
}
_, err := nats.Connect(nats.DefaultURL,
nats.ConnectHandler(connHandler))

if err == nil {
t.Fatalf("Expected error on connect, got nil")
}
select {
case <-connected:
t.Fatalf("ConnectedCB invoked when no connection established")
case <-time.After(100 * time.Millisecond):
}
})
}

func TestFlushReleaseOnClose(t *testing.T) {
serverInfo := "INFO {\"server_id\":\"foobar\",\"host\":\"%s\",\"port\":%d,\"auth_required\":false,\"tls_required\":false,\"max_payload\":1048576}\r\n"

Expand Down