Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

core: delay CallCredentialsApplyingTransport shutdown until metadataApplier finalized #7813

Merged
merged 8 commits into from Jan 26, 2021
Expand Up @@ -29,9 +29,11 @@
import io.grpc.MethodDescriptor;
import io.grpc.SecurityLevel;
import io.grpc.Status;
import io.grpc.internal.MetadataApplierImpl.MetadataApplierListener;
import java.net.SocketAddress;
import java.util.concurrent.Executor;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.atomic.AtomicLong;

final class CallCredentialsApplyingTransportFactory implements ClientTransportFactory {
private final ClientTransportFactory delegate;
Expand Down Expand Up @@ -66,6 +68,8 @@ public void close() {
private class CallCredentialsApplyingTransport extends ForwardingConnectionClientTransport {
private final ConnectionClientTransport delegate;
private final String authority;
private final AtomicLong pendingApplier = new AtomicLong(0);
private volatile CallCredentialsApplyingTransportListener listener;
YifeiZhuang marked this conversation as resolved.
Show resolved Hide resolved

CallCredentialsApplyingTransport(ConnectionClientTransport delegate, String authority) {
this.delegate = checkNotNull(delegate, "delegate");
Expand All @@ -81,15 +85,25 @@ protected ConnectionClientTransport delegate() {
@SuppressWarnings("deprecation")
public ClientStream newStream(
final MethodDescriptor<?, ?> method, Metadata headers, final CallOptions callOptions) {
checkNotNull(listener, "listener");
CallCredentials creds = callOptions.getCredentials();
if (creds == null) {
creds = channelCallCredentials;
} else if (channelCallCredentials != null) {
creds = new CompositeCallCredentials(channelCallCredentials, creds);
}
if (creds != null) {
MetadataApplierListener applierListener = new MetadataApplierListener() {
YifeiZhuang marked this conversation as resolved.
Show resolved Hide resolved
@Override
public void onComplete() {
if (pendingApplier.decrementAndGet() == 0) {
listener.maybeTerminated();
}
}
};
MetadataApplierImpl applier = new MetadataApplierImpl(
delegate, method, headers, callOptions);
delegate, method, headers, callOptions, applierListener);
pendingApplier.incrementAndGet();
RequestInfo requestInfo = new RequestInfo() {
@Override
public MethodDescriptor<?, ?> getMethodDescriptor() {
Expand Down Expand Up @@ -126,5 +140,50 @@ public Attributes getTransportAttrs() {
return delegate.newStream(method, headers, callOptions);
ejona86 marked this conversation as resolved.
Show resolved Hide resolved
}
}

@Override
public Runnable start(Listener listener) {
this.listener = new CallCredentialsApplyingTransportListener(listener);
return super.start(this.listener);
}
// TODO(zivy@): cancel pending appliers when shutdownNow.

private class CallCredentialsApplyingTransportListener implements Listener {
private final Listener delegateListener;
private volatile boolean savedTransportTerminated;

public CallCredentialsApplyingTransportListener(Listener listener) {
this.delegateListener = listener;
}

@Override
public void transportShutdown(Status s) {
delegateListener.transportShutdown(s);
}

@Override
public void transportTerminated() {
savedTransportTerminated = true;
YifeiZhuang marked this conversation as resolved.
Show resolved Hide resolved
maybeTerminated();
}

@Override
public void transportReady() {
delegateListener.transportReady();
}

@Override
public void transportInUse(boolean inUse) {
delegateListener.transportInUse(inUse);
}

public void maybeTerminated() {
synchronized (this) {
if (pendingApplier.get() == 0 && savedTransportTerminated) {
YifeiZhuang marked this conversation as resolved.
Show resolved Hide resolved
delegateListener.transportTerminated();
YifeiZhuang marked this conversation as resolved.
Show resolved Hide resolved
}
}
}
}
}
}
19 changes: 17 additions & 2 deletions core/src/main/java/io/grpc/internal/MetadataApplierImpl.java
Expand Up @@ -35,6 +35,7 @@ final class MetadataApplierImpl extends MetadataApplier {
private final Metadata origHeaders;
private final CallOptions callOptions;
private final Context ctx;
private final MetadataApplierListener listener;

private final Object lock = new Object();

Expand All @@ -51,12 +52,13 @@ final class MetadataApplierImpl extends MetadataApplier {

MetadataApplierImpl(
ClientTransport transport, MethodDescriptor<?, ?> method, Metadata origHeaders,
CallOptions callOptions) {
CallOptions callOptions, MetadataApplierListener listener) {
this.transport = transport;
this.method = method;
this.origHeaders = origHeaders;
this.callOptions = callOptions;
this.ctx = Context.current();
this.listener = listener;
}

@Override
Expand Down Expand Up @@ -84,18 +86,24 @@ public void fail(Status status) {
private void finalizeWith(ClientStream stream) {
checkState(!finalized, "already finalized");
finalized = true;
boolean directStream = false;
synchronized (lock) {
if (returnedStream == null) {
// Fast path: returnStream() hasn't been called, the call will use the
// real stream directly.
returnedStream = stream;
return;
directStream = true;
}
}
if (directStream) {
listener.onComplete();
return;
}
// returnStream() has been called before me, thus delayedStream must have been
// created.
checkState(delayedStream != null, "delayedStream is null");
delayedStream.setStream(stream);
listener.onComplete();
}

/**
Expand All @@ -112,4 +120,11 @@ ClientStream returnStream() {
}
}
}

public interface MetadataApplierListener {
/**
* Notify that the metadata has been successfully applied, or failed.
* */
void onComplete();
}
}
Expand Up @@ -17,6 +17,7 @@
package io.grpc.internal;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotSame;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertSame;
import static org.mockito.ArgumentMatchers.any;
Expand All @@ -39,6 +40,7 @@
import io.grpc.SecurityLevel;
import io.grpc.Status;
import io.grpc.StringMarshaller;
import io.grpc.internal.ManagedClientTransport.Listener;
import java.net.SocketAddress;
import java.util.concurrent.Executor;
import org.junit.Before;
Expand Down Expand Up @@ -75,6 +77,9 @@ public class CallCredentials2ApplyingTest {
@Mock
private io.grpc.CallCredentials2 mockCreds;

@Mock
private Listener mockTransportListener;

@Mock
private Executor mockExecutor;

Expand Down Expand Up @@ -123,6 +128,7 @@ public void setUp() {
mockTransportFactory, null, mockExecutor);
transport = (ForwardingConnectionClientTransport)
transportFactory.newClientTransport(address, clientTransportOptions, channelLogger);
transport.start(mockTransportListener);
callOptions = CallOptions.DEFAULT.withCallCredentials(mockCreds);
verify(mockTransportFactory).newClientTransport(address, clientTransportOptions, channelLogger);
assertSame(mockTransport, transport.delegate());
Expand Down Expand Up @@ -203,6 +209,11 @@ public void credentialThrows() {
verify(mockTransport, never()).newStream(method, origHeaders, callOptions);
assertEquals(Status.Code.UNAUTHENTICATED, stream.getError().getCode());
assertSame(ex, stream.getError().getCause());

ArgumentCaptor<Listener> listenerCaptor = ArgumentCaptor.forClass(null);
verify(mockTransport).start(listenerCaptor.capture());
listenerCaptor.getValue().transportTerminated();
verify(mockTransportListener).transportTerminated();
}

@Test
Expand All @@ -227,6 +238,10 @@ public Void answer(InvocationOnMock invocation) throws Throwable {
assertSame(mockStream, stream);
assertEquals(CREDS_VALUE, origHeaders.get(CREDS_KEY));
assertEquals(ORIG_HEADER_VALUE, origHeaders.get(ORIG_HEADER_KEY));
ArgumentCaptor<Listener> listenerCaptor = ArgumentCaptor.forClass(null);
verify(mockTransport).start(listenerCaptor.capture());
listenerCaptor.getValue().transportTerminated();
verify(mockTransportListener).transportTerminated();
}

@Test
Expand All @@ -249,6 +264,10 @@ public Void answer(InvocationOnMock invocation) throws Throwable {

verify(mockTransport, never()).newStream(method, origHeaders, callOptions);
assertSame(error, stream.getError());
ArgumentCaptor<Listener> listenerCaptor = ArgumentCaptor.forClass(null);
verify(mockTransport).start(listenerCaptor.capture());
listenerCaptor.getValue().transportTerminated();
verify(mockTransportListener).transportTerminated();
}

@Test
Expand Down Expand Up @@ -290,6 +309,10 @@ public void fail_delayed() {
verify(mockTransport, never()).newStream(method, origHeaders, callOptions);
FailingClientStream failingStream = (FailingClientStream) stream.getRealStream();
assertSame(error, failingStream.getError());
ArgumentCaptor<Listener> listenerCaptor = ArgumentCaptor.forClass(null);
verify(mockTransport).start(listenerCaptor.capture());
listenerCaptor.getValue().transportTerminated();
verify(mockTransportListener).transportTerminated();
}

@Test
Expand All @@ -302,4 +325,52 @@ public void noCreds() {
assertNull(origHeaders.get(CREDS_KEY));
assertEquals(ORIG_HEADER_VALUE, origHeaders.get(ORIG_HEADER_KEY));
}

@Test
public void applierTransportListenerTest_base() {
ArgumentCaptor<Listener> captureListener = ArgumentCaptor.forClass(null);
verify(mockTransport).start(captureListener.capture());
Listener applierTransportListener = captureListener.getValue();
assertNotSame(applierTransportListener, mockTransportListener);
applierTransportListener.transportShutdown(Status.UNKNOWN);
verify(mockTransportListener).transportShutdown(Status.UNKNOWN);
applierTransportListener.transportReady();
verify(mockTransportListener).transportReady();
applierTransportListener.transportInUse(true);
verify(mockTransportListener).transportInUse(true);
applierTransportListener.transportTerminated();
verify(mockTransportListener).transportTerminated();
}

@Test
public void applierTransportListenerTest_terminateThenApply() {
transport.newStream(method, origHeaders, callOptions);

ArgumentCaptor<Listener> captureListener = ArgumentCaptor.forClass(null);
verify(mockTransport).start(captureListener.capture());
captureListener.getValue().transportTerminated();
verify(mockTransportListener, never()).transportTerminated();

ArgumentCaptor<MetadataApplier> applierCaptor = ArgumentCaptor.forClass(null);
verify(mockCreds).applyRequestMetadata(
any(RequestInfo.class), same(mockExecutor), applierCaptor.capture());
applierCaptor.getValue().apply(new Metadata());
verify(mockTransportListener).transportTerminated();
}

@Test
public void applierTransportListenerTest_applyThenTerminate() {
transport.newStream(method, origHeaders, callOptions);

ArgumentCaptor<MetadataApplier> applierCaptor = ArgumentCaptor.forClass(null);
verify(mockCreds).applyRequestMetadata(
any(RequestInfo.class), same(mockExecutor), applierCaptor.capture());
applierCaptor.getValue().apply(new Metadata());
verify(mockTransportListener, never()).transportTerminated();

ArgumentCaptor<Listener> listenerCaptor = ArgumentCaptor.forClass(null);
verify(mockTransport).start(listenerCaptor.capture());
listenerCaptor.getValue().transportTerminated();
verify(mockTransportListener).transportTerminated();
}
}