Skip to content

Commit

Permalink
Managed Identity token cache (#1825)
Browse files Browse the repository at this point in the history
  • Loading branch information
David-Engel committed May 26, 2022
1 parent bff2f4b commit 7ffc2f0
Show file tree
Hide file tree
Showing 8 changed files with 125 additions and 20 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,9 @@ public class SQLServerConnection implements ISQLServerConnection, java.io.Serial
/** encrypted truststore password */
byte[] encryptedTrustStorePassword = null;

/** cached MSI token time-to-live */
private int cachedMsiTokenTtl = 0;

/**
* Return an existing cached SharedTimer associated with this Connection or create a new one.
*
Expand Down Expand Up @@ -2636,6 +2639,26 @@ else if (0 == requestedPacketSize)
activeConnectionProperties.setProperty(sPropKey, sPropValue);
}

cachedMsiTokenTtl = SQLServerDriverIntProperty.MSI_TOKEN_CACHE_TTL.getDefaultValue();
sPropValue = activeConnectionProperties
.getProperty(SQLServerDriverIntProperty.MSI_TOKEN_CACHE_TTL.toString());
if (null != sPropValue && sPropValue.length() > 0) {
try {
cachedMsiTokenTtl = Integer.parseInt(sPropValue);
} catch (NumberFormatException e) {
MessageFormat form = new MessageFormat(
SQLServerException.getErrString("R_invalidMsiTokenCacheTtl"));
Object[] msgArgs = {sPropValue};
SQLServerException.makeFromDriverError(this, this, form.format(msgArgs), null, false);
}
if (cachedMsiTokenTtl < 0) {
MessageFormat form = new MessageFormat(
SQLServerException.getErrString("R_invalidMsiTokenCacheTtl"));
Object[] msgArgs = {sPropValue};
SQLServerException.makeFromDriverError(this, this, form.format(msgArgs), null, false);
}
}

sPropKey = SQLServerDriverStringProperty.CLIENT_CERTIFICATE.toString();
sPropValue = activeConnectionProperties.getProperty(sPropKey);
if (null != sPropValue) {
Expand Down Expand Up @@ -5405,7 +5428,7 @@ private SqlFedAuthToken getFedAuthToken(SqlFedAuthInfo fedAuthInfo) throws SQLSe

String user = activeConnectionProperties.getProperty(SQLServerDriverStringProperty.USER.toString());

// No:of milliseconds to sleep for the inital back off.
// No:of milliseconds to sleep for the initial back off.
int sleepInterval = 100;

while (true) {
Expand All @@ -5422,7 +5445,8 @@ private SqlFedAuthToken getFedAuthToken(SqlFedAuthInfo fedAuthInfo) throws SQLSe
break;
} else if (authenticationString.equalsIgnoreCase(SqlAuthentication.ActiveDirectoryMSI.toString())) {
fedAuthToken = SQLServerSecurityUtility.getMSIAuthToken(fedAuthInfo.spn,
activeConnectionProperties.getProperty(SQLServerDriverStringProperty.MSI_CLIENT_ID.toString()));
activeConnectionProperties.getProperty(SQLServerDriverStringProperty.MSI_CLIENT_ID.toString()),
cachedMsiTokenTtl);

// Break out of the retry loop in successful case.
break;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -525,7 +525,8 @@ enum SQLServerDriverIntProperty {
STATEMENT_POOLING_CACHE_SIZE("statementPoolingCacheSize", SQLServerConnection.DEFAULT_STATEMENT_POOLING_CACHE_SIZE),
CANCEL_QUERY_TIMEOUT("cancelQueryTimeout", -1),
CONNECT_RETRY_COUNT("connectRetryCount", 1, 0, 255),
CONNECT_RETRY_INTERVAL("connectRetryInterval", 10, 1, 60);
CONNECT_RETRY_INTERVAL("connectRetryInterval", 10, 1, 60),
MSI_TOKEN_CACHE_TTL("msiTokenCacheTtl", 3600, 0, Integer.MAX_VALUE);

private final String name;
private final int defaultValue;
Expand Down Expand Up @@ -775,6 +776,9 @@ public final class SQLServerDriver implements java.sql.Driver {
false, TRUE_FALSE),
new SQLServerDriverPropertyInfo(SQLServerDriverStringProperty.MSI_CLIENT_ID.toString(),
SQLServerDriverStringProperty.MSI_CLIENT_ID.getDefaultValue(), false, null),
new SQLServerDriverPropertyInfo(SQLServerDriverIntProperty.MSI_TOKEN_CACHE_TTL.toString(),
Integer.toString(SQLServerDriverIntProperty.MSI_TOKEN_CACHE_TTL.getDefaultValue()), false,
null),
new SQLServerDriverPropertyInfo(SQLServerDriverStringProperty.KEY_VAULT_PROVIDER_CLIENT_ID.toString(),
SQLServerDriverStringProperty.KEY_VAULT_PROVIDER_CLIENT_ID.getDefaultValue(), false, null),
new SQLServerDriverPropertyInfo(SQLServerDriverStringProperty.KEY_VAULT_PROVIDER_CLIENT_KEY.toString(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,7 @@ protected Object[][] getContents() {
{"R_lockTimeoutPropertyDescription", "The number of milliseconds to wait before the database reports a lock time-out."},
{"R_connectRetryCountPropertyDescription", "The number of reconnection attempts if there is a connection failure."},
{"R_connectRetryIntervalPropertyDescription", "The number of seconds between each connection retry attempt."},
{"R_msiTokenCacheTtlPropertyDescription", "The number of seconds a Managed Identity (MSI) access token should be cached."},
{"R_loginTimeoutPropertyDescription", "The number of seconds the driver should wait before timing out a failed connection."},
{"R_instanceNamePropertyDescription", "The name of the SQL Server instance to connect to."},
{"R_xopenStatesPropertyDescription", "Determines if the driver returns XOPEN-compliant SQL state codes in exceptions."},
Expand Down Expand Up @@ -499,6 +500,7 @@ protected Object[][] getContents() {
{"R_maxResultBufferInvalidSyntax", "Invalid syntax: {0} in maxResultBuffer parameter."},
{"R_maxResultBufferNegativeParameterValue", "MaxResultBuffer must have positive value: {0}."},
{"R_maxResultBufferPropertyExceeded", "MaxResultBuffer property exceeded: {0}. MaxResultBuffer was set to: {1}."},
{"R_invalidMsiTokenCacheTtl", "msiTokenCacheTtl {0} is not valid."},
{"R_invalidConnectRetryCount", "Connection retry count {0} is not valid."},
{"R_connectRetryCountPropertyDescription", "The maximum number of attempts to reestablish a broken connection."},
{"R_invalidConnectRetryInterval", "Connection retry interval {0} is not valid."},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import java.security.InvalidKeyException;
import java.security.NoSuchAlgorithmException;
import java.text.MessageFormat;
import java.time.Duration;
import java.util.ArrayList;
import java.util.Calendar;
import java.util.Date;
Expand Down Expand Up @@ -318,17 +319,31 @@ static void verifyColumnMasterKeyMetadata(SQLServerConnection connection, SQLSer
}
}

private static SimpleTtlCache<String, SqlFedAuthToken> msiTokenCache = new SimpleTtlCache<>();

/**
* Get Managed Identity Authentication token
*
* @param resource
* token resource
* @param msiClientId
* Managed Identity or User Assigned Managed Identity
* @param tokenCacheTtl
* The number of seconds the token should remain in the cache
* @return fedauth token
* @throws SQLServerException
*/
static SqlFedAuthToken getMSIAuthToken(String resource, String msiClientId) throws SQLServerException {
static SqlFedAuthToken getMSIAuthToken(String resource, String msiClientId,
int tokenCacheTtl) throws SQLServerException {
String cacheKey = "resource:" + resource + "|clientid:" + msiClientId;
SqlFedAuthToken token = msiTokenCache.get(cacheKey);
if (token != null) {
if (connectionlogger.isLoggable(Level.FINER)) {
connectionlogger.finer("Using cached Managed Identity auth token: " + token.toString());
}
return token;
}

// IMDS upgrade time can take up to 70s
final int imdsUpgradeTimeInMs = 70 * 1000;
final List<Integer> retrySlots = new ArrayList<>();
Expand Down Expand Up @@ -419,11 +434,23 @@ static SqlFedAuthToken getMSIAuthToken(String resource, String msiClientId) thro
+ ActiveDirectoryAuthentication.ACCESS_TOKEN_EXPIRES_IN_IDENTIFIER.length();
}

String accessTokenExpiry = result.substring(startIndex_ATX,
result.indexOf("\"", startIndex_ATX + 1));
cal.add(Calendar.SECOND, Integer.parseInt(accessTokenExpiry));
int accessTokenExpiry = Integer
.parseInt(result.substring(startIndex_ATX, result.indexOf("\"", startIndex_ATX + 1)));
cal.add(Calendar.SECOND, accessTokenExpiry);
token = new SqlFedAuthToken(accessToken, cal.getTime());

if (connectionlogger.isLoggable(Level.FINER)) {
connectionlogger.finer("Obtained new Managed Identity auth token: " + token.toString());
}

if (tokenCacheTtl > 0) {
// Cache the token for up to tokenCacheTtl but not longer than 5 minutes less than the token's
// expiration, in case we are given a token with a very short lifetime.
msiTokenCache.put(cacheKey, token,
Duration.ofSeconds(Math.min(tokenCacheTtl, accessTokenExpiry - 300)));
}

return new SqlFedAuthToken(accessToken, cal.getTime());
return token;
}
} catch (Exception e) {
retry++;
Expand Down
25 changes: 25 additions & 0 deletions src/main/java/com/microsoft/sqlserver/jdbc/SimpleTtlCache.java
Original file line number Diff line number Diff line change
Expand Up @@ -170,4 +170,29 @@ V put(K key, V value) {
return previousValue;
}

/**
* Put (Key, Value, TTL) entry into cache.
*
* @param key
* key
* @param value
* value
* @param ttl
* Time-To-Live for this cache entry
* @return value
*/
V put(K key, V value, Duration ttl) {
V previousValue = null;
long cacheTtlInSeconds = ttl.getSeconds();

if (0 < cacheTtlInSeconds) {
previousValue = cache.put(key, value);
if (simpleCacheLogger.isLoggable(java.util.logging.Level.FINEST)) {
simpleCacheLogger.fine("Adding encryption key to cache...");
}
scheduler.schedule(new CacheClear(key), cacheTtlInSeconds, SECONDS);
}

return previousValue;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -33,4 +33,8 @@ class SqlFedAuthToken implements Serializable {
this.accessToken = accessToken;
this.expiresOn = expiresOn;
}

public String toString() {
return "accessToken hashCode: " + accessToken.hashCode() + " expiresOn: " + expiresOn.toInstant().toString();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,14 @@
*/
package com.microsoft.sqlserver.jdbc.AlwaysEncrypted;

import static org.junit.jupiter.api.Assertions.fail;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.junit.jupiter.api.Assertions.fail;

import java.io.File;
import java.io.FileOutputStream;
import java.io.OutputStream;
import java.sql.Connection;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Statement;
import java.util.HashMap;
import java.util.Map;
import java.util.Properties;
Expand Down Expand Up @@ -59,6 +57,21 @@ public void testMSIAuth() throws SQLException {
connStr = TestUtils.addOrOverrideProperty(connStr, Constants.PASSWORD, "");
connStr = TestUtils.addOrOverrideProperty(connStr, Constants.AUTHENTICATION, "ActiveDirectoryMSI");

connStr = TestUtils.addOrOverrideProperty(connStr, Constants.MSITOKENCACHETTL, "0");

testSimpleConnect(connStr);

connStr = TestUtils.addOrOverrideProperty(connStr, Constants.MSITOKENCACHETTL,
Integer.toString(Integer.MAX_VALUE));

testSimpleConnect(connStr);

connStr = TestUtils.addOrOverrideProperty(connStr, Constants.MSITOKENCACHETTL, "");

testSimpleConnect(connStr); // This call will use a cached token
}

private void testSimpleConnect(String connStr) {
try (SQLServerConnection con = PrepUtil.getConnection(connStr)) {} catch (Exception e) {
fail(TestResource.getResource("R_loginFailed") + e.getMessage());
}
Expand All @@ -79,9 +92,18 @@ public void testMSIAuthWithMSIClientId() throws SQLException {
connStr = TestUtils.addOrOverrideProperty(connStr, Constants.AUTHENTICATION, "ActiveDirectoryMSI");
connStr = TestUtils.addOrOverrideProperty(connStr, Constants.MSICLIENTID, msiClientId);

try (SQLServerConnection con = PrepUtil.getConnection(connStr)) {} catch (Exception e) {
fail(TestResource.getResource("R_loginFailed") + e.getMessage());
}
connStr = TestUtils.addOrOverrideProperty(connStr, Constants.MSITOKENCACHETTL, "0");

testSimpleConnect(connStr);

connStr = TestUtils.addOrOverrideProperty(connStr, Constants.MSITOKENCACHETTL,
Integer.toString(Integer.MAX_VALUE));

testSimpleConnect(connStr);

connStr = TestUtils.addOrOverrideProperty(connStr, Constants.MSITOKENCACHETTL, "");

testSimpleConnect(connStr); // This call will use a cached token
}

/*
Expand All @@ -101,9 +123,7 @@ public void testDSMSIAuth() throws SQLException {
ds.setAuthentication("ActiveDirectoryMSI");
AbstractTest.updateDataSource(connStr, ds);

try (Connection con = ds.getConnection(); Statement stmt = con.createStatement()) {} catch (Exception e) {
fail(TestResource.getResource("R_loginFailed") + e.getMessage());
}
testSimpleConnect(connStr);
}

/*
Expand All @@ -124,9 +144,7 @@ public void testDSMSIAuthWithMSIClientId() throws SQLException {
ds.setMSIClientId(msiClientId);
AbstractTest.updateDataSource(connStr, ds);

try (Connection con = ds.getConnection(); Statement stmt = con.createStatement()) {} catch (Exception e) {
fail(TestResource.getResource("R_loginFailed") + e.getMessage());
}
testSimpleConnect(connStr);
}

/*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,7 @@ private Constants() {}
public static final String ENCLAVE_ATTESTATIONPROTOCOL = "enclaveAttestationProtocol";

public static final String MSICLIENTID = "MSICLIENTID";
public static final String MSITOKENCACHETTL = "MSITOKENCACHETTL";
public static final String KEYVAULTPROVIDER_CLIENTID = "KEYVAULTPROVIDERCLIENTID";
public static final String KEYVAULTPROVIDER_CLIENTKEY = "KEYVAULTPROVIDERCLIENTKEY";
public static final String KEYSTORE_AUTHENTICATION = "KEYSTOREAUTHENTICATION";
Expand Down

0 comments on commit 7ffc2f0

Please sign in to comment.