Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Custom OIDC claim verification #27292

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Expand Up @@ -2,6 +2,7 @@

import java.time.Duration;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
Expand All @@ -10,6 +11,7 @@
import io.quarkus.oidc.common.runtime.OidcCommonConfig;
import io.quarkus.oidc.common.runtime.OidcConstants;
import io.quarkus.oidc.runtime.OidcConfig;
import io.quarkus.runtime.annotations.ConfigDocMapKey;
import io.quarkus.runtime.annotations.ConfigGroup;
import io.quarkus.runtime.annotations.ConfigItem;

Expand Down Expand Up @@ -950,6 +952,17 @@ public static Token fromAudience(String... audience) {
@ConfigItem
public Optional<List<String>> audience = Optional.empty();

/**
* A map of required claims and their expected values.
* For example, `quarkus.oidc.token.required-claims.org_id = org_xyz` would require tokens to have the `org_id` claim to
* be present and set to `org_xyz`.
* Strings are the only supported types. Use {@linkplain SecurityIdentityAugmentor} to verify claims of other types or
* complex claims.
*/
@ConfigItem
@ConfigDocMapKey("claim-name")
public Map<String, String> requiredClaims = new HashMap<>();

/**
* Expected token type
*/
Expand Down Expand Up @@ -1167,6 +1180,14 @@ public Optional<String> getDecryptionKeyLocation() {
public void setDecryptionKeyLocation(String decryptionKeyLocation) {
this.decryptionKeyLocation = Optional.of(decryptionKeyLocation);
}

public Map<String, String> getRequiredClaims() {
return requiredClaims;
}

public void setRequiredClaims(Map<String, String> requiredClaims) {
this.requiredClaims = requiredClaims;
}
}

public static enum ApplicationType {
Expand Down
Expand Up @@ -4,18 +4,22 @@
import java.security.Key;
import java.time.Duration;
import java.util.List;
import java.util.Map;
import java.util.function.BiFunction;
import java.util.function.Function;

import org.eclipse.microprofile.jwt.Claims;
import org.jboss.logging.Logger;
import org.jose4j.jwa.AlgorithmConstraints;
import org.jose4j.jws.JsonWebSignature;
import org.jose4j.jwt.MalformedClaimException;
import org.jose4j.jwt.consumer.ErrorCodeValidator;
import org.jose4j.jwt.consumer.ErrorCodes;
import org.jose4j.jwt.consumer.InvalidJwtException;
import org.jose4j.jwt.consumer.JwtConsumer;
import org.jose4j.jwt.consumer.JwtConsumerBuilder;
import org.jose4j.jwt.consumer.JwtContext;
import org.jose4j.jwt.consumer.Validator;
import org.jose4j.jwx.HeaderParameterNames;
import org.jose4j.jwx.JsonWebStructure;
import org.jose4j.keys.resolvers.VerificationKeyResolver;
Expand Down Expand Up @@ -53,6 +57,7 @@ public class OidcProvider implements Closeable {
final OidcTenantConfig oidcConfig;
final String issuer;
final String[] audience;
final Map<String, String> requiredClaims;
final Key tokenDecryptionKey;

public OidcProvider(OidcProviderClient client, OidcTenantConfig oidcConfig, JsonWebKeySet jwks, Key tokenDecryptionKey) {
Expand All @@ -63,6 +68,7 @@ public OidcProvider(OidcProviderClient client, OidcTenantConfig oidcConfig, Json

this.issuer = checkIssuerProp();
this.audience = checkAudienceProp();
this.requiredClaims = checkRequiredClaimsProp();
this.tokenDecryptionKey = tokenDecryptionKey;
}

Expand All @@ -72,6 +78,7 @@ public OidcProvider(String publicKeyEnc, OidcTenantConfig oidcConfig, Key tokenD
this.asymmetricKeyResolver = new LocalPublicKeyResolver(publicKeyEnc);
this.issuer = checkIssuerProp();
this.audience = checkAudienceProp();
this.requiredClaims = checkRequiredClaimsProp();
this.tokenDecryptionKey = tokenDecryptionKey;
}

Expand All @@ -91,6 +98,10 @@ private String[] checkAudienceProp() {
return audienceProp != null ? audienceProp.toArray(new String[] {}) : null;
}

private Map<String, String> checkRequiredClaimsProp() {
return oidcConfig != null ? oidcConfig.token.requiredClaims : null;
}

public TokenVerificationResult verifySelfSignedJwtToken(String token) throws InvalidJwtException {
return verifyJwtTokenInternal(token, SYMMETRIC_ALGORITHM_CONSTRAINTS, new SymmetricKeyResolver(), true);
}
Expand Down Expand Up @@ -135,6 +146,9 @@ private TokenVerificationResult verifyJwtTokenInternal(String token, AlgorithmCo
} else {
builder.setSkipDefaultAudienceValidation();
}
if (requiredClaims != null) {
builder.registerValidator(new CustomClaimsValidator(requiredClaims));
}

if (oidcConfig.token.lifespanGrace.isPresent()) {
final int lifespanGrace = oidcConfig.token.lifespanGrace.getAsInt();
Expand Down Expand Up @@ -383,4 +397,33 @@ default Uni<Void> refresh() {
return Uni.createFrom().voidItem();
}
}

private static class CustomClaimsValidator implements Validator {

private final Map<String, String> customClaims;

public CustomClaimsValidator(Map<String, String> customClaims) {
this.customClaims = customClaims;
}

@Override
public String validate(JwtContext jwtContext) throws MalformedClaimException {
var claims = jwtContext.getJwtClaims();
for (var targetClaim : customClaims.entrySet()) {
var claimName = targetClaim.getKey();
if (!claims.hasClaim(claimName)) {
return "claim " + claimName + " is missing";
}
if (!claims.isClaimValueString(claimName)) {
throw new MalformedClaimException("expected claim " + claimName + " to be a string");
}
var claimValue = claims.getStringClaimValue(claimName);
var targetValue = targetClaim.getValue();
if (!claimValue.equals(targetValue)) {
return "claim " + claimName + "does not match expected value of " + targetValue;
}
}
return null;
}
}
}
Expand Up @@ -95,6 +95,10 @@ quarkus.oidc.tenant-customheader.credentials.secret=secret
quarkus.oidc.tenant-customheader.token.header=X-Forwarded-Authorization
quarkus.oidc.tenant-customheader.application-type=service

# Required claim (Uses tenant-b settings as it has multiple clients)
quarkus.oidc.tenant-requiredclaim.auth-server-url=${keycloak.url}/realms/quarkus-b
quarkus.oidc.tenant-requiredclaim.application-type=service
quarkus.oidc.tenant-requiredclaim.token.required-claims.azp=quarkus-app-b

quarkus.oidc.tenant-public-key.client-id=test
quarkus.oidc.tenant-public-key.public-key=MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAlivFI8qB4D0y2jy0CfEqFyy46R0o7S8TKpsx5xbHKoU1VWg6QkQm+ntyIv1p4kE1sPEQO73+HY8+Bzs75XwRTYL1BmR1w8J5hmjVWjc6R2BTBGAYRPFRhor3kpM6ni2SPmNNhurEAHw7TaqszP5eUF/F9+KEBWkwVta+PZ37bwqSE4sCb1soZFrVz/UT/LF4tYpuVYt3YbqToZ3pZOZ9AX2o1GCG3xwOjkc4x0W7ezbQZdC9iftPxVHR8irOijJRRjcPDtA6vPKpzLl6CyYnsIYPd99ltwxTHjr3npfv/3Lw50bAkbT4HeLFxTx4flEoZLKO/g0bAoV2uqBhkA9xnQIDAQAB
Expand Down
Expand Up @@ -499,6 +499,24 @@ public void testResolveTenantIdentifierWebAppDynamic() throws IOException {
}
}

@Test
public void testRequiredClaimPass() {
//Client id should match the required azp claim
RestAssured.given().auth().oauth2(getAccessToken("alice", "b", "b"))
.when().get("/tenant/tenant-requiredclaim/api/user")
.then()
.statusCode(200);
}

@Test
public void testRequiredClaimFail() {
//Client id does not match required azp claim
RestAssured.given().auth().oauth2(getAccessToken("alice", "b", "b2"))
.when().get("/tenant/tenant-requiredclaim/api/user")
.then()
.statusCode(401);
}

private String getAccessToken(String userName, String clientId) {
return getAccessToken(userName, clientId, clientId);
}
Expand Down