diff --git a/core/src/main/java/io/grpc/internal/DelayedClientTransport.java b/core/src/main/java/io/grpc/internal/DelayedClientTransport.java index 761300f19fc..6a72eb7c21e 100644 --- a/core/src/main/java/io/grpc/internal/DelayedClientTransport.java +++ b/core/src/main/java/io/grpc/internal/DelayedClientTransport.java @@ -30,6 +30,7 @@ import io.grpc.MethodDescriptor; import io.grpc.Status; import io.grpc.SynchronizationContext; +import io.grpc.internal.ClientStreamListener.RpcProgress; import java.util.ArrayList; import java.util.Collection; import java.util.Collections; @@ -238,7 +239,13 @@ public final void shutdownNow(Status status) { } if (savedReportTransportTerminated != null) { for (PendingStream stream : savedPendingStreams) { - stream.cancel(status); + Runnable runnable = stream.setStream(new FailingClientStream(status, RpcProgress.REFUSED)); + if (runnable != null) { + // Drain in-line instead of using an executor as failing stream just throws everything + // away. This is essentially the same behavior as DelayedStream.cancel() but can be done + // before stream.start(). + runnable.run(); + } } syncContext.execute(savedReportTransportTerminated); } @@ -294,12 +301,10 @@ final void reprocess(@Nullable SubchannelPicker picker) { if (callOptions.getExecutor() != null) { executor = callOptions.getExecutor(); } - executor.execute(new Runnable() { - @Override - public void run() { - stream.createRealStream(transport); - } - }); + Runnable runnable = stream.createRealStream(transport); + if (runnable != null) { + executor.execute(runnable); + } toRemove.add(stream); } // else: stay pending } @@ -346,7 +351,8 @@ private PendingStream(PickSubchannelArgs args) { this.args = args; } - private void createRealStream(ClientTransport transport) { + /** Runnable may be null. */ + private Runnable createRealStream(ClientTransport transport) { ClientStream realStream; Context origContext = context.attach(); try { @@ -355,7 +361,7 @@ private void createRealStream(ClientTransport transport) { } finally { context.detach(origContext); } - setStream(realStream); + return setStream(realStream); } @Override diff --git a/core/src/main/java/io/grpc/internal/DelayedStream.java b/core/src/main/java/io/grpc/internal/DelayedStream.java index be21b4991ba..7abd517032f 100644 --- a/core/src/main/java/io/grpc/internal/DelayedStream.java +++ b/core/src/main/java/io/grpc/internal/DelayedStream.java @@ -29,6 +29,7 @@ import java.io.InputStream; import java.util.ArrayList; import java.util.List; +import javax.annotation.CheckReturnValue; import javax.annotation.concurrent.GuardedBy; /** @@ -59,38 +60,35 @@ class DelayedStream implements ClientStream { private long startTimeNanos; @GuardedBy("this") private long streamSetTimeNanos; + // No need to synchronize; start() synchronization provides a happens-before + private List preStartPendingCalls = new ArrayList<>(); @Override public void setMaxInboundMessageSize(final int maxSize) { - if (passThrough) { - realStream.setMaxInboundMessageSize(maxSize); - } else { - delayOrExecute(new Runnable() { - @Override - public void run() { - realStream.setMaxInboundMessageSize(maxSize); - } - }); - } + checkState(listener == null, "May only be called before start"); + preStartPendingCalls.add(new Runnable() { + @Override + public void run() { + realStream.setMaxInboundMessageSize(maxSize); + } + }); } @Override public void setMaxOutboundMessageSize(final int maxSize) { - if (passThrough) { - realStream.setMaxOutboundMessageSize(maxSize); - } else { - delayOrExecute(new Runnable() { - @Override - public void run() { - realStream.setMaxOutboundMessageSize(maxSize); - } - }); - } + checkState(listener == null, "May only be called before start"); + preStartPendingCalls.add(new Runnable() { + @Override + public void run() { + realStream.setMaxOutboundMessageSize(maxSize); + } + }); } @Override public void setDeadline(final Deadline deadline) { - delayOrExecute(new Runnable() { + checkState(listener == null, "May only be called before start"); + preStartPendingCalls.add(new Runnable() { @Override public void run() { realStream.setDeadline(deadline); @@ -115,21 +113,41 @@ public void appendTimeoutInsight(InsightBuilder insight) { } /** - * Transfers all pending and future requests and mutations to the given stream. + * Transfers all pending and future requests and mutations to the given stream. Method will return + * quickly, but if the returned Runnable is non-null it must be called to complete the process. + * The Runnable may take a while to execute. * *

No-op if either this method or {@link #cancel} have already been called. */ - // When this method returns, passThrough is guaranteed to be true - final void setStream(ClientStream stream) { + // When this method returns, start() has been called on realStream or passThrough is guaranteed to + // be true + @CheckReturnValue + final Runnable setStream(ClientStream stream) { + ClientStreamListener savedListener; synchronized (this) { // If realStream != null, then either setStream() or cancel() has been called. if (realStream != null) { - return; + return null; } setRealStream(checkNotNull(stream, "stream")); + savedListener = listener; + if (savedListener == null) { + assert pendingCalls.isEmpty(); + pendingCalls = null; + passThrough = true; + } + } + if (savedListener == null) { + return null; + } else { + internalStart(savedListener); + return new Runnable() { + @Override + public void run() { + drainPendingCalls(); + } + }; } - - drainPendingCalls(); } /** @@ -177,6 +195,7 @@ private void drainPendingCalls() { * only if {@code runnable} is thread-safe. */ private void delayOrExecute(Runnable runnable) { + checkState(listener != null, "May only be called after start"); synchronized (this) { if (!passThrough) { pendingCalls.add(runnable); @@ -190,7 +209,7 @@ private void delayOrExecute(Runnable runnable) { public void setAuthority(final String authority) { checkState(listener == null, "May only be called before start"); checkNotNull(authority, "authority"); - delayOrExecute(new Runnable() { + preStartPendingCalls.add(new Runnable() { @Override public void run() { realStream.setAuthority(authority); @@ -200,18 +219,19 @@ public void run() { @Override public void start(ClientStreamListener listener) { + checkNotNull(listener, "listener"); checkState(this.listener == null, "already started"); Status savedError; boolean savedPassThrough; synchronized (this) { - this.listener = checkNotNull(listener, "listener"); // If error != null, then cancel() has been called and was unable to close the listener savedError = error; savedPassThrough = passThrough; if (!savedPassThrough) { listener = delayedListener = new DelayedStreamListener(listener); } + this.listener = listener; startTimeNanos = System.nanoTime(); } if (savedError != null) { @@ -220,16 +240,20 @@ public void start(ClientStreamListener listener) { } if (savedPassThrough) { - realStream.start(listener); - } else { - final ClientStreamListener finalListener = listener; - delayOrExecute(new Runnable() { - @Override - public void run() { - realStream.start(finalListener); - } - }); + internalStart(listener); + } // else internalStart() will be called by setStream + } + + /** + * Starts stream without synchronization. {@code listener} should be same instance as {@link + * #listener}. + */ + private void internalStart(ClientStreamListener listener) { + for (Runnable runnable : preStartPendingCalls) { + runnable.run(); } + preStartPendingCalls = null; + realStream.start(listener); } @Override @@ -247,6 +271,7 @@ public Attributes getAttributes() { @Override public void writeMessage(final InputStream message) { + checkState(listener != null, "May only be called after start"); checkNotNull(message, "message"); if (passThrough) { realStream.writeMessage(message); @@ -262,6 +287,7 @@ public void run() { @Override public void flush() { + checkState(listener != null, "May only be called after start"); if (passThrough) { realStream.flush(); } else { @@ -277,16 +303,14 @@ public void run() { // When this method returns, passThrough is guaranteed to be true @Override public void cancel(final Status reason) { + checkState(listener != null, "May only be called after start"); checkNotNull(reason, "reason"); boolean delegateToRealStream = true; - ClientStreamListener listenerToClose = null; synchronized (this) { // If realStream != null, then either setStream() or cancel() has been called if (realStream == null) { setRealStream(NoopClientStream.INSTANCE); delegateToRealStream = false; - // If listener == null, then start() will later call listener with 'error' - listenerToClose = listener; error = reason; } } @@ -298,10 +322,9 @@ public void run() { } }); } else { - if (listenerToClose != null) { - listenerToClose.closed(reason, new Metadata()); - } drainPendingCalls(); + // Note that listener is a DelayedStreamListener + listener.closed(reason, new Metadata()); } } @@ -314,6 +337,7 @@ private void setRealStream(ClientStream realStream) { @Override public void halfClose() { + checkState(listener != null, "May only be called after start"); delayOrExecute(new Runnable() { @Override public void run() { @@ -324,6 +348,7 @@ public void run() { @Override public void request(final int numMessages) { + checkState(listener != null, "May only be called after start"); if (passThrough) { realStream.request(numMessages); } else { @@ -338,7 +363,8 @@ public void run() { @Override public void optimizeForDirectExecutor() { - delayOrExecute(new Runnable() { + checkState(listener == null, "May only be called before start"); + preStartPendingCalls.add(new Runnable() { @Override public void run() { realStream.optimizeForDirectExecutor(); @@ -348,8 +374,9 @@ public void run() { @Override public void setCompressor(final Compressor compressor) { + checkState(listener == null, "May only be called before start"); checkNotNull(compressor, "compressor"); - delayOrExecute(new Runnable() { + preStartPendingCalls.add(new Runnable() { @Override public void run() { realStream.setCompressor(compressor); @@ -359,7 +386,8 @@ public void run() { @Override public void setFullStreamDecompression(final boolean fullStreamDecompression) { - delayOrExecute( + checkState(listener == null, "May only be called before start"); + preStartPendingCalls.add( new Runnable() { @Override public void run() { @@ -370,8 +398,9 @@ public void run() { @Override public void setDecompressorRegistry(final DecompressorRegistry decompressorRegistry) { + checkState(listener == null, "May only be called before start"); checkNotNull(decompressorRegistry, "decompressorRegistry"); - delayOrExecute(new Runnable() { + preStartPendingCalls.add(new Runnable() { @Override public void run() { realStream.setDecompressorRegistry(decompressorRegistry); @@ -390,6 +419,7 @@ public boolean isReady() { @Override public void setMessageCompression(final boolean enable) { + checkState(listener != null, "May only be called after start"); if (passThrough) { realStream.setMessageCompression(enable); } else { diff --git a/core/src/main/java/io/grpc/internal/MetadataApplierImpl.java b/core/src/main/java/io/grpc/internal/MetadataApplierImpl.java index c3196ddd107..4c49a14a06b 100644 --- a/core/src/main/java/io/grpc/internal/MetadataApplierImpl.java +++ b/core/src/main/java/io/grpc/internal/MetadataApplierImpl.java @@ -95,7 +95,11 @@ private void finalizeWith(ClientStream stream) { // returnStream() has been called before me, thus delayedStream must have been // created. checkState(delayedStream != null, "delayedStream is null"); - delayedStream.setStream(stream); + Runnable slow = delayedStream.setStream(stream); + if (slow != null) { + // TODO(ejona): run this on a separate thread + slow.run(); + } } /** diff --git a/core/src/test/java/io/grpc/internal/DelayedClientTransportTest.java b/core/src/test/java/io/grpc/internal/DelayedClientTransportTest.java index f01b9d4d6fe..89655a7d7b5 100644 --- a/core/src/test/java/io/grpc/internal/DelayedClientTransportTest.java +++ b/core/src/test/java/io/grpc/internal/DelayedClientTransportTest.java @@ -23,6 +23,7 @@ import static org.junit.Assert.assertSame; import static org.junit.Assert.assertTrue; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; import static org.mockito.ArgumentMatchers.same; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.inOrder; @@ -161,7 +162,7 @@ public void uncaughtException(Thread t, Throwable e) { delayedTransport.shutdown(SHUTDOWN_STATUS); verify(transportListener).transportShutdown(same(SHUTDOWN_STATUS)); verify(transportListener).transportTerminated(); - assertEquals(1, fakeExecutor.runDueTasks()); + assertEquals(0, fakeExecutor.runDueTasks()); verify(mockRealTransport).newStream(same(method), same(headers), same(callOptions)); stream.start(streamListener); verify(mockRealStream).start(same(streamListener)); @@ -199,15 +200,6 @@ public void uncaughtException(Thread t, Throwable e) { any(MethodDescriptor.class), any(Metadata.class), any(CallOptions.class)); } - @Test public void cancelStreamWithoutSetTransport() { - ClientStream stream = delayedTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); - assertEquals(1, delayedTransport.getPendingStreamsCount()); - stream.cancel(Status.CANCELLED); - assertEquals(0, delayedTransport.getPendingStreamsCount()); - verifyNoMoreInteractions(mockRealTransport); - verifyNoMoreInteractions(mockRealStream); - } - @Test public void startThenCancelStreamWithoutSetTransport() { ClientStream stream = delayedTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); stream.start(streamListener); @@ -258,6 +250,7 @@ public void uncaughtException(Thread t, Throwable e) { verify(transportListener).transportShutdown(same(SHUTDOWN_STATUS)); verify(transportListener, times(0)).transportTerminated(); assertEquals(1, delayedTransport.getPendingStreamsCount()); + stream.start(streamListener); stream.cancel(Status.CANCELLED); verify(transportListener).transportTerminated(); assertEquals(0, delayedTransport.getPendingStreamsCount()); @@ -282,7 +275,8 @@ public void uncaughtException(Thread t, Throwable e) { delayedTransport.shutdownNow(Status.UNAVAILABLE); verify(transportListener).transportShutdown(any(Status.class)); verify(transportListener).transportTerminated(); - verify(streamListener).closed(statusCaptor.capture(), any(Metadata.class)); + verify(streamListener) + .closed(statusCaptor.capture(), eq(RpcProgress.REFUSED), any(Metadata.class)); assertEquals(Status.Code.UNAVAILABLE, statusCaptor.getValue().getCode()); } @@ -316,6 +310,8 @@ public void uncaughtException(Thread t, Throwable e) { // Fail-fast streams DelayedStream ff1 = (DelayedStream) delayedTransport.newStream( method, headers, failFastCallOptions); + ff1.start(mock(ClientStreamListener.class)); + ff1.halfClose(); PickSubchannelArgsImpl ff1args = new PickSubchannelArgsImpl(method, headers, failFastCallOptions); verify(transportListener).transportInUse(true); @@ -346,6 +342,8 @@ public void uncaughtException(Thread t, Throwable e) { wfr3Executor.getScheduledExecutorService()); DelayedStream wfr3 = (DelayedStream) delayedTransport.newStream( method, headers, wfr3callOptions); + wfr3.start(mock(ClientStreamListener.class)); + wfr3.halfClose(); PickSubchannelArgsImpl wfr3args = new PickSubchannelArgsImpl(method, headers, wfr3callOptions); DelayedStream wfr4 = (DelayedStream) delayedTransport.newStream( @@ -381,18 +379,22 @@ public void uncaughtException(Thread t, Throwable e) { inOrder.verify(picker).pickSubchannel(wfr4args); inOrder.verifyNoMoreInteractions(); - // Make sure that real transport creates streams in the executor - verify(mockRealTransport, never()).newStream( - any(MethodDescriptor.class), any(Metadata.class), any(CallOptions.class)); - verify(mockRealTransport2, never()).newStream( - any(MethodDescriptor.class), any(Metadata.class), any(CallOptions.class)); - fakeExecutor.runDueTasks(); - assertEquals(0, fakeExecutor.numPendingTasks()); + // Make sure that streams are created and started immediately, not in any executor. This is + // necessary during shut down to guarantee that when DelayedClientTransport terminates, all + // streams are now owned by a real transport (which should prevent the Channel from + // terminating). // ff1 and wfr1 went through verify(mockRealTransport).newStream(method, headers, failFastCallOptions); verify(mockRealTransport2).newStream(method, headers, waitForReadyCallOptions); assertSame(mockRealStream, ff1.getRealStream()); assertSame(mockRealStream2, wfr1.getRealStream()); + verify(mockRealStream).start(any(ClientStreamListener.class)); + // But also verify that non-start()-related calls are run within the Executor, since they may be + // slow. + verify(mockRealStream, never()).halfClose(); + fakeExecutor.runDueTasks(); + assertEquals(0, fakeExecutor.numPendingTasks()); + verify(mockRealStream).halfClose(); // The ff2 has failed due to picker returning an error assertSame(Status.UNAVAILABLE, ((FailingClientStream) ff2.getRealStream()).getError()); // Other streams are still buffered @@ -431,10 +433,11 @@ public void uncaughtException(Thread t, Throwable e) { assertSame(mockRealStream2, wfr2.getRealStream()); assertSame(mockRealStream2, wfr4.getRealStream()); + assertSame(mockRealStream, wfr3.getRealStream()); // If there is an executor in the CallOptions, it will be used to create the real stream. - assertNull(wfr3.getRealStream()); + verify(mockRealStream, times(1)).halfClose(); // 1 for ff1 wfr3Executor.runDueTasks(); - assertSame(mockRealStream, wfr3.getRealStream()); + verify(mockRealStream, times(2)).halfClose(); // New streams will use the last picker DelayedStream wfr5 = (DelayedStream) delayedTransport.newStream( diff --git a/core/src/test/java/io/grpc/internal/DelayedStreamTest.java b/core/src/test/java/io/grpc/internal/DelayedStreamTest.java index 393a6c6e6d0..f3126ba6f42 100644 --- a/core/src/test/java/io/grpc/internal/DelayedStreamTest.java +++ b/core/src/test/java/io/grpc/internal/DelayedStreamTest.java @@ -72,7 +72,7 @@ public void setStream_setAuthority() { final String authority = "becauseIsaidSo"; stream.setAuthority(authority); stream.start(listener); - stream.setStream(realStream); + callMeMaybe(stream.setStream(realStream)); InOrder inOrder = inOrder(realStream); inOrder.verify(realStream).setAuthority(authority); inOrder.verify(realStream).start(any(ClientStreamListener.class)); @@ -92,9 +92,9 @@ public void start_afterStart() { @Test public void setStream_sendsAllMessages() { - stream.start(listener); stream.setCompressor(Codec.Identity.NONE); stream.setDecompressorRegistry(DecompressorRegistry.getDefaultInstance()); + stream.start(listener); stream.setMessageCompression(true); InputStream message = new ByteArrayInputStream(new byte[]{'a'}); @@ -102,7 +102,7 @@ public void setStream_sendsAllMessages() { stream.setMessageCompression(false); stream.writeMessage(message); - stream.setStream(realStream); + callMeMaybe(stream.setStream(realStream)); verify(realStream).setCompressor(Codec.Identity.NONE); verify(realStream).setDecompressorRegistry(DecompressorRegistry.getDefaultInstance()); @@ -125,7 +125,7 @@ public void setStream_sendsAllMessages() { public void setStream_halfClose() { stream.start(listener); stream.halfClose(); - stream.setStream(realStream); + callMeMaybe(stream.setStream(realStream)); verify(realStream).halfClose(); } @@ -134,7 +134,7 @@ public void setStream_halfClose() { public void setStream_flush() { stream.start(listener); stream.flush(); - stream.setStream(realStream); + callMeMaybe(stream.setStream(realStream)); verify(realStream).flush(); stream.flush(); @@ -146,7 +146,7 @@ public void setStream_flowControl() { stream.start(listener); stream.request(1); stream.request(2); - stream.setStream(realStream); + callMeMaybe(stream.setStream(realStream)); verify(realStream).request(1); verify(realStream).request(2); @@ -158,7 +158,7 @@ public void setStream_flowControl() { public void setStream_setMessageCompression() { stream.start(listener); stream.setMessageCompression(false); - stream.setStream(realStream); + callMeMaybe(stream.setStream(realStream)); verify(realStream).setMessageCompression(false); stream.setMessageCompression(true); @@ -169,7 +169,7 @@ public void setStream_setMessageCompression() { public void setStream_isReady() { stream.start(listener); assertFalse(stream.isReady()); - stream.setStream(realStream); + callMeMaybe(stream.setStream(realStream)); verify(realStream, never()).isReady(); assertFalse(stream.isReady()); @@ -190,7 +190,7 @@ public void setStream_getAttributes() { assertEquals(Attributes.EMPTY, stream.getAttributes()); - stream.setStream(realStream); + callMeMaybe(stream.setStream(realStream)); assertEquals(attributes, stream.getAttributes()); } @@ -204,7 +204,7 @@ public void startThenCancelled() { @Test public void startThenSetStreamThenCancelled() { stream.start(listener); - stream.setStream(realStream); + callMeMaybe(stream.setStream(realStream)); stream.cancel(Status.CANCELLED); verify(realStream).start(any(ClientStreamListener.class)); verify(realStream).cancel(same(Status.CANCELLED)); @@ -212,52 +212,36 @@ public void startThenSetStreamThenCancelled() { @Test public void setStreamThenStartThenCancelled() { - stream.setStream(realStream); + callMeMaybe(stream.setStream(realStream)); stream.start(listener); stream.cancel(Status.CANCELLED); verify(realStream).start(same(listener)); verify(realStream).cancel(same(Status.CANCELLED)); } - @Test - public void setStreamThenCancelled() { - stream.setStream(realStream); - stream.cancel(Status.CANCELLED); - verify(realStream).cancel(same(Status.CANCELLED)); - } - @Test public void setStreamTwice() { stream.start(listener); - stream.setStream(realStream); + callMeMaybe(stream.setStream(realStream)); verify(realStream).start(any(ClientStreamListener.class)); - stream.setStream(mock(ClientStream.class)); + callMeMaybe(stream.setStream(mock(ClientStream.class))); stream.flush(); verify(realStream).flush(); } @Test public void cancelThenSetStream() { - stream.cancel(Status.CANCELLED); - stream.setStream(realStream); stream.start(listener); + stream.cancel(Status.CANCELLED); + callMeMaybe(stream.setStream(realStream)); stream.isReady(); verifyNoMoreInteractions(realStream); } - @Test + @Test(expected = IllegalStateException.class) public void cancel_beforeStart() { Status status = Status.CANCELLED.withDescription("that was quick"); stream.cancel(status); - stream.start(listener); - verify(listener).closed(same(status), any(Metadata.class)); - } - - @Test - public void cancelledThenStart() { - stream.cancel(Status.CANCELLED); - stream.start(listener); - verify(listener).closed(eq(Status.CANCELLED), any(Metadata.class)); } @Test @@ -275,7 +259,7 @@ public void onReady() { IsReadyListener isReadyListener = new IsReadyListener(); stream.start(isReadyListener); - stream.setStream(new NoopClientStream() { + callMeMaybe(stream.setStream(new NoopClientStream() { @Override public void start(ClientStreamListener listener) { // This call to the listener should end up being delayed. @@ -286,7 +270,7 @@ public void start(ClientStreamListener listener) { public boolean isReady() { return true; } - }); + })); assertTrue(isReadyListener.onReadyCalled); } @@ -302,7 +286,7 @@ public void listener_allQueued() { final InOrder inOrder = inOrder(listener); stream.start(listener); - stream.setStream(new NoopClientStream() { + callMeMaybe(stream.setStream(new NoopClientStream() { @Override public void start(ClientStreamListener passedListener) { passedListener.onReady(); @@ -314,7 +298,7 @@ public void start(ClientStreamListener passedListener) { verifyNoMoreInteractions(listener); } - }); + })); inOrder.verify(listener).onReady(); inOrder.verify(listener).headersRead(headers); inOrder.verify(listener).messagesAvailable(producer1); @@ -332,7 +316,7 @@ public void listener_noQueued() { final Status status = Status.UNKNOWN.withDescription("unique status"); stream.start(listener); - stream.setStream(realStream); + callMeMaybe(stream.setStream(realStream)); verify(realStream).start(listenerCaptor.capture()); ClientStreamListener delayedListener = listenerCaptor.getValue(); delayedListener.onReady(); @@ -371,11 +355,17 @@ public Void answer(InvocationOnMock in) { } }).when(realStream).appendTimeoutInsight(any(InsightBuilder.class)); stream.start(listener); - stream.setStream(realStream); + callMeMaybe(stream.setStream(realStream)); InsightBuilder insight = new InsightBuilder(); stream.appendTimeoutInsight(insight); assertThat(insight.toString()) .matches("\\[buffered_nanos=[0-9]+, remote_addr=127\\.0\\.0\\.1:443\\]"); } + + private void callMeMaybe(Runnable r) { + if (r != null) { + r.run(); + } + } }