diff --git a/rpc/websocket.go b/rpc/websocket.go index afeb4c2081b84..5571324af8544 100644 --- a/rpc/websocket.go +++ b/rpc/websocket.go @@ -37,6 +37,7 @@ const ( wsWriteBuffer = 1024 wsPingInterval = 60 * time.Second wsPingWriteTimeout = 5 * time.Second + wsPongTimeout = 30 * time.Second wsMessageSizeLimit = 15 * 1024 * 1024 ) @@ -241,6 +242,10 @@ type websocketCodec struct { func newWebsocketCodec(conn *websocket.Conn) ServerCodec { conn.SetReadLimit(wsMessageSizeLimit) + conn.SetPongHandler(func(appData string) error { + conn.SetReadDeadline(time.Time{}) + return nil + }) wc := &websocketCodec{ jsonCodec: NewFuncCodec(conn, conn.WriteJSON, conn.ReadJSON).(*jsonCodec), conn: conn, @@ -287,6 +292,7 @@ func (wc *websocketCodec) pingLoop() { wc.jsonCodec.encMu.Lock() wc.conn.SetWriteDeadline(time.Now().Add(wsPingWriteTimeout)) wc.conn.WriteMessage(websocket.PingMessage, nil) + wc.conn.SetReadDeadline(time.Now().Add(wsPongTimeout)) wc.jsonCodec.encMu.Unlock() timer.Reset(wsPingInterval) } diff --git a/rpc/websocket_test.go b/rpc/websocket_test.go index 4976853baf82b..2486092836bc8 100644 --- a/rpc/websocket_test.go +++ b/rpc/websocket_test.go @@ -18,11 +18,15 @@ package rpc import ( "context" + "io" "net" "net/http" "net/http/httptest" + "net/http/httputil" + "net/url" "reflect" "strings" + "sync/atomic" "testing" "time" @@ -188,6 +192,63 @@ func TestClientWebsocketLargeMessage(t *testing.T) { } } +func TestClientWebsocketSevered(t *testing.T) { + t.Parallel() + + var ( + server = wsPingTestServer(t, nil) + ctx = context.Background() + ) + defer server.Shutdown(ctx) + + u, err := url.Parse("http://" + server.Addr) + if err != nil { + t.Fatal(err) + } + rproxy := httputil.NewSingleHostReverseProxy(u) + var severable *severableReadWriteCloser + rproxy.ModifyResponse = func(response *http.Response) error { + severable = &severableReadWriteCloser{ReadWriteCloser: response.Body.(io.ReadWriteCloser)} + response.Body = severable + return nil + } + frontendProxy := httptest.NewServer(rproxy) + defer frontendProxy.Close() + + wsURL := "ws:" + strings.TrimPrefix(frontendProxy.URL, "http:") + client, err := DialWebsocket(ctx, wsURL, "") + if err != nil { + t.Fatalf("client dial error: %v", err) + } + defer client.Close() + + resultChan := make(chan int) + sub, err := client.EthSubscribe(ctx, resultChan, "foo") + if err != nil { + t.Fatalf("client subscribe error: %v", err) + } + + // sever the connection + severable.Sever() + + // Wait for subscription error. + timeout := time.NewTimer(3 * wsPingInterval) + defer timeout.Stop() + for { + select { + case err := <-sub.Err(): + t.Log("client subscription error:", err) + return + case result := <-resultChan: + t.Error("unexpected result:", result) + return + case <-timeout.C: + t.Error("didn't get any error within the test timeout") + return + } + } +} + // wsPingTestServer runs a WebSocket server which accepts a single subscription request. // When a value arrives on sendPing, the server sends a ping frame, waits for a matching // pong and finally delivers a single subscription result. @@ -290,3 +351,31 @@ func wsPingTestHandler(t *testing.T, conn *websocket.Conn, shutdown, sendPing <- } } } + +// severableReadWriteCloser wraps an io.ReadWriteCloser and provides a Sever() method to drop writes and read empty. +type severableReadWriteCloser struct { + io.ReadWriteCloser + severed int32 // atomic +} + +func (s *severableReadWriteCloser) Sever() { + atomic.StoreInt32(&s.severed, 1) +} + +func (s *severableReadWriteCloser) Read(p []byte) (n int, err error) { + if atomic.LoadInt32(&s.severed) > 0 { + return 0, nil + } + return s.ReadWriteCloser.Read(p) +} + +func (s *severableReadWriteCloser) Write(p []byte) (n int, err error) { + if atomic.LoadInt32(&s.severed) > 0 { + return len(p), nil + } + return s.ReadWriteCloser.Write(p) +} + +func (s *severableReadWriteCloser) Close() error { + return s.ReadWriteCloser.Close() +}