Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

core: delay CallCredentialsApplyingTransport shutdown until metadataApplier finalized #7813

Merged
merged 8 commits into from Jan 26, 2021
Expand Up @@ -29,9 +29,12 @@
import io.grpc.MethodDescriptor;
import io.grpc.SecurityLevel;
import io.grpc.Status;
import io.grpc.internal.MetadataApplierImpl.MetadataApplierListener;
import java.net.SocketAddress;
import java.util.concurrent.Executor;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.atomic.AtomicInteger;
import javax.annotation.concurrent.GuardedBy;

final class CallCredentialsApplyingTransportFactory implements ClientTransportFactory {
private final ClientTransportFactory delegate;
Expand Down Expand Up @@ -66,6 +69,21 @@ public void close() {
private class CallCredentialsApplyingTransport extends ForwardingConnectionClientTransport {
private final ConnectionClientTransport delegate;
private final String authority;
// Negative value means transport active, non-negative value indicates shutdown invoked.
private final AtomicInteger pendingApplier = new AtomicInteger(Integer.MIN_VALUE + 1);
private volatile Status shutdownStatus;
@GuardedBy("this")
private Status savedShutdownStatus;
@GuardedBy("this")
private Status savedShutdownNowStatus;
private final MetadataApplierListener applierListener = new MetadataApplierListener() {
@Override
public void onComplete() {
if (pendingApplier.decrementAndGet() == 0) {
maybeShutdown();
}
}
};

CallCredentialsApplyingTransport(ConnectionClientTransport delegate, String authority) {
this.delegate = checkNotNull(delegate, "delegate");
Expand All @@ -89,7 +107,11 @@ public ClientStream newStream(
}
if (creds != null) {
MetadataApplierImpl applier = new MetadataApplierImpl(
delegate, method, headers, callOptions);
delegate, method, headers, callOptions, applierListener);
if (pendingApplier.incrementAndGet() > 0) {
applierListener.onComplete();
return new FailingClientStream(shutdownStatus);
}
RequestInfo requestInfo = new RequestInfo() {
@Override
public MethodDescriptor<?, ?> getMethodDescriptor() {
Expand Down Expand Up @@ -123,8 +145,69 @@ public Attributes getTransportAttrs() {
}
return applier.returnStream();
} else {
if (pendingApplier.get() >= 0) {
return new FailingClientStream(shutdownStatus);
}
return delegate.newStream(method, headers, callOptions);
ejona86 marked this conversation as resolved.
Show resolved Hide resolved
}
}

@Override
public void shutdown(Status status) {
checkNotNull(status, "status");
synchronized (this) {
if (pendingApplier.get() < 0) {
YifeiZhuang marked this conversation as resolved.
Show resolved Hide resolved
shutdownStatus = status;
pendingApplier.addAndGet(Integer.MAX_VALUE);
} else {
return;
}
if (pendingApplier.get() != 0) {
savedShutdownStatus = status;
YifeiZhuang marked this conversation as resolved.
Show resolved Hide resolved
return;
}
}
super.shutdown(status);
}

// TODO(zivy): cancel pending applier here.
@Override
public void shutdownNow(Status status) {
checkNotNull(status, "status");
synchronized (this) {
if (pendingApplier.get() < 0) {
YifeiZhuang marked this conversation as resolved.
Show resolved Hide resolved
shutdownStatus = status;
pendingApplier.addAndGet(Integer.MAX_VALUE);
} else if (savedShutdownNowStatus != null) {
return;
}
if (pendingApplier.get() != 0) {
savedShutdownNowStatus = status;
// TODO(zivy): propagate shutdownNow to the delegate immediately.
return;
YifeiZhuang marked this conversation as resolved.
Show resolved Hide resolved
}
}
super.shutdownNow(status);
}

private void maybeShutdown() {
Status maybeShutdown;
Status maybeShutdownNow;
synchronized (this) {
if (pendingApplier.get() != 0) {
return;
}
maybeShutdown = savedShutdownStatus;
maybeShutdownNow = savedShutdownNowStatus;
savedShutdownStatus = null;
savedShutdownNowStatus = null;
}
if (maybeShutdown != null) {
super.shutdown(maybeShutdown);
}
if (maybeShutdownNow != null) {
super.shutdownNow(maybeShutdownNow);
}
}
}
}
19 changes: 17 additions & 2 deletions core/src/main/java/io/grpc/internal/MetadataApplierImpl.java
Expand Up @@ -35,6 +35,7 @@ final class MetadataApplierImpl extends MetadataApplier {
private final Metadata origHeaders;
private final CallOptions callOptions;
private final Context ctx;
private final MetadataApplierListener listener;

private final Object lock = new Object();

Expand All @@ -51,12 +52,13 @@ final class MetadataApplierImpl extends MetadataApplier {

MetadataApplierImpl(
ClientTransport transport, MethodDescriptor<?, ?> method, Metadata origHeaders,
CallOptions callOptions) {
CallOptions callOptions, MetadataApplierListener listener) {
this.transport = transport;
this.method = method;
this.origHeaders = origHeaders;
this.callOptions = callOptions;
this.ctx = Context.current();
this.listener = listener;
}

@Override
Expand Down Expand Up @@ -84,14 +86,19 @@ public void fail(Status status) {
private void finalizeWith(ClientStream stream) {
checkState(!finalized, "already finalized");
finalized = true;
boolean directStream = false;
synchronized (lock) {
if (returnedStream == null) {
// Fast path: returnStream() hasn't been called, the call will use the
// real stream directly.
returnedStream = stream;
return;
directStream = true;
}
}
if (directStream) {
listener.onComplete();
return;
}
// returnStream() has been called before me, thus delayedStream must have been
// created.
checkState(delayedStream != null, "delayedStream is null");
Expand All @@ -100,6 +107,7 @@ private void finalizeWith(ClientStream stream) {
// TODO(ejona): run this on a separate thread
slow.run();
}
listener.onComplete();
}

/**
Expand All @@ -116,4 +124,11 @@ ClientStream returnStream() {
}
}
}

public interface MetadataApplierListener {
/**
* Notify that the metadata has been successfully applied, or failed.
* */
void onComplete();
}
}
Expand Up @@ -19,6 +19,7 @@
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertSame;
import static org.junit.Assert.assertTrue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.same;
import static org.mockito.Mockito.doAnswer;
Expand Down Expand Up @@ -203,6 +204,10 @@ public void credentialThrows() {
verify(mockTransport, never()).newStream(method, origHeaders, callOptions);
assertEquals(Status.Code.UNAUTHENTICATED, stream.getError().getCode());
assertSame(ex, stream.getError().getCause());
transport.shutdown(Status.UNAVAILABLE);
assertTrue(transport.newStream(method, origHeaders, callOptions)
instanceof FailingClientStream);
verify(mockTransport).shutdown(Status.UNAVAILABLE);
}

@Test
Expand All @@ -227,6 +232,10 @@ public Void answer(InvocationOnMock invocation) throws Throwable {
assertSame(mockStream, stream);
assertEquals(CREDS_VALUE, origHeaders.get(CREDS_KEY));
assertEquals(ORIG_HEADER_VALUE, origHeaders.get(ORIG_HEADER_KEY));
transport.shutdown(Status.UNAVAILABLE);
assertTrue(transport.newStream(method, origHeaders, callOptions)
instanceof FailingClientStream);
verify(mockTransport).shutdown(Status.UNAVAILABLE);
}

@Test
Expand All @@ -249,6 +258,10 @@ public Void answer(InvocationOnMock invocation) throws Throwable {

verify(mockTransport, never()).newStream(method, origHeaders, callOptions);
assertSame(error, stream.getError());
transport.shutdownNow(Status.UNAVAILABLE);
assertTrue(transport.newStream(method, origHeaders, callOptions)
instanceof FailingClientStream);
verify(mockTransport).shutdownNow(Status.UNAVAILABLE);
}

@Test
Expand All @@ -263,6 +276,9 @@ public void applyMetadata_delayed() {
any(RequestInfo.class), same(mockExecutor), applierCaptor.capture());
verify(mockTransport, never()).newStream(method, origHeaders, callOptions);

transport.shutdown(Status.UNAVAILABLE);
verify(mockTransport, never()).shutdown(Status.UNAVAILABLE);

Metadata headers = new Metadata();
headers.put(CREDS_KEY, CREDS_VALUE);
applierCaptor.getValue().apply(headers);
Expand All @@ -271,6 +287,9 @@ public void applyMetadata_delayed() {
assertSame(mockStream, stream.getRealStream());
assertEquals(CREDS_VALUE, origHeaders.get(CREDS_KEY));
assertEquals(ORIG_HEADER_VALUE, origHeaders.get(ORIG_HEADER_KEY));
assertTrue(transport.newStream(method, origHeaders, callOptions)
instanceof FailingClientStream);
verify(mockTransport).shutdown(Status.UNAVAILABLE);
}

@Test
Expand All @@ -290,6 +309,10 @@ public void fail_delayed() {
verify(mockTransport, never()).newStream(method, origHeaders, callOptions);
FailingClientStream failingStream = (FailingClientStream) stream.getRealStream();
assertSame(error, failingStream.getError());
transport.shutdown(Status.UNAVAILABLE);
assertTrue(transport.newStream(method, origHeaders, callOptions)
instanceof FailingClientStream);
verify(mockTransport).shutdown(Status.UNAVAILABLE);
}

@Test
Expand All @@ -301,5 +324,9 @@ public void noCreds() {
assertSame(mockStream, stream);
assertNull(origHeaders.get(CREDS_KEY));
assertEquals(ORIG_HEADER_VALUE, origHeaders.get(ORIG_HEADER_KEY));
transport.shutdown(Status.UNAVAILABLE);
assertTrue(transport.newStream(method, origHeaders, callOptions)
instanceof FailingClientStream);
verify(mockTransport).shutdown(Status.UNAVAILABLE);
}
}