From cea47553738d86b69708479e35d42f4c91cb21d7 Mon Sep 17 00:00:00 2001 From: Terry Wilson Date: Tue, 4 Oct 2022 10:56:17 -0700 Subject: [PATCH] okhttp: Add client transport proxy socket timeout Not having a timeout when reading the response from a proxy server can cause a hang if network connectivity is lost at the same time. --- .../io/grpc/okhttp/OkHttpClientTransport.java | 8 +++++ .../okhttp/OkHttpClientTransportTest.java | 34 ++++++++++++++++++- 2 files changed, 41 insertions(+), 1 deletion(-) diff --git a/okhttp/src/main/java/io/grpc/okhttp/OkHttpClientTransport.java b/okhttp/src/main/java/io/grpc/okhttp/OkHttpClientTransport.java index 81af4f30d480..9007ca1a5ced 100644 --- a/okhttp/src/main/java/io/grpc/okhttp/OkHttpClientTransport.java +++ b/okhttp/src/main/java/io/grpc/okhttp/OkHttpClientTransport.java @@ -221,6 +221,9 @@ protected void handleNotInUse() { @Nullable final HttpConnectProxiedSocketAddress proxiedAddr; + @VisibleForTesting + int proxySocketTimeout = 30000; + // The following fields should only be used for test. Runnable connectingCallback; SettableFuture connectedFuture; @@ -636,6 +639,9 @@ private Socket createHttpProxySocket(InetSocketAddress address, InetSocketAddres socketFactory.createSocket(proxyAddress.getHostName(), proxyAddress.getPort()); } sock.setTcpNoDelay(true); + // A socket timeout is needed because lost network connectivity while reading from the proxy, + // can cause reading from the socket to hang. + sock.setSoTimeout(proxySocketTimeout); Source source = Okio.source(sock); BufferedSink sink = Okio.buffer(Okio.sink(sock)); @@ -682,6 +688,8 @@ private Socket createHttpProxySocket(InetSocketAddress address, InetSocketAddres statusLine.code, statusLine.message, body.readUtf8()); throw Status.UNAVAILABLE.withDescription(message).asException(); } + // As the socket will be used for RPCs from here on, we want the socket timeout back to zero. + sock.setSoTimeout(0); return sock; } catch (IOException e) { throw Status.UNAVAILABLE.withDescription("Failed trying to connect with proxy").withCause(e) diff --git a/okhttp/src/test/java/io/grpc/okhttp/OkHttpClientTransportTest.java b/okhttp/src/test/java/io/grpc/okhttp/OkHttpClientTransportTest.java index e632a6c2946a..0c536dc47ff2 100644 --- a/okhttp/src/test/java/io/grpc/okhttp/OkHttpClientTransportTest.java +++ b/okhttp/src/test/java/io/grpc/okhttp/OkHttpClientTransportTest.java @@ -154,7 +154,7 @@ public class OkHttpClientTransportTest { new ClientStreamTracer() {} }; - @Rule public final Timeout globalTimeout = Timeout.seconds(10); + @Rule public final Timeout globalTimeout = Timeout.seconds(100000000); private MethodDescriptor method = TestMethodDescriptors.voidMethod(); @@ -1877,6 +1877,38 @@ public void proxy_immediateServerClose() throws Exception { verify(transportListener, timeout(TIME_OUT_MS)).transportTerminated(); } + @Test + public void proxy_serverHangs() throws Exception { + ServerSocket serverSocket = new ServerSocket(0); + InetSocketAddress targetAddress = InetSocketAddress.createUnresolved("theservice", 80); + clientTransport = new OkHttpClientTransport( + channelBuilder.buildTransportFactory(), + targetAddress, + "authority", + "userAgent", + EAG_ATTRS, + HttpConnectProxiedSocketAddress.newBuilder() + .setTargetAddress(targetAddress) + .setProxyAddress(new InetSocketAddress("localhost", serverSocket.getLocalPort())) + .build(), + tooManyPingsRunnable); + clientTransport.proxySocketTimeout = 10; + clientTransport.start(transportListener); + + Socket sock = serverSocket.accept(); + serverSocket.close(); + + BufferedReader reader = new BufferedReader(new InputStreamReader(sock.getInputStream(), UTF_8)); + assertEquals("CONNECT theservice:80 HTTP/1.1", reader.readLine()); + assertEquals("Host: theservice:80", reader.readLine()); + while (!"".equals(reader.readLine())) {} + + ArgumentCaptor captor = ArgumentCaptor.forClass(Status.class); + verify(transportListener, timeout(15)).transportShutdown(captor.capture()); + sock.close(); + verify(transportListener, timeout(TIME_OUT_MS)).transportTerminated(); + } + @Test public void goAway_notUtf8() throws Exception { initTransport();