Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Managed Identity token cache #1825

Merged
merged 2 commits into from
May 26, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,9 @@ public class SQLServerConnection implements ISQLServerConnection, java.io.Serial
/** encrypted truststore password */
byte[] encryptedTrustStorePassword = null;

/** cached MSI token time-to-live */
private int cachedMsiTokenTtl = 0;

/**
* Return an existing cached SharedTimer associated with this Connection or create a new one.
*
Expand Down Expand Up @@ -2636,6 +2639,26 @@ else if (0 == requestedPacketSize)
activeConnectionProperties.setProperty(sPropKey, sPropValue);
}

cachedMsiTokenTtl = SQLServerDriverIntProperty.MSI_TOKEN_CACHE_TTL.getDefaultValue();
sPropValue = activeConnectionProperties
.getProperty(SQLServerDriverIntProperty.MSI_TOKEN_CACHE_TTL.toString());
if (null != sPropValue && sPropValue.length() > 0) {
try {
cachedMsiTokenTtl = Integer.parseInt(sPropValue);
} catch (NumberFormatException e) {
MessageFormat form = new MessageFormat(
SQLServerException.getErrString("R_invalidMsiTokenCacheTtl"));
Object[] msgArgs = {sPropValue};
SQLServerException.makeFromDriverError(this, this, form.format(msgArgs), null, false);
}
if (cachedMsiTokenTtl < 0) {
MessageFormat form = new MessageFormat(
SQLServerException.getErrString("R_invalidMsiTokenCacheTtl"));
Object[] msgArgs = {sPropValue};
SQLServerException.makeFromDriverError(this, this, form.format(msgArgs), null, false);
}
}

sPropKey = SQLServerDriverStringProperty.CLIENT_CERTIFICATE.toString();
sPropValue = activeConnectionProperties.getProperty(sPropKey);
if (null != sPropValue) {
Expand Down Expand Up @@ -5405,7 +5428,7 @@ private SqlFedAuthToken getFedAuthToken(SqlFedAuthInfo fedAuthInfo) throws SQLSe

String user = activeConnectionProperties.getProperty(SQLServerDriverStringProperty.USER.toString());

// No:of milliseconds to sleep for the inital back off.
// No:of milliseconds to sleep for the initial back off.
int sleepInterval = 100;

while (true) {
Expand All @@ -5422,7 +5445,8 @@ private SqlFedAuthToken getFedAuthToken(SqlFedAuthInfo fedAuthInfo) throws SQLSe
break;
} else if (authenticationString.equalsIgnoreCase(SqlAuthentication.ActiveDirectoryMSI.toString())) {
fedAuthToken = SQLServerSecurityUtility.getMSIAuthToken(fedAuthInfo.spn,
activeConnectionProperties.getProperty(SQLServerDriverStringProperty.MSI_CLIENT_ID.toString()));
activeConnectionProperties.getProperty(SQLServerDriverStringProperty.MSI_CLIENT_ID.toString()),
cachedMsiTokenTtl);

// Break out of the retry loop in successful case.
break;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -525,7 +525,8 @@ enum SQLServerDriverIntProperty {
STATEMENT_POOLING_CACHE_SIZE("statementPoolingCacheSize", SQLServerConnection.DEFAULT_STATEMENT_POOLING_CACHE_SIZE),
CANCEL_QUERY_TIMEOUT("cancelQueryTimeout", -1),
CONNECT_RETRY_COUNT("connectRetryCount", 1, 0, 255),
CONNECT_RETRY_INTERVAL("connectRetryInterval", 10, 1, 60);
CONNECT_RETRY_INTERVAL("connectRetryInterval", 10, 1, 60),
MSI_TOKEN_CACHE_TTL("msiTokenCacheTtl", 3600, 0, Integer.MAX_VALUE);

private final String name;
private final int defaultValue;
Expand Down Expand Up @@ -775,6 +776,9 @@ public final class SQLServerDriver implements java.sql.Driver {
false, TRUE_FALSE),
new SQLServerDriverPropertyInfo(SQLServerDriverStringProperty.MSI_CLIENT_ID.toString(),
SQLServerDriverStringProperty.MSI_CLIENT_ID.getDefaultValue(), false, null),
new SQLServerDriverPropertyInfo(SQLServerDriverIntProperty.MSI_TOKEN_CACHE_TTL.toString(),
Integer.toString(SQLServerDriverIntProperty.MSI_TOKEN_CACHE_TTL.getDefaultValue()), false,
null),
new SQLServerDriverPropertyInfo(SQLServerDriverStringProperty.KEY_VAULT_PROVIDER_CLIENT_ID.toString(),
SQLServerDriverStringProperty.KEY_VAULT_PROVIDER_CLIENT_ID.getDefaultValue(), false, null),
new SQLServerDriverPropertyInfo(SQLServerDriverStringProperty.KEY_VAULT_PROVIDER_CLIENT_KEY.toString(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,7 @@ protected Object[][] getContents() {
{"R_lockTimeoutPropertyDescription", "The number of milliseconds to wait before the database reports a lock time-out."},
{"R_connectRetryCountPropertyDescription", "The number of reconnection attempts if there is a connection failure."},
{"R_connectRetryIntervalPropertyDescription", "The number of seconds between each connection retry attempt."},
{"R_msiTokenCacheTtlPropertyDescription", "The number of seconds a Managed Identity (MSI) access token should be cached."},
{"R_loginTimeoutPropertyDescription", "The number of seconds the driver should wait before timing out a failed connection."},
{"R_instanceNamePropertyDescription", "The name of the SQL Server instance to connect to."},
{"R_xopenStatesPropertyDescription", "Determines if the driver returns XOPEN-compliant SQL state codes in exceptions."},
Expand Down Expand Up @@ -499,6 +500,7 @@ protected Object[][] getContents() {
{"R_maxResultBufferInvalidSyntax", "Invalid syntax: {0} in maxResultBuffer parameter."},
{"R_maxResultBufferNegativeParameterValue", "MaxResultBuffer must have positive value: {0}."},
{"R_maxResultBufferPropertyExceeded", "MaxResultBuffer property exceeded: {0}. MaxResultBuffer was set to: {1}."},
{"R_invalidMsiTokenCacheTtl", "msiTokenCacheTtl {0} is not valid."},
{"R_invalidConnectRetryCount", "Connection retry count {0} is not valid."},
{"R_connectRetryCountPropertyDescription", "The maximum number of attempts to reestablish a broken connection."},
{"R_invalidConnectRetryInterval", "Connection retry interval {0} is not valid."},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import java.security.InvalidKeyException;
import java.security.NoSuchAlgorithmException;
import java.text.MessageFormat;
import java.time.Duration;
import java.util.ArrayList;
import java.util.Calendar;
import java.util.Date;
Expand Down Expand Up @@ -318,17 +319,31 @@ static void verifyColumnMasterKeyMetadata(SQLServerConnection connection, SQLSer
}
}

private static SimpleTtlCache<String, SqlFedAuthToken> msiTokenCache = new SimpleTtlCache<>();

/**
* Get Managed Identity Authentication token
*
* @param resource
* token resource
* @param msiClientId
* Managed Identity or User Assigned Managed Identity
* @param tokenCacheTtl
* The number of seconds the token should remain in the cache
* @return fedauth token
* @throws SQLServerException
*/
static SqlFedAuthToken getMSIAuthToken(String resource, String msiClientId) throws SQLServerException {
static SqlFedAuthToken getMSIAuthToken(String resource, String msiClientId,
int tokenCacheTtl) throws SQLServerException {
String cacheKey = "resource:" + resource + "|clientid:" + msiClientId;
SqlFedAuthToken token = msiTokenCache.get(cacheKey);
if (token != null) {
if (connectionlogger.isLoggable(Level.FINER)) {
connectionlogger.finer("Using cached Managed Identity auth token: " + token.toString());
}
return token;
}

// IMDS upgrade time can take up to 70s
final int imdsUpgradeTimeInMs = 70 * 1000;
final List<Integer> retrySlots = new ArrayList<>();
Expand Down Expand Up @@ -419,11 +434,23 @@ static SqlFedAuthToken getMSIAuthToken(String resource, String msiClientId) thro
+ ActiveDirectoryAuthentication.ACCESS_TOKEN_EXPIRES_IN_IDENTIFIER.length();
}

String accessTokenExpiry = result.substring(startIndex_ATX,
result.indexOf("\"", startIndex_ATX + 1));
cal.add(Calendar.SECOND, Integer.parseInt(accessTokenExpiry));
int accessTokenExpiry = Integer
.parseInt(result.substring(startIndex_ATX, result.indexOf("\"", startIndex_ATX + 1)));
cal.add(Calendar.SECOND, accessTokenExpiry);
token = new SqlFedAuthToken(accessToken, cal.getTime());

if (connectionlogger.isLoggable(Level.FINER)) {
connectionlogger.finer("Obtained new Managed Identity auth token: " + token.toString());
}

if (tokenCacheTtl > 0) {
// Cache the token for up to tokenCacheTtl but not longer than 5 minutes less than the token's
// expiration, in case we are given a token with a very short lifetime.
msiTokenCache.put(cacheKey, token,
Duration.ofSeconds(Math.min(tokenCacheTtl, accessTokenExpiry - 300)));
}

return new SqlFedAuthToken(accessToken, cal.getTime());
return token;
}
} catch (Exception e) {
retry++;
Expand Down
25 changes: 25 additions & 0 deletions src/main/java/com/microsoft/sqlserver/jdbc/SimpleTtlCache.java
Original file line number Diff line number Diff line change
Expand Up @@ -170,4 +170,29 @@ V put(K key, V value) {
return previousValue;
}

/**
* Put (Key, Value, TTL) entry into cache.
*
* @param key
* key
* @param value
* value
* @param ttl
* Time-To-Live for this cache entry
* @return value
*/
V put(K key, V value, Duration ttl) {
V previousValue = null;
long cacheTtlInSeconds = ttl.getSeconds();

if (0 < cacheTtlInSeconds) {
previousValue = cache.put(key, value);
if (simpleCacheLogger.isLoggable(java.util.logging.Level.FINEST)) {
simpleCacheLogger.fine("Adding encryption key to cache...");
}
scheduler.schedule(new CacheClear(key), cacheTtlInSeconds, SECONDS);
}

return previousValue;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -33,4 +33,8 @@ class SqlFedAuthToken implements Serializable {
this.accessToken = accessToken;
this.expiresOn = expiresOn;
}

public String toString() {
return "accessToken hashCode: " + accessToken.hashCode() + " expiresOn: " + expiresOn.toInstant().toString();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,14 @@
*/
package com.microsoft.sqlserver.jdbc.AlwaysEncrypted;

import static org.junit.jupiter.api.Assertions.fail;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.junit.jupiter.api.Assertions.fail;

import java.io.File;
import java.io.FileOutputStream;
import java.io.OutputStream;
import java.sql.Connection;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Statement;
import java.util.HashMap;
import java.util.Map;
import java.util.Properties;
Expand Down Expand Up @@ -59,6 +57,21 @@ public void testMSIAuth() throws SQLException {
connStr = TestUtils.addOrOverrideProperty(connStr, Constants.PASSWORD, "");
connStr = TestUtils.addOrOverrideProperty(connStr, Constants.AUTHENTICATION, "ActiveDirectoryMSI");

connStr = TestUtils.addOrOverrideProperty(connStr, Constants.MSITOKENCACHETTL, "0");

testSimpleConnect(connStr);

connStr = TestUtils.addOrOverrideProperty(connStr, Constants.MSITOKENCACHETTL,
Integer.toString(Integer.MAX_VALUE));

testSimpleConnect(connStr);

connStr = TestUtils.addOrOverrideProperty(connStr, Constants.MSITOKENCACHETTL, "");

testSimpleConnect(connStr); // This call will use a cached token
}

private void testSimpleConnect(String connStr) {
try (SQLServerConnection con = PrepUtil.getConnection(connStr)) {} catch (Exception e) {
fail(TestResource.getResource("R_loginFailed") + e.getMessage());
}
Expand All @@ -79,9 +92,18 @@ public void testMSIAuthWithMSIClientId() throws SQLException {
connStr = TestUtils.addOrOverrideProperty(connStr, Constants.AUTHENTICATION, "ActiveDirectoryMSI");
connStr = TestUtils.addOrOverrideProperty(connStr, Constants.MSICLIENTID, msiClientId);

try (SQLServerConnection con = PrepUtil.getConnection(connStr)) {} catch (Exception e) {
fail(TestResource.getResource("R_loginFailed") + e.getMessage());
}
connStr = TestUtils.addOrOverrideProperty(connStr, Constants.MSITOKENCACHETTL, "0");

testSimpleConnect(connStr);

connStr = TestUtils.addOrOverrideProperty(connStr, Constants.MSITOKENCACHETTL,
Integer.toString(Integer.MAX_VALUE));

testSimpleConnect(connStr);

connStr = TestUtils.addOrOverrideProperty(connStr, Constants.MSITOKENCACHETTL, "");

testSimpleConnect(connStr); // This call will use a cached token
}

/*
Expand All @@ -101,9 +123,7 @@ public void testDSMSIAuth() throws SQLException {
ds.setAuthentication("ActiveDirectoryMSI");
AbstractTest.updateDataSource(connStr, ds);

try (Connection con = ds.getConnection(); Statement stmt = con.createStatement()) {} catch (Exception e) {
fail(TestResource.getResource("R_loginFailed") + e.getMessage());
}
testSimpleConnect(connStr);
}

/*
Expand All @@ -124,9 +144,7 @@ public void testDSMSIAuthWithMSIClientId() throws SQLException {
ds.setMSIClientId(msiClientId);
AbstractTest.updateDataSource(connStr, ds);

try (Connection con = ds.getConnection(); Statement stmt = con.createStatement()) {} catch (Exception e) {
fail(TestResource.getResource("R_loginFailed") + e.getMessage());
}
testSimpleConnect(connStr);
}

/*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,7 @@ private Constants() {}
public static final String ENCLAVE_ATTESTATIONPROTOCOL = "enclaveAttestationProtocol";

public static final String MSICLIENTID = "MSICLIENTID";
public static final String MSITOKENCACHETTL = "MSITOKENCACHETTL";
public static final String KEYVAULTPROVIDER_CLIENTID = "KEYVAULTPROVIDERCLIENTID";
public static final String KEYVAULTPROVIDER_CLIENTKEY = "KEYVAULTPROVIDERCLIENTKEY";
public static final String KEYSTORE_AUTHENTICATION = "KEYSTOREAUTHENTICATION";
Expand Down