Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

xds, rbac: build per route serverInterceptor for httpConfig #8524

Merged
merged 6 commits into from Sep 16, 2021
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
161 changes: 106 additions & 55 deletions xds/src/main/java/io/grpc/xds/XdsServerWrapper.java
Expand Up @@ -22,6 +22,7 @@
import com.google.auto.value.AutoValue;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.util.concurrent.SettableFuture;
import io.grpc.Attributes;
import io.grpc.InternalServerInterceptors;
Expand Down Expand Up @@ -55,6 +56,7 @@
import io.grpc.xds.XdsNameResolverProvider.XdsClientPoolFactory;
import io.grpc.xds.XdsServerBuilder.XdsServingStatusListener;
import io.grpc.xds.internal.sds.SslContextProviderSupplier;

ejona86 marked this conversation as resolved.
Show resolved Hide resolved
import java.io.IOException;
import java.net.SocketAddress;
import java.util.ArrayList;
Expand Down Expand Up @@ -345,6 +347,15 @@ private final class DiscoveryState implements LdsResourceWatcher {
@Nullable
private FilterChain defaultFilterChain;
private boolean stopped;
private final Map<FilterChain, AtomicReference<ImmutableMap<Route, ServerInterceptor>>>
rdsPrebuiltInterceptorRef = new HashMap<>();
private final ServerInterceptor noopInterceptor = new ServerInterceptor() {
@Override
public <ReqT, RespT> Listener<ReqT> interceptCall(ServerCall<ReqT, RespT> call,
Metadata headers, ServerCallHandler<ReqT, RespT> next) {
return next.startCall(call, headers);
}
};

private DiscoveryState(String resourceName) {
this.resourceName = checkNotNull(resourceName, "resourceName");
Expand Down Expand Up @@ -452,6 +463,7 @@ private void shutdown() {

private void updateSelector() {
Map<FilterChain, ServerRoutingConfig> filterChainRouting = new HashMap<>();
rdsPrebuiltInterceptorRef.clear();
for (FilterChain filterChain: filterChains) {
filterChainRouting.put(filterChain, generateRoutingConfig(filterChain));
}
Expand All @@ -470,13 +482,77 @@ private void updateSelector() {
private ServerRoutingConfig generateRoutingConfig(FilterChain filterChain) {
HttpConnectionManager hcm = filterChain.getHttpConnectionManager();
if (hcm.virtualHosts() != null) {
return ServerRoutingConfig.create(hcm.httpFilterConfigs(),
new AtomicReference<>(hcm.virtualHosts()));
AtomicReference<ImmutableMap<Route, ServerInterceptor>> interceptorRef =
new AtomicReference<>(generatePerRouteInterceptors(
hcm.httpFilterConfigs(), hcm.virtualHosts()));
return ServerRoutingConfig.create(new AtomicReference<>(hcm.virtualHosts()),
interceptorRef);
} else {
RouteDiscoveryState rds = routeDiscoveryStates.get(hcm.rdsName());
checkNotNull(rds, "rds");
return ServerRoutingConfig.create(hcm.httpFilterConfigs(), rds.savedVirtualHosts);
AtomicReference<ImmutableMap<Route, ServerInterceptor>> interceptorRef =
ejona86 marked this conversation as resolved.
Show resolved Hide resolved
new AtomicReference<>(generatePerRouteInterceptors(
hcm.httpFilterConfigs(), rds.savedVirtualHosts.get()));
rdsPrebuiltInterceptorRef.put(filterChain, interceptorRef);
return ServerRoutingConfig.create(rds.savedVirtualHosts, interceptorRef);
}
}

private ImmutableMap<Route, ServerInterceptor> generatePerRouteInterceptors(
List<NamedFilterConfig> namedFilterConfigs, @Nullable List<VirtualHost> virtualHosts) {
if (virtualHosts == null) {
return null;
}
ImmutableMap.Builder<Route, ServerInterceptor> perRouteInterceptors =
new ImmutableMap.Builder<>();
for (VirtualHost virtualHost : virtualHosts) {
for (Route route : virtualHost.routes()) {
List<ServerInterceptor> filterInterceptors = new ArrayList<>();
Map<String, FilterConfig> selectedOverrideConfigs =
new HashMap<>(virtualHost.filterConfigOverrides());
selectedOverrideConfigs.putAll(route.filterConfigOverrides());
for (NamedFilterConfig namedFilterConfig : namedFilterConfigs) {
FilterConfig filterConfig = namedFilterConfig.filterConfig;
Filter filter = filterRegistry.get(filterConfig.typeUrl());
if (filter instanceof ServerInterceptorBuilder) {
ServerInterceptor interceptor =
((ServerInterceptorBuilder) filter).buildServerInterceptor(
filterConfig, selectedOverrideConfigs.get(namedFilterConfig.name));
if (interceptor != null) {
filterInterceptors.add(interceptor);
}
} else {
logger.log(Level.WARNING, "HttpFilterConfig(type URL: "
+ filterConfig.typeUrl() + ") is not supported on server-side. "
+ "Probably a bug at ClientXdsClient verification.");
}
}
ServerInterceptor interceptor = combineInterceptors(filterInterceptors);
perRouteInterceptors.put(route, interceptor);
}
}
return perRouteInterceptors.build();
}

private ServerInterceptor combineInterceptors(final List<ServerInterceptor> interceptors) {
if (interceptors.isEmpty()) {
return noopInterceptor;
}
if (interceptors.size() == 1) {
return interceptors.get(0);
}
return new ServerInterceptor() {
@Override
public <ReqT, RespT> Listener<ReqT> interceptCall(ServerCall<ReqT, RespT> call,
Metadata headers, ServerCallHandler<ReqT, RespT> next) {
// intercept forward
for (int i = interceptors.size() - 1; i >= 0; i--) {
next = InternalServerInterceptors.interceptCallHandlerCreate(
interceptors.get(i), next);
}
return next.startCall(call, headers);
}
};
}

private void handleConfigNotFound(StatusException exception) {
Expand Down Expand Up @@ -507,6 +583,7 @@ private void cleanUpRouteDiscoveryStates() {
xdsClient.cancelRdsResourceWatch(rdsName, rdsState);
}
routeDiscoveryStates.clear();
rdsPrebuiltInterceptorRef.clear();
}

private List<SslContextProviderSupplier> getSuppliersInUse() {
Expand Down Expand Up @@ -560,6 +637,7 @@ public void run() {
return;
}
savedVirtualHosts.set(ImmutableList.copyOf(update.virtualHosts));
updateRdsPrebuiltInterceptorRef(update.virtualHosts);
maybeUpdateSelector();
}
});
Expand All @@ -575,6 +653,7 @@ public void run() {
}
logger.log(Level.WARNING, "Rds {0} unavailable", resourceName);
savedVirtualHosts.set(null);
updateRdsPrebuiltInterceptorRef(null);
maybeUpdateSelector();
}
});
Expand All @@ -595,6 +674,17 @@ public void run() {
});
}

