Skip to content

Commit

Permalink
rls: overhall RouteLookupConfig validation
Browse files Browse the repository at this point in the history
  • Loading branch information
dapengzhang0 committed Nov 2, 2021
1 parent 59c6b49 commit 9f3db23
Show file tree
Hide file tree
Showing 11 changed files with 539 additions and 266 deletions.
17 changes: 17 additions & 0 deletions core/src/main/java/io/grpc/internal/JsonUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,23 @@ public static Long getStringAsDuration(Map<String, ?> obj, String key) {
}
}

/**
* Gets a string from an object for the given key, parsed as a long integer. If
* the key is not present, this returns null. If the value is not a String or not properly
* formatted, throws an exception.
*/
public static Long getStringAsLong(Map<String, ?> obj, String key) {
String value = getString(obj, key);
if (value == null) {
return null;
}
try {
return Long.valueOf(value);
} catch (NumberFormatException e) {
throw new RuntimeException(e);
}
}

/**
* Gets a boolean from an object for the given key. If the key is not present, this returns null.
* If the value is not a Boolean, throws an exception.
Expand Down
6 changes: 3 additions & 3 deletions rls/src/main/java/io/grpc/rls/CachingRlsLbClient.java
Original file line number Diff line number Diff line change
Expand Up @@ -128,9 +128,9 @@ private CachingRlsLbClient(Builder builder) {
synchronizationContext = helper.getSynchronizationContext();
lbPolicyConfig = checkNotNull(builder.lbPolicyConfig, "lbPolicyConfig");
RouteLookupConfig rlsConfig = lbPolicyConfig.getRouteLookupConfig();
maxAgeNanos = TimeUnit.MILLISECONDS.toNanos(rlsConfig.getMaxAgeInMillis());
staleAgeNanos = TimeUnit.MILLISECONDS.toNanos(rlsConfig.getStaleAgeInMillis());
callTimeoutNanos = TimeUnit.MILLISECONDS.toNanos(rlsConfig.getLookupServiceTimeoutInMillis());
maxAgeNanos = rlsConfig.getMaxAgeInNanos();
staleAgeNanos = rlsConfig.getStaleAgeInNanos();
callTimeoutNanos = rlsConfig.getLookupServiceTimeoutInNanos();
timeProvider = checkNotNull(builder.timeProvider, "timeProvider");
throttler = checkNotNull(builder.throttler, "throttler");
linkedHashLruCache =
Expand Down
16 changes: 0 additions & 16 deletions rls/src/main/java/io/grpc/rls/RlsLoadBalancerProvider.java
Original file line number Diff line number Diff line change
Expand Up @@ -67,22 +67,6 @@ public ConfigOrError parseLoadBalancingPolicyConfig(Map<String, ?> rawLoadBalanc
JsonUtil.getString(rawLoadBalancingConfigPolicy, "childPolicyConfigTargetFieldName"),
JsonUtil.checkObjectList(
checkNotNull(JsonUtil.getList(rawLoadBalancingConfigPolicy, "childPolicy"))));
// Checking all valid targets to make sure the config is always valid. This strict check
// prevents child policy to handle invalid child policy.
for (String validTarget : routeLookupConfig.getValidTargets()) {
ConfigOrError childPolicyConfigOrError =
lbPolicy
.getEffectiveLbProvider()
.parseLoadBalancingPolicyConfig(lbPolicy.getEffectiveChildPolicy(validTarget));
if (childPolicyConfigOrError.getError() != null) {
return
ConfigOrError.fromError(
childPolicyConfigOrError
.getError()
.augmentDescription(
"failed to parse childPolicy for validTarget: " + validTarget));
}
}
return ConfigOrError.fromConfig(new LbPolicyConfiguration(routeLookupConfig, lbPolicy));
} catch (Exception e) {
return ConfigOrError.fromError(
Expand Down
100 changes: 68 additions & 32 deletions rls/src/main/java/io/grpc/rls/RlsProtoConverters.java
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,10 @@

import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkNotNull;
import static java.util.concurrent.TimeUnit.SECONDS;

import com.google.common.base.Converter;
import com.google.common.base.Strings;
import com.google.common.collect.ImmutableMap;
import io.grpc.internal.JsonUtil;
import io.grpc.lookup.v1.RouteLookupRequest;
Expand All @@ -30,8 +32,10 @@
import io.grpc.rls.RlsProtoData.NameMatcher;
import io.grpc.rls.RlsProtoData.RouteLookupConfig;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.TimeUnit;
import javax.annotation.Nullable;

Expand All @@ -41,6 +45,9 @@
*/
final class RlsProtoConverters {

private static final long MAX_AGE_NANOS = TimeUnit.MINUTES.toNanos(5);
private static final long MAX_CACHE_SIZE = 5 * 1024 * 1024;

/**
* RouteLookupRequestConverter converts between {@link RouteLookupRequest} and {@link
* RlsProtoData.RouteLookupRequest}.
Expand Down Expand Up @@ -96,31 +103,47 @@ static final class RouteLookupConfigConverter
@Override
protected RouteLookupConfig doForward(Map<String, ?> json) {
List<GrpcKeyBuilder> grpcKeyBuilders =
GrpcKeyBuilderConverter
.covertAll(JsonUtil.checkObjectList(JsonUtil.getList(json, "grpcKeyBuilders")));
GrpcKeyBuilderConverter.covertAll(
checkNotNull(JsonUtil.getListOfObjects(json, "grpcKeyBuilders"), "grpcKeyBuilders"));
checkArgument(!grpcKeyBuilders.isEmpty(), "must have at least one GrpcKeyBuilder");
Set<Name> names = new HashSet<>();
for (GrpcKeyBuilder keyBuilder : grpcKeyBuilders) {
for (Name name : keyBuilder.getNames()) {
checkArgument(names.add(name), "duplicate names in grpc_keybuilders: " + name);
}
}
String lookupService = JsonUtil.getString(json, "lookupService");
long timeout =
TimeUnit.SECONDS.toMillis(
orDefault(
JsonUtil.getNumberAsLong(json, "lookupServiceTimeout"),
0L));
Long maxAge =
convertTimeIfNotNull(
TimeUnit.SECONDS, TimeUnit.MILLISECONDS, JsonUtil.getNumberAsLong(json, "maxAge"));
Long staleAge =
convertTimeIfNotNull(
TimeUnit.SECONDS, TimeUnit.MILLISECONDS, JsonUtil.getNumberAsLong(json, "staleAge"));
long cacheSize = orDefault(JsonUtil.getNumberAsLong(json, "cacheSizeBytes"), Long.MAX_VALUE);
List<String> validTargets = JsonUtil.checkStringList(JsonUtil.getList(json, "validTargets"));
// TODO(creamsoup) also check if it is URI
checkArgument(!Strings.isNullOrEmpty(lookupService), "lookupService must not be empty");
long timeout = orDefault(
JsonUtil.getStringAsDuration(json, "lookupServiceTimeout"),
SECONDS.toNanos(10));
checkArgument(timeout > 0, "lookupServiceTimeout should be positive");
Long maxAge = JsonUtil.getStringAsDuration(json, "maxAge");
Long staleAge = JsonUtil.getStringAsDuration(json, "staleAge");
if (maxAge == null) {
checkArgument(staleAge == null, "to specify staleAge, must have maxAge");
maxAge = MAX_AGE_NANOS;
}
if (staleAge == null) {
staleAge = MAX_AGE_NANOS;
}
maxAge = Math.min(maxAge, MAX_AGE_NANOS);
staleAge = Math.min(staleAge, maxAge);
long cacheSize = orDefault(JsonUtil.getStringAsLong(json, "cacheSizeBytes"), MAX_CACHE_SIZE);
checkArgument(cacheSize > 0, "cacheSize must be positive");
cacheSize = Math.min(cacheSize, MAX_CACHE_SIZE);
String defaultTarget = JsonUtil.getString(json, "defaultTarget");
if (Strings.isNullOrEmpty(defaultTarget)) {
defaultTarget = null;
}
return new RouteLookupConfig(
grpcKeyBuilders,
lookupService,
/* lookupServiceTimeoutInMillis= */ timeout,
/* maxAgeInMillis= */ maxAge,
/* staleAgeInMillis= */ staleAge,
/* lookupServiceTimeoutInNanos= */ timeout,
/* maxAgeInNanos= */ maxAge,
/* staleAgeInNanos= */ staleAge,
/* cacheSizeBytes= */ cacheSize,
validTargets,
defaultTarget);
}

