Skip to content

Commit

Permalink
Clear prepared statement handle before reconnect (#2364)
Browse files Browse the repository at this point in the history
  • Loading branch information
lukasaignerrsg committed Apr 3, 2024
1 parent 9de1a5d commit 51ca5a0
Show file tree
Hide file tree
Showing 5 changed files with 150 additions and 0 deletions.
15 changes: 15 additions & 0 deletions src/main/java/com/microsoft/sqlserver/jdbc/ReconnectListener.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
/*
* Microsoft JDBC Driver for SQL Server Copyright(c) Microsoft Corporation All rights reserved. This program is made
* available under the terms of the MIT License. See the LICENSE file in the project root for more information.
*/
package com.microsoft.sqlserver.jdbc;

/**
* This functional interface represents a listener which is called before a reconnect of {@link SQLServerConnection}.
*/
@FunctionalInterface
public interface ReconnectListener {

void beforeReconnect();

}
Original file line number Diff line number Diff line change
Expand Up @@ -1759,6 +1759,19 @@ SQLServerPooledConnection getPooledConnectionParent() {
return pooledConnectionParent;
}

/**
* List of listeners which are called before reconnecting.
*/
private List<ReconnectListener> reconnectListeners = new ArrayList<>();

public void registerBeforeReconnectListener(ReconnectListener reconnectListener) {
reconnectListeners.add(reconnectListener);
}

public void removeBeforeReconnectListener(ReconnectListener reconnectListener) {
reconnectListeners.remove(reconnectListener);
}

SQLServerConnection(String parentInfo) {
int connectionID = nextConnectionID(); // sequential connection id
traceID = "ConnectionID:" + connectionID;
Expand Down Expand Up @@ -4345,6 +4358,8 @@ boolean executeCommand(TDSCommand newCommand) throws SQLServerException {
preparedStatementHandleCache.clear();
}

this.reconnectListeners.forEach(ReconnectListener::beforeReconnect);

if (loggerResiliency.isLoggable(Level.FINE)) {
loggerResiliency.fine(toString()
+ " Idle connection resiliency - starting idle connection resiliency reconnect.");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,12 @@ private boolean resetPrepStmtHandle(boolean discardCurrentCacheItem) {
*/
private Vector<CryptoMetadata> cryptoMetaBatch = new Vector<>();

/**
* Listener to clear the {@link SQLServerPreparedStatement#prepStmtHandle} and
* {@link SQLServerPreparedStatement#cachedPreparedStatementHandle} before reconnecting.
*/
private ReconnectListener clearPrepStmtHandleOnReconnectListener;

/**
* Constructs a SQLServerPreparedStatement.
*
Expand All @@ -254,6 +260,9 @@ private boolean resetPrepStmtHandle(boolean discardCurrentCacheItem) {
SQLServerStatementColumnEncryptionSetting stmtColEncSetting) throws SQLServerException {
super(conn, nRSType, nRSConcur, stmtColEncSetting);

clearPrepStmtHandleOnReconnectListener = this::clearPrepStmtHandle;
connection.registerBeforeReconnectListener(clearPrepStmtHandleOnReconnectListener);

if (null == sql) {
MessageFormat form = new MessageFormat(SQLServerException.getErrString("R_NullValue"));
Object[] msgArgs1 = {"Statement SQL"};
Expand Down Expand Up @@ -291,6 +300,8 @@ private boolean resetPrepStmtHandle(boolean discardCurrentCacheItem) {
* Closes the prepared statement's prepared handle.
*/
private void closePreparedHandle() {
connection.removeBeforeReconnectListener(clearPrepStmtHandleOnReconnectListener);

if (!hasPreparedStatementHandle())
return;

Expand Down Expand Up @@ -3586,4 +3597,12 @@ public void addBatch(String sql) throws SQLServerException {
Object[] msgArgs = {"addBatch()"};
throw new SQLServerException(this, form.format(msgArgs), null, 0, false);
}

private void clearPrepStmtHandle() {
prepStmtHandle = 0;
cachedPreparedStatementHandle = null;
if (getStatementLogger().isLoggable(Level.FINER)) {
getStatementLogger().finer(toString() + " cleared cachedPrepStmtHandle!");
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -529,6 +529,8 @@ private List<String> getVerifiedMethodNames() {
verifiedMethodNames.add("setUseFlexibleCallableStatements");
verifiedMethodNames.add("getCalcBigDecimalPrecision");
verifiedMethodNames.add("setCalcBigDecimalPrecision");
verifiedMethodNames.add("registerBeforeReconnectListener");
verifiedMethodNames.add("removeBeforeReconnectListener");
return verifiedMethodNames;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,20 @@
package com.microsoft.sqlserver.jdbc.resiliency;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotEquals;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;

import java.lang.reflect.Field;
import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Statement;
import java.util.Collections;
import java.util.LinkedList;
import java.util.List;
import java.util.UUID;

import javax.sql.PooledConnection;
Expand Down Expand Up @@ -367,6 +372,100 @@ public void testPreparedStatementCacheShouldBeCleared() throws SQLException {
}
}

@Test
public void testPreparedStatementHandleOfStatementShouldBeCleared() throws SQLException {
try (SQLServerConnection con = (SQLServerConnection) ResiliencyUtils.getConnection(connectionString)) {
int cacheSize = 2;
String query = String.format("/*testPreparedStatementHandleOfStatementShouldBeCleared%s*/SELECT 1; -- ",
UUID.randomUUID().toString());

// enable caching
con.setDisableStatementPooling(false);
con.setStatementPoolingCacheSize(cacheSize);
con.setServerPreparedStatementDiscardThreshold(cacheSize);

List<SQLServerPreparedStatement> statements = new LinkedList<>();

// add statements to fill cache
for (int i = 0; i < cacheSize + 1; ++i) {
SQLServerPreparedStatement pstmt = (SQLServerPreparedStatement) con.prepareStatement(query + i);
pstmt.execute();
pstmt.execute();
pstmt.execute();
pstmt.getMoreResults();
statements.add(pstmt);
}

// handle of the prepared statement should be set
assertNotEquals(0, statements.get(1).getPreparedStatementHandle());

ResiliencyUtils.killConnection(con, connectionString, 1);

// call first statement to trigger reconnect
statements.get(0).execute();

// handle of the other statements should be cleared after reconnect
assertEquals(0, statements.get(1).getPreparedStatementHandle());
}
}

@Test
public void testPreparedStatementShouldNotUseWrongHandleAfterReconnect() throws SQLException {
try (SQLServerConnection con = (SQLServerConnection) ResiliencyUtils.getConnection(connectionString)) {
int cacheSize = 3;
String queryOne = "select * from sys.sysusers where name=?;";
String queryTwo = "select * from sys.sysusers where name=? and uid=?;";
String queryThree = "select * from sys.sysusers where name=? and uid=? and islogin=?";

String parameterOne = "name";
int parameterUid = 0;
int parameterIsLogin = 0;

// enable caching
con.setDisableStatementPooling(false);
con.setStatementPoolingCacheSize(cacheSize);
con.setServerPreparedStatementDiscardThreshold(cacheSize);

List<PreparedStatement> statements = new LinkedList<>();

PreparedStatement ps = con.prepareStatement(queryOne);
ps.setString(1, parameterOne);
statements.add(ps);

ps = con.prepareStatement(queryTwo);
ps.setString(1, parameterOne);
ps.setInt(2, parameterUid);
statements.add(ps);

ps = con.prepareStatement(queryThree);
ps.setString(1, parameterOne);
ps.setInt(2, parameterUid);
ps.setInt(3, parameterIsLogin);
statements.add(ps);

// add new statements to fill cache
for (PreparedStatement preparedStatement : statements) {
preparedStatement.execute();
preparedStatement.execute();
preparedStatement.execute();
preparedStatement.getMoreResults();
}

ResiliencyUtils.killConnection(con, connectionString, 1);

// call statements in reversed order, in order to force the statement to use the wrong handle
// first execute triggers a reconnect
Collections.reverse(statements);
for (PreparedStatement preparedStatement : statements) {
preparedStatement.execute();
preparedStatement.execute();
preparedStatement.execute();
preparedStatement.getMoreResults();
}
}
}


@Test
public void testUnprocessedResponseCountSuccessfulIdleConnectionRecovery() throws SQLException {
try (SQLServerConnection con = (SQLServerConnection) ResiliencyUtils.getConnection(connectionString)) {
Expand Down

0 comments on commit 51ca5a0

Please sign in to comment.