From 7b8105e100c9d4c4697515030085554d4d332169 Mon Sep 17 00:00:00 2001 From: Eric Anderson Date: Wed, 20 Jan 2021 15:01:49 -0800 Subject: [PATCH] core: DelayedStream should start() real stream immediately DelayedClientTransport needs to avoid becoming terminated while it owns RPCs. Previously DelayedClientTransport could terminate when some of its RPCs had their realStream but realStream.start() hadn't yet been called. To avoid that, we now make sure to call realStream.start() synchronously with setting realStream. Since start() and the method calls before start execute quickly, we can run it in-line. But it does mean we now need to split the Stream methods into "before start" and "after start" categories for queuing. Fixes #6283 --- .../grpc/internal/DelayedClientTransport.java | 24 ++-- .../java/io/grpc/internal/DelayedStream.java | 126 +++++++++++------- .../io/grpc/internal/MetadataApplierImpl.java | 6 +- .../internal/DelayedClientTransportTest.java | 43 +++--- .../io/grpc/internal/DelayedStreamTest.java | 66 ++++----- 5 files changed, 149 insertions(+), 116 deletions(-) diff --git a/core/src/main/java/io/grpc/internal/DelayedClientTransport.java b/core/src/main/java/io/grpc/internal/DelayedClientTransport.java index 761300f19fc..6a72eb7c21e 100644 --- a/core/src/main/java/io/grpc/internal/DelayedClientTransport.java +++ b/core/src/main/java/io/grpc/internal/DelayedClientTransport.java @@ -30,6 +30,7 @@ import io.grpc.MethodDescriptor; import io.grpc.Status; import io.grpc.SynchronizationContext; +import io.grpc.internal.ClientStreamListener.RpcProgress; import java.util.ArrayList; import java.util.Collection; import java.util.Collections; @@ -238,7 +239,13 @@ public final void shutdownNow(Status status) { } if (savedReportTransportTerminated != null) { for (PendingStream stream : savedPendingStreams) { - stream.cancel(status); + Runnable runnable = stream.setStream(new FailingClientStream(status, RpcProgress.REFUSED)); + if (runnable != null) { + // Drain in-line instead of using an executor as failing stream just throws everything + // away. This is essentially the same behavior as DelayedStream.cancel() but can be done + // before stream.start(). + runnable.run(); + } } syncContext.execute(savedReportTransportTerminated); } @@ -294,12 +301,10 @@ final void reprocess(@Nullable SubchannelPicker picker) { if (callOptions.getExecutor() != null) { executor = callOptions.getExecutor(); } - executor.execute(new Runnable() { - @Override - public void run() { - stream.createRealStream(transport); - } - }); + Runnable runnable = stream.createRealStream(transport); + if (runnable != null) { + executor.execute(runnable); + } toRemove.add(stream); } // else: stay pending } @@ -346,7 +351,8 @@ private PendingStream(PickSubchannelArgs args) { this.args = args; } - private void createRealStream(ClientTransport transport) { + /** Runnable may be null. */ + private Runnable createRealStream(ClientTransport transport) { ClientStream realStream; Context origContext = context.attach(); try { @@ -355,7 +361,7 @@ private void createRealStream(ClientTransport transport) { } finally { context.detach(origContext); } - setStream(realStream); + return setStream(realStream); } @Override diff --git a/core/src/main/java/io/grpc/internal/DelayedStream.java b/core/src/main/java/io/grpc/internal/DelayedStream.java index be21b4991ba..7abd517032f 100644 --- a/core/src/main/java/io/grpc/internal/DelayedStream.java +++ b/core/src/main/java/io/grpc/internal/DelayedStream.java @@ -29,6 +29,7 @@ import java.io.InputStream; import java.util.ArrayList; import java.util.List; +import javax.annotation.CheckReturnValue; import javax.annotation.concurrent.GuardedBy; /** @@ -59,38 +60,35 @@ class DelayedStream implements ClientStream { private long startTimeNanos; @GuardedBy("this") private long streamSetTimeNanos; + // No need to synchronize; start() synchronization provides a happens-before + private List preStartPendingCalls = new ArrayList<>(); @Override public void setMaxInboundMessageSize(final int maxSize) { - if (passThrough) { - realStream.setMaxInboundMessageSize(maxSize); - } else { - delayOrExecute(new Runnable() { - @Override - public void run() { - realStream.setMaxInboundMessageSize(maxSize); - } - }); - } + checkState(listener == null, "May only be called before start"); + preStartPendingCalls.add(new Runnable() { + @Override + public void run() { + realStream.setMaxInboundMessageSize(maxSize); + } + }); } @Override public void setMaxOutboundMessageSize(final int maxSize) { - if (passThrough) { - realStream.setMaxOutboundMessageSize(maxSize); - } else { - delayOrExecute(new Runnable() { - @Override - public void run() { - realStream.setMaxOutboundMessageSize(maxSize); - } - }); - } + checkState(listener == null, "May only be called before start"); + preStartPendingCalls.add(new Runnable() { + @Override + public void run() { + realStream.setMaxOutboundMessageSize(maxSize); + } + }); } @Override public void setDeadline(final Deadline deadline) { - delayOrExecute(new Runnable() { + checkState(listener == null, "May only be called before start"); + preStartPendingCalls.add(new Runnable() { @Override public void run() { realStream.setDeadline(deadline); @@ -115,21 +113,41 @@ public void appendTimeoutInsight(InsightBuilder insight) { } /** - * Transfers all pending and future requests and mutations to the given stream. + * Transfers all pending and future requests and mutations to the given stream. Method will return + * quickly, but if the returned Runnable is non-null it must be called to complete the process. + * The Runnable may take a while to execute. * *