Expand All @@ -131,13 +154,6 @@ private static <T> T orDefault(@Nullable T value, T defaultValue) {
return value;
}

private static Long convertTimeIfNotNull(TimeUnit from, TimeUnit to, Long value) {
if (value == null) {
return null;
}
return to.convert(value, from);
}

@Override
protected Map<String, Object> doBackward(RouteLookupConfig routeLookupConfig) {
throw new UnsupportedOperationException();
Expand All @@ -155,10 +171,16 @@ public static List<GrpcKeyBuilder> covertAll(List<Map<String, ?>> keyBuilders) {

@SuppressWarnings("unchecked")
static GrpcKeyBuilder convert(Map<String, ?> keyBuilder) {
List<?> rawRawNames = JsonUtil.getList(keyBuilder, "names");
checkArgument(
rawRawNames != null && !rawRawNames.isEmpty(),
"each keyBuilder must have at least one name");
List<Map<String, ?>> rawNames =
JsonUtil.checkObjectList(JsonUtil.getList(keyBuilder, "names"));
List<Name> names = new ArrayList<>();
for (Map<String, ?> rawName : rawNames) {
String serviceName = JsonUtil.getString(rawName, "service");
checkArgument(!Strings.isNullOrEmpty(serviceName), "service must not be empty or null");
names.add(
new Name(
JsonUtil.getString(rawName, "service"), JsonUtil.getString(rawName, "method")));
Expand All @@ -167,13 +189,12 @@ static GrpcKeyBuilder convert(Map<String, ?> keyBuilder) {
JsonUtil.checkObjectList(JsonUtil.getList(keyBuilder, "headers"));
List<NameMatcher> nameMatchers = new ArrayList<>();
for (Map<String, ?> rawHeader : rawHeaders) {
NameMatcher matcher =
new NameMatcher(
JsonUtil.getString(rawHeader, "key"),
(List<String>) rawHeader.get("names"),
(Boolean) rawHeader.get("optional"));
Boolean requiredMatch = JsonUtil.getBoolean(rawHeader, "requiredMatch");
checkArgument(
matcher.isOptional(), "NameMatcher for GrpcKeyBuilders shouldn't be required");
requiredMatch == null || !requiredMatch,
"requiredMatch shouldn't be specified for gRPC");
NameMatcher matcher = new NameMatcher(
JsonUtil.getString(rawHeader, "key"), (List<String>) rawHeader.get("names"));
nameMatchers.add(matcher);
}
ExtraKeys extraKeys = ExtraKeys.DEFAULT;
Expand All @@ -188,9 +209,24 @@ static GrpcKeyBuilder convert(Map<String, ?> keyBuilder) {
if (constantKeys == null) {
constantKeys = ImmutableMap.of();
}
checkUniqueKey(nameMatchers, constantKeys.keySet());
return new GrpcKeyBuilder(names, nameMatchers, extraKeys, constantKeys);
}
}

private static void checkUniqueKey(List<NameMatcher> nameMatchers, Set<String> constantKeys) {
Set<String> keys = new HashSet<>();
keys.addAll(constantKeys);
keys.add("host");
keys.add("service");
keys.add("method");
for (NameMatcher nameMatcher : nameMatchers) {
keys.add(nameMatcher.getKey());
}
if (keys.size() != nameMatchers.size() + constantKeys.size() + 3) {
throw new IllegalArgumentException("keys in KeyBuilder must be unique");
}
}

private RlsProtoConverters() {}
}

0 comments on commit 9f3db23

Please sign in to comment.