private void updateRdsPrebuiltInterceptorRef(@Nullable List<VirtualHost> virtualHosts) {
for (FilterChain filterChain : rdsPrebuiltInterceptorRef.keySet()) {
if (resourceName.equals(filterChain.getHttpConnectionManager().rdsName())) {
ImmutableMap<Route, ServerInterceptor> updatedInterceptors =
generatePerRouteInterceptors(
filterChain.getHttpConnectionManager().httpFilterConfigs(), virtualHosts);
rdsPrebuiltInterceptorRef.get(filterChain).set(updatedInterceptors);
}
}
}

// Update the selector to use the most recently updated configs only after all rds have been
// discovered for the first time. Later changes on rds will be applied through virtual host
// list atomic ref.
Expand Down Expand Up @@ -652,14 +742,11 @@ public <ReqT, RespT> Listener<ReqT> interceptCall(ServerCall<ReqT, RespT> call,
return new Listener<ReqT>() {};
}
Route selectedRoute = null;
Map<String, FilterConfig> selectedOverrideConfigs =
new HashMap<>(virtualHost.filterConfigOverrides());
MethodDescriptor<ReqT, RespT> method = call.getMethodDescriptor();
for (Route route : virtualHost.routes()) {
if (RoutingUtils.matchRoute(
route.routeMatch(), "/" + method.getFullMethodName(), headers, random)) {
selectedRoute = route;
selectedOverrideConfigs.putAll(route.filterConfigOverrides());
break;
}
}
Expand All @@ -669,48 +756,12 @@ public <ReqT, RespT> Listener<ReqT> interceptCall(ServerCall<ReqT, RespT> call,
new Metadata());
return new ServerCall.Listener<ReqT>() {};
}
List<ServerInterceptor> filterInterceptors = new ArrayList<>();
for (NamedFilterConfig namedFilterConfig : routingConfig.httpFilterConfigs()) {
FilterConfig filterConfig = namedFilterConfig.filterConfig;
Filter filter = filterRegistry.get(filterConfig.typeUrl());
if (filter instanceof ServerInterceptorBuilder) {
ServerInterceptor interceptor =
((ServerInterceptorBuilder) filter).buildServerInterceptor(
filterConfig, selectedOverrideConfigs.get(namedFilterConfig.name));
if (interceptor != null) {
filterInterceptors.add(interceptor);
}
} else {
call.close(
Status.UNAVAILABLE.withDescription("HttpFilterConfig(type URL: "
+ filterConfig.typeUrl() + ") is not supported on server-side."),
new Metadata());
return new Listener<ReqT>() {};
}
ServerInterceptor routeInterceptor = noopInterceptor;
Map<Route, ServerInterceptor> perRouteInterceptors = routingConfig.interceptors().get();
ejona86 marked this conversation as resolved.
Show resolved Hide resolved
if (perRouteInterceptors != null && perRouteInterceptors.get(selectedRoute) != null) {
routeInterceptor = perRouteInterceptors.get(selectedRoute);
}
ServerInterceptor interceptor = combineInterceptors(filterInterceptors);
return interceptor.interceptCall(call, headers, next);
}

