Skip to content

Commit

Permalink
core: delay CallCredentialsApplyingTransport shutdown until metadataA…
Browse files Browse the repository at this point in the history
…pplier finalized (#7813)

Improve the CallCredentialsApplyingTransport shutdown lifecycle management. Right now CallCredentialsApplyingTransport shutdown the delegated real transport too early. It should be waiting for the metadataAppliers to finish because they may execute asynchronously. In addition, there is no shutdown check on CallCredentialsApplyingTransport for newStream(). The degraded lifecycle implementation may cause RejectionExecutionException, or accepting new RPCs after the underlying transport is already closed during channel shutdown.

We added listener on metadataApplier to notify completion, a magic counter to track the pending metadataApplier for delaying shutdown, also added shutdown check for newStream().
  • Loading branch information
YifeiZhuang committed Jan 26, 2021
1 parent dbd903c commit ac2ead7
Show file tree
Hide file tree
Showing 4 changed files with 229 additions and 3 deletions.
Expand Up @@ -29,9 +29,12 @@
import io.grpc.MethodDescriptor;
import io.grpc.SecurityLevel;
import io.grpc.Status;
import io.grpc.internal.MetadataApplierImpl.MetadataApplierListener;
import java.net.SocketAddress;
import java.util.concurrent.Executor;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.atomic.AtomicInteger;
import javax.annotation.concurrent.GuardedBy;

final class CallCredentialsApplyingTransportFactory implements ClientTransportFactory {
private final ClientTransportFactory delegate;
Expand Down Expand Up @@ -66,6 +69,21 @@ public void close() {
private class CallCredentialsApplyingTransport extends ForwardingConnectionClientTransport {
private final ConnectionClientTransport delegate;
private final String authority;
// Negative value means transport active, non-negative value indicates shutdown invoked.
private final AtomicInteger pendingApplier = new AtomicInteger(Integer.MIN_VALUE + 1);
private volatile Status shutdownStatus;
@GuardedBy("this")
private Status savedShutdownStatus;
@GuardedBy("this")
private Status savedShutdownNowStatus;
private final MetadataApplierListener applierListener = new MetadataApplierListener() {
@Override
public void onComplete() {
if (pendingApplier.decrementAndGet() == 0) {
maybeShutdown();
}
}
};

CallCredentialsApplyingTransport(ConnectionClientTransport delegate, String authority) {
this.delegate = checkNotNull(delegate, "delegate");
Expand All @@ -89,7 +107,11 @@ public ClientStream newStream(
}
if (creds != null) {
MetadataApplierImpl applier = new MetadataApplierImpl(
delegate, method, headers, callOptions);
delegate, method, headers, callOptions, applierListener);
if (pendingApplier.incrementAndGet() > 0) {
applierListener.onComplete();
return new FailingClientStream(shutdownStatus);
}
RequestInfo requestInfo = new RequestInfo() {
@Override
public MethodDescriptor<?, ?> getMethodDescriptor() {
Expand Down Expand Up @@ -123,8 +145,69 @@ public Attributes getTransportAttrs() {
}
return applier.returnStream();
} else {
if (pendingApplier.get() >= 0) {
return new FailingClientStream(shutdownStatus);
}
return delegate.newStream(method, headers, callOptions);
}
}

@Override
public void shutdown(Status status) {
checkNotNull(status, "status");
synchronized (this) {
if (pendingApplier.get() < 0) {
shutdownStatus = status;
pendingApplier.addAndGet(Integer.MAX_VALUE);
} else {
return;
}
if (pendingApplier.get() != 0) {
savedShutdownStatus = status;
return;
}
}
super.shutdown(status);
}

// TODO(zivy): cancel pending applier here.
@Override
public void shutdownNow(Status status) {
checkNotNull(status, "status");
synchronized (this) {
if (pendingApplier.get() < 0) {
shutdownStatus = status;
pendingApplier.addAndGet(Integer.MAX_VALUE);
} else if (savedShutdownNowStatus != null) {
return;
}
if (pendingApplier.get() != 0) {
savedShutdownNowStatus = status;
// TODO(zivy): propagate shutdownNow to the delegate immediately.
return;
}
}
super.shutdownNow(status);
}

private void maybeShutdown() {
Status maybeShutdown;
Status maybeShutdownNow;
synchronized (this) {
if (pendingApplier.get() != 0) {
return;
}
maybeShutdown = savedShutdownStatus;
maybeShutdownNow = savedShutdownNowStatus;
savedShutdownStatus = null;
savedShutdownNowStatus = null;
}
if (maybeShutdown != null) {
super.shutdown(maybeShutdown);
}
if (maybeShutdownNow != null) {
super.shutdownNow(maybeShutdownNow);
}
}
}
}
19 changes: 17 additions & 2 deletions core/src/main/java/io/grpc/internal/MetadataApplierImpl.java
Expand Up @@ -35,6 +35,7 @@ final class MetadataApplierImpl extends MetadataApplier {
private final Metadata origHeaders;
private final CallOptions callOptions;
private final Context ctx;
private final MetadataApplierListener listener;

private final Object lock = new Object();

Expand All @@ -51,12 +52,13 @@ final class MetadataApplierImpl extends MetadataApplier {

MetadataApplierImpl(
ClientTransport transport, MethodDescriptor<?, ?> method, Metadata origHeaders,
CallOptions callOptions) {
CallOptions callOptions, MetadataApplierListener listener) {
this.transport = transport;
this.method = method;
this.origHeaders = origHeaders;
this.callOptions = callOptions;
this.ctx = Context.current();
this.listener = listener;
}

@Override
Expand Down Expand Up @@ -84,14 +86,19 @@ public void fail(Status status) {
private void finalizeWith(ClientStream stream) {
checkState(!finalized, "already finalized");
finalized = true;
boolean directStream = false;
synchronized (lock) {
if (returnedStream == null) {
// Fast path: returnStream() hasn't been called, the call will use the
// real stream directly.
returnedStream = stream;
return;
directStream = true;
}
}
if (directStream) {
listener.onComplete();
return;
}
// returnStream() has been called before me, thus delayedStream must have been
// created.
checkState(delayedStream != null, "delayedStream is null");
Expand All @@ -100,6 +107,7 @@ private void finalizeWith(ClientStream stream) {
// TODO(ejona): run this on a separate thread
slow.run();
}
listener.onComplete();
}

/**
Expand All @@ -116,4 +124,11 @@ ClientStream returnStream() {
}
}
}

public interface MetadataApplierListener {
/**
* Notify that the metadata has been successfully applied, or failed.
* */
void onComplete();
}
}
Expand Up @@ -19,6 +19,7 @@
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertSame;
import static org.junit.Assert.assertTrue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.same;
import static org.mockito.Mockito.doAnswer;
Expand Down Expand Up @@ -203,6 +204,10 @@ public void credentialThrows() {
verify(mockTransport, never()).newStream(method, origHeaders, callOptions);
assertEquals(Status.Code.UNAUTHENTICATED, stream.getError().getCode());
assertSame(ex, stream.getError().getCause());
transport.shutdown(Status.UNAVAILABLE);
assertTrue(transport.newStream(method, origHeaders, callOptions)
instanceof FailingClientStream);
verify(mockTransport).shutdown(Status.UNAVAILABLE);
}

@Test
Expand All @@ -227,6 +232,10 @@ public Void answer(InvocationOnMock invocation) throws Throwable {
assertSame(mockStream, stream);
assertEquals(CREDS_VALUE, origHeaders.get(CREDS_KEY));
assertEquals(ORIG_HEADER_VALUE, origHeaders.get(ORIG_HEADER_KEY));
transport.shutdown(Status.UNAVAILABLE);
assertTrue(transport.newStream(method, origHeaders, callOptions)
instanceof FailingClientStream);
verify(mockTransport).shutdown(Status.UNAVAILABLE);
}

@Test
Expand All @@ -249,6 +258,10 @@ public Void answer(InvocationOnMock invocation) throws Throwable {

verify(mockTransport, never()).newStream(method, origHeaders, callOptions);
assertSame(error, stream.getError());
transport.shutdownNow(Status.UNAVAILABLE);
assertTrue(transport.newStream(method, origHeaders, callOptions)
instanceof FailingClientStream);
verify(mockTransport).shutdownNow(Status.UNAVAILABLE);
}

@Test
Expand All @@ -263,6 +276,9 @@ public void applyMetadata_delayed() {
any(RequestInfo.class), same(mockExecutor), applierCaptor.capture());
verify(mockTransport, never()).newStream(method, origHeaders, callOptions);

transport.shutdown(Status.UNAVAILABLE);
verify(mockTransport, never()).shutdown(Status.UNAVAILABLE);

Metadata headers = new Metadata();
headers.put(CREDS_KEY, CREDS_VALUE);
applierCaptor.getValue().apply(headers);
Expand All @@ -271,6 +287,9 @@ public void applyMetadata_delayed() {
assertSame(mockStream, stream.getRealStream());
assertEquals(CREDS_VALUE, origHeaders.get(CREDS_KEY));
assertEquals(ORIG_HEADER_VALUE, origHeaders.get(ORIG_HEADER_KEY));
assertTrue(transport.newStream(method, origHeaders, callOptions)
instanceof FailingClientStream);
verify(mockTransport).shutdown(Status.UNAVAILABLE);
}

@Test
Expand All @@ -290,6 +309,10 @@ public void fail_delayed() {
verify(mockTransport, never()).newStream(method, origHeaders, callOptions);
FailingClientStream failingStream = (FailingClientStream) stream.getRealStream();
assertSame(error, failingStream.getError());
transport.shutdown(Status.UNAVAILABLE);
assertTrue(transport.newStream(method, origHeaders, callOptions)
instanceof FailingClientStream);
verify(mockTransport).shutdown(Status.UNAVAILABLE);
}

@Test
Expand All @@ -301,5 +324,9 @@ public void noCreds() {
assertSame(mockStream, stream);
assertNull(origHeaders.get(CREDS_KEY));
assertEquals(ORIG_HEADER_VALUE, origHeaders.get(ORIG_HEADER_KEY));
transport.shutdown(Status.UNAVAILABLE);
assertTrue(transport.newStream(method, origHeaders, callOptions)
instanceof FailingClientStream);
verify(mockTransport).shutdown(Status.UNAVAILABLE);
}
}

0 comments on commit ac2ead7

Please sign in to comment.