From 0755c7a1f5502f44cdbdf7772f0c5035179b7682 Mon Sep 17 00:00:00 2001 From: jmank88 Date: Wed, 8 Sep 2021 13:11:11 -0500 Subject: [PATCH] rpc: add TestClientWebsocketSevered --- rpc/websocket_test.go | 89 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 89 insertions(+) diff --git a/rpc/websocket_test.go b/rpc/websocket_test.go index 4976853baf82b..2486092836bc8 100644 --- a/rpc/websocket_test.go +++ b/rpc/websocket_test.go @@ -18,11 +18,15 @@ package rpc import ( "context" + "io" "net" "net/http" "net/http/httptest" + "net/http/httputil" + "net/url" "reflect" "strings" + "sync/atomic" "testing" "time" @@ -188,6 +192,63 @@ func TestClientWebsocketLargeMessage(t *testing.T) { } } +func TestClientWebsocketSevered(t *testing.T) { + t.Parallel() + + var ( + server = wsPingTestServer(t, nil) + ctx = context.Background() + ) + defer server.Shutdown(ctx) + + u, err := url.Parse("http://" + server.Addr) + if err != nil { + t.Fatal(err) + } + rproxy := httputil.NewSingleHostReverseProxy(u) + var severable *severableReadWriteCloser + rproxy.ModifyResponse = func(response *http.Response) error { + severable = &severableReadWriteCloser{ReadWriteCloser: response.Body.(io.ReadWriteCloser)} + response.Body = severable + return nil + } + frontendProxy := httptest.NewServer(rproxy) + defer frontendProxy.Close() + + wsURL := "ws:" + strings.TrimPrefix(frontendProxy.URL, "http:") + client, err := DialWebsocket(ctx, wsURL, "") + if err != nil { + t.Fatalf("client dial error: %v", err) + } + defer client.Close() + + resultChan := make(chan int) + sub, err := client.EthSubscribe(ctx, resultChan, "foo") + if err != nil { + t.Fatalf("client subscribe error: %v", err) + } + + // sever the connection + severable.Sever() + + // Wait for subscription error. + timeout := time.NewTimer(3 * wsPingInterval) + defer timeout.Stop() + for { + select { + case err := <-sub.Err(): + t.Log("client subscription error:", err) + return + case result := <-resultChan: + t.Error("unexpected result:", result) + return + case <-timeout.C: + t.Error("didn't get any error within the test timeout") + return + } + } +} + // wsPingTestServer runs a WebSocket server which accepts a single subscription request. // When a value arrives on sendPing, the server sends a ping frame, waits for a matching // pong and finally delivers a single subscription result. @@ -290,3 +351,31 @@ func wsPingTestHandler(t *testing.T, conn *websocket.Conn, shutdown, sendPing <- } } } + +// severableReadWriteCloser wraps an io.ReadWriteCloser and provides a Sever() method to drop writes and read empty. +type severableReadWriteCloser struct { + io.ReadWriteCloser + severed int32 // atomic +} + +func (s *severableReadWriteCloser) Sever() { + atomic.StoreInt32(&s.severed, 1) +} + +func (s *severableReadWriteCloser) Read(p []byte) (n int, err error) { + if atomic.LoadInt32(&s.severed) > 0 { + return 0, nil + } + return s.ReadWriteCloser.Read(p) +} + +func (s *severableReadWriteCloser) Write(p []byte) (n int, err error) { + if atomic.LoadInt32(&s.severed) > 0 { + return len(p), nil + } + return s.ReadWriteCloser.Write(p) +} + +func (s *severableReadWriteCloser) Close() error { + return s.ReadWriteCloser.Close() +}