private ServerInterceptor combineInterceptors(final List<ServerInterceptor> interceptors) {
if (interceptors.isEmpty()) {
return noopInterceptor;
}
if (interceptors.size() == 1) {
return interceptors.get(0);
}
return new ServerInterceptor() {
@Override
public <ReqT, RespT> Listener<ReqT> interceptCall(ServerCall<ReqT, RespT> call,
Metadata headers, ServerCallHandler<ReqT, RespT> next) {
// intercept forward
for (int i = interceptors.size() - 1; i >= 0; i--) {
next = InternalServerInterceptors.interceptCallHandlerCreate(
interceptors.get(i), next);
}
return next.startCall(call, headers);
}
};
return routeInterceptor.interceptCall(call, headers, next);
}
}

Expand All @@ -719,20 +770,20 @@ public <ReqT, RespT> Listener<ReqT> interceptCall(ServerCall<ReqT, RespT> call,
*/
@AutoValue
abstract static class ServerRoutingConfig {
// Top level http filter configs.
abstract ImmutableList<NamedFilterConfig> httpFilterConfigs();

abstract AtomicReference<ImmutableList<VirtualHost>> virtualHosts();

// Prebuilt per route server interceptors from http filter configs.
abstract AtomicReference<ImmutableMap<Route, ServerInterceptor>> interceptors();

/**
* Server routing configuration.
* */
public static ServerRoutingConfig create(List<NamedFilterConfig> httpFilterConfigs,
AtomicReference<ImmutableList<VirtualHost>> virtualHosts) {
checkNotNull(httpFilterConfigs, "httpFilterConfigs");
public static ServerRoutingConfig create(
AtomicReference<ImmutableList<VirtualHost>> virtualHosts,
AtomicReference<ImmutableMap<Route, ServerInterceptor>> interceptors) {
checkNotNull(virtualHosts, "virtualHosts");
return new AutoValue_XdsServerWrapper_ServerRoutingConfig(
ImmutableList.copyOf(httpFilterConfigs), virtualHosts);
checkNotNull(interceptors, "interceptors");
return new AutoValue_XdsServerWrapper_ServerRoutingConfig(virtualHosts, interceptors);
}
}
}
Expand Up @@ -26,6 +26,7 @@
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.util.concurrent.SettableFuture;
import io.grpc.ServerInterceptor;
import io.grpc.internal.TestUtils.NoopChannelLogger;
import io.grpc.netty.GrpcHttp2ConnectionHandler;
import io.grpc.netty.InternalProtocolNegotiationEvent;
Expand Down Expand Up @@ -92,7 +93,8 @@ public class FilterChainMatchingProtocolNegotiatorsTest {
private static final String REMOTE_IP = "10.4.2.3"; // source
private static final int PORT = 7000;
private final ServerRoutingConfig noopConfig = ServerRoutingConfig.create(
new ArrayList<NamedFilterConfig>(), new AtomicReference<ImmutableList<VirtualHost>>());
new AtomicReference<ImmutableList<VirtualHost>>(),
new AtomicReference<ImmutableMap<Route, ServerInterceptor>>());

@Test
public void nofilterChainMatch_defaultSslContext() throws Exception {
Expand Down Expand Up @@ -239,9 +241,9 @@ public void destPortFails_returnDefaultFilterChain() throws Exception {
"filter-chain-bar", null, HTTP_CONNECTION_MANAGER,
tlsContextForDefaultFilterChain, tlsContextManager);

ServerRoutingConfig routingConfig = ServerRoutingConfig.create(
new ArrayList<NamedFilterConfig>(), new AtomicReference<>(
ImmutableList.of(createVirtualHost("virtual"))));
ServerRoutingConfig routingConfig = ServerRoutingConfig.create(new AtomicReference<>(
ImmutableList.of(createVirtualHost("virtual"))),
new AtomicReference<ImmutableMap<Route, ServerInterceptor>>());
FilterChainSelector selector = new FilterChainSelector(
ImmutableMap.of(filterChainWithDestPort, routingConfig),
defaultFilterChain.getSslContextProviderSupplier(), noopConfig);
Expand Down Expand Up @@ -1146,9 +1148,10 @@ private static VirtualHost createVirtualHost(String name) {
}

private static ServerRoutingConfig randomConfig(String domain) {
return ServerRoutingConfig.create(
new ArrayList<NamedFilterConfig>(), new AtomicReference<>(
ImmutableList.of(createVirtualHost(domain))));
return ServerRoutingConfig.create(new AtomicReference<>(
ImmutableList.of(createVirtualHost(domain))),
new AtomicReference<>(ImmutableMap.<Route, ServerInterceptor>of())
);
}

private EnvoyServerProtoData.DownstreamTlsContext createTls() {
Expand Down