From 2b1657adc501317e161e765c76248d126cbee618 Mon Sep 17 00:00:00 2001 From: Matt Brittan Date: Wed, 10 Aug 2022 15:01:18 +1200 Subject: [PATCH] Rearchitect status handling as per issue 605 --- client.go | 336 ++++++++++++++++++-------------- fvt_client_test.go | 62 +++--- net.go | 11 +- ping.go | 4 +- status.go | 296 +++++++++++++++++++++++++++++ unit_client_test.go | 8 +- unit_messageids_test.go | 6 +- unit_status_test.go | 411 ++++++++++++++++++++++++++++++++++++++++ 8 files changed, 952 insertions(+), 182 deletions(-) create mode 100644 status.go create mode 100644 unit_status_test.go diff --git a/client.go b/client.go index 959a30b..200daa8 100644 --- a/client.go +++ b/client.go @@ -38,13 +38,6 @@ import ( "github.com/eclipse/paho.mqtt.golang/packets" ) -const ( - disconnected uint32 = iota - connecting - reconnecting - connected -) - // Client is the interface definition for a Client as used by this // library, the interface is primarily to allow mocking tests. // @@ -52,9 +45,12 @@ const ( // with an MQTT server using non-blocking methods that allow work // to be done in the background. // An application may connect to an MQTT server using: -// A plain TCP socket -// A secure SSL/TLS socket -// A websocket +// +// A plain TCP socket (e.g. mqtt://test.mosquitto.org:1833) +// A secure SSL/TLS socket (e.g. tls://test.mosquitto.org:8883) +// A websocket (e.g ws://test.mosquitto.org:8080 or wss://test.mosquitto.org:8081) +// Something else (using `options.CustomOpenConnectionFn`) +// // To enable ensured message delivery at Quality of Service (QoS) levels // described in the MQTT spec, a message persistence mechanism must be // used. This is done by providing a type which implements the Store @@ -128,8 +124,7 @@ type client struct { lastReceived atomic.Value // time.Time - the last time a packet was successfully received from network pingOutstanding int32 // set to 1 if a ping has been sent but response not ret received - status uint32 // see const definitions at top of file for possible values - sync.RWMutex // Protects the above two variables (note: atomic writes are also used somewhat inconsistently) + status connectionStatus // see constants in status.go for values messageIds // effectively a map from message id to token completor @@ -169,7 +164,6 @@ func NewClient(o *ClientOptions) Client { c.options.protocolVersionExplicit = false } c.persist = c.options.Store - c.status = disconnected c.messageIds = messageIds{index: make(map[uint16]tokenCompletor)} c.msgRouter = newRouter() c.msgRouter.setDefaultHandler(c.options.DefaultPublishHandler) @@ -196,47 +190,27 @@ func (c *client) AddRoute(topic string, callback MessageHandler) { // the client is connected or not. // connected means that the connection is up now OR it will // be established/reestablished automatically when possible +// Warning: The connection status may change at any time so use this with care! func (c *client) IsConnected() bool { - c.RLock() - defer c.RUnlock() - status := atomic.LoadUint32(&c.status) + // This will need to change if additional statuses are added + s, r := c.status.ConnectionStatusRetry() switch { - case status == connected: - return true - case c.options.AutoReconnect && status > connecting: + case s == connected: return true - case c.options.ConnectRetry && status == connecting: + case c.options.ConnectRetry && s == connecting: return true + case c.options.AutoReconnect: + return s == reconnecting || (s == disconnecting && r) // r indicates we will reconnect default: return false } } // IsConnectionOpen return a bool signifying whether the client has an active -// connection to mqtt broker, i.e not in disconnected or reconnect mode +// connection to mqtt broker, i.e. not in disconnected or reconnect mode +// Warning: The connection status may change at any time so use this with care! func (c *client) IsConnectionOpen() bool { - c.RLock() - defer c.RUnlock() - status := atomic.LoadUint32(&c.status) - switch { - case status == connected: - return true - default: - return false - } -} - -func (c *client) connectionStatus() uint32 { - c.RLock() - defer c.RUnlock() - status := atomic.LoadUint32(&c.status) - return status -} - -func (c *client) setConnected(status uint32) { - c.Lock() - defer c.Unlock() - atomic.StoreUint32(&c.status, status) + return c.status.ConnectionStatus() == connected } // ErrNotConnected is the error returned from function calls that are @@ -253,25 +227,31 @@ func (c *client) Connect() Token { t := newToken(packets.Connect).(*ConnectToken) DEBUG.Println(CLI, "Connect()") - if c.options.ConnectRetry && atomic.LoadUint32(&c.status) != disconnected { - // if in any state other than disconnected and ConnectRetry is - // enabled then the connection will come up automatically - // client can assume connection is up - WARN.Println(CLI, "Connect() called but not disconnected") - t.returnCode = packets.Accepted - t.flowComplete() + connectionUp, err := c.status.Connecting() + if err != nil { + if err == errAlreadyConnectedOrReconnecting && c.options.AutoReconnect { + // When reconnection is active we don't consider calls tro Connect to ba an error (mainly for compatability) + WARN.Println(CLI, "Connect() called but not disconnected") + t.returnCode = packets.Accepted + t.flowComplete() + return t + } + ERROR.Println(CLI, err) // CONNECT should never be called unless we are disconnected + t.setError(err) return t } c.persist.Open() if c.options.ConnectRetry { - c.reserveStoredPublishIDs() // Reserve IDs to allow publish before connect complete + c.reserveStoredPublishIDs() // Reserve IDs to allow publishing before connect complete } - c.setConnected(connecting) go func() { if len(c.options.Servers) == 0 { t.setError(fmt.Errorf("no servers defined to connect to")) + if err := connectionUp(false); err != nil { + ERROR.Println(CLI, err.Error()) + } return } @@ -285,26 +265,28 @@ func (c *client) Connect() Token { DEBUG.Println(CLI, "Connect failed, sleeping for", int(c.options.ConnectRetryInterval.Seconds()), "seconds and will then retry, error:", err.Error()) time.Sleep(c.options.ConnectRetryInterval) - if atomic.LoadUint32(&c.status) == connecting { + if c.status.ConnectionStatus() == connecting { // Possible connection aborted elsewhere goto RETRYCONN } } ERROR.Println(CLI, "Failed to connect to a broker") - c.setConnected(disconnected) c.persist.Close() t.returnCode = rc t.setError(err) + if err := connectionUp(false); err != nil { + ERROR.Println(CLI, err.Error()) + } return } - inboundFromStore := make(chan packets.ControlPacket) // there may be some inbound comms packets in the store that are awaiting processing - if c.startCommsWorkers(conn, inboundFromStore) { + inboundFromStore := make(chan packets.ControlPacket) // there may be some inbound comms packets in the store that are awaiting processing + if c.startCommsWorkers(conn, connectionUp, inboundFromStore) { // note that this takes care of updating the status (to connected or disconnected) // Take care of any messages in the store if !c.options.CleanSession { c.resume(c.options.ResumeSubs, inboundFromStore) } else { c.persist.Reset() } - } else { + } else { // Note: With the new status subsystem this should only happen if Disconnect called simultaneously with the above WARN.Println(CLI, "Connect() called but connection established in another goroutine") } @@ -316,7 +298,8 @@ func (c *client) Connect() Token { } // internal function used to reconnect the client when it loses its connection -func (c *client) reconnect() { +// The connection status MUST be reconnecting prior to calling this function (via call to status.connectionLost) +func (c *client) reconnect(connectionUp connCompletedFn) { DEBUG.Println(CLI, "enter reconnect") var ( sleep = 1 * time.Second @@ -341,23 +324,18 @@ func (c *client) reconnect() { if sleep > c.options.MaxReconnectInterval { sleep = c.options.MaxReconnectInterval } - // Disconnect may have been called - if atomic.LoadUint32(&c.status) == disconnected { - break - } - } - // Disconnect() must have been called while we were trying to reconnect. - if c.connectionStatus() == disconnected { - if conn != nil { - conn.Close() + if c.status.ConnectionStatus() != reconnecting { // Disconnect may have been called + if err := connectionUp(false); err != nil { // Should always return an error + ERROR.Println(CLI, err.Error()) + } + DEBUG.Println(CLI, "Client moved to disconnected state while reconnecting, abandoning reconnect") + return } - DEBUG.Println(CLI, "Client moved to disconnected state while reconnecting, abandoning reconnect") - return } - inboundFromStore := make(chan packets.ControlPacket) // there may be some inbound comms packets in the store that are awaiting processing - if c.startCommsWorkers(conn, inboundFromStore) { + inboundFromStore := make(chan packets.ControlPacket) // there may be some inbound comms packets in the store that are awaiting processing + if c.startCommsWorkers(conn, connectionUp, inboundFromStore) { // note that this takes care of updating the status (to connected or disconnected) c.resume(c.options.ResumeSubs, inboundFromStore) } close(inboundFromStore) @@ -417,7 +395,7 @@ func (c *client) attemptConnection() (net.Conn, byte, bool, error) { ERROR.Println(CLI, "set deadline for handshake ", err) } - // Now we send the perform the MQTT connection handshake + // Now we perform the MQTT connection handshake rc, sessionPresent, err = connectMQTT(conn, cm, protocolVersion) if rc == packets.Accepted { if err := conn.SetDeadline(time.Time{}); err != nil { @@ -460,43 +438,59 @@ func (c *client) attemptConnection() (net.Conn, byte, bool, error) { // reusing the `client` may lead to panics. If you want to reconnect when the connection drops then use // `SetAutoReconnect` and/or `SetConnectRetry`options instead of implementing this yourself. func (c *client) Disconnect(quiesce uint) { - defer c.disconnect() - - status := atomic.LoadUint32(&c.status) - c.setConnected(disconnected) - - if status != connected { - WARN.Println(CLI, "Disconnect() called but not connected (disconnected/reconnecting)") - return - } + done := make(chan struct{}) // Simplest way to ensure quiesce is always honoured + go func() { + defer close(done) + disDone, err := c.status.Disconnecting() + if err != nil { + // Status has been set to disconnecting, but we had to wait for something else to complete + WARN.Println(CLI, err.Error()) + return + } + defer func() { + c.disconnect() // Force disconnection + disDone() // Update status + }() + DEBUG.Println(CLI, "disconnecting") + dm := packets.NewControlPacket(packets.Disconnect).(*packets.DisconnectPacket) + dt := newToken(packets.Disconnect) + select { + case c.oboundP <- &PacketAndToken{p: dm, t: dt}: + // wait for work to finish, or quiesce time consumed + DEBUG.Println(CLI, "calling WaitTimeout") + dt.WaitTimeout(time.Duration(quiesce) * time.Millisecond) + DEBUG.Println(CLI, "WaitTimeout done") + // Below code causes a potential data race. Following status refactor it should no longer be required + // but leaving in as need to check code further. + // case <-c.commsStopped: + // WARN.Println("Disconnect packet could not be sent because comms stopped") + case <-time.After(time.Duration(quiesce) * time.Millisecond): + WARN.Println("Disconnect packet not sent due to timeout") + } + }() - DEBUG.Println(CLI, "disconnecting") - dm := packets.NewControlPacket(packets.Disconnect).(*packets.DisconnectPacket) - dt := newToken(packets.Disconnect) + // Return when done or after timeout expires (would like to change but this maintains compatibility) + delay := time.NewTimer(time.Duration(quiesce) * time.Millisecond) select { - case c.oboundP <- &PacketAndToken{p: dm, t: dt}: - // wait for work to finish, or quiesce time consumed - DEBUG.Println(CLI, "calling WaitTimeout") - dt.WaitTimeout(time.Duration(quiesce) * time.Millisecond) - DEBUG.Println(CLI, "WaitTimeout done") - // Let's comment this chunk of code until we are able to safely read this variable - // without data races. - // case <-c.commsStopped: - // WARN.Println("Disconnect packet could not be sent because comms stopped") - case <-time.After(time.Duration(quiesce) * time.Millisecond): - WARN.Println("Disconnect packet not sent due to timeout") + case <-done: + if !delay.Stop() { + <-delay.C + } + case <-delay.C: } } // forceDisconnect will end the connection with the mqtt broker immediately (used for tests only) func (c *client) forceDisconnect() { - if !c.IsConnected() { - WARN.Println(CLI, "already disconnected") + disDone, err := c.status.Disconnecting() + if err != nil { + // Possible that we are not actually connected + WARN.Println(CLI, err.Error()) return } - c.setConnected(disconnected) DEBUG.Println(CLI, "forcefully disconnecting") c.disconnect() + disDone() } // disconnect cleans up after a final disconnection (user requested so no auto reconnection) @@ -513,49 +507,79 @@ func (c *client) disconnect() { // internalConnLost cleanup when connection is lost or an error occurs // Note: This function will not block -func (c *client) internalConnLost(err error) { +func (c *client) internalConnLost(whyConnLost error) { // It is possible that internalConnLost will be called multiple times simultaneously // (including after sending a DisconnectPacket) as such we only do cleanup etc if the // routines were actually running and are not being disconnected at users request DEBUG.Println(CLI, "internalConnLost called") + disDone, err := c.status.ConnectionLost(c.options.AutoReconnect && c.status.ConnectionStatus() > connecting) + if err != nil { + if err == errConnLossWhileDisconnecting || err == errAlreadyHandlingConnectionLoss { + return // Loss of connection is expected or already being handled + } + ERROR.Println(CLI, fmt.Sprintf("internalConnLost unexpected status: %s", err.Error())) + return + } + + // c.stopCommsWorker returns a channel that is closed when the operation completes. This was required prior + // to the implementation of proper status management but has been left in place, for now, to minimise change stopDone := c.stopCommsWorkers() - if stopDone != nil { // stopDone will be nil if workers already in the process of stopping or stopped - go func() { - DEBUG.Println(CLI, "internalConnLost waiting on workers") - <-stopDone - DEBUG.Println(CLI, "internalConnLost workers stopped") - // It is possible that Disconnect was called which led to this error so reconnection depends upon status - reconnect := c.options.AutoReconnect && c.connectionStatus() > connecting - - if c.options.CleanSession && !reconnect { - c.messageIds.cleanUp() // completes PUB/SUB/UNSUB tokens - } else if !c.options.ResumeSubs { - c.messageIds.cleanUpSubscribe() // completes SUB/UNSUB tokens - } - if reconnect { - c.setConnected(reconnecting) - go c.reconnect() - } else { - c.setConnected(disconnected) - } - if c.options.OnConnectionLost != nil { - go c.options.OnConnectionLost(c, err) - } - DEBUG.Println(CLI, "internalConnLost complete") - }() + // stopDone was required in previous versions because there was no connectionLost status (and there were + // issues with status handling). This code has been left in place for the time being just in case the new + // status handling contains bugs (refactoring required at some point). + if stopDone == nil { // stopDone will be nil if workers already in the process of stopping or stopped + ERROR.Println(CLI, "internalConnLost stopDone unexpectedly nil - BUG BUG") + // Cannot really do anything other than leave things disconnected + if _, err = disDone(false); err != nil { // Safest option - cannot leave status as connectionLost + ERROR.Println(CLI, fmt.Sprintf("internalConnLost failed to set status to disconnected (stopDone): %s", err.Error())) + } + return } + + // It may take a while for the disconnection to complete whatever called us needs to exit cleanly so finnish in goRoutine + go func() { + DEBUG.Println(CLI, "internalConnLost waiting on workers") + <-stopDone + DEBUG.Println(CLI, "internalConnLost workers stopped") + + reConnDone, err := disDone(true) + if err != nil { + ERROR.Println(CLI, "failure whilst reporting completion of disconnect", err) + } else if reConnDone == nil { // Should never happen + ERROR.Println(CLI, "BUG BUG BUG reconnection function is nil", err) + } + + reconnect := err == nil && reConnDone != nil + + if c.options.CleanSession && !reconnect { + c.messageIds.cleanUp() // completes PUB/SUB/UNSUB tokens + } else if !c.options.ResumeSubs { + c.messageIds.cleanUpSubscribe() // completes SUB/UNSUB tokens + } + if reconnect { + go c.reconnect(reConnDone) // Will set connection status to reconnecting + } + if c.options.OnConnectionLost != nil { + go c.options.OnConnectionLost(c, whyConnLost) + } + DEBUG.Println(CLI, "internalConnLost complete") + }() } // startCommsWorkers is called when the connection is up. -// It starts off all of the routines needed to process incoming and outgoing messages. -// Returns true if the comms workers were started (i.e. they were not already running) -func (c *client) startCommsWorkers(conn net.Conn, inboundFromStore <-chan packets.ControlPacket) bool { +// It starts off the routines needed to process incoming and outgoing messages. +// Returns true if the comms workers were started (i.e. successful connection) +// connectionUp(true) will be called once everything is up; connectionUp(false) will be called on failure +func (c *client) startCommsWorkers(conn net.Conn, connectionUp connCompletedFn, inboundFromStore <-chan packets.ControlPacket) bool { DEBUG.Println(CLI, "startCommsWorkers called") c.connMu.Lock() defer c.connMu.Unlock() - if c.conn != nil { - WARN.Println(CLI, "startCommsWorkers called when commsworkers already running") - conn.Close() // No use for the new network connection + if c.conn != nil { // Should never happen due to new status handling; leaving in for safety for the time being + WARN.Println(CLI, "startCommsWorkers called when commsworkers already running BUG BUG") + _ = conn.Close() // No use for the new network connection + if err := connectionUp(false); err != nil { + ERROR.Println(CLI, err.Error()) + } return false } c.conn = conn // Store the connection @@ -575,7 +599,17 @@ func (c *client) startCommsWorkers(conn net.Conn, inboundFromStore <-chan packet c.workers.Add(1) // Done will be called when ackOut is closed ackOut := c.msgRouter.matchAndDispatch(incomingPubChan, c.options.Order, c) - c.setConnected(connected) + // The connection is now ready for use (we spin up a few go routines below). It is possible that + // Disconnect has been called in the interim... + if err := connectionUp(true); err != nil { + DEBUG.Println(CLI, err) + close(c.stop) // Tidy up anything we have already started + close(incomingPubChan) + c.workers.Wait() + c.conn.Close() + c.conn = nil + return false + } DEBUG.Println(CLI, "client is connected/reconnected") if c.options.OnConnect != nil { go c.options.OnConnect(c) @@ -668,8 +702,9 @@ func (c *client) startCommsWorkers(conn net.Conn, inboundFromStore <-chan packet } // stopWorkersAndComms - Cleanly shuts down worker go routines (including the comms routines) and waits until everything has stopped -// Returns nil it workers did not need to be stopped; otherwise returns a channel which will be closed when the stop is complete +// Returns nil if workers did not need to be stopped; otherwise returns a channel which will be closed when the stop is complete // Note: This may block so run as a go routine if calling from any of the comms routines +// Note2: It should be possible to simplify this now that the new status management code is in place. func (c *client) stopCommsWorkers() chan struct{} { DEBUG.Println(CLI, "stopCommsWorkers called") // It is possible that this function will be called multiple times simultaneously due to the way things get shutdown @@ -718,7 +753,8 @@ func (c *client) Publish(topic string, qos byte, retained bool, payload interfac case !c.IsConnected(): token.setError(ErrNotConnected) return token - case c.connectionStatus() == reconnecting && qos == 0: + case c.status.ConnectionStatus() == reconnecting && qos == 0: + // message written to store and will be sent when connection comes up token.flowComplete() return token } @@ -748,11 +784,13 @@ func (c *client) Publish(topic string, qos byte, retained bool, payload interfac token.messageID = mID } persistOutbound(c.persist, pub) - switch c.connectionStatus() { + switch c.status.ConnectionStatus() { case connecting: DEBUG.Println(CLI, "storing publish message (connecting), topic:", topic) case reconnecting: DEBUG.Println(CLI, "storing publish message (reconnecting), topic:", topic) + case disconnecting: + DEBUG.Println(CLI, "storing publish message (disconnecting), topic:", topic) default: DEBUG.Println(CLI, "sending publish message, topic:", topic) publishWaitTimeout := c.options.WriteTimeout @@ -785,11 +823,11 @@ func (c *client) Subscribe(topic string, qos byte, callback MessageHandler) Toke if !c.IsConnectionOpen() { switch { case !c.options.ResumeSubs: - // if not connected and resumesubs not set this sub will be thrown away + // if not connected and resumeSubs not set this sub will be thrown away token.setError(fmt.Errorf("not currently connected and ResumeSubs not set")) return token - case c.options.CleanSession && c.connectionStatus() == reconnecting: - // if reconnecting and cleansession is true this sub will be thrown away + case c.options.CleanSession && c.status.ConnectionStatus() == reconnecting: + // if reconnecting and cleanSession is true this sub will be thrown away token.setError(fmt.Errorf("reconnecting state and cleansession is true")) return token } @@ -830,11 +868,13 @@ func (c *client) Subscribe(topic string, qos byte, callback MessageHandler) Toke if c.options.ResumeSubs { // Only persist if we need this to resume subs after a disconnection persistOutbound(c.persist, sub) } - switch c.connectionStatus() { + switch c.status.ConnectionStatus() { case connecting: DEBUG.Println(CLI, "storing subscribe message (connecting), topic:", topic) case reconnecting: DEBUG.Println(CLI, "storing subscribe message (reconnecting), topic:", topic) + case disconnecting: + DEBUG.Println(CLI, "storing subscribe message (disconnecting), topic:", topic) default: DEBUG.Println(CLI, "sending subscribe message, topic:", topic) subscribeWaitTimeout := c.options.WriteTimeout @@ -872,8 +912,8 @@ func (c *client) SubscribeMultiple(filters map[string]byte, callback MessageHand // if not connected and resumesubs not set this sub will be thrown away token.setError(fmt.Errorf("not currently connected and ResumeSubs not set")) return token - case c.options.CleanSession && c.connectionStatus() == reconnecting: - // if reconnecting and cleansession is true this sub will be thrown away + case c.options.CleanSession && c.status.ConnectionStatus() == reconnecting: + // if reconnecting and cleanSession is true this sub will be thrown away token.setError(fmt.Errorf("reconnecting state and cleansession is true")) return token } @@ -904,11 +944,13 @@ func (c *client) SubscribeMultiple(filters map[string]byte, callback MessageHand if c.options.ResumeSubs { // Only persist if we need this to resume subs after a disconnection persistOutbound(c.persist, sub) } - switch c.connectionStatus() { + switch c.status.ConnectionStatus() { case connecting: DEBUG.Println(CLI, "storing subscribe message (connecting), topics:", sub.Topics) case reconnecting: DEBUG.Println(CLI, "storing subscribe message (reconnecting), topics:", sub.Topics) + case disconnecting: + DEBUG.Println(CLI, "storing subscribe message (disconnecting), topics:", sub.Topics) default: DEBUG.Println(CLI, "sending subscribe message, topics:", sub.Topics) subscribeWaitTimeout := c.options.WriteTimeout @@ -1058,7 +1100,7 @@ func (c *client) resume(subscription bool, ibound chan packets.ControlPacket) { } releaseSemaphore(token) // If limiting simultaneous messages then we need to know when message is acknowledged default: - ERROR.Println(STR, "invalid message type in store (discarded)") + ERROR.Println(STR, fmt.Sprintf("invalid message type (inbound - %T) in store (discarded)", packet)) c.persist.Del(key) } } else { @@ -1072,7 +1114,7 @@ func (c *client) resume(subscription bool, ibound chan packets.ControlPacket) { return } default: - ERROR.Println(STR, "invalid message type in store (discarded)") + ERROR.Println(STR, fmt.Sprintf("invalid message type (%T) in store (discarded)", packet)) c.persist.Del(key) } } @@ -1093,11 +1135,11 @@ func (c *client) Unsubscribe(topics ...string) Token { if !c.IsConnectionOpen() { switch { case !c.options.ResumeSubs: - // if not connected and resumesubs not set this unsub will be thrown away + // if not connected and resumeSubs not set this unsub will be thrown away token.setError(fmt.Errorf("not currently connected and ResumeSubs not set")) return token - case c.options.CleanSession && c.connectionStatus() == reconnecting: - // if reconnecting and cleansession is true this unsub will be thrown away + case c.options.CleanSession && c.status.ConnectionStatus() == reconnecting: + // if reconnecting and cleanSession is true this unsub will be thrown away token.setError(fmt.Errorf("reconnecting state and cleansession is true")) return token } @@ -1120,11 +1162,13 @@ func (c *client) Unsubscribe(topics ...string) Token { persistOutbound(c.persist, unsub) } - switch c.connectionStatus() { + switch c.status.ConnectionStatus() { case connecting: DEBUG.Println(CLI, "storing unsubscribe message (connecting), topics:", topics) case reconnecting: DEBUG.Println(CLI, "storing unsubscribe message (reconnecting), topics:", topics) + case disconnecting: + DEBUG.Println(CLI, "storing unsubscribe message (reconnecting), topics:", topics) default: DEBUG.Println(CLI, "sending unsubscribe message, topics:", topics) subscribeWaitTimeout := c.options.WriteTimeout diff --git a/fvt_client_test.go b/fvt_client_test.go index 22b8a56..d266009 100644 --- a/fvt_client_test.go +++ b/fvt_client_test.go @@ -126,7 +126,7 @@ func Test_Disconnect(t *testing.T) { go func() { c.Disconnect(250) cli := c.(*client) - cli.status = connected + cli.status.forceConnectionStatus(connected) c.Disconnect(250) close(disconnectC) }() @@ -1191,29 +1191,36 @@ func Test_cleanUpMids_2(t *testing.T) { ops.SetKeepAlive(10 * time.Second) c := NewClient(ops) + cl := c.(*client) if token := c.Connect(); token.Wait() && token.Error() != nil { t.Fatalf("Error on Client.Connect(): %v", token.Error()) } token := c.Publish("/test/cleanUP", 2, false, "cleanup test 2") - if len(c.(*client).messageIds.index) == 0 { + cl.messageIds.mu.Lock() + mq := len(c.(*client).messageIds.index) + cl.messageIds.mu.Unlock() + if mq == 0 { t.Fatalf("Should be a token in the messageIDs, none found") } - fmt.Println("Disconnecting", len(c.(*client).messageIds.index)) + // fmt.Println("Disconnecting", len(cl.messageIds.index)) c.Disconnect(0) fmt.Println("Wait on Token") // We should be able to wait on this token without any issue token.Wait() - if len(c.(*client).messageIds.index) > 0 { + cl.messageIds.mu.Lock() + mq = len(c.(*client).messageIds.index) + cl.messageIds.mu.Unlock() + if mq > 0 { t.Fatalf("Should have cleaned up messageIDs, have %d left", len(c.(*client).messageIds.index)) } if token.Error() == nil { t.Fatal("token should have received an error on connection loss") } - fmt.Println(token.Error()) + // fmt.Println(token.Error()) } func Test_ConnectRetry(t *testing.T) { @@ -1339,7 +1346,6 @@ func Test_ResumeSubs(t *testing.T) { t.Fatalf("Expected 1 packet to be in store") } packet := subMemStore.Get(ids[0]) - fmt.Println("packet", packet) if packet == nil { t.Fatal("Failed to retrieve packet from store") } @@ -1471,11 +1477,12 @@ func Test_ResumeSubsWithReconnect(t *testing.T) { c.Disconnect(250) } -// Issue 209 - occasional deadlock when connections are lost unexpectedly +// Issue 509 - occasional deadlock when connections are lost unexpectedly // This was quite a nasty deadlock which occurred in very rare circumstances; I could not come up with a reliable way of // replicating this but the below would cause it to happen fairly consistently (when the test was run a decent number // of times). Following the fix it ran 10,000 times without issue. -// go test -count 10000 -run DisconnectWhileProcessingIncomingPublish +// +// go test -count 10000 -run DisconnectWhileProcessingIncomingPublish func Test_DisconnectWhileProcessingIncomingPublish(t *testing.T) { topic := "/test/DisconnectWhileProcessingIncomingPublish" @@ -1487,11 +1494,11 @@ func Test_DisconnectWhileProcessingIncomingPublish(t *testing.T) { sops := NewClientOptions() sops.AddBroker(FVTTCP) - sops.SetAutoReconnect(false) // We dont want the connection to be re-established + sops.SetAutoReconnect(false) // We don't want the connection to be re-established sops.SetWriteTimeout(500 * time.Millisecond) // We will be sending a lot of publish messages and want go routines to clear... // sops.SetOrderMatters(false) sops.SetClientID("dwpip-sub") - // We need to know when the subscriber has lost its connection (this indicates that the deadlock has not occured) + // We need to know when the subscriber has lost its connection (this indicates that the deadlock has not occurred) sDisconnected := make(chan struct{}) sops.SetConnectionLostHandler(func(Client, error) { close(sDisconnected) }) @@ -1523,10 +1530,9 @@ func Test_DisconnectWhileProcessingIncomingPublish(t *testing.T) { i := 0 for { p.Publish(topic, 1, false, fmt.Sprintf("test message: %d", i)) - // After the connection goes down s.Publish will start blocking (this is not ideal but fixing its a problem for another time) - go func() { s.Publish(topic+"IGNORE", 1, false, fmt.Sprintf("test message: %d", i)) }() + // After the connection goes down s.Publish will start blocking (this is not ideal but fixing it's a problem for another time) + go func(i int) { s.Publish(topic+"IGNORE", 1, false, fmt.Sprintf("test message: %d", i)) }(i) i++ - if ctx.Err() != nil { return } @@ -1534,9 +1540,13 @@ func Test_DisconnectWhileProcessingIncomingPublish(t *testing.T) { }() // Wait until we have received a message (ensuring that the stream of messages has started) + delay := time.NewTimer(time.Second) // Be careful with timers as this will be run in a tight loop! select { case <-msgReceived: // All good - case <-time.After(time.Second): + if !delay.Stop() { // Cleanly close timer as this may be run in a tight loop! + <-delay.C + } + case <-delay.C: t.Errorf("no messages received") } @@ -1545,15 +1555,19 @@ func Test_DisconnectWhileProcessingIncomingPublish(t *testing.T) { dm := packets.NewControlPacket(packets.Disconnect).(*packets.DisconnectPacket) err := dm.Write(s.conn) if err != nil { - t.Fatalf("error dending disconnect packet: %s", err) + t.Fatalf("error sending disconnect packet: %s", err) } // Lets give the library up to a second to shutdown (indicated by the status changing) + delay = time.NewTimer(time.Second) // Be careful with timers as this will be run in a tight loop! select { case <-sDisconnected: // All good - case <-time.After(time.Second): - cancel() // no point leaving publisher running - time.Sleep(time.Second) // Allow publish calls to timeout (otherwise there will be tons of go routines running!) + if !delay.Stop() { + <-delay.C + } + case <-delay.C: + cancel() // no point leaving publisher running + time.Sleep(10 * time.Second) // Allow publish calls to timeout (otherwise there will be tons of go routines running!) buf := make([]byte, 1<<20) stacklen := runtime.Stack(buf, true) t.Fatalf("connection was not lost as expected - probable deadlock. Stacktrace follows: %s", buf[:stacklen]) @@ -1561,18 +1575,22 @@ func Test_DisconnectWhileProcessingIncomingPublish(t *testing.T) { cancel() // no point leaving publisher running + delay = time.NewTimer(time.Second) // Be careful with timers as this will be run in a tight loop! select { case <-pubDone: - case <-time.After(time.Second): - t.Errorf("pubdone not closed within a second") + if !delay.Stop() { + <-delay.C + } + case <-delay.C: + t.Errorf("pubdone not closed within two seconds (probably due to load on system but may be an issue)") } p.Disconnect(250) // Close publisher } // Test_ResumeSubsMaxInflight - Check the MaxResumePubInFlight option. // This is difficult to test without control of the broker (because we will be communicating via the broker not -// directly. However due to the way resume works when there is no limit to inflight messages message ordering is not -// guaranteed. However with SetMaxResumePubInFlight(1) it is guaranteed so we use that to test. +// directly. However, due to the way resume works when there is no limit to inflight messages message ordering is not +// guaranteed. However, with SetMaxResumePubInFlight(1) it is guaranteed so we use that to test. // On my PC (using mosquitto under docker) running this without SetMaxResumePubInFlight(1) will fail with 1000 messages // (generally passes if only 100 are sent). With the option set it always passes. func Test_ResumeSubsMaxInflight(t *testing.T) { diff --git a/net.go b/net.go index 56dd9e6..10cc7da 100644 --- a/net.go +++ b/net.go @@ -150,7 +150,7 @@ type incomingComms struct { // startIncomingComms initiates incoming communications; this includes starting a goroutine to process incoming // messages. -// Accepts a channel of inbound messages from the store (persisted messages); note this must be closed as soon as the +// Accepts a channel of inbound messages from the store (persisted messages); note this must be closed as soon as // everything in the store has been sent. // Returns a channel that will be passed any received packets; this will be closed on a network error (and inboundFromStore closed) func startIncomingComms(conn io.Reader, @@ -332,7 +332,7 @@ func startOutgoingComms(conn net.Conn, DEBUG.Println(NET, "outbound wrote disconnect, closing connection") // As per the MQTT spec "After sending a DISCONNECT Packet the Client MUST close the Network Connection" // Closing the connection will cause the goroutines to end in sequence (starting with incoming comms) - conn.Close() + _ = conn.Close() } case msg, ok := <-oboundFromIncoming: // message triggered by an inbound message (PubrecPacket or PubrelPacket) if !ok { @@ -370,9 +370,10 @@ type commsFns interface { // startComms initiates goroutines that handles communications over the network connection // Messages will be stored (via commsFns) and deleted from the store as necessary // It returns two channels: -// packets.PublishPacket - Will receive publish packets received over the network. -// Closed when incoming comms routines exit (on shutdown or if network link closed) -// error - Any errors will be sent on this channel. The channel is closed when all comms routines have shut down +// +// packets.PublishPacket - Will receive publish packets received over the network. +// Closed when incoming comms routines exit (on shutdown or if network link closed) +// error - Any errors will be sent on this channel. The channel is closed when all comms routines have shut down // // Note: The comms routines monitoring oboundp and obound will not shutdown until those channels are both closed. Any messages received between the // connection being closed and those channels being closed will generate errors (and nothing will be sent). That way the chance of a deadlock is diff --git a/ping.go b/ping.go index 63283bf..857aa0e 100644 --- a/ping.go +++ b/ping.go @@ -58,8 +58,8 @@ func keepalive(c *client, conn io.Writer) { if atomic.LoadInt32(&c.pingOutstanding) == 0 { DEBUG.Println(PNG, "keepalive sending ping") ping := packets.NewControlPacket(packets.Pingreq).(*packets.PingreqPacket) - // We don't want to wait behind large messages being sent, the Write call - // will block until it it able to send the packet. + // We don't want to wait behind large messages being sent, the `Write` call + // will block until it is able to send the packet. atomic.StoreInt32(&c.pingOutstanding, 1) if err := ping.Write(conn); err != nil { ERROR.Println(PNG, err) diff --git a/status.go b/status.go new file mode 100644 index 0000000..d25fbf5 --- /dev/null +++ b/status.go @@ -0,0 +1,296 @@ +/* + * Copyright (c) 2021 IBM Corp and others. + * + * All rights reserved. This program and the accompanying materials + * are made available under the terms of the Eclipse Public License v2.0 + * and Eclipse Distribution License v1.0 which accompany this distribution. + * + * The Eclipse Public License is available at + * https://www.eclipse.org/legal/epl-2.0/ + * and the Eclipse Distribution License is available at + * http://www.eclipse.org/org/documents/edl-v10.php. + * + * Contributors: + * Seth Hoenig + * Allan Stockdill-Mander + * Mike Robertson + * Matt Brittan + */ + +package mqtt + +import ( + "errors" + "sync" +) + +// Status - Manage the connection status + +// Multiple go routines will want to access/set this. Previously status was implemented as a `uint32` and updated +// with a mixture of atomic functions and a mutex (leading to some deadlock type issues that were very hard to debug). + +// In this new implementation `connectionStatus` takes over managing the state and provides functions that allow the +// client to request a move to a particular state (it may reject these requests!). In some cases the 'state' is +// transitory, for example `connecting`, in those cases a function will be returned that allows the client to move +// to a more static state (`disconnected` or `connected`). + +// This "belts-and-braces" may be a little over the top but issues with the status have caused a number of difficult +// to trace bugs in the past and the likelihood that introducing a new system would introduce bugs seemed high! +// I have written this in a way that should make it very difficult to misuse it (but it does make things a little +// complex with functions returning functions that return functions!). + +type status uint32 + +const ( + disconnected status = iota // default (nil) status is disconnected + disconnecting // Transitioning from one of the below states back to disconnected + connecting + reconnecting + connected +) + +// String simplify output of statuses +func (s status) String() string { + switch s { + case disconnected: + return "disconnected" + case disconnecting: + return "disconnecting" + case connecting: + return "connecting" + case reconnecting: + return "reconnecting" + case connected: + return "connected" + default: + return "invalid" + } +} + +type connCompletedFn func(success bool) error +type disconnectCompletedFn func() +type connectionLostHandledFn func(bool) (connCompletedFn, error) + +/* State transitions + +static states are `disconnected` and `connected`. For all other states a process will hold a function that will move +the state to one of those. That function effectively owns the state and any other changes must not proceed until it +completes. One exception to that is that the state can always be moved to `disconnecting` which provides a signal that +transitions to `connected` will be rejected (this is required because a Disconnect can be requested while in the +Connecting state). + +# Basic Operations + +The standard workflows are: + +disconnected -> `Connecting()` -> connecting -> `connCompletedFn(true)` -> connected +connected -> `Disconnecting()` -> disconnecting -> `disconnectCompletedFn()` -> disconnected +connected -> `ConnectionLost(false)` -> disconnecting -> `connectionLostHandledFn(true/false)` -> disconnected +connected -> `ConnectionLost(true)` -> disconnecting -> `connectionLostHandledFn(true)` -> connected + +Unfortunately the above workflows are complicated by the fact that `Disconnecting()` or `ConnectionLost()` may, +potentially, be called at any time (i.e. whilst in the middle of transitioning between states). If this happens: + +* The state will be set to disconnecting (which will prevent any request to move the status to connected) +* The call to `Disconnecting()`/`ConnectionLost()` will block until the previously active call completes and then + handle the disconnection. + +Reading the tests (unit_status_test.go) might help understand these rules. +*/ + +var ( + errAbortConnection = errors.New("disconnect called whist connection attempt in progress") + errAlreadyConnectedOrReconnecting = errors.New("status is already connected or reconnecting") + errStatusMustBeDisconnected = errors.New("status can only transition to connecting from disconnected") + errAlreadyDisconnected = errors.New("status is already disconnected") + errDisconnectionRequested = errors.New("disconnection was requested whilst the action was in progress") + errDisconnectionInProgress = errors.New("disconnection already in progress") + errAlreadyHandlingConnectionLoss = errors.New("status is already Connection Lost") + errConnLossWhileDisconnecting = errors.New("connection status is disconnecting so loss of connection is expected") +) + +// connectionStatus encapsulates, and protects, the connection status. +type connectionStatus struct { + sync.RWMutex // Protects the variables below + status status + willReconnect bool // only used when status == disconnecting. Indicates that an attempt will be made to reconnect (allows us to abort that) + + // Some statuses are transitional (e.g. connecting, connectionLost, reconnecting, disconnecting), that is, whatever + // process moves us into that status will move us out of it when an action is complete. Sometimes other users + // will need to know when the action is complete (e.g. the user calls `Disconnect()` whilst the status is + // `connecting`). `actionCompleted` will be set whenever we move into one of the above statues and the channel + // returned to anything else requesting a status change. The channel will be closed when the operation is complete. + actionCompleted chan struct{} // Only valid whilst status is Connecting or Reconnecting; will be closed when connection completed (success or failure) +} + +// ConnectionStatus returns the connection status. +// WARNING: the status may change at any time so users should not assume they are the only goroutine touching this +func (c *connectionStatus) ConnectionStatus() status { + c.RLock() + defer c.RUnlock() + return c.status +} + +// ConnectionStatusRetry returns the connection status and retry flag (indicates that we expect to reconnect). +// WARNING: the status may change at any time so users should not assume they are the only goroutine touching this +func (c *connectionStatus) ConnectionStatusRetry() (status, bool) { + c.RLock() + defer c.RUnlock() + return c.status, c.willReconnect +} + +// Connecting - Changes the status to connecting if that is a permitted operation +// Will do nothing unless the current status is disconnected +// Returns a function that MUST be called when the operation is complete (pass in true if successful) +func (c *connectionStatus) Connecting() (connCompletedFn, error) { + c.Lock() + defer c.Unlock() + // Calling Connect when already connecting (or if reconnecting) may not always be considered an error + if c.status == connected || c.status == reconnecting { + return nil, errAlreadyConnectedOrReconnecting + } + if c.status != disconnected { + return nil, errStatusMustBeDisconnected + } + c.status = connecting + c.actionCompleted = make(chan struct{}) + return c.connected, nil +} + +// connected is an internal function (it is returned by functions that set the status to connecting or reconnecting, +// calling it completes the operation). `success` is used to indicate whether the operation was successfully completed. +func (c *connectionStatus) connected(success bool) error { + c.Lock() + defer func() { + close(c.actionCompleted) // Alert anything waiting on the connection process to complete + c.actionCompleted = nil // Be tidy + c.Unlock() + }() + + // Status may have moved to disconnecting in the interim (i.e. at users request) + if c.status == disconnecting { + return errAbortConnection + } + if success { + c.status = connected + } else { + c.status = disconnected + } + return nil +} + +// Disconnecting - should be called when beginning the disconnection process (cleanup etc.). +// Can be called from ANY status and the end result will always be a status of disconnected +// Note that if a connection/reconnection attempt is in progress this function will set the status to `disconnecting` +// then block until the connection process completes (or aborts). +// Returns a function that MUST be called when the operation is complete (assumed to always be successful!) +func (c *connectionStatus) Disconnecting() (disconnectCompletedFn, error) { + c.Lock() + if c.status == disconnected { + c.Unlock() + return nil, errAlreadyDisconnected // May not always be treated as an error + } + if c.status == disconnecting { // Need to wait for existing process to complete + c.willReconnect = false // Ensure that the existing disconnect process will not reconnect + disConnectDone := c.actionCompleted + c.Unlock() + <-disConnectDone // Wait for existing operation to complete + return nil, errAlreadyDisconnected // Well we are now! + } + + prevStatus := c.status + c.status = disconnecting + + // We may need to wait for connection/reconnection process to complete (they should regularly check the status) + if prevStatus == connecting || prevStatus == reconnecting { + connectDone := c.actionCompleted + c.Unlock() // Safe because the only way to leave the disconnecting status is via this function + <-connectDone + + if prevStatus == reconnecting && !c.willReconnect { + return nil, errAlreadyDisconnected // Following connectionLost process we will be disconnected + } + c.Lock() + } + c.actionCompleted = make(chan struct{}) + c.Unlock() + return c.disconnectionCompleted, nil +} + +// disconnectionCompleted is an internal function (it is returned by functions that set the status to disconnecting) +func (c *connectionStatus) disconnectionCompleted() { + c.Lock() + defer c.Unlock() + c.status = disconnected + close(c.actionCompleted) // Alert anything waiting on the connection process to complete + c.actionCompleted = nil +} + +// ConnectionLost - should be called when the connection is lost. +// This really only differs from Disconnecting in that we may transition into a reconnection (but that could be +// cancelled something else calls Disconnecting in the meantime). +// The returned function should be called when cleanup is completed. It will return a function to be called when +// reconnect completes (or nil if no reconnect requested/disconnect called in the interim). +// Note: This function may block if a connection is in progress (the move to connected will be rejected) +func (c *connectionStatus) ConnectionLost(willReconnect bool) (connectionLostHandledFn, error) { + c.Lock() + defer c.Unlock() + if c.status == disconnected { + return nil, errAlreadyDisconnected + } + if c.status == disconnecting { // its expected that connection lost will be called during the disconnection process + return nil, errDisconnectionInProgress + } + + c.willReconnect = willReconnect + prevStatus := c.status + c.status = disconnecting + + // There is a slight possibility that a connection attempt is in progress (connection up and goroutines started but + // status not yet changed). By changing the status we ensure that process will exit cleanly + if prevStatus == connecting || prevStatus == reconnecting { + connectDone := c.actionCompleted + c.Unlock() // Safe because the only way to leave the disconnecting status is via this function + <-connectDone + c.Lock() + if !willReconnect { + // In this case the connection will always be aborted so there is nothing more for us to do + return nil, errAlreadyDisconnected + } + } + c.actionCompleted = make(chan struct{}) + + return c.getConnectionLostHandler(willReconnect), nil +} + +// getConnectionLostHandler is an internal function. It returns the function to be returned by ConnectionLost +func (c *connectionStatus) getConnectionLostHandler(reconnectRequested bool) connectionLostHandledFn { + return func(proceed bool) (connCompletedFn, error) { + // Note that connCompletedFn will only be provided if both reconnectRequested and proceed are true + c.Lock() + defer c.Unlock() + + // `Disconnecting()` may have been called while the disconnection was being processed (this makes it permanent!) + if !c.willReconnect || !proceed { + c.status = disconnected + close(c.actionCompleted) // Alert anything waiting on the connection process to complete + c.actionCompleted = nil + if !reconnectRequested || !proceed { + return nil, nil + } + return nil, errDisconnectionRequested + } + + c.status = reconnecting + return c.connected, nil // Note that c.actionCompleted is still live and will be closed in connected + } +} + +// forceConnectionStatus - forces the connection status to the specified value. +// This should only be used when there is no alternative (i.e. only in tests and to recover from situations that +// are unexpected) +func (c *connectionStatus) forceConnectionStatus(s status) { + c.Lock() + defer c.Unlock() + c.status = s +} diff --git a/unit_client_test.go b/unit_client_test.go index b29575a..9cfa5e1 100644 --- a/unit_client_test.go +++ b/unit_client_test.go @@ -85,7 +85,7 @@ func Test_isConnection(t *testing.T) { ops := NewClientOptions() c := NewClient(ops) - c.(*client).setConnected(connected) + c.(*client).status.forceConnectionStatus(connected) if !c.IsConnectionOpen() { t.Fail() } @@ -95,15 +95,15 @@ func Test_isConnectionOpenNegative(t *testing.T) { ops := NewClientOptions() c := NewClient(ops) - c.(*client).setConnected(reconnecting) + c.(*client).status.forceConnectionStatus(reconnecting) if c.IsConnectionOpen() { t.Fail() } - c.(*client).setConnected(connecting) + c.(*client).status.forceConnectionStatus(connecting) if c.IsConnectionOpen() { t.Fail() } - c.(*client).setConnected(disconnected) + c.(*client).status.forceConnectionStatus(disconnected) if c.IsConnectionOpen() { t.Fail() } diff --git a/unit_messageids_test.go b/unit_messageids_test.go index 84ea781..e3e1fdb 100644 --- a/unit_messageids_test.go +++ b/unit_messageids_test.go @@ -19,7 +19,6 @@ package mqtt import ( - "fmt" "testing" ) @@ -56,8 +55,9 @@ func Test_freeID(t *testing.T) { t.Fatalf("i1 was wrong: %v", i1) } - i2 := mids.getID(&DummyToken{}) - fmt.Printf("i2: %v\n", i2) + // The below may be needed for a specific test but leaving it in permanently makes output confusing + // i2 := mids.getID(&DummyToken{}) + // fmt.Printf("i2: %v\n", i2) } func Test_noFreeID(t *testing.T) { diff --git a/unit_status_test.go b/unit_status_test.go new file mode 100644 index 0000000..e2563ce --- /dev/null +++ b/unit_status_test.go @@ -0,0 +1,411 @@ +/* + * Copyright (c) 2022 IBM Corp and others. + * + * All rights reserved. This program and the accompanying materials + * are made available under the terms of the Eclipse Public License v2.0 + * and Eclipse Distribution License v1.0 which accompany this distribution. + * + * The Eclipse Public License is available at + * https://www.eclipse.org/legal/epl-2.0/ + * and the Eclipse Distribution License is available at + * http://www.eclipse.org/org/documents/edl-v10.php. + * + * Contributors: + * Matt Brittan + */ + +package mqtt + +import ( + "fmt" + "testing" + "time" +) + +func Test_BasicStatusOperations(t *testing.T) { + t.Parallel() + s := connectionStatus{} + if s.ConnectionStatus() != disconnected { + t.Fatalf("Expected disconnected; got: %v", s.ConnectionStatus()) + } + + // Normal connection and disconnection + cf, err := s.Connecting() + if err != nil { + t.Fatalf("Error connecting: %v", err) + } + if s.ConnectionStatus() != connecting { + t.Fatalf("Expected connecting; got: %v", s.ConnectionStatus()) + } + if err = cf(true); err != nil { + t.Fatalf("Error completing connection: %v", err) + } + if s.ConnectionStatus() != connected { + t.Fatalf("Expected connected; got: %v", s.ConnectionStatus()) + } + + // reconnect so we test all statuses + rf, err := s.ConnectionLost(true) + if err != nil { + t.Fatalf("Error calling connection lost: %v", err) + } + if s.ConnectionStatus() != disconnecting { + t.Fatalf("Expected disconnecting; got: %v", s.ConnectionStatus()) + } + if cf, err = rf(true); err != nil { + t.Fatalf("Error completing disconnection portion of reconnect: %v", err) + } + if s.ConnectionStatus() != reconnecting { + t.Fatalf("Expected reconnecting; got: %v", s.ConnectionStatus()) + } + if err = cf(true); err != nil { + t.Fatalf("Error completing reconnection: %v", err) + } + if s.ConnectionStatus() != connected { + t.Fatalf("Expected connected(2); got: %v", s.ConnectionStatus()) + } + + // And disconnect + df, err := s.Disconnecting() + if err != nil { + t.Fatalf("Error disconnecting: %v", err) + } + if s.ConnectionStatus() != disconnecting { + t.Fatalf("Expected disconnecting; got: %v", s.ConnectionStatus()) + } + df() + if s.ConnectionStatus() != disconnected { + t.Fatalf("Expected disconnected; got: %v", s.ConnectionStatus()) + } +} + +// Test_AdvancedStatusOperations checks a few of the more unusual transitions +func Test_AdvancedStatusOperations(t *testing.T) { + t.Parallel() + + // Aborted connection (i.e. user triggered) + s := connectionStatus{} + if s.ConnectionStatus() != disconnected { + t.Fatalf("Expected disconnected; got: %v", s.ConnectionStatus()) + } + + // Normal connection and disconnection + cf, err := s.Connecting() + if err != nil { + t.Fatalf("Error connecting: %v", err) + } + if s.ConnectionStatus() != connecting { + t.Fatalf("Expected connecting; got: %v", s.ConnectionStatus()) + } + if err = cf(false); err != nil { // Unsuccessful connection (e.g. user aborted connection) + t.Fatalf("Error completing connection: %v", err) + } + if s.ConnectionStatus() != disconnected { + t.Fatalf("Expected disconnected; got: %v", s.ConnectionStatus()) + } + + // Connection lost - no reconnection requested + s = connectionStatus{status: connected} + rf, err := s.ConnectionLost(false) + if err != nil { + t.Fatalf("Error connecting: %v", err) + } + if s.ConnectionStatus() != disconnecting { + t.Fatalf("Expected disconnecting; got: %v", s.ConnectionStatus()) + } + cf, err = rf(true) // argument should be ignored as no reconnect was requested + if cf != nil { + t.Fatalf("Function to complete reconnection should not be returned (as reconnection not requested)") + } + if err != nil { + t.Fatalf("Error completing connection lost operation: %v", err) + } + if s.ConnectionStatus() != disconnected { + t.Fatalf("Expected disconnected; got: %v", s.ConnectionStatus()) + } + + // Aborted reconnection - stage 1 (i.e. user triggered whist disconnect in progress) + s = connectionStatus{status: connected} + rf, err = s.ConnectionLost(true) + if err != nil { + t.Fatalf("Error connecting: %v", err) + } + if s.ConnectionStatus() != disconnecting { + t.Fatalf("Expected disconnecting; got: %v", s.ConnectionStatus()) + } + cf, err = rf(false) + if cf != nil { + t.Fatalf("Function to complete reconnection should not be returned (as reconnection not requested)") + } + if err != nil { + t.Fatalf("Error completing connection lost operation: %v", err) + } + if s.ConnectionStatus() != disconnected { + t.Fatalf("Expected disconnected; got: %v", s.ConnectionStatus()) + } + + // Aborted reconnection - stage 2 (i.e. user triggered whist disconnect in progress) + s = connectionStatus{status: connected} + rf, err = s.ConnectionLost(true) + if err != nil { + t.Fatalf("Error connecting: %v", err) + } + if s.ConnectionStatus() != disconnecting { + t.Fatalf("Expected disconnecting; got: %v", s.ConnectionStatus()) + } + cf, err = rf(true) + if err != nil { + t.Fatalf("Error completing connection lost operation: %v", err) + } + if cf == nil { + t.Fatalf("Function to complete reconnection should be returned (as reconnection requested)") + } + if s.ConnectionStatus() != reconnecting { + t.Fatalf("Expected reconnecting; got: %v", s.ConnectionStatus()) + } + if err = cf(false); err != nil { + t.Fatalf("Error completing reconnection: %v", err) + } + if s.ConnectionStatus() != disconnected { + t.Fatalf("Expected disconnected; got: %v", s.ConnectionStatus()) + } +} + +func Test_AbortedConnection(t *testing.T) { + t.Parallel() + s := connectionStatus{} + if s.ConnectionStatus() != disconnected { + t.Fatalf("Expected disconnected; got: %v", s.ConnectionStatus()) + } + + // Start Connection + cf, err := s.Connecting() + if err != nil { + t.Fatalf("Error connecting: %v", err) + } + if s.ConnectionStatus() != connecting { + t.Fatalf("Expected connecting; got: %v", s.ConnectionStatus()) + } + + // Another goroutine calls Disconnect + discErr := make(chan error) + go func() { + dfFn, err := s.Disconnecting() + discErr <- err + dfFn() + close(discErr) + }() + time.Sleep(time.Millisecond) // Provide time for Disconnect call to run + if s.ConnectionStatus() != disconnecting { + t.Fatalf("Expected disconnecting; got: %v", s.ConnectionStatus()) + } + select { + case err = <-discErr: + t.Fatalf("Disconnecting must block until connection attempt terminates: %v", err) + default: + } + + err = cf(true) // status should not matter + if err != errAbortConnection { + t.Fatalf("Expected errAbortConnection got: %v", err) + } + if s.ConnectionStatus() != disconnecting { + t.Fatalf("Expected disconnecting; got: %v", s.ConnectionStatus()) + } + + select { + case err = <-discErr: + if err != nil { + t.Fatalf("Did not expect an error: %v", err) + } + case <-time.After(time.Second): + t.Fatalf("Timeout waiting for goroutine to complete") + } + + time.Sleep(time.Millisecond) // Provide time for other goroutine to complete + if s.ConnectionStatus() != disconnected { + t.Fatalf("Expected disconnected; got: %v", s.ConnectionStatus()) + } + select { + case <-discErr: // channel should be closed + default: + t.Fatalf("Completion of connect should unblock Disconnecting call") + } +} + +func Test_AbortedReConnection(t *testing.T) { + t.Parallel() + s := connectionStatus{status: connected} // start in connected state + if s.ConnectionStatus() != connected { + t.Fatalf("Expected connected; got: %v", s.ConnectionStatus()) + } + + // Connection is lost but we want to reconnect + lhf, err := s.ConnectionLost(true) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + // Another goroutine calls Disconnect + discErr := make(chan error) + go func() { + dfFn, err := s.Disconnecting() + if dfFn != nil { + discErr <- fmt.Errorf("should not get a functiuon back from s.Disconnecting in this case") + return + } + discErr <- err + }() + time.Sleep(time.Millisecond) // Provide time for Disconnect call to run + if s.ConnectionStatus() != disconnecting { + t.Fatalf("Expected disconnecting; got: %v", s.ConnectionStatus()) + } + select { + case err = <-discErr: + t.Fatalf("Disconnecting must block until reconnection attempt terminates: %v", err) + default: + } + + cf, err := lhf(true) // status should not matter + if cf != nil { + t.Fatalf("As Disconnect has been called we should not have any ability to continue") + } + if err != errDisconnectionRequested { + t.Fatalf("Expected errDisconnectionRequested got: %v", err) + } + if s.ConnectionStatus() != disconnected { + t.Fatalf("Expected disconnected; got: %v", s.ConnectionStatus()) + } + + select { + case err = <-discErr: + if err != errAlreadyDisconnected { + t.Fatalf("Expected errAlreadyDisconnected got: %v", err) + } + case <-time.After(time.Second): + t.Fatalf("Timeout waiting for goroutine to complete") + } +} + +// Test_ConnectionLostDuringConnect don't really expect this to happen due to connMu +// If it does happen and reconnect is true the results would not be great +func Test_ConnectionLostDuringConnect(t *testing.T) { + t.Parallel() + s := connectionStatus{} + if s.ConnectionStatus() != disconnected { + t.Fatalf("Expected disconnected; got: %v", s.ConnectionStatus()) + } + + // Start Connection + cf, err := s.Connecting() + if err != nil { + t.Fatalf("Error connecting: %v", err) + } + if s.ConnectionStatus() != connecting { + t.Fatalf("Expected connecting; got: %v", s.ConnectionStatus()) + } + + // Another goroutine calls ConnectionLost (don't expect this to every actually happen but...) + clErr := make(chan error) + go func() { + _, err := s.ConnectionLost(false) + clErr <- err + }() + time.Sleep(time.Millisecond) // Provide time for Disconnect call to run + if s.ConnectionStatus() != disconnecting { + t.Fatalf("Expected disconnecting; got: %v", s.ConnectionStatus()) + } + select { + case err = <-clErr: + t.Fatalf("ConnectionLost must block until connection attempt terminates: %v", err) + default: + } + + err = cf(true) // status should not matter + if err != errAbortConnection { + t.Fatalf("Expected errAbortConnection got: %v", err) + } + if s.ConnectionStatus() != disconnecting { + t.Fatalf("Expected disconnecting; got: %v", s.ConnectionStatus()) + } + + select { + case err = <-clErr: + if err != errAlreadyDisconnected { + t.Fatalf("Expected errAlreadyDisconnected got: %v", err) + } + case <-time.After(time.Second): + t.Fatalf("Timeout waiting for goroutine to complete") + } +} + +/* +clErr := make(chan error) + go func() { + rf, err := s.ConnectionLost(false) + clErr <- err + cf, err := rf(false) + if err != errAlreadyDisconnected { + clErr <- fmt.Errorf("expected errAlreadyDisconnected got %v", err) + } + if cf != nil { + clErr <- fmt.Errorf("cf is not nil") + } + close(clErr) + + }() + time.Sleep(time.Millisecond) // Provide time for Disconnect call to run + if s.ConnectionStatus() != disconnecting { + t.Fatalf("Expected disconnecting; got: %v", s.ConnectionStatus()) + } + select { + case err = <-clErr: + t.Fatalf("ConnectionLost must block until connection attempt terminates: %v", err) + default: + } + + err = cf(true) // status should not matter + if err != errAbortConnection { + t.Fatalf("Expected errAbortConnection got: %v", err) + } + if s.ConnectionStatus() != disconnecting { + t.Fatalf("Expected disconnecting; got: %v", s.ConnectionStatus()) + } + + select { + case err = <-clErr: + if err != nil { + t.Fatalf("Did not expect an error: %v", err) + } + case <-time.After(time.Second): + t.Fatalf("Timeout waiting for goroutine to complete") + } + + time.Sleep(time.Millisecond) // Provide time for other goroutine to complete + if s.ConnectionStatus() != disconnected { + t.Fatalf("Expected disconnected; got: %v", s.ConnectionStatus()) + } + select { + case <-clErr: // channel should be closed + default: + t.Fatalf("Completion of connect should unblock Disconnecting call") + } +*/ + +/* +// TODO - Test aborting functions etc + +disconnected -> `Connecting()` -> connecting -> `connCompletedFn(true)` -> connected +connected -> `Disconnecting()` -> disconnecting -> `disconnectCompletedFn()` -> disconnected +connected -> `ConnectionLost(false)` -> disconnecting -> `connectionLostHandledFn(true/false)` -> disconnected +connected -> `ConnectionLost(true)` -> disconnecting -> `connectionLostHandledFn(true)` -> connected + +Unfortunately the above workflows are complicated by the fact that `Disconnecting()` or `ConnectionLost()` may, +potentially, be called at any time (i.e.whilst in the middle of transitioning between states).If this happens: + +* The state will be set to disconnecting (which will prevent any request to move the status to connected) +* The call to `Disconnecting()`/`ConnectionLost()` will block until the previously active call completes and then +handle the disconnection. + +Reading the tests (unit_client_test.go ) might help understand these rules. +*/