From f3385997a19ce36d5324a19581d8a171f2cb8e27 Mon Sep 17 00:00:00 2001 From: larry-safran Date: Mon, 27 Jun 2022 14:21:29 -0700 Subject: [PATCH] [core] Use SyncContext for InProcessTransport listener callbacks to avoid deadlocks (Fixes bug #3084) Also support unary calls returning null values --- .../io/grpc/inprocess/InProcessTransport.java | 284 +++++++++++------- .../helloworld/HelloWorldClientTest.java | 2 - .../helloworld/HelloWorldServerTest.java | 2 - .../routeguide/RouteGuideClientTest.java | 2 - .../routeguide/RouteGuideServerTest.java | 2 - .../main/java/io/grpc/stub/ClientCalls.java | 6 +- 6 files changed, 173 insertions(+), 125 deletions(-) diff --git a/core/src/main/java/io/grpc/inprocess/InProcessTransport.java b/core/src/main/java/io/grpc/inprocess/InProcessTransport.java index e40658d08ff4..41ab994ffa22 100644 --- a/core/src/main/java/io/grpc/inprocess/InProcessTransport.java +++ b/core/src/main/java/io/grpc/inprocess/InProcessTransport.java @@ -40,6 +40,7 @@ import io.grpc.SecurityLevel; import io.grpc.ServerStreamTracer; import io.grpc.Status; +import io.grpc.SynchronizationContext; import io.grpc.internal.ClientStream; import io.grpc.internal.ClientStreamListener; import io.grpc.internal.ClientStreamListener.RpcProgress; @@ -407,16 +408,14 @@ private void streamClosed() { private class InProcessServerStream implements ServerStream { final StatsTraceContext statsTraceCtx; - @GuardedBy("this") private ClientStreamListener clientStreamListener; + private final SynchronizationContext syncContext; @GuardedBy("this") private int clientRequested; @GuardedBy("this") private ArrayDeque clientReceiveQueue = new ArrayDeque<>(); - @GuardedBy("this") private Status clientNotifyStatus; - @GuardedBy("this") private Metadata clientNotifyTrailers; // Only is intended to prevent double-close when client cancels. @GuardedBy("this") @@ -426,7 +425,14 @@ private class InProcessServerStream implements ServerStream { InProcessServerStream(MethodDescriptor method, Metadata headers) { statsTraceCtx = StatsTraceContext.newServerContext( - serverStreamTracerFactories, method.getFullMethodName(), headers); + serverStreamTracerFactories, method.getFullMethodName(), headers); + + syncContext = new SynchronizationContext(new Thread.UncaughtExceptionHandler() { + @Override + public void uncaughtException(Thread t, Throwable e) { + throw new RuntimeException(e); + } + }); } private synchronized void setListener(ClientStreamListener listener) { @@ -442,42 +448,51 @@ public void setListener(ServerStreamListener serverStreamListener) { public void request(int numMessages) { boolean onReady = clientStream.serverRequested(numMessages); if (onReady) { - synchronized (this) { + synchronized (this) { // TODO How should this be handled (probably just remove) if (!closed) { - clientStreamListener.onReady(); + syncContext.executeLater(() -> clientStreamListener.onReady()); } } } + syncContext.drain(); } // This method is the only reason we have to synchronize field accesses. + // Using a SynchronizationContext to avoid possibility of deadlock in direct executors. /** * Client requested more messages. * * @return whether onReady should be called on the server */ - private synchronized boolean clientRequested(int numMessages) { - if (closed) { - return false; - } - boolean previouslyReady = clientRequested > 0; - clientRequested += numMessages; - while (clientRequested > 0 && !clientReceiveQueue.isEmpty()) { - clientRequested--; - clientStreamListener.messagesAvailable(clientReceiveQueue.poll()); - } - // Attempt being reentrant-safe - if (closed) { - return false; - } - if (clientReceiveQueue.isEmpty() && clientNotifyStatus != null) { - closed = true; - clientStream.statsTraceCtx.clientInboundTrailers(clientNotifyTrailers); - clientStream.statsTraceCtx.streamClosed(clientNotifyStatus); - clientStreamListener.closed( - clientNotifyStatus, RpcProgress.PROCESSED, clientNotifyTrailers); + private boolean clientRequested(int numMessages) { + boolean previouslyReady; + boolean nowReady; + synchronized (this) { + if (closed) { + return false; + } + + previouslyReady = clientRequested > 0; + clientRequested += numMessages; + while (clientRequested > 0 && !clientReceiveQueue.isEmpty()) { + clientRequested--; + StreamListener.MessageProducer producer = clientReceiveQueue.poll(); + syncContext.executeLater(() -> clientStreamListener.messagesAvailable(producer)); + } + + if (clientReceiveQueue.isEmpty() && clientNotifyStatus != null) { + closed = true; + clientStream.statsTraceCtx.clientInboundTrailers(clientNotifyTrailers); + clientStream.statsTraceCtx.streamClosed(clientNotifyStatus); + syncContext.executeLater(() -> + clientStreamListener.closed( + clientNotifyStatus, RpcProgress.PROCESSED, clientNotifyTrailers)); + } + + nowReady = clientRequested > 0; } - boolean nowReady = clientRequested > 0; + + syncContext.drain(); return !previouslyReady && nowReady; } @@ -485,23 +500,28 @@ private void clientCancelled(Status status) { internalCancel(status); } + // Using syncContext to avoid possibility of deadlock @Override - public synchronized void writeMessage(InputStream message) { - if (closed) { - return; - } - statsTraceCtx.outboundMessage(outboundSeqNo); - statsTraceCtx.outboundMessageSent(outboundSeqNo, -1, -1); - clientStream.statsTraceCtx.inboundMessage(outboundSeqNo); - clientStream.statsTraceCtx.inboundMessageRead(outboundSeqNo, -1, -1); - outboundSeqNo++; - StreamListener.MessageProducer producer = new SingleMessageProducer(message); - if (clientRequested > 0) { - clientRequested--; - clientStreamListener.messagesAvailable(producer); - } else { - clientReceiveQueue.add(producer); + public void writeMessage(InputStream message) { + synchronized (this) { + if (closed) { + return; + } + statsTraceCtx.outboundMessage(outboundSeqNo); + statsTraceCtx.outboundMessageSent(outboundSeqNo, -1, -1); + clientStream.statsTraceCtx.inboundMessage(outboundSeqNo); + clientStream.statsTraceCtx.inboundMessageRead(outboundSeqNo, -1, -1); + outboundSeqNo++; + StreamListener.MessageProducer producer = new SingleMessageProducer(message); + if (clientRequested > 0) { + clientRequested--; + syncContext.executeLater(() -> clientStreamListener.messagesAvailable(producer)); + } else { + clientReceiveQueue.add(producer); + } } + + syncContext.drain(); } @Override @@ -540,8 +560,9 @@ public void writeHeaders(Metadata headers) { } clientStream.statsTraceCtx.clientInboundHeaders(); - clientStreamListener.headersRead(headers); + syncContext.executeLater(() -> clientStreamListener.headersRead(headers)); } + syncContext.drain(); } @Override @@ -585,13 +606,14 @@ private void notifyClientClose(Status status, Metadata trailers) { closed = true; clientStream.statsTraceCtx.clientInboundTrailers(trailers); clientStream.statsTraceCtx.streamClosed(clientStatus); - clientStreamListener.closed(clientStatus, RpcProgress.PROCESSED, trailers); + syncContext.executeLater( + () -> clientStreamListener.closed(clientStatus, RpcProgress.PROCESSED, trailers)); } else { clientNotifyStatus = clientStatus; clientNotifyTrailers = trailers; } } - + syncContext.drain(); streamClosed(); } @@ -604,24 +626,29 @@ public void cancel(Status status) { streamClosed(); } - private synchronized boolean internalCancel(Status clientStatus) { - if (closed) { - return false; - } - closed = true; - StreamListener.MessageProducer producer; - while ((producer = clientReceiveQueue.poll()) != null) { - InputStream message; - while ((message = producer.next()) != null) { - try { - message.close(); - } catch (Throwable t) { - log.log(Level.WARNING, "Exception closing stream", t); + private boolean internalCancel(Status clientStatus) { + synchronized (this) { + if (closed) { + return false; + } + closed = true; + StreamListener.MessageProducer producer; + while ((producer = clientReceiveQueue.poll()) != null) { + InputStream message; + while ((message = producer.next()) != null) { + try { + message.close(); + } catch (Throwable t) { + log.log(Level.WARNING, "Exception closing stream", t); + } } } + clientStream.statsTraceCtx.streamClosed(clientStatus); + syncContext.executeLater( + () -> + clientStreamListener.closed(clientStatus, RpcProgress.PROCESSED, new Metadata())); } - clientStream.statsTraceCtx.streamClosed(clientStatus); - clientStreamListener.closed(clientStatus, RpcProgress.PROCESSED, new Metadata()); + syncContext.drain(); return true; } @@ -662,8 +689,8 @@ public int streamId() { private class InProcessClientStream implements ClientStream { final StatsTraceContext statsTraceCtx; final CallOptions callOptions; - @GuardedBy("this") private ServerStreamListener serverStreamListener; + private final SynchronizationContext syncContext; @GuardedBy("this") private int serverRequested; @GuardedBy("this") @@ -681,6 +708,15 @@ private class InProcessClientStream implements ClientStream { CallOptions callOptions, StatsTraceContext statsTraceContext) { this.callOptions = callOptions; statsTraceCtx = statsTraceContext; + + + syncContext = new SynchronizationContext(new Thread.UncaughtExceptionHandler() { + @Override + public void uncaughtException(Thread t, Throwable e) { + throw new RuntimeException(e); + } + }); + } private synchronized void setListener(ServerStreamListener listener) { @@ -693,9 +729,10 @@ public void request(int numMessages) { if (onReady) { synchronized (this) { if (!closed) { - serverStreamListener.onReady(); + syncContext.executeLater(() -> serverStreamListener.onReady()); } } + syncContext.drain(); } } @@ -705,21 +742,29 @@ public void request(int numMessages) { * * @return whether onReady should be called on the server */ - private synchronized boolean serverRequested(int numMessages) { - if (closed) { - return false; - } - boolean previouslyReady = serverRequested > 0; - serverRequested += numMessages; - while (serverRequested > 0 && !serverReceiveQueue.isEmpty()) { - serverRequested--; - serverStreamListener.messagesAvailable(serverReceiveQueue.poll()); - } - if (serverReceiveQueue.isEmpty() && serverNotifyHalfClose) { - serverNotifyHalfClose = false; - serverStreamListener.halfClosed(); + private boolean serverRequested(int numMessages) { + boolean previouslyReady; + boolean nowReady; + synchronized (this) { + if (closed) { + return false; + } + previouslyReady = serverRequested > 0; + serverRequested += numMessages; + + while (serverRequested > 0 && !serverReceiveQueue.isEmpty()) { + serverRequested--; + StreamListener.MessageProducer producer = serverReceiveQueue.poll(); + syncContext.executeLater(() -> serverStreamListener.messagesAvailable(producer)); + } + + if (serverReceiveQueue.isEmpty() && serverNotifyHalfClose) { + serverNotifyHalfClose = false; + syncContext.executeLater(() -> serverStreamListener.halfClosed()); + } + nowReady = serverRequested > 0; } - boolean nowReady = serverRequested > 0; + syncContext.drain(); return !previouslyReady && nowReady; } @@ -728,22 +773,25 @@ private void serverClosed(Status serverListenerStatus, Status serverTracerStatus } @Override - public synchronized void writeMessage(InputStream message) { - if (closed) { - return; - } - statsTraceCtx.outboundMessage(outboundSeqNo); - statsTraceCtx.outboundMessageSent(outboundSeqNo, -1, -1); - serverStream.statsTraceCtx.inboundMessage(outboundSeqNo); - serverStream.statsTraceCtx.inboundMessageRead(outboundSeqNo, -1, -1); - outboundSeqNo++; - StreamListener.MessageProducer producer = new SingleMessageProducer(message); - if (serverRequested > 0) { - serverRequested--; - serverStreamListener.messagesAvailable(producer); - } else { - serverReceiveQueue.add(producer); + public void writeMessage(InputStream message) { + synchronized (this) { + if (closed) { + return; + } + statsTraceCtx.outboundMessage(outboundSeqNo); + statsTraceCtx.outboundMessageSent(outboundSeqNo, -1, -1); + serverStream.statsTraceCtx.inboundMessage(outboundSeqNo); + serverStream.statsTraceCtx.inboundMessageRead(outboundSeqNo, -1, -1); + outboundSeqNo++; + StreamListener.MessageProducer producer = new SingleMessageProducer(message); + if (serverRequested > 0) { + serverRequested--; + syncContext.executeLater(() -> serverStreamListener.messagesAvailable(producer)); + } else { + serverReceiveQueue.add(producer); + } } + syncContext.drain(); } @Override @@ -768,39 +816,45 @@ public void cancel(Status reason) { streamClosed(); } - private synchronized boolean internalCancel( + private boolean internalCancel( Status serverListenerStatus, Status serverTracerStatus) { - if (closed) { - return false; - } - closed = true; - - StreamListener.MessageProducer producer; - while ((producer = serverReceiveQueue.poll()) != null) { - InputStream message; - while ((message = producer.next()) != null) { - try { - message.close(); - } catch (Throwable t) { - log.log(Level.WARNING, "Exception closing stream", t); + synchronized (this) { + if (closed) { + return false; + } + closed = true; + + StreamListener.MessageProducer producer; + while ((producer = serverReceiveQueue.poll()) != null) { + InputStream message; + while ((message = producer.next()) != null) { + try { + message.close(); + } catch (Throwable t) { + log.log(Level.WARNING, "Exception closing stream", t); + } } } + serverStream.statsTraceCtx.streamClosed(serverTracerStatus); + syncContext.executeLater(() -> serverStreamListener.closed(serverListenerStatus)); } - serverStream.statsTraceCtx.streamClosed(serverTracerStatus); - serverStreamListener.closed(serverListenerStatus); + syncContext.drain(); return true; } @Override - public synchronized void halfClose() { - if (closed) { - return; - } - if (serverReceiveQueue.isEmpty()) { - serverStreamListener.halfClosed(); - } else { - serverNotifyHalfClose = true; + public void halfClose() { + synchronized (this) { + if (closed) { + return; + } + if (serverReceiveQueue.isEmpty()) { + syncContext.executeLater(() -> serverStreamListener.halfClosed()); + } else { + serverNotifyHalfClose = true; + } } + syncContext.drain(); } @Override diff --git a/examples/src/test/java/io/grpc/examples/helloworld/HelloWorldClientTest.java b/examples/src/test/java/io/grpc/examples/helloworld/HelloWorldClientTest.java index d0262d687eee..8c6cf60279a4 100644 --- a/examples/src/test/java/io/grpc/examples/helloworld/HelloWorldClientTest.java +++ b/examples/src/test/java/io/grpc/examples/helloworld/HelloWorldClientTest.java @@ -40,8 +40,6 @@ * Not intended to provide a high code coverage or to test every major usecase. * * directExecutor() makes it easier to have deterministic tests. - * However, if your implementation uses another thread and uses streaming it is better to use - * the default executor, to avoid hitting bug #3084. * *

For more unit test examples see {@link io.grpc.examples.routeguide.RouteGuideClientTest} and * {@link io.grpc.examples.routeguide.RouteGuideServerTest}. diff --git a/examples/src/test/java/io/grpc/examples/helloworld/HelloWorldServerTest.java b/examples/src/test/java/io/grpc/examples/helloworld/HelloWorldServerTest.java index 9a20476772c8..63281eeba1a5 100644 --- a/examples/src/test/java/io/grpc/examples/helloworld/HelloWorldServerTest.java +++ b/examples/src/test/java/io/grpc/examples/helloworld/HelloWorldServerTest.java @@ -33,8 +33,6 @@ * Not intended to provide a high code coverage or to test every major usecase. * * directExecutor() makes it easier to have deterministic tests. - * However, if your implementation uses another thread and uses streaming it is better to use - * the default executor, to avoid hitting bug #3084. * *

For more unit test examples see {@link io.grpc.examples.routeguide.RouteGuideClientTest} and * {@link io.grpc.examples.routeguide.RouteGuideServerTest}. diff --git a/examples/src/test/java/io/grpc/examples/routeguide/RouteGuideClientTest.java b/examples/src/test/java/io/grpc/examples/routeguide/RouteGuideClientTest.java index be2337cc4ce9..4c184fb82eea 100644 --- a/examples/src/test/java/io/grpc/examples/routeguide/RouteGuideClientTest.java +++ b/examples/src/test/java/io/grpc/examples/routeguide/RouteGuideClientTest.java @@ -53,8 +53,6 @@ * Not intended to provide a high code coverage or to test every major usecase. * * directExecutor() makes it easier to have deterministic tests. - * However, if your implementation uses another thread and uses streaming it is better to use - * the default executor, to avoid hitting bug #3084. * *

For basic unit test examples see {@link io.grpc.examples.helloworld.HelloWorldClientTest} and * {@link io.grpc.examples.helloworld.HelloWorldServerTest}. diff --git a/examples/src/test/java/io/grpc/examples/routeguide/RouteGuideServerTest.java b/examples/src/test/java/io/grpc/examples/routeguide/RouteGuideServerTest.java index 19322c2d72c9..a5a84824af63 100644 --- a/examples/src/test/java/io/grpc/examples/routeguide/RouteGuideServerTest.java +++ b/examples/src/test/java/io/grpc/examples/routeguide/RouteGuideServerTest.java @@ -50,8 +50,6 @@ * Not intended to provide a high code coverage or to test every major usecase. * * directExecutor() makes it easier to have deterministic tests. - * However, if your implementation uses another thread and uses streaming it is better to use - * the default executor, to avoid hitting bug #3084. * *

For basic unit test examples see {@link io.grpc.examples.helloworld.HelloWorldClientTest} and * {@link io.grpc.examples.helloworld.HelloWorldServerTest}. diff --git a/stub/src/main/java/io/grpc/stub/ClientCalls.java b/stub/src/main/java/io/grpc/stub/ClientCalls.java index 4fac94bfaefa..6986a285ae27 100644 --- a/stub/src/main/java/io/grpc/stub/ClientCalls.java +++ b/stub/src/main/java/io/grpc/stub/ClientCalls.java @@ -509,6 +509,7 @@ void onStart() { private static final class UnaryStreamToFuture extends StartableListener { private final GrpcFuture responseFuture; private RespT value; + private boolean isValueReceived = false; // Non private to avoid synthetic class UnaryStreamToFuture(GrpcFuture responseFuture) { @@ -521,17 +522,18 @@ public void onHeaders(Metadata headers) { @Override public void onMessage(RespT value) { - if (this.value != null) { + if (this.isValueReceived) { throw Status.INTERNAL.withDescription("More than one value received for unary call") .asRuntimeException(); } this.value = value; + this.isValueReceived = true; } @Override public void onClose(Status status, Metadata trailers) { if (status.isOk()) { - if (value == null) { + if (!isValueReceived) { // No value received so mark the future as an error responseFuture.setException( Status.INTERNAL.withDescription("No value received for unary call")