Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

core: delay CallCredentialsApplyingTransport shutdown until metadataApplier finalized #7813

Merged
merged 8 commits into from Jan 26, 2021
Expand Up @@ -33,7 +33,8 @@
import java.net.SocketAddress;
import java.util.concurrent.Executor;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import javax.annotation.concurrent.GuardedBy;

final class CallCredentialsApplyingTransportFactory implements ClientTransportFactory {
Expand Down Expand Up @@ -69,11 +70,22 @@ public void close() {
private class CallCredentialsApplyingTransport extends ForwardingConnectionClientTransport {
private final ConnectionClientTransport delegate;
private final String authority;
private final AtomicLong pendingApplier = new AtomicLong(0);
// Negative value means transport active, non-negative value indicates shutdown invoked.
private final AtomicInteger pendingApplier = new AtomicInteger(Integer.MIN_VALUE + 1);
private final AtomicBoolean shutdownInvoked = new AtomicBoolean(false);
private final Status shutdownStatus = Status.UNAVAILABLE;
ejona86 marked this conversation as resolved.
Show resolved Hide resolved
@GuardedBy("this")
private Status savedShutdownStatus;
@GuardedBy("this")
private Status savedShutdownNowStatus;
private final MetadataApplierListener applierListener = new MetadataApplierListener() {
@Override
public void onComplete() {
if (pendingApplier.decrementAndGet() == 0) {
maybeShutdown();
}
}
};

CallCredentialsApplyingTransport(ConnectionClientTransport delegate, String authority) {
this.delegate = checkNotNull(delegate, "delegate");
Expand All @@ -89,24 +101,22 @@ protected ConnectionClientTransport delegate() {
@SuppressWarnings("deprecation")
public ClientStream newStream(
final MethodDescriptor<?, ?> method, Metadata headers, final CallOptions callOptions) {
if (shutdownInvoked.get()) {
return new FailingClientStream(shutdownStatus);
}
CallCredentials creds = callOptions.getCredentials();
if (creds == null) {
creds = channelCallCredentials;
} else if (channelCallCredentials != null) {
creds = new CompositeCallCredentials(channelCallCredentials, creds);
}
if (creds != null) {
MetadataApplierListener applierListener = new MetadataApplierListener() {
@Override
public void onComplete() {
if (pendingApplier.decrementAndGet() == 0) {
maybeShutdown();
}
}
};
MetadataApplierImpl applier = new MetadataApplierImpl(
delegate, method, headers, callOptions, applierListener);
pendingApplier.incrementAndGet();
if (pendingApplier.incrementAndGet() > 0) {
applierListener.onComplete();
return new FailingClientStream(shutdownStatus);
}
RequestInfo requestInfo = new RequestInfo() {
@Override
public MethodDescriptor<?, ?> getMethodDescriptor() {
Expand Down Expand Up @@ -147,6 +157,9 @@ public Attributes getTransportAttrs() {
@Override
public void shutdown(Status status) {
checkNotNull(status, "status");
if (shutdownInvoked.compareAndSet(false, true)) {
pendingApplier.addAndGet(Integer.MAX_VALUE);
YifeiZhuang marked this conversation as resolved.
Show resolved Hide resolved
}
synchronized (this) {
if (pendingApplier.get() != 0) {
savedShutdownStatus = status;
YifeiZhuang marked this conversation as resolved.
Show resolved Hide resolved
Expand All @@ -156,10 +169,13 @@ public void shutdown(Status status) {
super.shutdown(status);
}

// TODO(zivy@): add cancel pending applier.
// TODO(zivy@): should call delegate shutdownNow asap. Maybe cancel pending applier.
@Override
public void shutdownNow(Status status) {
checkNotNull(status, "status");
if (shutdownInvoked.compareAndSet(false, true)) {
pendingApplier.addAndGet(Integer.MAX_VALUE);
}
synchronized (this) {
if (pendingApplier.get() != 0) {
savedShutdownNowStatus = status;
Expand All @@ -173,11 +189,11 @@ private void maybeShutdown() {
Status maybeShutdown;
Status maybeShutdownNow;
synchronized (this) {
maybeShutdown = savedShutdownStatus;
maybeShutdownNow = savedShutdownNowStatus;
if ((maybeShutdown == null && maybeShutdownNow == null) || pendingApplier.get() != 0) {
if (pendingApplier.get() != 0) {
return;
}
maybeShutdown = savedShutdownStatus;
maybeShutdownNow = savedShutdownNowStatus;
savedShutdownStatus = null;
savedShutdownNowStatus = null;
}
Expand Down
Expand Up @@ -19,6 +19,7 @@
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertSame;
import static org.junit.Assert.assertTrue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.same;
import static org.mockito.Mockito.doAnswer;
Expand Down Expand Up @@ -204,6 +205,8 @@ public void credentialThrows() {
assertEquals(Status.Code.UNAUTHENTICATED, stream.getError().getCode());
assertSame(ex, stream.getError().getCause());
transport.shutdown(Status.UNAVAILABLE);
assertTrue(transport.newStream(method, origHeaders, callOptions)
instanceof FailingClientStream);
verify(mockTransport).shutdown(Status.UNAVAILABLE);
}

Expand All @@ -230,6 +233,8 @@ public Void answer(InvocationOnMock invocation) throws Throwable {
assertEquals(CREDS_VALUE, origHeaders.get(CREDS_KEY));
assertEquals(ORIG_HEADER_VALUE, origHeaders.get(ORIG_HEADER_KEY));
transport.shutdown(Status.UNAVAILABLE);
assertTrue(transport.newStream(method, origHeaders, callOptions)
instanceof FailingClientStream);
verify(mockTransport).shutdown(Status.UNAVAILABLE);
}

Expand All @@ -253,8 +258,10 @@ public Void answer(InvocationOnMock invocation) throws Throwable {

verify(mockTransport, never()).newStream(method, origHeaders, callOptions);
assertSame(error, stream.getError());
transport.shutdown(Status.UNAVAILABLE);
verify(mockTransport).shutdown(Status.UNAVAILABLE);
transport.shutdownNow(Status.UNAVAILABLE);
assertTrue(transport.newStream(method, origHeaders, callOptions)
instanceof FailingClientStream);
verify(mockTransport).shutdownNow(Status.UNAVAILABLE);
}

@Test
Expand All @@ -280,6 +287,8 @@ public void applyMetadata_delayed() {
assertSame(mockStream, stream.getRealStream());
assertEquals(CREDS_VALUE, origHeaders.get(CREDS_KEY));
assertEquals(ORIG_HEADER_VALUE, origHeaders.get(ORIG_HEADER_KEY));
assertTrue(transport.newStream(method, origHeaders, callOptions)
instanceof FailingClientStream);
verify(mockTransport).shutdown(Status.UNAVAILABLE);
}

Expand All @@ -301,6 +310,8 @@ public void fail_delayed() {
FailingClientStream failingStream = (FailingClientStream) stream.getRealStream();
assertSame(error, failingStream.getError());
transport.shutdown(Status.UNAVAILABLE);
assertTrue(transport.newStream(method, origHeaders, callOptions)
instanceof FailingClientStream);
verify(mockTransport).shutdown(Status.UNAVAILABLE);
}

Expand All @@ -314,6 +325,8 @@ public void noCreds() {
assertNull(origHeaders.get(CREDS_KEY));
assertEquals(ORIG_HEADER_VALUE, origHeaders.get(ORIG_HEADER_KEY));
transport.shutdown(Status.UNAVAILABLE);
assertTrue(transport.newStream(method, origHeaders, callOptions)
instanceof FailingClientStream);
verify(mockTransport).shutdown(Status.UNAVAILABLE);
}
}
Expand Up @@ -19,6 +19,7 @@
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertSame;
import static org.junit.Assert.assertTrue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.same;
import static org.mockito.Mockito.doAnswer;
Expand Down Expand Up @@ -182,6 +183,8 @@ public void credentialThrows() {
assertSame(ex, stream.getError().getCause());

transport.shutdown(Status.UNAVAILABLE);
assertTrue(transport.newStream(method, origHeaders, callOptions)
instanceof FailingClientStream);
verify(mockTransport).shutdown(Status.UNAVAILABLE);
}

Expand All @@ -197,6 +200,8 @@ public void applyMetadata_inline() {
assertEquals(CREDS_VALUE, origHeaders.get(CREDS_KEY));
assertEquals(ORIG_HEADER_VALUE, origHeaders.get(ORIG_HEADER_KEY));
transport.shutdown(Status.UNAVAILABLE);
assertTrue(transport.newStream(method, origHeaders, callOptions)
instanceof FailingClientStream);
verify(mockTransport).shutdown(Status.UNAVAILABLE);
}

Expand All @@ -221,6 +226,8 @@ public Void answer(InvocationOnMock invocation) throws Throwable {
verify(mockTransport, never()).newStream(method, origHeaders, callOptions);
assertSame(error, stream.getError());
transport.shutdownNow(Status.UNAVAILABLE);
assertTrue(transport.newStream(method, origHeaders, callOptions)
instanceof FailingClientStream);
verify(mockTransport).shutdownNow(Status.UNAVAILABLE);
}

Expand All @@ -238,6 +245,8 @@ public void applyMetadata_delayed() {

transport.shutdown(Status.UNAVAILABLE);
verify(mockTransport, never()).shutdown(Status.UNAVAILABLE);
assertTrue(transport.newStream(method, origHeaders, callOptions)
instanceof FailingClientStream);

Metadata headers = new Metadata();
headers.put(CREDS_KEY, CREDS_VALUE);
Expand All @@ -258,11 +267,15 @@ public void delayedShutdown_shutdownShutdownNowThenApply() {
same(mockExecutor), applierCaptor.capture());
transport.shutdown(Status.UNAVAILABLE);
transport.shutdownNow(Status.ABORTED);
assertTrue(transport.newStream(method, origHeaders, callOptions)
instanceof FailingClientStream);
verify(mockTransport, never()).shutdown(any(Status.class));
verify(mockTransport, never()).shutdownNow(any(Status.class));
Metadata headers = new Metadata();
headers.put(CREDS_KEY, CREDS_VALUE);
applierCaptor.getValue().apply(headers);
assertTrue(transport.newStream(method, origHeaders, callOptions)
instanceof FailingClientStream);
verify(mockTransport).shutdown(Status.UNAVAILABLE);
verify(mockTransport).shutdownNow(Status.ABORTED);
}
Expand All @@ -274,6 +287,8 @@ public void delayedShutdown_shutdownThenApplyThenShutdownNow() {
verify(mockCreds).applyRequestMetadata(any(RequestInfo.class),
same(mockExecutor), applierCaptor.capture());
transport.shutdown(Status.UNAVAILABLE);
assertTrue(transport.newStream(method, origHeaders, callOptions)
instanceof FailingClientStream);
verify(mockTransport, never()).shutdown(any(Status.class));
Metadata headers = new Metadata();
headers.put(CREDS_KEY, CREDS_VALUE);
Expand All @@ -296,12 +311,18 @@ public void delayedShutdown_shutdownMulti() {
same(mockExecutor), applierCaptor.capture());
applierCaptor.getAllValues().get(1).apply(headers);
transport.shutdown(Status.UNAVAILABLE);
assertTrue(transport.newStream(method, origHeaders, callOptions)
instanceof FailingClientStream);
verify(mockTransport, never()).shutdown(Status.UNAVAILABLE);

applierCaptor.getAllValues().get(0).apply(headers);
assertTrue(transport.newStream(method, origHeaders, callOptions)
instanceof FailingClientStream);
verify(mockTransport, never()).shutdown(Status.UNAVAILABLE);

applierCaptor.getAllValues().get(2).apply(headers);
assertTrue(transport.newStream(method, origHeaders, callOptions)
instanceof FailingClientStream);
verify(mockTransport).shutdown(Status.UNAVAILABLE);
}

Expand All @@ -323,6 +344,8 @@ public void fail_delayed() {
FailingClientStream failingStream = (FailingClientStream) stream.getRealStream();
assertSame(error, failingStream.getError());
transport.shutdown(Status.UNAVAILABLE);
assertTrue(transport.newStream(method, origHeaders, callOptions)
instanceof FailingClientStream);
verify(mockTransport).shutdown(Status.UNAVAILABLE);
}

Expand All @@ -336,6 +359,8 @@ public void noCreds() {
assertNull(origHeaders.get(CREDS_KEY));
assertEquals(ORIG_HEADER_VALUE, origHeaders.get(ORIG_HEADER_KEY));
transport.shutdown(Status.UNAVAILABLE);
assertTrue(transport.newStream(method, origHeaders, callOptions)
instanceof FailingClientStream);
verify(mockTransport).shutdown(Status.UNAVAILABLE);
}

Expand Down