diff --git a/okhttp/src/main/java/io/grpc/okhttp/OkHttpClientTransport.java b/okhttp/src/main/java/io/grpc/okhttp/OkHttpClientTransport.java index 81af4f30d48..02b1c3254fa 100644 --- a/okhttp/src/main/java/io/grpc/okhttp/OkHttpClientTransport.java +++ b/okhttp/src/main/java/io/grpc/okhttp/OkHttpClientTransport.java @@ -221,6 +221,9 @@ protected void handleNotInUse() { @Nullable final HttpConnectProxiedSocketAddress proxiedAddr; + @VisibleForTesting + int proxySocketTimeout = 30000; + // The following fields should only be used for test. Runnable connectingCallback; SettableFuture connectedFuture; @@ -626,8 +629,8 @@ private void sendConnectionPrefaceAndSettings() { private Socket createHttpProxySocket(InetSocketAddress address, InetSocketAddress proxyAddress, String proxyUsername, String proxyPassword) throws StatusException { + Socket sock = null; try { - Socket sock; // The proxy address may not be resolved if (proxyAddress.getAddress() != null) { sock = socketFactory.createSocket(proxyAddress.getAddress(), proxyAddress.getPort()); @@ -636,6 +639,9 @@ private Socket createHttpProxySocket(InetSocketAddress address, InetSocketAddres socketFactory.createSocket(proxyAddress.getHostName(), proxyAddress.getPort()); } sock.setTcpNoDelay(true); + // A socket timeout is needed because lost network connectivity while reading from the proxy, + // can cause reading from the socket to hang. + sock.setSoTimeout(proxySocketTimeout); Source source = Okio.source(sock); BufferedSink sink = Okio.buffer(Okio.sink(sock)); @@ -682,8 +688,13 @@ private Socket createHttpProxySocket(InetSocketAddress address, InetSocketAddres statusLine.code, statusLine.message, body.readUtf8()); throw Status.UNAVAILABLE.withDescription(message).asException(); } + // As the socket will be used for RPCs from here on, we want the socket timeout back to zero. + sock.setSoTimeout(0); return sock; } catch (IOException e) { + if (sock != null) { + GrpcUtil.closeQuietly(sock); + } throw Status.UNAVAILABLE.withDescription("Failed trying to connect with proxy").withCause(e) .asException(); } diff --git a/okhttp/src/test/java/io/grpc/okhttp/OkHttpClientTransportTest.java b/okhttp/src/test/java/io/grpc/okhttp/OkHttpClientTransportTest.java index e632a6c2946..fcc5e0d2381 100644 --- a/okhttp/src/test/java/io/grpc/okhttp/OkHttpClientTransportTest.java +++ b/okhttp/src/test/java/io/grpc/okhttp/OkHttpClientTransportTest.java @@ -1877,6 +1877,37 @@ public void proxy_immediateServerClose() throws Exception { verify(transportListener, timeout(TIME_OUT_MS)).transportTerminated(); } + @Test + public void proxy_serverHangs() throws Exception { + ServerSocket serverSocket = new ServerSocket(0); + InetSocketAddress targetAddress = InetSocketAddress.createUnresolved("theservice", 80); + clientTransport = new OkHttpClientTransport( + channelBuilder.buildTransportFactory(), + targetAddress, + "authority", + "userAgent", + EAG_ATTRS, + HttpConnectProxiedSocketAddress.newBuilder() + .setTargetAddress(targetAddress) + .setProxyAddress(new InetSocketAddress("localhost", serverSocket.getLocalPort())) + .build(), + tooManyPingsRunnable); + clientTransport.proxySocketTimeout = 10; + clientTransport.start(transportListener); + + Socket sock = serverSocket.accept(); + serverSocket.close(); + + BufferedReader reader = new BufferedReader(new InputStreamReader(sock.getInputStream(), UTF_8)); + assertEquals("CONNECT theservice:80 HTTP/1.1", reader.readLine()); + assertEquals("Host: theservice:80", reader.readLine()); + while (!"".equals(reader.readLine())) {} + + verify(transportListener, timeout(200)).transportShutdown(any(Status.class)); + verify(transportListener, timeout(TIME_OUT_MS)).transportTerminated(); + sock.close(); + } + @Test public void goAway_notUtf8() throws Exception { initTransport();