Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

core: delay CallCredentialsApplyingTransport shutdown until metadataApplier finalized #7813

Merged
merged 8 commits into from Jan 26, 2021
Expand Up @@ -29,9 +29,12 @@
import io.grpc.MethodDescriptor;
import io.grpc.SecurityLevel;
import io.grpc.Status;
import io.grpc.internal.MetadataApplierImpl.MetadataApplierListener;
import java.net.SocketAddress;
import java.util.concurrent.Executor;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.atomic.AtomicLong;
import javax.annotation.concurrent.GuardedBy;

final class CallCredentialsApplyingTransportFactory implements ClientTransportFactory {
private final ClientTransportFactory delegate;
Expand Down Expand Up @@ -66,6 +69,11 @@ public void close() {
private class CallCredentialsApplyingTransport extends ForwardingConnectionClientTransport {
private final ConnectionClientTransport delegate;
private final String authority;
private final AtomicLong pendingApplier = new AtomicLong(0);
@GuardedBy("this")
private Status savedShutdownStatus;
@GuardedBy("this")
private Status savedShutdownNowStatus;

CallCredentialsApplyingTransport(ConnectionClientTransport delegate, String authority) {
this.delegate = checkNotNull(delegate, "delegate");
Expand All @@ -88,8 +96,17 @@ public ClientStream newStream(
creds = new CompositeCallCredentials(channelCallCredentials, creds);
}
if (creds != null) {
MetadataApplierListener applierListener = new MetadataApplierListener() {
YifeiZhuang marked this conversation as resolved.
Show resolved Hide resolved
@Override
public void onComplete() {
if (pendingApplier.decrementAndGet() == 0) {
maybeShutdown();
}
}
};
MetadataApplierImpl applier = new MetadataApplierImpl(
delegate, method, headers, callOptions);
delegate, method, headers, callOptions, applierListener);
pendingApplier.incrementAndGet();
RequestInfo requestInfo = new RequestInfo() {
@Override
public MethodDescriptor<?, ?> getMethodDescriptor() {
Expand Down Expand Up @@ -126,5 +143,50 @@ public Attributes getTransportAttrs() {
return delegate.newStream(method, headers, callOptions);
ejona86 marked this conversation as resolved.
Show resolved Hide resolved
}
}

@Override
public void shutdown(Status status) {
checkNotNull(status, "status");
synchronized (this) {
if (pendingApplier.get() != 0) {
savedShutdownStatus = status;
YifeiZhuang marked this conversation as resolved.
Show resolved Hide resolved
return;
}
}
super.shutdown(status);
}

// TODO(zivy@): add cancel pending applier.
@Override
public void shutdownNow(Status status) {
checkNotNull(status, "status");
synchronized (this) {
if (pendingApplier.get() != 0) {
savedShutdownNowStatus = status;
return;
YifeiZhuang marked this conversation as resolved.
Show resolved Hide resolved
}
}
super.shutdownNow(status);
}

private void maybeShutdown() {
Status maybeShutdown;
Status maybeShutdownNow;
synchronized (this) {
maybeShutdown = savedShutdownStatus;
maybeShutdownNow = savedShutdownNowStatus;
if ((maybeShutdown == null && maybeShutdownNow == null) || pendingApplier.get() != 0) {
return;
}
savedShutdownStatus = null;
savedShutdownNowStatus = null;
}
if (maybeShutdown != null) {
super.shutdown(maybeShutdown);
}
if (maybeShutdownNow != null) {
super.shutdownNow(maybeShutdownNow);
}
}
}
}
19 changes: 17 additions & 2 deletions core/src/main/java/io/grpc/internal/MetadataApplierImpl.java
Expand Up @@ -35,6 +35,7 @@ final class MetadataApplierImpl extends MetadataApplier {
private final Metadata origHeaders;
private final CallOptions callOptions;
private final Context ctx;
private final MetadataApplierListener listener;

private final Object lock = new Object();

Expand All @@ -51,12 +52,13 @@ final class MetadataApplierImpl extends MetadataApplier {

MetadataApplierImpl(
ClientTransport transport, MethodDescriptor<?, ?> method, Metadata origHeaders,
CallOptions callOptions) {
CallOptions callOptions, MetadataApplierListener listener) {
this.transport = transport;
this.method = method;
this.origHeaders = origHeaders;
this.callOptions = callOptions;
this.ctx = Context.current();
this.listener = listener;
}

@Override
Expand Down Expand Up @@ -84,18 +86,24 @@ public void fail(Status status) {
private void finalizeWith(ClientStream stream) {
checkState(!finalized, "already finalized");
finalized = true;
boolean directStream = false;
synchronized (lock) {
if (returnedStream == null) {
// Fast path: returnStream() hasn't been called, the call will use the
// real stream directly.
returnedStream = stream;
return;
directStream = true;
}
}
if (directStream) {
listener.onComplete();
return;
}
// returnStream() has been called before me, thus delayedStream must have been
// created.
checkState(delayedStream != null, "delayedStream is null");
delayedStream.setStream(stream);
listener.onComplete();
}

/**
Expand All @@ -112,4 +120,11 @@ ClientStream returnStream() {
}
}
}

public interface MetadataApplierListener {
/**
* Notify that the metadata has been successfully applied, or failed.
* */
void onComplete();
}
}
Expand Up @@ -203,6 +203,8 @@ public void credentialThrows() {
verify(mockTransport, never()).newStream(method, origHeaders, callOptions);
assertEquals(Status.Code.UNAUTHENTICATED, stream.getError().getCode());
assertSame(ex, stream.getError().getCause());
transport.shutdown(Status.UNAVAILABLE);
verify(mockTransport).shutdown(Status.UNAVAILABLE);
}

@Test
Expand All @@ -227,6 +229,8 @@ public Void answer(InvocationOnMock invocation) throws Throwable {
assertSame(mockStream, stream);
assertEquals(CREDS_VALUE, origHeaders.get(CREDS_KEY));
assertEquals(ORIG_HEADER_VALUE, origHeaders.get(ORIG_HEADER_KEY));
transport.shutdown(Status.UNAVAILABLE);
verify(mockTransport).shutdown(Status.UNAVAILABLE);
}

@Test
Expand All @@ -249,6 +253,8 @@ public Void answer(InvocationOnMock invocation) throws Throwable {

verify(mockTransport, never()).newStream(method, origHeaders, callOptions);
assertSame(error, stream.getError());
transport.shutdown(Status.UNAVAILABLE);
verify(mockTransport).shutdown(Status.UNAVAILABLE);
}

@Test
Expand All @@ -263,6 +269,9 @@ public void applyMetadata_delayed() {
any(RequestInfo.class), same(mockExecutor), applierCaptor.capture());
verify(mockTransport, never()).newStream(method, origHeaders, callOptions);

transport.shutdown(Status.UNAVAILABLE);
verify(mockTransport, never()).shutdown(Status.UNAVAILABLE);

Metadata headers = new Metadata();
headers.put(CREDS_KEY, CREDS_VALUE);
applierCaptor.getValue().apply(headers);
Expand All @@ -271,6 +280,7 @@ public void applyMetadata_delayed() {
assertSame(mockStream, stream.getRealStream());
assertEquals(CREDS_VALUE, origHeaders.get(CREDS_KEY));
assertEquals(ORIG_HEADER_VALUE, origHeaders.get(ORIG_HEADER_KEY));
verify(mockTransport).shutdown(Status.UNAVAILABLE);
}

@Test
Expand All @@ -290,6 +300,8 @@ public void fail_delayed() {
verify(mockTransport, never()).newStream(method, origHeaders, callOptions);
FailingClientStream failingStream = (FailingClientStream) stream.getRealStream();
assertSame(error, failingStream.getError());
transport.shutdown(Status.UNAVAILABLE);
verify(mockTransport).shutdown(Status.UNAVAILABLE);
}

@Test
Expand All @@ -301,5 +313,7 @@ public void noCreds() {
assertSame(mockStream, stream);
assertNull(origHeaders.get(CREDS_KEY));
assertEquals(ORIG_HEADER_VALUE, origHeaders.get(ORIG_HEADER_KEY));
transport.shutdown(Status.UNAVAILABLE);
verify(mockTransport).shutdown(Status.UNAVAILABLE);
}
}
Expand Up @@ -25,6 +25,7 @@
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

Expand Down Expand Up @@ -179,6 +180,9 @@ public void credentialThrows() {
verify(mockTransport, never()).newStream(method, origHeaders, callOptions);
assertEquals(Status.Code.UNAUTHENTICATED, stream.getError().getCode());
assertSame(ex, stream.getError().getCause());

transport.shutdown(Status.UNAVAILABLE);
verify(mockTransport).shutdown(Status.UNAVAILABLE);
}

@Test
Expand All @@ -192,6 +196,8 @@ public void applyMetadata_inline() {
assertSame(mockStream, stream);
assertEquals(CREDS_VALUE, origHeaders.get(CREDS_KEY));
assertEquals(ORIG_HEADER_VALUE, origHeaders.get(ORIG_HEADER_KEY));
transport.shutdown(Status.UNAVAILABLE);
verify(mockTransport).shutdown(Status.UNAVAILABLE);
}

@Test
Expand All @@ -214,6 +220,8 @@ public Void answer(InvocationOnMock invocation) throws Throwable {

verify(mockTransport, never()).newStream(method, origHeaders, callOptions);
assertSame(error, stream.getError());
transport.shutdownNow(Status.UNAVAILABLE);
verify(mockTransport).shutdownNow(Status.UNAVAILABLE);
}

@Test
Expand All @@ -228,6 +236,9 @@ public void applyMetadata_delayed() {
same(mockExecutor), applierCaptor.capture());
verify(mockTransport, never()).newStream(method, origHeaders, callOptions);

transport.shutdown(Status.UNAVAILABLE);
verify(mockTransport, never()).shutdown(Status.UNAVAILABLE);

Metadata headers = new Metadata();
headers.put(CREDS_KEY, CREDS_VALUE);
applierCaptor.getValue().apply(headers);
Expand All @@ -236,6 +247,62 @@ public void applyMetadata_delayed() {
assertSame(mockStream, stream.getRealStream());
assertEquals(CREDS_VALUE, origHeaders.get(CREDS_KEY));
assertEquals(ORIG_HEADER_VALUE, origHeaders.get(ORIG_HEADER_KEY));
verify(mockTransport).shutdown(Status.UNAVAILABLE);
}

@Test
public void delayedShutdown_shutdownShutdownNowThenApply() {
transport.newStream(method, origHeaders, callOptions);
ArgumentCaptor<CallCredentials.MetadataApplier> applierCaptor = ArgumentCaptor.forClass(null);
verify(mockCreds).applyRequestMetadata(any(RequestInfo.class),
same(mockExecutor), applierCaptor.capture());
transport.shutdown(Status.UNAVAILABLE);
transport.shutdownNow(Status.ABORTED);
verify(mockTransport, never()).shutdown(any(Status.class));
verify(mockTransport, never()).shutdownNow(any(Status.class));
Metadata headers = new Metadata();
headers.put(CREDS_KEY, CREDS_VALUE);
applierCaptor.getValue().apply(headers);
verify(mockTransport).shutdown(Status.UNAVAILABLE);
verify(mockTransport).shutdownNow(Status.ABORTED);
}

@Test
public void delayedShutdown_shutdownThenApplyThenShutdownNow() {
transport.newStream(method, origHeaders, callOptions);
ArgumentCaptor<CallCredentials.MetadataApplier> applierCaptor = ArgumentCaptor.forClass(null);
verify(mockCreds).applyRequestMetadata(any(RequestInfo.class),
same(mockExecutor), applierCaptor.capture());
transport.shutdown(Status.UNAVAILABLE);
verify(mockTransport, never()).shutdown(any(Status.class));
Metadata headers = new Metadata();
headers.put(CREDS_KEY, CREDS_VALUE);
applierCaptor.getValue().apply(headers);
verify(mockTransport).shutdown(Status.UNAVAILABLE);
transport.shutdownNow(Status.ABORTED);
verify(mockTransport).shutdownNow(Status.ABORTED);
}

@Test
public void delayedShutdown_shutdownMulti() {
Metadata headers = new Metadata();
headers.put(CREDS_KEY, CREDS_VALUE);

transport.newStream(method, origHeaders, callOptions);
transport.newStream(method, origHeaders, callOptions);
transport.newStream(method, origHeaders, callOptions);
ArgumentCaptor<CallCredentials.MetadataApplier> applierCaptor = ArgumentCaptor.forClass(null);
verify(mockCreds, times(3)).applyRequestMetadata(any(RequestInfo.class),
same(mockExecutor), applierCaptor.capture());
applierCaptor.getAllValues().get(1).apply(headers);
transport.shutdown(Status.UNAVAILABLE);
verify(mockTransport, never()).shutdown(Status.UNAVAILABLE);

applierCaptor.getAllValues().get(0).apply(headers);
verify(mockTransport, never()).shutdown(Status.UNAVAILABLE);

applierCaptor.getAllValues().get(2).apply(headers);
verify(mockTransport).shutdown(Status.UNAVAILABLE);
}

@Test
Expand All @@ -255,6 +322,8 @@ public void fail_delayed() {
verify(mockTransport, never()).newStream(method, origHeaders, callOptions);
FailingClientStream failingStream = (FailingClientStream) stream.getRealStream();
assertSame(error, failingStream.getError());
transport.shutdown(Status.UNAVAILABLE);
verify(mockTransport).shutdown(Status.UNAVAILABLE);
}

@Test
Expand All @@ -266,6 +335,8 @@ public void noCreds() {
assertSame(mockStream, stream);
assertNull(origHeaders.get(CREDS_KEY));
assertEquals(ORIG_HEADER_VALUE, origHeaders.get(ORIG_HEADER_KEY));
transport.shutdown(Status.UNAVAILABLE);
verify(mockTransport).shutdown(Status.UNAVAILABLE);
}

@Test
Expand Down