Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ADDED] Force reconnect #1624

Merged
merged 3 commits into from May 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
13 changes: 13 additions & 0 deletions example_test.go
Expand Up @@ -89,6 +89,19 @@ func ExampleConn_Subscribe() {
})
}

func ExampleConn_ForceReconnect() {
nc, _ := nats.Connect(nats.DefaultURL)
defer nc.Close()

nc.Subscribe("foo", func(m *nats.Msg) {
fmt.Printf("Received a message: %s\n", string(m.Data))
})

// Reconnect to the server.
// the subscription will be recreated after the reconnect.
nc.ForceReconnect()
}

// This Example shows a synchronous subscriber.
func ExampleConn_SubscribeSync() {
nc, _ := nats.Connect(nats.DefaultURL)
Expand Down
59 changes: 52 additions & 7 deletions nats.go
Expand Up @@ -2161,6 +2161,47 @@ func (nc *Conn) waitForExits() {
nc.wg.Wait()
}

// ForceReconnect forces a reconnect attempt to the server.
// This is a non-blocking call and will start the reconnect
// process without waiting for it to complete.
//
// If the connection is already in the process of reconnecting,
// this call will force an immediate reconnect attempt (bypassing
// the current reconnect delay).
func (nc *Conn) ForceReconnect() error {
nc.mu.Lock()
defer nc.mu.Unlock()

if nc.isClosed() {
return ErrConnectionClosed
}
if nc.isReconnecting() {
// if we're already reconnecting, force a reconnect attempt
// even if we're in the middle of a backoff
if nc.rqch != nil {
close(nc.rqch)
}
return nil
}

// Clear any queued pongs
nc.clearPendingFlushCalls()

// Clear any queued and blocking requests.
nc.clearPendingRequestCalls()

// Stop ping timer if set.
nc.stopPingTimer()

// Go ahead and make sure we have flushed the outbound
nc.bw.flush()
nc.conn.Close()

nc.changeConnStatus(RECONNECTING)
go nc.doReconnect(nil, true)
return nil
}

// ConnectedUrl reports the connected server's URL
func (nc *Conn) ConnectedUrl() string {
if nc == nil {
Expand Down Expand Up @@ -2420,7 +2461,7 @@ func (nc *Conn) connect() (bool, error) {
nc.setup()
nc.changeConnStatus(RECONNECTING)
nc.bw.switchToPending()
go nc.doReconnect(ErrNoServers)
go nc.doReconnect(ErrNoServers, false)
err = nil
} else {
nc.current = nil
Expand Down Expand Up @@ -2720,7 +2761,7 @@ func (nc *Conn) stopPingTimer() {

// Try to reconnect using the option parameters.
// This function assumes we are allowed to reconnect.
func (nc *Conn) doReconnect(err error) {
func (nc *Conn) doReconnect(err error, forceReconnect bool) {
// We want to make sure we have the other watchers shutdown properly
// here before we proceed past this point.
nc.waitForExits()
Expand Down Expand Up @@ -2776,7 +2817,8 @@ func (nc *Conn) doReconnect(err error) {
break
}

doSleep := i+1 >= len(nc.srvPool)
doSleep := i+1 >= len(nc.srvPool) && !forceReconnect
forceReconnect = false
nc.mu.Unlock()

if !doSleep {
Expand All @@ -2803,6 +2845,12 @@ func (nc *Conn) doReconnect(err error) {
select {
case <-rqch:
rt.Stop()

// we need to reset the rqch channel to avoid
// closing a closed channel in the next iteration
nc.mu.Lock()
nc.rqch = make(chan struct{})
nc.mu.Unlock()
case <-rt.C:
}
}
Expand Down Expand Up @@ -2872,9 +2920,6 @@ func (nc *Conn) doReconnect(err error) {
// Done with the pending buffer
nc.bw.doneWithPending()

// This is where we are truly connected.
nc.status = CONNECTED

// Queue up the correct callback. If we are in initial connect state
// (using retry on failed connect), we will call the ConnectedCB,
// otherwise the ReconnectedCB.
Expand Down Expand Up @@ -2930,7 +2975,7 @@ func (nc *Conn) processOpErr(err error) {
// Clear any queued pongs, e.g. pending flush calls.
nc.clearPendingFlushCalls()

go nc.doReconnect(err)
go nc.doReconnect(err, false)
nc.mu.Unlock()
return
}
Expand Down
18 changes: 4 additions & 14 deletions test/conn_test.go
Expand Up @@ -2946,16 +2946,6 @@ func TestRetryOnFailedConnectWithTLSError(t *testing.T) {
}

func TestConnStatusChangedEvents(t *testing.T) {
waitForStatus := func(t *testing.T, ch chan nats.Status, expected nats.Status) {
select {
case s := <-ch:
if s != expected {
t.Fatalf("Expected status: %s; got: %s", expected, s)
}
case <-time.After(5 * time.Second):
t.Fatalf("Timeout waiting for status %q", expected)
}
}
t.Run("default events", func(t *testing.T) {
s := RunDefaultServer()
nc, err := nats.Connect(s.ClientURL())
Expand All @@ -2978,15 +2968,15 @@ func TestConnStatusChangedEvents(t *testing.T) {
time.Sleep(50 * time.Millisecond)

s.Shutdown()
waitForStatus(t, newStatus, nats.RECONNECTING)
WaitOnChannel(t, newStatus, nats.RECONNECTING)

s = RunDefaultServer()
defer s.Shutdown()

waitForStatus(t, newStatus, nats.CONNECTED)
WaitOnChannel(t, newStatus, nats.CONNECTED)

nc.Close()
waitForStatus(t, newStatus, nats.CLOSED)
WaitOnChannel(t, newStatus, nats.CLOSED)

select {
case s := <-newStatus:
Expand Down Expand Up @@ -3019,7 +3009,7 @@ func TestConnStatusChangedEvents(t *testing.T) {
s = RunDefaultServer()
defer s.Shutdown()
nc.Close()
waitForStatus(t, newStatus, nats.CLOSED)
WaitOnChannel(t, newStatus, nats.CLOSED)

select {
case s := <-newStatus:
Expand Down
12 changes: 12 additions & 0 deletions test/helper_test.go
Expand Up @@ -54,6 +54,18 @@ func WaitTime(ch chan bool, timeout time.Duration) error {
return errors.New("timeout")
}

func WaitOnChannel[T comparable](t *testing.T, ch <-chan T, expected T) {
t.Helper()
select {
case s := <-ch:
if s != expected {
t.Fatalf("Expected result: %v; got: %v", expected, s)
}
case <-time.After(5 * time.Second):
t.Fatalf("Timeout waiting for result %v", expected)
}
}

func stackFatalf(t tLogger, f string, args ...any) {
lines := make([]string, 0, 32)
msg := fmt.Sprintf(f, args...)
Expand Down