diff --git a/client_test.go b/client_test.go index 40e514a..0294324 100644 --- a/client_test.go +++ b/client_test.go @@ -31,15 +31,27 @@ func TestCustomConnectionFunction(t *testing.T) { netClient, netServer := net.Pipe() defer netClient.Close() defer netServer.Close() - var firstMessage = "" + + outputChan := make(chan struct { + msg []byte + err error + }) go func() { // read first message only bytes := make([]byte, 1024) + netServer.SetDeadline(time.Now().Add(time.Second)) // Ensure this will always complete n, err := netServer.Read(bytes) if err != nil { - t.Errorf("%v", err) + outputChan <- struct { + msg []byte + err error + }{err: err} + } else { + outputChan <- struct { + msg []byte + err error + }{msg: bytes[:n]} } - firstMessage = string(bytes[:n]) }() // Set custom network connection function and client connect var customConnectionFunc OpenConnectionFunc = func(uri *url.URL, options ClientOptions) (net.Conn, error) { @@ -52,10 +64,16 @@ func TestCustomConnectionFunction(t *testing.T) { // Try to connect using custom function, wait for 2 seconds, to pass MQTT first message if token := client.Connect(); token.WaitTimeout(2*time.Second) && token.Error() != nil { - t.Errorf("%v", token.Error()) + t.Fatalf("%v", token.Error()) + } + + msg := <-outputChan + if msg.err != nil { + t.Fatalf("read from simulated connection failed: %v", msg.err) } // Analyze first message sent by client and received by the server + firstMessage := string(msg.msg) if len(firstMessage) <= 0 || !strings.Contains(firstMessage, "MQTT") { t.Error("no message received on connect") }