diff --git a/binder/src/main/java/io/grpc/binder/SecurityPolicies.java b/binder/src/main/java/io/grpc/binder/SecurityPolicies.java index 879cc69167c..0fd5430ba61 100644 --- a/binder/src/main/java/io/grpc/binder/SecurityPolicies.java +++ b/binder/src/main/java/io/grpc/binder/SecurityPolicies.java @@ -29,7 +29,10 @@ import com.google.errorprone.annotations.CheckReturnValue; import io.grpc.ExperimentalApi; import io.grpc.Status; +import java.util.ArrayList; import java.util.Collection; +import java.util.Iterator; +import java.util.List; /** Static factory methods for creating standard security policies. */ @CheckReturnValue @@ -218,6 +221,60 @@ public Status checkAuthorization(int uid) { }; } + /** + * Creates a {@link SecurityPolicy} that allows access if *any* of the specified {@code + * securityPolicies} allow access. + * + *

Policies will be checked in the order that they are passed. If a policy allows access, + * subsequent policies will not be checked. + * + *

If all policies deny access, the {@link io.grpc.Status} returned by {@code + * checkAuthorization} will included the concatenated descriptions of the failed policies and + * attach any additional causes as suppressed throwables. The status code will be that of the + * first failed policy. + * + * @param securityPolicies the security policies that will be checked. + * @throws NullPointerException if any of the inputs are {@code null}. + * @throws IllegalArgumentException if {@code securityPolicies} is empty. + */ + public static SecurityPolicy anyOf(SecurityPolicy... securityPolicies) { + Preconditions.checkNotNull(securityPolicies, "securityPolicies"); + Preconditions.checkArgument(securityPolicies.length > 0, "securityPolicies must not be empty"); + + return anyOfSecurityPolicy(securityPolicies); + } + + private static SecurityPolicy anyOfSecurityPolicy(SecurityPolicy... securityPolicies) { + return new SecurityPolicy() { + @Override + public Status checkAuthorization(int uid) { + List failed = new ArrayList<>(); + for (SecurityPolicy policy : securityPolicies) { + Status checkAuth = policy.checkAuthorization(uid); + if (checkAuth.isOk()) { + return checkAuth; + } + failed.add(checkAuth); + } + + Iterator iter = failed.iterator(); + Status toReturn = iter.next(); + while (iter.hasNext()) { + Status append = iter.next(); + toReturn = toReturn.augmentDescription(append.getDescription()); + if (append.getCause() != null) { + if (toReturn.getCause() != null) { + toReturn.getCause().addSuppressed(append.getCause()); + } else { + toReturn = toReturn.withCause(append.getCause()); + } + } + } + return toReturn; + } + }; + } + /** * Creates a {@link SecurityPolicy} which checks if the caller has all of the given permissions * from {@code permissions}. diff --git a/binder/src/test/java/io/grpc/binder/SecurityPoliciesTest.java b/binder/src/test/java/io/grpc/binder/SecurityPoliciesTest.java index bff414db60f..a17162325e6 100644 --- a/binder/src/test/java/io/grpc/binder/SecurityPoliciesTest.java +++ b/binder/src/test/java/io/grpc/binder/SecurityPoliciesTest.java @@ -34,6 +34,7 @@ import com.google.common.collect.ImmutableSet; import io.grpc.Status; import java.util.HashMap; +import java.util.concurrent.atomic.AtomicInteger; import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; @@ -392,4 +393,46 @@ public void testAllOf_failsIfOneSecurityPoliciesNotAllowed() throws Exception { assertThat(policy.checkAuthorization(MY_UID).getDescription()) .contains("Not allowed SecurityPolicy"); } + + @Test + public void testAnyOf_succeedsIfAnySecurityPoliciesAllowed() throws Exception { + RecordingPolicy recordingPolicy = new RecordingPolicy(); + policy = SecurityPolicies.anyOf(SecurityPolicies.internalOnly(), recordingPolicy); + + assertThat(policy.checkAuthorization(MY_UID).getCode()).isEqualTo(Status.OK.getCode()); + assertThat(recordingPolicy.numCalls.get()).isEqualTo(0); + } + + @Test + public void testAnyOf_failsIfNoSecurityPolicyIsAllowed() throws Exception { + policy = + SecurityPolicies.anyOf( + new SecurityPolicy() { + @Override + public Status checkAuthorization(int uid) { + return Status.PERMISSION_DENIED.withDescription("Not allowed: first"); + } + }, + new SecurityPolicy() { + @Override + public Status checkAuthorization(int uid) { + return Status.UNAUTHENTICATED.withDescription("Not allowed: second"); + } + }); + + assertThat(policy.checkAuthorization(MY_UID).getCode()) + .isEqualTo(Status.PERMISSION_DENIED.getCode()); + assertThat(policy.checkAuthorization(MY_UID).getDescription()).contains("Not allowed: first"); + assertThat(policy.checkAuthorization(MY_UID).getDescription()).contains("Not allowed: second"); + } + + private static final class RecordingPolicy extends SecurityPolicy { + private final AtomicInteger numCalls = new AtomicInteger(0); + + @Override + public Status checkAuthorization(int uid) { + numCalls.incrementAndGet(); + return Status.OK; + } + } }