Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

api,core: change ManagedChannel and Server Builders to use GlobalInterceptors #9312

Merged
merged 2 commits into from Jun 30, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
26 changes: 11 additions & 15 deletions api/src/main/java/io/grpc/GlobalInterceptors.java
Expand Up @@ -16,17 +16,19 @@

package io.grpc;

import static com.google.common.base.Preconditions.checkNotNull;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

/** The collection of global interceptors and global server stream tracers. */
@Internal
final class GlobalInterceptors {
private static List<ClientInterceptor> clientInterceptors = Collections.emptyList();
private static List<ServerInterceptor> serverInterceptors = Collections.emptyList();
private static List<ClientInterceptor> clientInterceptors = null;
private static List<ServerInterceptor> serverInterceptors = null;
private static List<ServerStreamTracer.Factory> serverStreamTracerFactories =
Collections.emptyList();
null;
private static boolean isGlobalInterceptorsTracersSet;
private static boolean isGlobalInterceptorsTracersGet;

Expand Down Expand Up @@ -61,19 +63,13 @@ static synchronized void setInterceptorsTracers(
if (isGlobalInterceptorsTracersSet) {
throw new IllegalStateException("Global interceptors and tracers are already set");
}

if (clientInterceptorList != null) {
clientInterceptors = Collections.unmodifiableList(new ArrayList<>(clientInterceptorList));
}

if (serverInterceptorList != null) {
serverInterceptors = Collections.unmodifiableList(new ArrayList<>(serverInterceptorList));
}

if (serverStreamTracerFactoryList != null) {
serverStreamTracerFactories =
checkNotNull(clientInterceptorList);
checkNotNull(serverInterceptorList);
checkNotNull(serverStreamTracerFactoryList);
clientInterceptors = Collections.unmodifiableList(new ArrayList<>(clientInterceptorList));
serverInterceptors = Collections.unmodifiableList(new ArrayList<>(serverInterceptorList));
serverStreamTracerFactories =
Collections.unmodifiableList(new ArrayList<>(serverStreamTracerFactoryList));
}
isGlobalInterceptorsTracersSet = true;
}

Expand Down
9 changes: 5 additions & 4 deletions api/src/test/java/io/grpc/GlobalInterceptorsTest.java
Expand Up @@ -21,6 +21,7 @@

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.regex.Pattern;
import org.junit.Test;
Expand Down Expand Up @@ -98,7 +99,7 @@ public static final class StaticTestingClassLoaderSetTwice implements Runnable {
public void run() {
GlobalInterceptors.setInterceptorsTracers(
new ArrayList<>(Arrays.asList(new NoopClientInterceptor())),
null,
Collections.emptyList(),
new ArrayList<>(Arrays.asList(new NoopServerStreamTracerFactory())));
try {
GlobalInterceptors.setInterceptorsTracers(
Expand All @@ -115,7 +116,7 @@ public static final class StaticTestingClassLoaderGetBeforeSetClientInterceptor
@Override
public void run() {
List<ClientInterceptor> clientInterceptors = GlobalInterceptors.getClientInterceptors();
assertThat(clientInterceptors).isEmpty();
assertThat(clientInterceptors).isNull();

try {
GlobalInterceptors.setInterceptorsTracers(
Expand All @@ -132,7 +133,7 @@ public static final class StaticTestingClassLoaderGetBeforeSetServerInterceptor
@Override
public void run() {
List<ServerInterceptor> serverInterceptors = GlobalInterceptors.getServerInterceptors();
assertThat(serverInterceptors).isEmpty();
assertThat(serverInterceptors).isNull();

try {
GlobalInterceptors.setInterceptorsTracers(
Expand All @@ -150,7 +151,7 @@ public static final class StaticTestingClassLoaderGetBeforeSetServerStreamTracer
public void run() {
List<ServerStreamTracer.Factory> serverStreamTracerFactories =
GlobalInterceptors.getServerStreamTracerFactories();
assertThat(serverStreamTracerFactories).isEmpty();
assertThat(serverStreamTracerFactories).isNull();

try {
GlobalInterceptors.setInterceptorsTracers(
Expand Down
15 changes: 11 additions & 4 deletions core/src/main/java/io/grpc/internal/ManagedChannelImplBuilder.java
Expand Up @@ -31,6 +31,7 @@
import io.grpc.DecompressorRegistry;
import io.grpc.EquivalentAddressGroup;
import io.grpc.InternalChannelz;
import io.grpc.InternalGlobalInterceptors;
import io.grpc.ManagedChannel;
import io.grpc.ManagedChannelBuilder;
import io.grpc.NameResolver;
Expand Down Expand Up @@ -636,9 +637,15 @@ public ManagedChannel build() {
// TODO(zdapeng): FIX IT
@VisibleForTesting
List<ClientInterceptor> getEffectiveInterceptors() {
List<ClientInterceptor> effectiveInterceptors =
new ArrayList<>(this.interceptors);
if (statsEnabled) {
List<ClientInterceptor> effectiveInterceptors = new ArrayList<>(this.interceptors);
boolean isGlobalInterceptorsSet = false;
List<ClientInterceptor> globalClientInterceptors =
sanjaypujare marked this conversation as resolved.
Show resolved Hide resolved
InternalGlobalInterceptors.getClientInterceptors();
if (globalClientInterceptors != null) {
effectiveInterceptors.addAll(globalClientInterceptors);
isGlobalInterceptorsSet = true;
}
if (!isGlobalInterceptorsSet && statsEnabled) {
ClientInterceptor statsInterceptor = null;
try {
Class<?> censusStatsAccessor =
Expand Down Expand Up @@ -674,7 +681,7 @@ List<ClientInterceptor> getEffectiveInterceptors() {
effectiveInterceptors.add(0, statsInterceptor);
}
}
if (tracingEnabled) {
if (!isGlobalInterceptorsSet && tracingEnabled) {
ClientInterceptor tracingInterceptor = null;
try {
Class<?> censusTracingAccessor =
Expand Down
15 changes: 13 additions & 2 deletions core/src/main/java/io/grpc/internal/ServerImplBuilder.java
Expand Up @@ -30,6 +30,7 @@
import io.grpc.DecompressorRegistry;
import io.grpc.HandlerRegistry;
import io.grpc.InternalChannelz;
import io.grpc.InternalGlobalInterceptors;
import io.grpc.Server;
import io.grpc.ServerBuilder;
import io.grpc.ServerCallExecutorSupplier;
Expand Down Expand Up @@ -246,7 +247,17 @@ public Server build() {
@VisibleForTesting
List<? extends ServerStreamTracer.Factory> getTracerFactories() {
ArrayList<ServerStreamTracer.Factory> tracerFactories = new ArrayList<>();
if (statsEnabled) {
boolean isGlobalInterceptorsTracersSet = false;
List<ServerInterceptor> globalServerInterceptors
= InternalGlobalInterceptors.getServerInterceptors();
List<ServerStreamTracer.Factory> globalServerStreamTracerFactories
= InternalGlobalInterceptors.getServerStreamTracerFactories();
if (globalServerInterceptors != null) {
tracerFactories.addAll(globalServerStreamTracerFactories);
interceptors.addAll(globalServerInterceptors);
isGlobalInterceptorsTracersSet = true;
}
if (!isGlobalInterceptorsTracersSet && statsEnabled) {
ServerStreamTracer.Factory censusStatsTracerFactory = null;
try {
Class<?> censusStatsAccessor =
Expand Down Expand Up @@ -278,7 +289,7 @@ List<? extends ServerStreamTracer.Factory> getTracerFactories() {
tracerFactories.add(censusStatsTracerFactory);
}
}
if (tracingEnabled) {
if (!isGlobalInterceptorsTracersSet && tracingEnabled) {
ServerStreamTracer.Factory tracingStreamTracerFactory = null;
try {
Class<?> censusTracingAccessor =
Expand Down
Expand Up @@ -35,9 +35,11 @@
import io.grpc.ClientInterceptor;
import io.grpc.CompressorRegistry;
import io.grpc.DecompressorRegistry;
import io.grpc.InternalGlobalInterceptors;
import io.grpc.ManagedChannel;
import io.grpc.MethodDescriptor;
import io.grpc.NameResolver;
import io.grpc.StaticTestingClassLoader;
import io.grpc.internal.ManagedChannelImplBuilder.ChannelBuilderDefaultPortProvider;
import io.grpc.internal.ManagedChannelImplBuilder.ClientTransportFactoryBuilder;
import io.grpc.internal.ManagedChannelImplBuilder.FixedPortProvider;
Expand All @@ -47,12 +49,14 @@
import java.net.SocketAddress;
import java.net.URI;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.Executor;
import java.util.concurrent.TimeUnit;
import java.util.regex.Pattern;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
Expand All @@ -78,6 +82,14 @@ public <ReqT, RespT> ClientCall<ReqT, RespT> interceptCall(
return next.newCall(method, callOptions);
}
};
private static final ClientInterceptor DUMMY_USER_INTERCEPTOR1 =
new ClientInterceptor() {
@Override
public <ReqT, RespT> ClientCall<ReqT, RespT> interceptCall(
MethodDescriptor<ReqT, RespT> method, CallOptions callOptions, Channel next) {
return next.newCall(method, callOptions);
}
};

@Rule public final MockitoRule mocks = MockitoJUnit.rule();
@SuppressWarnings("deprecation") // https://github.com/grpc/grpc-java/issues/7467
Expand All @@ -90,7 +102,12 @@ public <ReqT, RespT> ClientCall<ReqT, RespT> interceptCall(
private ManagedChannelImplBuilder builder;
private ManagedChannelImplBuilder directAddressBuilder;
private final FakeClock clock = new FakeClock();

private final StaticTestingClassLoader classLoader =
new StaticTestingClassLoader(
getClass().getClassLoader(),
Pattern.compile(
"io\\.grpc\\.InternalGlobalInterceptors|io\\.grpc\\.GlobalInterceptors|"
+ "io\\.grpc\\.internal\\.[^.]+"));

@Before
public void setUp() throws Exception {
Expand Down Expand Up @@ -447,6 +464,86 @@ public void getEffectiveInterceptors_disableBoth() {
assertThat(effectiveInterceptors).containsExactly(DUMMY_USER_INTERCEPTOR);
}

@Test
public void getEffectiveInterceptors_callsGetGlobalInterceptors() throws Exception {
Class<?> runnable = classLoader.loadClass(StaticTestingClassLoaderCallsGet.class.getName());
((Runnable) runnable.getDeclaredConstructor().newInstance()).run();
}

// UsedReflectively
public static final class StaticTestingClassLoaderCallsGet implements Runnable {

@Override
public void run() {
ManagedChannelImplBuilder builder =
new ManagedChannelImplBuilder(
DUMMY_TARGET,
new UnsupportedClientTransportFactoryBuilder(),
new FixedPortProvider(DUMMY_PORT));
List<ClientInterceptor> effectiveInterceptors = builder.getEffectiveInterceptors();
assertThat(effectiveInterceptors).hasSize(2);
try {
InternalGlobalInterceptors.setInterceptorsTracers(
Arrays.asList(DUMMY_USER_INTERCEPTOR),
Collections.emptyList(),
Collections.emptyList());
fail("exception expected");
} catch (IllegalStateException e) {
assertThat(e).hasMessageThat().contains("Set cannot be called after any get call");
}
}
}

@Test
public void getEffectiveInterceptors_callsSetGlobalInterceptors() throws Exception {
Class<?> runnable = classLoader.loadClass(StaticTestingClassLoaderCallsSet.class.getName());
((Runnable) runnable.getDeclaredConstructor().newInstance()).run();
}

// UsedReflectively
public static final class StaticTestingClassLoaderCallsSet implements Runnable {

@Override
public void run() {
InternalGlobalInterceptors.setInterceptorsTracers(
Arrays.asList(DUMMY_USER_INTERCEPTOR, DUMMY_USER_INTERCEPTOR1),
Collections.emptyList(),
Collections.emptyList());
ManagedChannelImplBuilder builder =
new ManagedChannelImplBuilder(
DUMMY_TARGET,
new UnsupportedClientTransportFactoryBuilder(),
new FixedPortProvider(DUMMY_PORT));
List<ClientInterceptor> effectiveInterceptors = builder.getEffectiveInterceptors();
assertThat(effectiveInterceptors)
.containsExactly(DUMMY_USER_INTERCEPTOR, DUMMY_USER_INTERCEPTOR1);
}
}

@Test
public void getEffectiveInterceptors_setEmptyGlobalInterceptors() throws Exception {
Class<?> runnable =
classLoader.loadClass(StaticTestingClassLoaderCallsSetEmpty.class.getName());
((Runnable) runnable.getDeclaredConstructor().newInstance()).run();
}

// UsedReflectively
public static final class StaticTestingClassLoaderCallsSetEmpty implements Runnable {

@Override
public void run() {
InternalGlobalInterceptors.setInterceptorsTracers(
Collections.emptyList(), Collections.emptyList(), Collections.emptyList());
ManagedChannelImplBuilder builder =
new ManagedChannelImplBuilder(
DUMMY_TARGET,
new UnsupportedClientTransportFactoryBuilder(),
new FixedPortProvider(DUMMY_PORT));
List<ClientInterceptor> effectiveInterceptors = builder.getEffectiveInterceptors();
assertThat(effectiveInterceptors).isEmpty();
}
}

@Test
public void idleTimeout() {
assertEquals(ManagedChannelImplBuilder.IDLE_MODE_DEFAULT_TIMEOUT_MILLIS,
Expand Down