Skip to content

Commit

Permalink
Added token cache map to fix use of unintended auth token for subsequ…
Browse files Browse the repository at this point in the history
…ent connections (#2341)

* Added token cache map

* Added null check

* Removed print debug statements from tests

* Corrected return value

* Comments

* Applied formatting

* Increased cache TTL to 24 hrs

* Code review comments

* Code reivew comments p2
  • Loading branch information
tkyc committed Mar 20, 2024
1 parent aa46637 commit 1d4b7d6
Show file tree
Hide file tree
Showing 4 changed files with 155 additions and 11 deletions.
Expand Up @@ -22,13 +22,16 @@
* @see <a href="https://aka.ms/msal4j-token-cache">https://aka.ms/msal4j-token-cache</a>
*/
public class PersistentTokenCacheAccessAspect implements ITokenCacheAccessAspect {
private static PersistentTokenCacheAccessAspect instance = new PersistentTokenCacheAccessAspect();

private static PersistentTokenCacheAccessAspect instance;
private final Lock lock = new ReentrantLock();

private PersistentTokenCacheAccessAspect() {}
static final long TIME_TO_LIVE = 86400000L; // Token cache time to live (24 hrs).
private long expiryTime;

static PersistentTokenCacheAccessAspect getInstance() {
if (instance == null) {
instance = new PersistentTokenCacheAccessAspect();
}
return instance;
}

Expand Down Expand Up @@ -62,6 +65,14 @@ public void afterCacheAccess(ITokenCacheAccessContext iTokenCacheAccessContext)

}

public long getExpiryTime() {
return this.expiryTime;
}

public void setExpiryTime(long expiryTime) {
this.expiryTime = expiryTime;
}

/**
* Clears User token cache. This will clear all account info so interactive login will be required on the next
* request to acquire an access token.
Expand Down
Expand Up @@ -14,13 +14,15 @@
import java.net.URI;
import java.net.URISyntaxException;

import java.security.MessageDigest;
import java.text.MessageFormat;

import java.util.Collections;
import java.util.Enumeration;
import java.util.HashSet;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
Expand Down Expand Up @@ -61,6 +63,8 @@ class SQLServerMSAL4JUtils {
static final String SLASH_DEFAULT = "/.default";
static final String ACCESS_TOKEN_EXPIRE = "access token expires: ";

private static final TokenCacheMap TOKEN_CACHE_MAP = new TokenCacheMap();

private final static String LOGCONTEXT = "MSAL version "
+ com.microsoft.aad.msal4j.PublicClientApplication.class.getPackage().getImplementationVersion() + ": ";

Expand All @@ -84,10 +88,17 @@ static SqlAuthenticationToken getSqlFedAuthToken(SqlFedAuthInfo fedAuthInfo, Str
lock.lock();

try {
String hashedSecret = getHashedSecret(new String[] {fedAuthInfo.stsurl, user, password});
PersistentTokenCacheAccessAspect persistentTokenCacheAccessAspect = TOKEN_CACHE_MAP.getEntry(hashedSecret);

if (null == persistentTokenCacheAccessAspect) {
persistentTokenCacheAccessAspect = new PersistentTokenCacheAccessAspect();
TOKEN_CACHE_MAP.addEntry(hashedSecret, persistentTokenCacheAccessAspect);
}

final PublicClientApplication pca = PublicClientApplication
.builder(ActiveDirectoryAuthentication.JDBC_FEDAUTH_CLIENT_ID).executorService(executorService)
.setTokenCacheAccessAspect(PersistentTokenCacheAccessAspect.getInstance())
.authority(fedAuthInfo.stsurl).build();
.setTokenCacheAccessAspect(persistentTokenCacheAccessAspect).authority(fedAuthInfo.stsurl).build();

final CompletableFuture<IAuthenticationResult> future = pca.acquireToken(UserNamePasswordParameters
.builder(Collections.singleton(fedAuthInfo.spn + SLASH_DEFAULT), user, password.toCharArray())
Expand Down Expand Up @@ -132,11 +143,19 @@ static SqlAuthenticationToken getSqlFedAuthTokenPrincipal(SqlFedAuthInfo fedAuth
lock.lock();

try {
String hashedSecret = getHashedSecret(
new String[] {fedAuthInfo.stsurl, aadPrincipalID, aadPrincipalSecret});
PersistentTokenCacheAccessAspect persistentTokenCacheAccessAspect = TOKEN_CACHE_MAP.getEntry(hashedSecret);

if (null == persistentTokenCacheAccessAspect) {
persistentTokenCacheAccessAspect = new PersistentTokenCacheAccessAspect();
TOKEN_CACHE_MAP.addEntry(hashedSecret, persistentTokenCacheAccessAspect);
}

IClientCredential credential = ClientCredentialFactory.createFromSecret(aadPrincipalSecret);
ConfidentialClientApplication clientApplication = ConfidentialClientApplication
.builder(aadPrincipalID, credential).executorService(executorService)
.setTokenCacheAccessAspect(PersistentTokenCacheAccessAspect.getInstance())
.authority(fedAuthInfo.stsurl).build();
.setTokenCacheAccessAspect(persistentTokenCacheAccessAspect).authority(fedAuthInfo.stsurl).build();

final CompletableFuture<IAuthenticationResult> future = clientApplication
.acquireToken(ClientCredentialParameters.builder(scopes).build());
Expand Down Expand Up @@ -181,6 +200,15 @@ static SqlAuthenticationToken getSqlFedAuthTokenPrincipalCertificate(SqlFedAuthI
lock.lock();

try {
String hashedSecret = getHashedSecret(new String[] {fedAuthInfo.stsurl, aadPrincipalID, certFile,
certPassword, certKey, certKeyPassword});
PersistentTokenCacheAccessAspect persistentTokenCacheAccessAspect = TOKEN_CACHE_MAP.getEntry(hashedSecret);

if (null == persistentTokenCacheAccessAspect) {
persistentTokenCacheAccessAspect = new PersistentTokenCacheAccessAspect();
TOKEN_CACHE_MAP.addEntry(hashedSecret, persistentTokenCacheAccessAspect);
}

ConfidentialClientApplication clientApplication = null;

// check if cert is PKCS12 first
Expand All @@ -202,8 +230,7 @@ static SqlAuthenticationToken getSqlFedAuthTokenPrincipalCertificate(SqlFedAuthI

IClientCredential credential = ClientCredentialFactory.createFromCertificate(is, certPassword);
clientApplication = ConfidentialClientApplication.builder(aadPrincipalID, credential)
.executorService(executorService)
.setTokenCacheAccessAspect(PersistentTokenCacheAccessAspect.getInstance())
.executorService(executorService).setTokenCacheAccessAspect(persistentTokenCacheAccessAspect)
.authority(fedAuthInfo.stsurl).build();
} catch (FileNotFoundException e) {
// re-throw if file not there no point to try another format
Expand Down Expand Up @@ -232,8 +259,7 @@ static SqlAuthenticationToken getSqlFedAuthTokenPrincipalCertificate(SqlFedAuthI

IClientCredential credential = ClientCredentialFactory.createFromCertificate(privateKey, cert);
clientApplication = ConfidentialClientApplication.builder(aadPrincipalID, credential)
.executorService(executorService)
.setTokenCacheAccessAspect(PersistentTokenCacheAccessAspect.getInstance())
.executorService(executorService).setTokenCacheAccessAspect(persistentTokenCacheAccessAspect)
.authority(fedAuthInfo.stsurl).build();
}

Expand Down Expand Up @@ -449,4 +475,45 @@ private static SQLServerException getCorrectedException(Exception e, String user
return new SQLServerException(form.format(msgArgs), null, 0, correctedExecutionException);
}
}

private static String getHashedSecret(String[] secrets) throws SQLServerException {
try {
MessageDigest md = MessageDigest.getInstance("SHA-256");
for (String secret : secrets) {
if (null != secret) {
md.update(secret.getBytes(java.nio.charset.StandardCharsets.UTF_16LE));
}
}
return new String(md.digest());
} catch (NoSuchAlgorithmException e) {
throw new SQLServerException(SQLServerException.getErrString("R_NoSHA256Algorithm"), e);
}
}

private static class TokenCacheMap {
private ConcurrentHashMap<String, PersistentTokenCacheAccessAspect> tokenCacheMap = new ConcurrentHashMap<>();

PersistentTokenCacheAccessAspect getEntry(String key) {
PersistentTokenCacheAccessAspect persistentTokenCacheAccessAspect = tokenCacheMap.get(key);

if (null != persistentTokenCacheAccessAspect) {
if (System.currentTimeMillis() > persistentTokenCacheAccessAspect.getExpiryTime()) {
tokenCacheMap.remove(key);

persistentTokenCacheAccessAspect = new PersistentTokenCacheAccessAspect();
persistentTokenCacheAccessAspect
.setExpiryTime(System.currentTimeMillis() + PersistentTokenCacheAccessAspect.TIME_TO_LIVE);

tokenCacheMap.put(key, persistentTokenCacheAccessAspect);
}
}

return persistentTokenCacheAccessAspect;
}

void addEntry(String key, PersistentTokenCacheAccessAspect value) {
value.setExpiryTime(System.currentTimeMillis() + PersistentTokenCacheAccessAspect.TIME_TO_LIVE);
tokenCacheMap.put(key, value);
}
}
}
5 changes: 5 additions & 0 deletions src/test/java/com/microsoft/sqlserver/jdbc/TestResource.java
Expand Up @@ -63,6 +63,11 @@ protected Object[][] getContents() {
{"R_ConnectionURLNull", "The connection URL is null."},
{"R_connectionIsNotClosed", "The connection is not closed."},
{"R_invalidExceptionMessage", "Invalid exception message"},
{"R_invalidClientSecret", "AADSTS7000215: Invalid client secret provided"},
{"R_invalidCertFields",
"Error reading certificate, please verify the location of the certificate.signed fields invalid"},
{"R_invalidAADAuth",
"Failed to authenticate the user {0} in Active Directory (Authentication={1})"},
{"R_failedValidate", "failed to validate values in $0} "}, {"R_tableNotDropped", "table not dropped. "},
{"R_connectionReset", "Connection reset"}, {"R_unknownException", "Unknown exception"},
{"R_deadConnection", "Dead connection should be invalid"},
Expand Down
Expand Up @@ -18,6 +18,7 @@
import java.sql.DriverManager;
import java.sql.SQLException;
import java.sql.Statement;
import java.text.MessageFormat;
import java.util.Collections;
import java.util.Properties;
import java.util.concurrent.CompletableFuture;
Expand Down Expand Up @@ -323,6 +324,66 @@ public void testAADServicePrincipalAuth() {
}
}

@Test
public void testAADServicePrincipalAuthFailureOnSubsequentConnectionsWithInvalidatedTokenCacheWithInvalidSecret() throws Exception {
String url = "jdbc:sqlserver://" + azureServer + ";database=" + azureDatabase + ";authentication="
+ SqlAuthentication.ActiveDirectoryServicePrincipal + ";Username=" + applicationClientID + ";Password="
+ applicationKey;

String invalidSecretUrl = "jdbc:sqlserver://" + azureServer + ";database=" + azureDatabase + ";authentication="
+ SqlAuthentication.ActiveDirectoryServicePrincipal + ";Username=" + applicationClientID + ";Password="
+ "invalidSecret";

// Should succeed on valid secret
try (Connection connection = DriverManager.getConnection(url)) {}

// Should fail on invalid secret
try (Connection connection = DriverManager.getConnection(invalidSecretUrl)) {
fail(TestResource.getResource("R_expectedFailPassed"));
} catch (Exception e) {
assertTrue(e.getMessage().contains(TestResource.getResource("R_invalidClientSecret")),
"Expected R_invalidClientSecret error.");
}
}

@Test
public void testActiveDirectoryPasswordFailureOnSubsequentConnectionsWithInvalidatedTokenCacheWithInvalidPassword() throws Exception {

// Should succeed on valid password
try (Connection conn = DriverManager.getConnection(adPasswordConnectionStr)) {}

// Should fail on invalid password
try (Connection conn = DriverManager.getConnection(adPasswordConnectionStr + ";password=invalidPassword;")) {
fail(TestResource.getResource("R_expectedFailPassed"));
} catch (Exception e) {
MessageFormat form = new MessageFormat(TestResource.getResource("R_invalidAADAuth"));
Object[] msgArgs = {azureUserName, "ActiveDirectoryPassword"};
assertTrue(e.getMessage().contains(form.format(msgArgs)), "Expected R_invalidAADAuth error.");
}
}

@Test
public void testAADServicePrincipalCertAuthFailureOnSubsequentConnectionsWithInvalidatedTokenCacheWithInvalidPassword() throws Exception {
// Should succeed on valid cert field values
String url = "jdbc:sqlserver://" + azureServer + ";database=" + azureDatabase + ";authentication="
+ SqlAuthentication.ActiveDirectoryServicePrincipalCertificate + ";Username=" + applicationClientID
+ ";password=" + certificatePassword + ";clientCertificate=" + clientCertificate;

// Should fail on invalid cert field values
String invalidPasswordUrl = "jdbc:sqlserver://" + azureServer + ";database=" + azureDatabase
+ ";authentication=" + SqlAuthentication.ActiveDirectoryServicePrincipalCertificate + ";Username="
+ applicationClientID + ";password=invalidPassword;clientCertificate=" + clientCertificate;

try (Connection conn = DriverManager.getConnection(url)) {}

try (Connection conn = DriverManager.getConnection(invalidPasswordUrl)) {
fail(TestResource.getResource("R_expectedFailPassed"));
} catch (Exception e) {
assertTrue(e.getMessage().contains(TestResource.getResource("R_invalidCertFields")),
"Expected R_invalidCertFields error.");
}
}

/**
* Test invalid connection property combinations when using AAD Service Principal Authentication.
*/
Expand Down

0 comments on commit 1d4b7d6

Please sign in to comment.