No-op if either this method or {@link #cancel} have already been called. */ - // When this method returns, passThrough is guaranteed to be true - final void setStream(ClientStream stream) { + // When this method returns, start() has been called on realStream or passThrough is guaranteed to + // be true + @CheckReturnValue + final Runnable setStream(ClientStream stream) { + ClientStreamListener savedListener; synchronized (this) { // If realStream != null, then either setStream() or cancel() has been called. if (realStream != null) { - return; + return null; } setRealStream(checkNotNull(stream, "stream")); + savedListener = listener; + if (savedListener == null) { + assert pendingCalls.isEmpty(); + pendingCalls = null; + passThrough = true; + } + } + if (savedListener == null) { + return null; + } else { + internalStart(savedListener); + return new Runnable() { + @Override + public void run() { + drainPendingCalls(); + } + }; } - - drainPendingCalls(); } /** @@ -177,6 +195,7 @@ private void drainPendingCalls() { * only if {@code runnable} is thread-safe. */ private void delayOrExecute(Runnable runnable) { + checkState(listener != null, "May only be called after start"); synchronized (this) { if (!passThrough) { pendingCalls.add(runnable); @@ -190,7 +209,7 @@ private void delayOrExecute(Runnable runnable) { public void setAuthority(final String authority) { checkState(listener == null, "May only be called before start"); checkNotNull(authority, "authority"); - delayOrExecute(new Runnable() { + preStartPendingCalls.add(new Runnable() { @Override public void run() { realStream.setAuthority(authority); @@ -200,18 +219,19 @@ public void run() { @Override public void start(ClientStreamListener listener) { + checkNotNull(listener, "listener"); checkState(this.listener == null, "already started"); Status savedError; boolean savedPassThrough; synchronized (this) { - this.listener = checkNotNull(listener, "listener"); // If error != null, then cancel() has been called and was unable to close the listener savedError = error; savedPassThrough = passThrough; if (!savedPassThrough) { listener = delayedListener = new DelayedStreamListener(listener); } + this.listener = listener; startTimeNanos = System.nanoTime(); } if (savedError != null) { @@ -220,16 +240,20 @@ public void start(ClientStreamListener listener) { } if (savedPassThrough) { - realStream.start(listener); - } else { - final ClientStreamListener finalListener = listener; - delayOrExecute(new Runnable() { - @Override - public void run() { - realStream.start(finalListener); - } - }); + internalStart(listener); + } // else internalStart() will be called by setStream + } + + /** + * Starts stream without synchronization. {@code listener} should be same instance as {@link + * #listener}. + */ + private void internalStart(ClientStreamListener listener) { + for (Runnable runnable : preStartPendingCalls) { + runnable.run(); } + preStartPendingCalls = null; + realStream.start(listener); } @Override @@ -247,6 +271,7 @@ public Attributes getAttributes() { @Override public void writeMessage(final InputStream message) { + checkState(listener != null, "May only be called after start"); checkNotNull(message, "message"); if (passThrough) { realStream.writeMessage(message); @@ -262,6 +287,7 @@ public void run() { @Override public void flush() { + checkState(listener != null, "May only be called after start"); if (passThrough) { realStream.flush(); } else { @@ -277,16 +303,14 @@ public void run() { // When this method returns, passThrough is guaranteed to be true @Override public void cancel(final Status reason) { + checkState(listener != null, "May only be called after start"); checkNotNull(reason, "reason"); boolean delegateToRealStream = true; - ClientStreamListener listenerToClose = null; synchronized (this) { // If realStream != null, then either setStream() or cancel() has been called if (realStream == null) { setRealStream(NoopClientStream.INSTANCE); delegateToRealStream = false; - // If listener == null, then start() will later call listener with 'error' - listenerToClose = listener; error = reason; } } @@ -298,10 +322,9 @@ public void run() { } }); } else { - if (listenerToClose != null) { - listenerToClose.closed(reason, new Metadata()); - } drainPendingCalls(); + // Note that listener is a DelayedStreamListener + listener.closed(reason, new Metadata()); } } @@ -314,6 +337,7 @@ private void setRealStream(ClientStream realStream) { @Override public void halfClose() { + checkState(listener != null, "May only be called after start"); delayOrExecute(new Runnable() { @Override public void run() { @@ -324,6 +348,7 @@ public void run() { @Override public void request(final int numMessages) { + checkState(listener != null, "May only be called after start"); if (passThrough) { realStream.request(numMessages); } else { @@ -338,7 +363,8 @@ public void run() { @Override public void optimizeForDirectExecutor() { - delayOrExecute(new Runnable() { + checkState(listener == null, "May only be called before start"); + preStartPendingCalls.add(new Runnable() { @Override public void run() { realStream.optimizeForDirectExecutor(); @@ -348,8 +374,9 @@ public void run() { @Override public void setCompressor(final Compressor compressor) { + checkState(listener == null, "May only be called before start"); checkNotNull(compressor, "compressor"); - delayOrExecute(new Runnable() { + preStartPendingCalls.add(new Runnable() { @Override public void run() { realStream.setCompressor(compressor); @@ -359,7 +386,8 @@ public void run() { @Override public void setFullStreamDecompression(final boolean fullStreamDecompression) { - delayOrExecute( + checkState(listener == null, "May only be called before start"); + preStartPendingCalls.add( new Runnable() { @Override public void run() { @@ -370,8 +398,9 @@ public void run() { @Override public void setDecompressorRegistry(final DecompressorRegistry decompressorRegistry) { + checkState(listener == null, "May only be called before start"); checkNotNull(decompressorRegistry, "decompressorRegistry"); - delayOrExecute(new Runnable() { + preStartPendingCalls.add(new Runnable() { @Override public void run() { realStream.setDecompressorRegistry(decompressorRegistry); @@ -390,6 +419,7 @@ public boolean isReady() { @Override public void setMessageCompression(final boolean enable) { + checkState(listener != null, "May only be called after start"); if (passThrough) { realStream.setMessageCompression(enable); } else { diff --git a/core/src/main/java/io/grpc/internal/MetadataApplierImpl.java b/core/src/main/java/io/grpc/internal/MetadataApplierImpl.java index c3196ddd107..4c49a14a06b 100644 --- a/core/src/main/java/io/grpc/internal/MetadataApplierImpl.java +++ b/core/src/main/java/io/grpc/internal/MetadataApplierImpl.java @@ -95,7 +95,11 @@ private void finalizeWith(ClientStream stream) { // returnStream() has been called before me, thus delayedStream must have been // created. checkState(delayedStream != null, "delayedStream is null"); - delayedStream.setStream(stream); + Runnable slow = delayedStream.setStream(stream); + if (slow != null) { + // TODO(ejona): run this on a separate thread + slow.run(); + } } /** diff --git a/core/src/test/java/io/grpc/internal/DelayedClientTransportTest.java b/core/src/test/java/io/grpc/internal/DelayedClientTransportTest.java index f01b9d4d6fe..89655a7d7b5 100644 --- a/core/src/test/java/io/grpc/internal/DelayedClientTransportTest.java +++ b/core/src/test/java/io/grpc/internal/DelayedClientTransportTest.java @@ -23,6 +23,7 @@ import static org.junit.Assert.assertSame; import static org.junit.Assert.assertTrue; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; import static org.mockito.ArgumentMatchers.same; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.inOrder; @@ -161,7 +162,7 @@ public void uncaughtException(Thread t, Throwable e) { delayedTransport.shutdown(SHUTDOWN_STATUS); verify(transportListener).transportShutdown(same(SHUTDOWN_STATUS)); verify(transportListener).transportTerminated(); - assertEquals(1, fakeExecutor.runDueTasks()); + assertEquals(0, fakeExecutor.runDueTasks()); verify(mockRealTransport).newStream(same(method), same(headers), same(callOptions)); stream.start(streamListener); verify(mockRealStream).start(same(streamListener)); @@ -199,15 +200,6 @@ public void uncaughtException(Thread t, Throwable e) { any(MethodDescriptor.class), any(Metadata.class), any(CallOptions.class)); } - @Test public void cancelStreamWithoutSetTransport() { - ClientStream stream = delayedTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); - assertEquals(1, delayedTransport.getPendingStreamsCount()); - stream.cancel(Status.CANCELLED); - assertEquals(0, delayedTransport.getPendingStreamsCount()); - verifyNoMoreInteractions(mockRealTransport); - verifyNoMoreInteractions(mockRealStream); - } - @Test public void startThenCancelStreamWithoutSetTransport() { ClientStream stream = delayedTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); stream.start(streamListener); @@ -258,6 +250,7 @@ public void uncaughtException(Thread t, Throwable e) { verify(transportListener).transportShutdown(same(SHUTDOWN_STATUS)); verify(transportListener, times(0)).transportTerminated(); assertEquals(1, delayedTransport.getPendingStreamsCount()); + stream.start(streamListener); stream.cancel(Status.CANCELLED); verify(transportListener).transportTerminated(); assertEquals(0, delayedTransport.getPendingStreamsCount()); @@ -282,7 +275,8 @@ public void uncaughtException(Thread t, Throwable e) { delayedTransport.shutdownNow(Status.UNAVAILABLE); verify(transportListener).transportShutdown(any(Status.class)); verify(transportListener).transportTerminated(); - verify(streamListener).closed(statusCaptor.capture(), any(Metadata.class)); + verify(streamListener) + .closed(statusCaptor.capture(), eq(RpcProgress.REFUSED), any(Metadata.class)); assertEquals(Status.Code.UNAVAILABLE, statusCaptor.getValue().getCode()); } @@ -316,6 +310,8 @@ public void uncaughtException(Thread t, Throwable e) { // Fail-fast streams DelayedStream ff1 = (DelayedStream) delayedTransport.newStream( method, headers, failFastCallOptions); + ff1.start(mock(ClientStreamListener.class)); + ff1.halfClose(); PickSubchannelArgsImpl ff1args = new PickSubchannelArgsImpl(method, headers, failFastCallOptions); verify(transportListener).transportInUse(true); @@ -346,6 +342,8 @@ public void uncaughtException(Thread t, Throwable e) { wfr3Executor.getScheduledExecutorService()); DelayedStream wfr3 = (DelayedStream) delayedTransport.newStream( method, headers, wfr3callOptions); + wfr3.start(mock(ClientStreamListener.class)); + wfr3.halfClose(); PickSubchannelArgsImpl wfr3args = new PickSubchannelArgsImpl(method, headers, wfr3callOptions); DelayedStream wfr4 = (DelayedStream) delayedTransport.newStream( @@ -381,18 +379,22 @@ public void uncaughtException(Thread t, Throwable e) { inOrder.verify(picker).pickSubchannel(wfr4args); inOrder.verifyNoMoreInteractions(); - // Make sure that real transport creates streams in the executor - verify(mockRealTransport, never()).newStream( - any(MethodDescriptor.class), any(Metadata.class), any(CallOptions.class)); - verify(mockRealTransport2, never()).newStream( - any(MethodDescriptor.class), any(Metadata.class), any(CallOptions.class)); - fakeExecutor.runDueTasks(); - assertEquals(0, fakeExecutor.numPendingTasks()); + // Make sure that streams are created and started immediately, not in any executor. This is + // necessary during shut down to guarantee that when DelayedClientTransport terminates, all + // streams are now owned by a real transport (which should prevent the Channel from + // terminating). // ff1 and wfr1 went through verify(mockRealTransport).newStream(method, headers, failFastCallOptions); verify(mockRealTransport2).newStream(method, headers, waitForReadyCallOptions); assertSame(mockRealStream, ff1.getRealStream()); assertSame(mockRealStream2, wfr1.getRealStream()); + verify(mockRealStream).start(any(ClientStreamListener.class)); + // But also verify that non-start()-related calls are run within the Executor, since they may be + // slow. + verify(mockRealStream, never()).halfClose(); + fakeExecutor.runDueTasks(); + assertEquals(0, fakeExecutor.numPendingTasks()); + verify(mockRealStream).halfClose(); // The ff2 has failed due to picker returning an error assertSame(Status.UNAVAILABLE, ((FailingClientStream) ff2.getRealStream()).getError()); // Other streams are still buffered @@ -431,10 +433,11 @@ public void uncaughtException(Thread t, Throwable e) { assertSame(mockRealStream2, wfr2.getRealStream()); assertSame(mockRealStream2, wfr4.getRealStream()); + assertSame(mockRealStream, wfr3.getRealStream()); // If there is an executor in the CallOptions, it will be used to create the real stream. - assertNull(wfr3.getRealStream()); + verify(mockRealStream, times(1)).halfClose(); // 1 for ff1 wfr3Executor.runDueTasks(); - assertSame(mockRealStream, wfr3.getRealStream()); + verify(mockRealStream, times(2)).halfClose(); // New streams will use the last picker DelayedStream wfr5 = (DelayedStream) delayedTransport.newStream( diff --git a/core/src/test/java/io/grpc/internal/DelayedStreamTest.java b/core/src/test/java/io/grpc/internal/DelayedStreamTest.java index 393a6c6e6d0..f3126ba6f42 100644 --- a/core/src/test/java/io/grpc/internal/DelayedStreamTest.java +++ b/core/src/test/java/io/grpc/internal/DelayedStreamTest.java @@ -72,7 +72,7 @@ public void setStream_setAuthority() { final String authority = "becauseIsaidSo"; stream.setAuthority(authority); stream.start(listener); - stream.setStream(realStream); + callMeMaybe(stream.setStream(realStream)); InOrder inOrder = inOrder(realStream); inOrder.verify(realStream).setAuthority(authority); inOrder.verify(realStream).start(any(ClientStreamListener.class)); @@ -92,9 +92,9 @@ public void start_afterStart() { @Test public void setStream_sendsAllMessages() { - stream.start(listener); stream.setCompressor(Codec.Identity.NONE); stream.setDecompressorRegistry(DecompressorRegistry.getDefaultInstance()); + stream.start(listener); stream.setMessageCompression(true); InputStream message = new ByteArrayInputStream(new byte[]{'a'}); @@ -102,7 +102,7 @@ public void setStream_sendsAllMessages() { stream.setMessageCompression(false); stream.writeMessage(message); - stream.setStream(realStream); + callMeMaybe(stream.setStream(realStream)); verify(realStream).setCompressor(Codec.Identity.NONE); verify(realStream).setDecompressorRegistry(DecompressorRegistry.getDefaultInstance()); @@ -125,7 +125,7 @@ public void setStream_sendsAllMessages() { public void setStream_halfClose() { stream.start(listener); stream.halfClose(); - stream.setStream(realStream); + callMeMaybe(stream.setStream(realStream)); verify(realStream).halfClose(); } @@ -134,7 +134,7 @@ public void setStream_halfClose() { public void setStream_flush() { stream.start(listener); stream.flush(); - stream.setStream(realStream); + callMeMaybe(stream.setStream(realStream)); verify(realStream).flush(); stream.flush(); @@ -146,7 +146,7 @@ public void setStream_flowControl() { stream.start(listener); stream.request(1); stream.request(2); - stream.setStream(realStream); + callMeMaybe(stream.setStream(realStream)); verify(realStream).request(1); verify(realStream).request(2); @@ -158,7 +158,7 @@ public void setStream_flowControl() { public void setStream_setMessageCompression() { stream.start(listener); stream.setMessageCompression(false); - stream.setStream(realStream); + callMeMaybe(stream.setStream(realStream)); verify(realStream).setMessageCompression(false); stream.setMessageCompression(true); @@ -169,7 +169,7 @@ public void setStream_setMessageCompression() { public void setStream_isReady() { stream.start(listener); assertFalse(stream.isReady()); - stream.setStream(realStream); + callMeMaybe(stream.setStream(realStream)); verify(realStream, never()).isReady(); assertFalse(stream.isReady()); @@ -190,7 +190,7 @@ public void setStream_getAttributes() { assertEquals(Attributes.EMPTY, stream.getAttributes()); - stream.setStream(realStream); + callMeMaybe(stream.setStream(realStream)); assertEquals(attributes, stream.getAttributes()); } @@ -204,7 +204,7 @@ public void startThenCancelled() { @Test public void startThenSetStreamThenCancelled() { stream.start(listener); - stream.setStream(realStream); + callMeMaybe(stream.setStream(realStream)); stream.cancel(Status.CANCELLED); verify(realStream).start(any(ClientStreamListener.class)); verify(realStream).cancel(same(Status.CANCELLED)); @@ -212,52 +212,36 @@ public void startThenSetStreamThenCancelled() { @Test public void setStreamThenStartThenCancelled() { - stream.setStream(realStream); + callMeMaybe(stream.setStream(realStream)); stream.start(listener); stream.cancel(Status.CANCELLED); verify(realStream).start(same(listener)); verify(realStream).cancel(same(Status.CANCELLED)); } - @Test - public void setStreamThenCancelled() { - stream.setStream(realStream); - stream.cancel(Status.CANCELLED); - verify(realStream).cancel(same(Status.CANCELLED)); - } - @Test public void setStreamTwice() { stream.start(listener); - stream.setStream(realStream); + callMeMaybe(stream.setStream(realStream)); verify(realStream).start(any(ClientStreamListener.class)); - stream.setStream(mock(ClientStream.class)); + callMeMaybe(stream.setStream(mock(ClientStream.class))); stream.flush(); verify(realStream).flush(); } @Test public void cancelThenSetStream() { - stream.cancel(Status.CANCELLED); - stream.setStream(realStream); stream.start(listener); + stream.cancel(Status.CANCELLED); + callMeMaybe(stream.setStream(realStream)); stream.isReady(); verifyNoMoreInteractions(realStream); } - @Test + @Test(expected = IllegalStateException.class) public void cancel_beforeStart() { Status status = Status.CANCELLED.withDescription("that was quick"); stream.cancel(status); - stream.start(listener); - verify(listener).closed(same(status), any(Metadata.class)); - } - - @Test - public void cancelledThenStart() { - stream.cancel(Status.CANCELLED); - stream.start(listener); - verify(listener).closed(eq(Status.CANCELLED), any(Metadata.class)); } @Test @@ -275,7 +259,7 @@ public void onReady() { IsReadyListener isReadyListener = new IsReadyListener(); stream.start(isReadyListener); - stream.setStream(new NoopClientStream() { + callMeMaybe(stream.setStream(new NoopClientStream() { @Override public void start(ClientStreamListener listener) { // This call to the listener should end up being delayed. @@ -286,7 +270,7 @@ public void start(ClientStreamListener listener) { public boolean isReady() { return true; } - }); + })); assertTrue(isReadyListener.onReadyCalled); } @@ -302,7 +286,7 @@ public void listener_allQueued() { final InOrder inOrder = inOrder(listener); stream.start(listener); - stream.setStream(new NoopClientStream() { + callMeMaybe(stream.setStream(new NoopClientStream() { @Override public void start(ClientStreamListener passedListener) { passedListener.onReady(); @@ -314,7 +298,7 @@ public void start(ClientStreamListener passedListener) { verifyNoMoreInteractions(listener); } - }); + })); inOrder.verify(listener).onReady(); inOrder.verify(listener).headersRead(headers); inOrder.verify(listener).messagesAvailable(producer1); @@ -332,7 +316,7 @@ public void listener_noQueued() { final Status status = Status.UNKNOWN.withDescription("unique status"); stream.start(listener); - stream.setStream(realStream); + callMeMaybe(stream.setStream(realStream)); verify(realStream).start(listenerCaptor.capture()); ClientStreamListener delayedListener = listenerCaptor.getValue(); delayedListener.onReady(); @@ -371,11 +355,17 @@ public Void answer(InvocationOnMock in) { } }).when(realStream).appendTimeoutInsight(any(InsightBuilder.class)); stream.start(listener); - stream.setStream(realStream); + callMeMaybe(stream.setStream(realStream)); InsightBuilder insight = new InsightBuilder(); stream.appendTimeoutInsight(insight); assertThat(insight.toString()) .matches("\\[buffered_nanos=[0-9]+, remote_addr=127\\.0\\.0\\.1:443\\]"); } + + private void callMeMaybe(Runnable r) { + if (r != null) { + r.run(); + } + } }