Skip to content

Commit

Permalink
[ADDED] Force reconnect to the server
Browse files Browse the repository at this point in the history
Signed-off-by: Piotr Piotrowski <piotr@synadia.com>
  • Loading branch information
piotrpio committed Apr 27, 2024
1 parent 8894a27 commit 3242548
Show file tree
Hide file tree
Showing 6 changed files with 243 additions and 52 deletions.
13 changes: 13 additions & 0 deletions example_test.go
Expand Up @@ -89,6 +89,19 @@ func ExampleConn_Subscribe() {
})
}

// ExampleConn_Reconnect shows how to force an established connection
// to reconnect to the server.
func ExampleConn_Reconnect() {
	nc, _ := nats.Connect(nats.DefaultURL)
	defer nc.Close()

	// Handler invoked for every message delivered on "foo".
	printMsg := func(msg *nats.Msg) {
		fmt.Printf("Received a message: %s\n", string(msg.Data))
	}
	nc.Subscribe("foo", printMsg)

	// Force a reconnect to the server.
	// The subscription on "foo" will be recreated after the reconnect.
	nc.Reconnect()
}

// This Example shows a synchronous subscriber.
func ExampleConn_SubscribeSync() {
nc, _ := nats.Connect(nats.DefaultURL)
Expand Down
59 changes: 52 additions & 7 deletions nats.go
Expand Up @@ -2161,6 +2161,47 @@ func (nc *Conn) waitForExits() {
nc.wg.Wait()
}

// Reconnect forces a reconnect attempt to the server.
// This is a non-blocking call and will start the reconnect
// process without waiting for it to complete.
//
// If the connection is already in the process of reconnecting,
// this call will force an immediate reconnect attempt (bypassing
// the current reconnect delay).
//
// It returns ErrConnectionClosed if the connection is closed and
// nil in every other case.
func (nc *Conn) Reconnect() error {
	nc.mu.Lock()
	defer nc.mu.Unlock()

	if nc.isClosed() {
		return ErrConnectionClosed
	}
	if nc.isReconnecting() {
		// If we're already reconnecting, force a reconnect attempt
		// even if we're in the middle of a backoff.
		if nc.rqch != nil {
			// Closing the channel wakes up a reconnect loop that is
			// currently waiting on it.
			close(nc.rqch)
			// Never leave a closed channel in nc.rqch: closing an
			// already closed channel panics, so a second call made
			// before the reconnect loop swaps the channel would
			// otherwise crash the process. The loop holds its own
			// reference to the channel closed above, so it is still
			// signalled.
			nc.rqch = make(chan struct{})
		}
		return nil
	}

	// Clear any queued pongs.
	nc.clearPendingFlushCalls()

	// Clear any queued and blocking requests.
	nc.clearPendingRequestCalls()

	// Stop ping timer if set.
	nc.stopPingTimer()

	// Go ahead and make sure we have flushed the outbound
	// before dropping the current connection.
	nc.bw.flush()
	nc.conn.Close()

	// Run the reconnect in the background; the second argument asks
	// doReconnect to treat this as a forced reconnect.
	nc.changeConnStatus(RECONNECTING)
	go nc.doReconnect(nil, true)
	return nil
}

// ConnectedUrl reports the connected server's URL
func (nc *Conn) ConnectedUrl() string {
if nc == nil {
Expand Down Expand Up @@ -2420,7 +2461,7 @@ func (nc *Conn) connect() (bool, error) {
nc.setup()
nc.changeConnStatus(RECONNECTING)
nc.bw.switchToPending()
go nc.doReconnect(ErrNoServers)
go nc.doReconnect(ErrNoServers, false)
err = nil
} else {
nc.current = nil
Expand Down Expand Up @@ -2720,7 +2761,7 @@ func (nc *Conn) stopPingTimer() {

// Try to reconnect using the option parameters.
// This function assumes we are allowed to reconnect.
func (nc *Conn) doReconnect(err error) {
func (nc *Conn) doReconnect(err error, forceReconnect bool) {
// We want to make sure we have the other watchers shutdown properly
// here before we proceed past this point.
nc.waitForExits()
Expand Down Expand Up @@ -2776,7 +2817,8 @@ func (nc *Conn) doReconnect(err error) {
break
}

doSleep := i+1 >= len(nc.srvPool)
doSleep := i+1 >= len(nc.srvPool) && !forceReconnect
forceReconnect = false
nc.mu.Unlock()

if !doSleep {
Expand All @@ -2803,6 +2845,12 @@ func (nc *Conn) doReconnect(err error) {
select {
case <-rqch:
rt.Stop()

// we need to reset the rqch channel to avoid
// closing a closed channel in the next iteration
nc.mu.Lock()
nc.rqch = make(chan struct{})
nc.mu.Unlock()
case <-rt.C:
}
}
Expand Down Expand Up @@ -2872,9 +2920,6 @@ func (nc *Conn) doReconnect(err error) {
// Done with the pending buffer
nc.bw.doneWithPending()

// This is where we are truly connected.
nc.status = CONNECTED

// Queue up the correct callback. If we are in initial connect state
// (using retry on failed connect), we will call the ConnectedCB,
// otherwise the ReconnectedCB.
Expand Down Expand Up @@ -2930,7 +2975,7 @@ func (nc *Conn) processOpErr(err error) {
// Clear any queued pongs, e.g. pending flush calls.
nc.clearPendingFlushCalls()

go nc.doReconnect(err)
go nc.doReconnect(err, false)
nc.mu.Unlock()
return
}
Expand Down
18 changes: 4 additions & 14 deletions test/conn_test.go
Expand Up @@ -2946,16 +2946,6 @@ func TestRetryOnFailedConnectWithTLSError(t *testing.T) {
}

func TestConnStatusChangedEvents(t *testing.T) {
waitForStatus := func(t *testing.T, ch chan nats.Status, expected nats.Status) {
select {
case s := <-ch:
if s != expected {
t.Fatalf("Expected status: %s; got: %s", expected, s)
}
case <-time.After(5 * time.Second):
t.Fatalf("Timeout waiting for status %q", expected)
}
}
t.Run("default events", func(t *testing.T) {
s := RunDefaultServer()
nc, err := nats.Connect(s.ClientURL())
Expand All @@ -2978,15 +2968,15 @@ func TestConnStatusChangedEvents(t *testing.T) {
time.Sleep(50 * time.Millisecond)

s.Shutdown()
waitForStatus(t, newStatus, nats.RECONNECTING)
WaitOnChannel(t, newStatus, nats.RECONNECTING)

s = RunDefaultServer()
defer s.Shutdown()

waitForStatus(t, newStatus, nats.CONNECTED)
WaitOnChannel(t, newStatus, nats.CONNECTED)

nc.Close()
waitForStatus(t, newStatus, nats.CLOSED)
WaitOnChannel(t, newStatus, nats.CLOSED)

select {
case s := <-newStatus:
Expand Down Expand Up @@ -3019,7 +3009,7 @@ func TestConnStatusChangedEvents(t *testing.T) {
s = RunDefaultServer()
defer s.Shutdown()
nc.Close()
waitForStatus(t, newStatus, nats.CLOSED)
WaitOnChannel(t, newStatus, nats.CLOSED)

select {
case s := <-newStatus:
Expand Down
12 changes: 12 additions & 0 deletions test/helper_test.go
Expand Up @@ -54,6 +54,18 @@ func WaitTime(ch chan bool, timeout time.Duration) error {
return errors.New("timeout")
}

// WaitOnChannel receives a single value from ch and fails the test
// immediately unless that value equals expected. The test also fails
// if nothing is received within 5 seconds.
func WaitOnChannel[T comparable](t *testing.T, ch <-chan T, expected T) {
	t.Helper()

	deadline := time.NewTimer(5 * time.Second)
	defer deadline.Stop()

	select {
	case <-deadline.C:
		t.Fatalf("Timeout waiting for result %v", expected)
	case got := <-ch:
		if got == expected {
			return
		}
		t.Fatalf("Expected result: %v; got: %v", expected, got)
	}
}

func stackFatalf(t tLogger, f string, args ...any) {
lines := make([]string, 0, 32)
msg := fmt.Sprintf(f, args...)
Expand Down
167 changes: 155 additions & 12 deletions test/reconnect_test.go
Expand Up @@ -853,7 +853,7 @@ func TestAuthExpiredReconnect(t *testing.T) {

jwtCB := func() (string, error) {
claims := jwt.NewUserClaims("test")
claims.Expires = time.Now().Add(500 * time.Millisecond).Unix()
claims.Expires = time.Now().Add(time.Second).Unix()
claims.Subject = upub
jwt, err := claims.Encode(akp)
if err != nil {
Expand Down Expand Up @@ -884,21 +884,164 @@ func TestAuthExpiredReconnect(t *testing.T) {
case <-time.After(2 * time.Second):
t.Fatal("Did not get the auth expired error")
}
select {
case s := <-stasusCh:
if s != nats.RECONNECTING {
t.Fatalf("Expected to be in reconnecting state after jwt expires, got %v", s)
WaitOnChannel(t, stasusCh, nats.RECONNECTING)
WaitOnChannel(t, stasusCh, nats.CONNECTED)
nc.Close()
}

// TestForceReconnect exercises nc.Reconnect in two situations: while the
// connection is healthy, and while the client is already reconnecting
// after the server was shut down. After each reconnect it publishes on
// "foo" and expects the existing subscription to receive the message.
//
// NOTE(review): two lines inside the status-forwarding goroutine below do
// not fit the surrounding code and look like removed lines carried over
// from the diff view - confirm against the repository.
func TestForceReconnect(t *testing.T) {
s := RunDefaultServer()

// A long ReconnectWait is used so that a regular (non-forced) reconnect
// would not complete before WaitOnChannel gives up.
nc, err := nats.Connect(s.ClientURL(), nats.ReconnectWait(10*time.Second))
if err != nil {
t.Fatalf("Unexpected error on connect: %v", err)
}
// defer nc.Close()

// Listen for RECONNECTING and CONNECTED status changes only.
statusCh := nc.StatusChanged(nats.RECONNECTING, nats.CONNECTED)
defer close(statusCh)
// Buffered channel the assertions below read from.
newStatus := make(chan nats.Status, 10)
// non-blocking channel, so we need to be constantly listening
go func() {
for {
s, ok := <-statusCh
if !ok {
// statusCh was closed, stop forwarding.
return
}
newStatus <- s
}
// NOTE(review): the next two lines are not valid at this position;
// presumably residue of removed diff lines - verify.
case <-time.After(2 * time.Second):
t.Fatal("Did not get the status change")
}()

// Sanity check: the subscription works before any reconnect.
sub, err := nc.SubscribeSync("foo")
if err != nil {
t.Fatalf("Error on subscribe: %v", err)
}
if err := nc.Publish("foo", []byte("msg")); err != nil {
t.Fatalf("Error on publish: %v", err)
}
_, err = sub.NextMsg(time.Second)
if err != nil {
t.Fatalf("Error getting message: %v", err)
}

// Force a reconnect
err = nc.Reconnect()
if err != nil {
t.Fatalf("Unexpected error on reconnect: %v", err)
}

// The connection must go through RECONNECTING and back to CONNECTED.
WaitOnChannel(t, newStatus, nats.RECONNECTING)
WaitOnChannel(t, newStatus, nats.CONNECTED)

// The same subscription must still receive messages after the reconnect.
if err := nc.Publish("foo", []byte("msg")); err != nil {
t.Fatalf("Error on publish: %v", err)
}
_, err = sub.NextMsg(time.Second)
if err != nil {
t.Fatalf("Error getting message: %v", err)
}

// shutdown server and then force a reconnect
s.Shutdown()
WaitOnChannel(t, newStatus, nats.RECONNECTING)
// With the server down no message is expected, so NextMsg must fail.
_, err = sub.NextMsg(100 * time.Millisecond)
if err == nil {
t.Fatal("Expected error getting message")
}

// restart server
s = RunDefaultServer()
defer s.Shutdown()

// Second forced reconnect, this time while already reconnecting.
if err := nc.Reconnect(); err != nil {
t.Fatalf("Unexpected error on reconnect: %v", err)
}
// wait for the reconnect
// because the connection has long ReconnectWait,
// if force reconnect does not work, the test will timeout
WaitOnChannel(t, newStatus, nats.CONNECTED)

// Final check that publish/receive works on the restarted server.
if err := nc.Publish("foo", []byte("msg")); err != nil {
t.Fatalf("Error on publish: %v", err)
}
_, err = sub.NextMsg(time.Second)
if err != nil {
t.Fatalf("Error getting message: %v", err)
}
nc.Close()
}

// TestAuthExpiredForceReconnect connects with a user JWT that expires one
// second after it is issued, waits for the auth expired error to reach the
// error handler and then calls nc.Reconnect, expecting the connection to
// report RECONNECTING followed by CONNECTED.
//
// NOTE(review): the tail of this function mixes in lines that look like
// removed lines from the diff view (a case reading from stasusCh, a second
// t.Fatal in the timeout case, an nc.Close() before the forced reconnect) -
// confirm against the repository.
func TestAuthExpiredForceReconnect(t *testing.T) {
ts := runTrustServer()
defer ts.Shutdown()

// Connecting without credentials is expected to be rejected.
_, err := nats.Connect(ts.ClientURL())
if err == nil {
t.Fatalf("Expecting an error on connect")
}
// Key pairs used to build and sign the user JWT.
ukp, err := nkeys.FromSeed(uSeed)
if err != nil {
t.Fatalf("Error creating user key pair: %v", err)
}
upub, err := ukp.PublicKey()
if err != nil {
t.Fatalf("Error getting user public key: %v", err)
}
akp, err := nkeys.FromSeed(aSeed)
if err != nil {
t.Fatalf("Error creating account key pair: %v", err)
}

// jwtCB returns a freshly encoded user JWT on every call; each one
// expires one second after it is created.
jwtCB := func() (string, error) {
claims := jwt.NewUserClaims("test")
claims.Expires = time.Now().Add(time.Second).Unix()
claims.Subject = upub
jwt, err := claims.Encode(akp)
if err != nil {
return "", err
}
return jwt, nil
}
// sigCB signs the nonce with the user key; errors are ignored here.
sigCB := func(nonce []byte) ([]byte, error) {
kp, _ := nkeys.FromSeed(uSeed)
sig, _ := kp.Sign(nonce)
return sig, nil
}

// Errors reported to the error handler are forwarded to errCh.
errCh := make(chan error, 1)
nc, err := nats.Connect(ts.ClientURL(), nats.UserJWT(jwtCB, sigCB), nats.ReconnectWait(10*time.Second),
nats.ErrorHandler(func(_ *nats.Conn, _ *nats.Subscription, err error) {
errCh <- err
}))
if err != nil {
t.Fatalf("Expected to connect, got %v", err)
}
defer nc.Close()
// Listen for RECONNECTING and CONNECTED status changes only.
statusCh := nc.StatusChanged(nats.RECONNECTING, nats.CONNECTED)
defer close(statusCh)
// Buffered channel the assertions below read from.
newStatus := make(chan nats.Status, 10)
// non-blocking channel, so we need to be constantly listening
go func() {
for {
s, ok := <-statusCh
if !ok {
// statusCh was closed, stop forwarding.
return
}
newStatus <- s
}
}()
time.Sleep(100 * time.Millisecond)
// Wait for the auth expired error to be reported.
select {
// NOTE(review): the next three lines reference stasusCh, which is not
// declared in this function; presumably removed diff lines - verify.
case s := <-stasusCh:
if s != nats.CONNECTED {
t.Fatalf("Expected to reconnect, got %v", s)
case err := <-errCh:
if !errors.Is(err, nats.ErrAuthExpired) {
t.Fatalf("Expected auth expired error, got %v", err)
}
case <-time.After(2 * time.Second):
// NOTE(review): two t.Fatal calls follow; the first looks like a
// removed diff line - verify.
t.Fatal("Did not get the status change")
t.Fatal("Did not get the auth expired error")
}
// NOTE(review): closing the connection right before forcing a reconnect
// looks like a removed diff line - verify.
nc.Close()
// Force the reconnect and expect the connection to come back.
if err := nc.Reconnect(); err != nil {
t.Fatalf("Unexpected error on reconnect: %v", err)
}
WaitOnChannel(t, newStatus, nats.RECONNECTING)
WaitOnChannel(t, newStatus, nats.CONNECTED)
}

0 comments on commit 3242548

Please sign in to comment.