Skip to content

Commit

Permalink
Add config property for custom claim verification
Browse files Browse the repository at this point in the history
Co-authored-by: sberyozkin <sberyozkin@gmail.com>
  • Loading branch information
djnalluri and sberyozkin committed Aug 16, 2022
1 parent 6ad5d86 commit a7b9c2e
Show file tree
Hide file tree
Showing 4 changed files with 86 additions and 0 deletions.
Expand Up @@ -2,6 +2,7 @@

import java.time.Duration;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
Expand All @@ -10,6 +11,7 @@
import io.quarkus.oidc.common.runtime.OidcCommonConfig;
import io.quarkus.oidc.common.runtime.OidcConstants;
import io.quarkus.oidc.runtime.OidcConfig;
import io.quarkus.runtime.annotations.ConfigDocMapKey;
import io.quarkus.runtime.annotations.ConfigGroup;
import io.quarkus.runtime.annotations.ConfigItem;

Expand Down Expand Up @@ -950,6 +952,17 @@ public static Token fromAudience(String... audience) {
@ConfigItem
public Optional<List<String>> audience = Optional.empty();

/**
* A map of required claims and their expected values.
* For example, `quarkus.oidc.token.required-claims.org_id = org_xyz` would require tokens to have the `org_id` claim to
* be present and set to `org_xyz`.
* Strings are the only supported types. Use {@linkplain SecurityIdentityAugmentor} to verify claims of other types or
* complex claims.
*/
@ConfigItem
@ConfigDocMapKey("claim-name")
public Map<String, String> requiredClaims = new HashMap<>();

/**
* Expected token type
*/
Expand Down Expand Up @@ -1167,6 +1180,14 @@ public Optional<String> getDecryptionKeyLocation() {
public void setDecryptionKeyLocation(String decryptionKeyLocation) {
this.decryptionKeyLocation = Optional.of(decryptionKeyLocation);
}

public Map<String, String> getRequiredClaims() {
return requiredClaims;
}

public void setRequiredClaims(Map<String, String> requiredClaims) {
this.requiredClaims = requiredClaims;
}
}

public static enum ApplicationType {
Expand Down
Expand Up @@ -4,18 +4,22 @@
import java.security.Key;
import java.time.Duration;
import java.util.List;
import java.util.Map;
import java.util.function.BiFunction;
import java.util.function.Function;

import org.eclipse.microprofile.jwt.Claims;
import org.jboss.logging.Logger;
import org.jose4j.jwa.AlgorithmConstraints;
import org.jose4j.jws.JsonWebSignature;
import org.jose4j.jwt.MalformedClaimException;
import org.jose4j.jwt.consumer.ErrorCodeValidator;
import org.jose4j.jwt.consumer.ErrorCodes;
import org.jose4j.jwt.consumer.InvalidJwtException;
import org.jose4j.jwt.consumer.JwtConsumer;
import org.jose4j.jwt.consumer.JwtConsumerBuilder;
import org.jose4j.jwt.consumer.JwtContext;
import org.jose4j.jwt.consumer.Validator;
import org.jose4j.jwx.HeaderParameterNames;
import org.jose4j.jwx.JsonWebStructure;
import org.jose4j.keys.resolvers.VerificationKeyResolver;
Expand Down Expand Up @@ -53,6 +57,7 @@ public class OidcProvider implements Closeable {
final OidcTenantConfig oidcConfig;
final String issuer;
final String[] audience;
final Map<String, String> requiredClaims;
final Key tokenDecryptionKey;

public OidcProvider(OidcProviderClient client, OidcTenantConfig oidcConfig, JsonWebKeySet jwks, Key tokenDecryptionKey) {
Expand All @@ -63,6 +68,7 @@ public OidcProvider(OidcProviderClient client, OidcTenantConfig oidcConfig, Json

this.issuer = checkIssuerProp();
this.audience = checkAudienceProp();
this.requiredClaims = checkRequiredClaimsProp();
this.tokenDecryptionKey = tokenDecryptionKey;
}

Expand All @@ -72,6 +78,7 @@ public OidcProvider(String publicKeyEnc, OidcTenantConfig oidcConfig, Key tokenD
this.asymmetricKeyResolver = new LocalPublicKeyResolver(publicKeyEnc);
this.issuer = checkIssuerProp();
this.audience = checkAudienceProp();
this.requiredClaims = checkRequiredClaimsProp();
this.tokenDecryptionKey = tokenDecryptionKey;
}

Expand All @@ -91,6 +98,10 @@ private String[] checkAudienceProp() {
return audienceProp != null ? audienceProp.toArray(new String[] {}) : null;
}

private Map<String, String> checkRequiredClaimsProp() {
return oidcConfig != null ? oidcConfig.token.requiredClaims : null;
}

public TokenVerificationResult verifySelfSignedJwtToken(String token) throws InvalidJwtException {
return verifyJwtTokenInternal(token, SYMMETRIC_ALGORITHM_CONSTRAINTS, new SymmetricKeyResolver(), true);
}
Expand Down Expand Up @@ -135,6 +146,9 @@ private TokenVerificationResult verifyJwtTokenInternal(String token, AlgorithmCo
} else {
builder.setSkipDefaultAudienceValidation();
}
if (requiredClaims != null) {
builder.registerValidator(new CustomClaimsValidator(requiredClaims));
}

if (oidcConfig.token.lifespanGrace.isPresent()) {
final int lifespanGrace = oidcConfig.token.lifespanGrace.getAsInt();
Expand Down Expand Up @@ -383,4 +397,33 @@ default Uni<Void> refresh() {
return Uni.createFrom().voidItem();
}
}

private static class CustomClaimsValidator implements Validator {

private final Map<String, String> customClaims;

public CustomClaimsValidator(Map<String, String> customClaims) {
this.customClaims = customClaims;
}

@Override
public String validate(JwtContext jwtContext) throws MalformedClaimException {
var claims = jwtContext.getJwtClaims();
for (var targetClaim : customClaims.entrySet()) {
var claimName = targetClaim.getKey();
if (!claims.hasClaim(claimName)) {
return "claim " + claimName + " is missing";
}
if (!claims.isClaimValueString(claimName)) {
throw new MalformedClaimException("expected claim " + claimName + " to be a string");
}
var claimValue = claims.getStringClaimValue(claimName);
var targetValue = targetClaim.getValue();
if (!claimValue.equals(targetValue)) {
return "claim " + claimName + "does not match expected value of " + targetValue;
}
}
return null;
}
}
}
Expand Up @@ -95,6 +95,10 @@ quarkus.oidc.tenant-customheader.credentials.secret=secret
quarkus.oidc.tenant-customheader.token.header=X-Forwarded-Authorization
quarkus.oidc.tenant-customheader.application-type=service

# Required claim (Uses tenant-b settings as it has multiple clients)
quarkus.oidc.tenant-requiredclaim.auth-server-url=${keycloak.url}/realms/quarkus-b
quarkus.oidc.tenant-requiredclaim.application-type=service
quarkus.oidc.tenant-requiredclaim.token.required-claims.azp=quarkus-app-b

quarkus.oidc.tenant-public-key.client-id=test
quarkus.oidc.tenant-public-key.public-key=MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAlivFI8qB4D0y2jy0CfEqFyy46R0o7S8TKpsx5xbHKoU1VWg6QkQm+ntyIv1p4kE1sPEQO73+HY8+Bzs75XwRTYL1BmR1w8J5hmjVWjc6R2BTBGAYRPFRhor3kpM6ni2SPmNNhurEAHw7TaqszP5eUF/F9+KEBWkwVta+PZ37bwqSE4sCb1soZFrVz/UT/LF4tYpuVYt3YbqToZ3pZOZ9AX2o1GCG3xwOjkc4x0W7ezbQZdC9iftPxVHR8irOijJRRjcPDtA6vPKpzLl6CyYnsIYPd99ltwxTHjr3npfv/3Lw50bAkbT4HeLFxTx4flEoZLKO/g0bAoV2uqBhkA9xnQIDAQAB
Expand Down
Expand Up @@ -499,6 +499,24 @@ public void testResolveTenantIdentifierWebAppDynamic() throws IOException {
}
}

@Test
public void testRequiredClaimPass() {
//Client id should match the required azp claim
RestAssured.given().auth().oauth2(getAccessToken("alice", "b", "b"))
.when().get("/tenant/tenant-requiredclaim/api/user")
.then()
.statusCode(200);
}

@Test
public void testRequiredClaimFail() {
//Client id does not match required azp claim
RestAssured.given().auth().oauth2(getAccessToken("alice", "b", "b2"))
.when().get("/tenant/tenant-requiredclaim/api/user")
.then()
.statusCode(401);
}

private String getAccessToken(String userName, String clientId) {
return getAccessToken(userName, clientId, clientId);
}
Expand Down

0 comments on commit a7b9c2e

Please sign in to comment.