Skip to content

Commit

Permalink
api,core: change ManagedChannel and Server Builders to use GlobalInte…
Browse files Browse the repository at this point in the history
…rceptors (#9312)

* api,core: change ManagedChannel and Server Builders to use GlobalInterceptors
also added a getter in GlobalInterceptors to expose the set flag
  • Loading branch information
sanjaypujare committed Jun 30, 2022
1 parent 6271bab commit c9a52eb
Show file tree
Hide file tree
Showing 6 changed files with 235 additions and 26 deletions.
26 changes: 11 additions & 15 deletions api/src/main/java/io/grpc/GlobalInterceptors.java
Expand Up @@ -16,17 +16,19 @@

package io.grpc;

import static com.google.common.base.Preconditions.checkNotNull;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

/** The collection of global interceptors and global server stream tracers. */
@Internal
final class GlobalInterceptors {
private static List<ClientInterceptor> clientInterceptors = Collections.emptyList();
private static List<ServerInterceptor> serverInterceptors = Collections.emptyList();
private static List<ClientInterceptor> clientInterceptors = null;
private static List<ServerInterceptor> serverInterceptors = null;
private static List<ServerStreamTracer.Factory> serverStreamTracerFactories =
Collections.emptyList();
null;
private static boolean isGlobalInterceptorsTracersSet;
private static boolean isGlobalInterceptorsTracersGet;

Expand Down Expand Up @@ -61,19 +63,13 @@ static synchronized void setInterceptorsTracers(
if (isGlobalInterceptorsTracersSet) {
throw new IllegalStateException("Global interceptors and tracers are already set");
}

if (clientInterceptorList != null) {
clientInterceptors = Collections.unmodifiableList(new ArrayList<>(clientInterceptorList));
}

if (serverInterceptorList != null) {
serverInterceptors = Collections.unmodifiableList(new ArrayList<>(serverInterceptorList));
}

if (serverStreamTracerFactoryList != null) {
serverStreamTracerFactories =
checkNotNull(clientInterceptorList);
checkNotNull(serverInterceptorList);
checkNotNull(serverStreamTracerFactoryList);
clientInterceptors = Collections.unmodifiableList(new ArrayList<>(clientInterceptorList));
serverInterceptors = Collections.unmodifiableList(new ArrayList<>(serverInterceptorList));
serverStreamTracerFactories =
Collections.unmodifiableList(new ArrayList<>(serverStreamTracerFactoryList));
}
isGlobalInterceptorsTracersSet = true;
}

Expand Down
9 changes: 5 additions & 4 deletions api/src/test/java/io/grpc/GlobalInterceptorsTest.java
Expand Up @@ -21,6 +21,7 @@

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.regex.Pattern;
import org.junit.Test;
Expand Down Expand Up @@ -98,7 +99,7 @@ public static final class StaticTestingClassLoaderSetTwice implements Runnable {
public void run() {
GlobalInterceptors.setInterceptorsTracers(
new ArrayList<>(Arrays.asList(new NoopClientInterceptor())),
null,
Collections.emptyList(),
new ArrayList<>(Arrays.asList(new NoopServerStreamTracerFactory())));
try {
GlobalInterceptors.setInterceptorsTracers(
Expand All @@ -115,7 +116,7 @@ public static final class StaticTestingClassLoaderGetBeforeSetClientInterceptor
@Override
public void run() {
List<ClientInterceptor> clientInterceptors = GlobalInterceptors.getClientInterceptors();
assertThat(clientInterceptors).isEmpty();
assertThat(clientInterceptors).isNull();

try {
GlobalInterceptors.setInterceptorsTracers(
Expand All @@ -132,7 +133,7 @@ public static final class StaticTestingClassLoaderGetBeforeSetServerInterceptor
@Override
public void run() {
List<ServerInterceptor> serverInterceptors = GlobalInterceptors.getServerInterceptors();
assertThat(serverInterceptors).isEmpty();
assertThat(serverInterceptors).isNull();

try {
GlobalInterceptors.setInterceptorsTracers(
Expand All @@ -150,7 +151,7 @@ public static final class StaticTestingClassLoaderGetBeforeSetServerStreamTracer
public void run() {
List<ServerStreamTracer.Factory> serverStreamTracerFactories =
GlobalInterceptors.getServerStreamTracerFactories();
assertThat(serverStreamTracerFactories).isEmpty();
assertThat(serverStreamTracerFactories).isNull();

try {
GlobalInterceptors.setInterceptorsTracers(
Expand Down
15 changes: 11 additions & 4 deletions core/src/main/java/io/grpc/internal/ManagedChannelImplBuilder.java
Expand Up @@ -31,6 +31,7 @@
import io.grpc.DecompressorRegistry;
import io.grpc.EquivalentAddressGroup;
import io.grpc.InternalChannelz;
import io.grpc.InternalGlobalInterceptors;
import io.grpc.ManagedChannel;
import io.grpc.ManagedChannelBuilder;
import io.grpc.NameResolver;
Expand Down Expand Up @@ -636,9 +637,15 @@ public ManagedChannel build() {
// TODO(zdapeng): FIX IT
@VisibleForTesting
List<ClientInterceptor> getEffectiveInterceptors() {
List<ClientInterceptor> effectiveInterceptors =
new ArrayList<>(this.interceptors);
if (statsEnabled) {
List<ClientInterceptor> effectiveInterceptors = new ArrayList<>(this.interceptors);
boolean isGlobalInterceptorsSet = false;
List<ClientInterceptor> globalClientInterceptors =
InternalGlobalInterceptors.getClientInterceptors();
if (globalClientInterceptors != null) {
effectiveInterceptors.addAll(globalClientInterceptors);
isGlobalInterceptorsSet = true;
}
if (!isGlobalInterceptorsSet && statsEnabled) {
ClientInterceptor statsInterceptor = null;
try {
Class<?> censusStatsAccessor =
Expand Down Expand Up @@ -674,7 +681,7 @@ List<ClientInterceptor> getEffectiveInterceptors() {
effectiveInterceptors.add(0, statsInterceptor);
}
}
if (tracingEnabled) {
if (!isGlobalInterceptorsSet && tracingEnabled) {
ClientInterceptor tracingInterceptor = null;
try {
Class<?> censusTracingAccessor =
Expand Down
15 changes: 13 additions & 2 deletions core/src/main/java/io/grpc/internal/ServerImplBuilder.java
Expand Up @@ -30,6 +30,7 @@
import io.grpc.DecompressorRegistry;
import io.grpc.HandlerRegistry;
import io.grpc.InternalChannelz;
import io.grpc.InternalGlobalInterceptors;
import io.grpc.Server;
import io.grpc.ServerBuilder;
import io.grpc.ServerCallExecutorSupplier;
Expand Down Expand Up @@ -246,7 +247,17 @@ public Server build() {
@VisibleForTesting
List<? extends ServerStreamTracer.Factory> getTracerFactories() {
ArrayList<ServerStreamTracer.Factory> tracerFactories = new ArrayList<>();
if (statsEnabled) {
boolean isGlobalInterceptorsTracersSet = false;
List<ServerInterceptor> globalServerInterceptors
= InternalGlobalInterceptors.getServerInterceptors();
List<ServerStreamTracer.Factory> globalServerStreamTracerFactories
= InternalGlobalInterceptors.getServerStreamTracerFactories();
if (globalServerInterceptors != null) {
tracerFactories.addAll(globalServerStreamTracerFactories);
interceptors.addAll(globalServerInterceptors);
isGlobalInterceptorsTracersSet = true;
}
if (!isGlobalInterceptorsTracersSet && statsEnabled) {
ServerStreamTracer.Factory censusStatsTracerFactory = null;
try {
Class<?> censusStatsAccessor =
Expand Down Expand Up @@ -278,7 +289,7 @@ List<? extends ServerStreamTracer.Factory> getTracerFactories() {
tracerFactories.add(censusStatsTracerFactory);
}
}
if (tracingEnabled) {
if (!isGlobalInterceptorsTracersSet && tracingEnabled) {
ServerStreamTracer.Factory tracingStreamTracerFactory = null;
try {
Class<?> censusTracingAccessor =
Expand Down
Expand Up @@ -35,9 +35,11 @@
import io.grpc.ClientInterceptor;
import io.grpc.CompressorRegistry;
import io.grpc.DecompressorRegistry;
import io.grpc.InternalGlobalInterceptors;
import io.grpc.ManagedChannel;
import io.grpc.MethodDescriptor;
import io.grpc.NameResolver;
import io.grpc.StaticTestingClassLoader;
import io.grpc.internal.ManagedChannelImplBuilder.ChannelBuilderDefaultPortProvider;
import io.grpc.internal.ManagedChannelImplBuilder.ClientTransportFactoryBuilder;
import io.grpc.internal.ManagedChannelImplBuilder.FixedPortProvider;
Expand All @@ -47,12 +49,14 @@
import java.net.SocketAddress;
import java.net.URI;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.Executor;
import java.util.concurrent.TimeUnit;
import java.util.regex.Pattern;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
Expand All @@ -78,6 +82,14 @@ public <ReqT, RespT> ClientCall<ReqT, RespT> interceptCall(
return next.newCall(method, callOptions);
}
};
private static final ClientInterceptor DUMMY_USER_INTERCEPTOR1 =
new ClientInterceptor() {
@Override
public <ReqT, RespT> ClientCall<ReqT, RespT> interceptCall(
MethodDescriptor<ReqT, RespT> method, CallOptions callOptions, Channel next) {
return next.newCall(method, callOptions);
}
};

@Rule public final MockitoRule mocks = MockitoJUnit.rule();
@SuppressWarnings("deprecation") // https://github.com/grpc/grpc-java/issues/7467
Expand All @@ -90,7 +102,12 @@ public <ReqT, RespT> ClientCall<ReqT, RespT> interceptCall(
private ManagedChannelImplBuilder builder;
private ManagedChannelImplBuilder directAddressBuilder;
private final FakeClock clock = new FakeClock();

private final StaticTestingClassLoader classLoader =
new StaticTestingClassLoader(
getClass().getClassLoader(),
Pattern.compile(
"io\\.grpc\\.InternalGlobalInterceptors|io\\.grpc\\.GlobalInterceptors|"
+ "io\\.grpc\\.internal\\.[^.]+"));

@Before
public void setUp() throws Exception {
Expand Down Expand Up @@ -447,6 +464,86 @@ public void getEffectiveInterceptors_disableBoth() {
assertThat(effectiveInterceptors).containsExactly(DUMMY_USER_INTERCEPTOR);
}

@Test
public void getEffectiveInterceptors_callsGetGlobalInterceptors() throws Exception {
Class<?> runnable = classLoader.loadClass(StaticTestingClassLoaderCallsGet.class.getName());
((Runnable) runnable.getDeclaredConstructor().newInstance()).run();
}

// UsedReflectively
public static final class StaticTestingClassLoaderCallsGet implements Runnable {

@Override
public void run() {
ManagedChannelImplBuilder builder =
new ManagedChannelImplBuilder(
DUMMY_TARGET,
new UnsupportedClientTransportFactoryBuilder(),
new FixedPortProvider(DUMMY_PORT));
List<ClientInterceptor> effectiveInterceptors = builder.getEffectiveInterceptors();
assertThat(effectiveInterceptors).hasSize(2);
try {
InternalGlobalInterceptors.setInterceptorsTracers(
Arrays.asList(DUMMY_USER_INTERCEPTOR),
Collections.emptyList(),
Collections.emptyList());
fail("exception expected");
} catch (IllegalStateException e) {
assertThat(e).hasMessageThat().contains("Set cannot be called after any get call");
}
}
}

@Test
public void getEffectiveInterceptors_callsSetGlobalInterceptors() throws Exception {
Class<?> runnable = classLoader.loadClass(StaticTestingClassLoaderCallsSet.class.getName());
((Runnable) runnable.getDeclaredConstructor().newInstance()).run();
}

// UsedReflectively
public static final class StaticTestingClassLoaderCallsSet implements Runnable {

@Override
public void run() {
InternalGlobalInterceptors.setInterceptorsTracers(
Arrays.asList(DUMMY_USER_INTERCEPTOR, DUMMY_USER_INTERCEPTOR1),
Collections.emptyList(),
Collections.emptyList());
ManagedChannelImplBuilder builder =
new ManagedChannelImplBuilder(
DUMMY_TARGET,
new UnsupportedClientTransportFactoryBuilder(),
new FixedPortProvider(DUMMY_PORT));
List<ClientInterceptor> effectiveInterceptors = builder.getEffectiveInterceptors();
assertThat(effectiveInterceptors)
.containsExactly(DUMMY_USER_INTERCEPTOR, DUMMY_USER_INTERCEPTOR1);
}
}

@Test
public void getEffectiveInterceptors_setEmptyGlobalInterceptors() throws Exception {
Class<?> runnable =
classLoader.loadClass(StaticTestingClassLoaderCallsSetEmpty.class.getName());
((Runnable) runnable.getDeclaredConstructor().newInstance()).run();
}

// UsedReflectively
public static final class StaticTestingClassLoaderCallsSetEmpty implements Runnable {

@Override
public void run() {
InternalGlobalInterceptors.setInterceptorsTracers(
Collections.emptyList(), Collections.emptyList(), Collections.emptyList());
ManagedChannelImplBuilder builder =
new ManagedChannelImplBuilder(
DUMMY_TARGET,
new UnsupportedClientTransportFactoryBuilder(),
new FixedPortProvider(DUMMY_PORT));
List<ClientInterceptor> effectiveInterceptors = builder.getEffectiveInterceptors();
assertThat(effectiveInterceptors).isEmpty();
}
}

@Test
public void idleTimeout() {
assertEquals(ManagedChannelImplBuilder.IDLE_MODE_DEFAULT_TIMEOUT_MILLIS,
Expand Down

0 comments on commit c9a52eb

Please sign in to comment.