diff --git a/client_test.go b/client_test.go index 0294324..13680cd 100644 --- a/client_test.go +++ b/client_test.go @@ -63,7 +63,12 @@ func TestCustomConnectionFunction(t *testing.T) { client := NewClient(options) // Try to connect using custom function, wait for 2 seconds, to pass MQTT first message - if token := client.Connect(); token.WaitTimeout(2*time.Second) && token.Error() != nil { + // Note that the token should NOT complete (because a CONNACK is never sent) + token := client.Connect() + if token.WaitTimeout(2 * time.Second) { + t.Fatal("token should not complete") // should be blocked waiting for CONNACK + } + if token.Error() != nil { // Should never have an error t.Fatalf("%v", token.Error()) } diff --git a/token.go b/token.go index 996ab5b..9eb122e 100644 --- a/token.go +++ b/token.go @@ -17,6 +17,7 @@ package mqtt import ( + "errors" "sync" "time" @@ -202,3 +203,20 @@ type UnsubscribeToken struct { type DisconnectToken struct { baseToken } + +// TimedOut is the error returned by WaitTimeout when the timeout expires +var TimedOut = errors.New("context canceled") + +// WaitTokenTimeout is a utility function used to simplify the use of token.WaitTimeout +// token.WaitTimeout may return `false` due to time out but t.Error() still results +// in nil. +// `if t := client.X(); t.WaitTimeout(time.Second) && t.Error() != nil {` may evaluate +// to false even if the operation fails. +// It is important to note that if TimedOut is returned, then the operation may still be running +// and could eventually complete successfully. +func WaitTokenTimeout(t Token, d time.Duration) error { + if !t.WaitTimeout(d) { + return TimedOut + } + return t.Error() +} diff --git a/token_test.go b/token_test.go index a175dbd..d80e270 100644 --- a/token_test.go +++ b/token_test.go @@ -40,3 +40,21 @@ func TestWaitTimeout(t *testing.T) { t.Fatal("Should have succeeded") } } + +func TestWaitTokenTimeout(t *testing.T) { + b := baseToken{} + + if !errors.Is(WaitTokenTimeout(&b, time.Second), TimedOut) { + t.Fatal("Should have failed") + } + + // Now let's confirm that WaitTimeout returns correct error + b = baseToken{complete: make(chan struct{})} + testError := errors.New("test") + go func(bt *baseToken) { + bt.setError(testError) + }(&b) + if !errors.Is(WaitTokenTimeout(&b, 5*time.Second), testError) { + t.Fatal("Unexpected error received") + } +}