Skip to content

Commit

Permalink
Backport of #1897 - batch query cancellation fix (#1996)
Browse files Browse the repository at this point in the history
* Backport of #1897 - batch query cancellation fix

* Updated SQLJdbcVersion file

* 10.2.2 updates for pipeline testing

* PR comments
  • Loading branch information
tkyc committed Dec 12, 2022
1 parent 9f21342 commit 9588222
Show file tree
Hide file tree
Showing 8 changed files with 105 additions and 16 deletions.
2 changes: 1 addition & 1 deletion src/main/java/com/microsoft/sqlserver/jdbc/IOBuffer.java
Expand Up @@ -7537,7 +7537,7 @@ protected void setInterruptsEnabled(boolean interruptsEnabled) {
// Flag set to indicate that an interrupt has happened.
private volatile boolean wasInterrupted = false;

private boolean wasInterrupted() {
boolean wasInterrupted() {
return wasInterrupted;
}

Expand Down
Expand Up @@ -8,7 +8,7 @@
final class SQLJdbcVersion {
static final int major = 10;
static final int minor = 2;
static final int patch = 1;
static final int patch = 2;
static final int build = 0;
/*
* Used to load mssql-jdbc_auth DLL.
Expand Down
Expand Up @@ -2804,6 +2804,21 @@ final void doExecutePreparedStatementBatch(PrepStmtBatchExecCmd batchCommand) th
for (int attempt = 1; attempt <= 2; ++attempt) {
try {

// If the command was interrupted, that means the TDS.PKT_CANCEL_REQ was sent to the server.
// Since the cancelation request was sent, stop processing the batch query and process the
// cancelation request and then return.
//
// Otherwise, if we do continue processing the batch query, in the case where a query requires
// prepexec/sp_prepare, the TDS request for prepexec/sp_prepare will be sent regardless of
// query cancelation. This will cause a TDS token error in the post processing when we
// close the query.
if (batchCommand.wasInterrupted()) {
ensureExecuteResultsReader(batchCommand.startResponse(getIsResponseBufferingAdaptive()));
startResults();
getNextResult(true);
return;
}

// Re-use handle if available, requires parameter definitions which are not available until here.
if (reuseCachedHandle(hasNewTypeDefinitions, 1 < attempt)) {
hasNewTypeDefinitions = false;
Expand Down
Expand Up @@ -48,7 +48,7 @@ protected Object[][] getContents() {
{"R_noServerResponse", "SQL Server did not return a response. The connection has been closed."},
{"R_truncatedServerResponse", "SQL Server returned an incomplete response. The connection has been closed."},
{"R_queryTimedOut", "The query has timed out."},
{"R_queryCancelled", "The query was canceled."},
{"R_queryCanceled", "The query was canceled."},
{"R_errorReadingStream", "An error occurred while reading the value from the stream object. Error: \"{0}\""},
{"R_streamReadReturnedInvalidValue", "The stream read operation returned an invalid value for the amount of data read."},
{"R_mismatchedStreamLength", "The stream value is not the specified length. The specified length was {0}, the actual length is {1}."},
Expand Down
Expand Up @@ -161,7 +161,6 @@ protected Object[][] getContents() {
{"R_cancellationFailed", "Cancellation failed."}, {"R_executionNotTimeout", "Execution did not timeout."},
{"R_executionTooLong", "Execution took too long."},
{"R_executionNotLong", "Execution did not take long enough."},
{"R_queryCancelled", "The query was canceled."},
{"R_statementShouldBeClosed", "statement should be closed since resultset is closed."},
{"R_statementShouldBeOpened", "statement should be opened since resultset is opened."},
{"R_shouldBeWrapper", "{0} should be a wrapper for {1}."},
Expand Down Expand Up @@ -201,5 +200,6 @@ protected Object[][] getContents() {
{"R_objectNullOrEmpty", "The {0} is null or empty."},
{"R_cekDecryptionFailed", "Failed to decrypt a column encryption key using key store provider: {0}."},
{"R_connectTimedOut", "connect timed out"},
{"R_queryCanceled", "The query was canceled."},
{"R_sessionKilled", "Cannot continue the execution because the session is in the kill state"}};
}
Expand Up @@ -152,7 +152,7 @@ public void testCallableStatementManyParameters() throws SQLException {
@Test
public void getStringGUIDTest() throws SQLException {

String sql = "{call " + AbstractSQLGenerator.escapeIdentifier(outputProcedureNameGUID) + "(?)}";
String sql = "{call " + outputProcedureNameGUID + "(?)}";

try (SQLServerCallableStatement callableStatement = (SQLServerCallableStatement) connection.prepareCall(sql)) {

Expand Down Expand Up @@ -181,7 +181,7 @@ public void getSetNullWithTypeVarchar() throws SQLException {
SQLServerDataSource ds = new SQLServerDataSource();
ds.setURL(connectionString);
ds.setSendStringParametersAsUnicode(true);
String sql = "{? = call " + AbstractSQLGenerator.escapeIdentifier(setNullProcedureName) + " (?,?)}";
String sql = "{? = call " + setNullProcedureName + " (?,?)}";
try (Connection connection = ds.getConnection();
SQLServerCallableStatement cs = (SQLServerCallableStatement) connection.prepareCall(sql);
SQLServerCallableStatement cs2 = (SQLServerCallableStatement) connection.prepareCall(sql)) {
Expand Down Expand Up @@ -213,7 +213,7 @@ public void getSetNullWithTypeVarchar() throws SQLException {
*/
@Test
public void testGetObjectAsLocalDateTime() throws SQLException {
String sql = "{CALL " + AbstractSQLGenerator.escapeIdentifier(getObjectLocalDateTimeProcedureName) + " (?)}";
String sql = "{CALL " + getObjectLocalDateTimeProcedureName + " (?)}";
try (Connection con = DriverManager.getConnection(connectionString);
CallableStatement cs = con.prepareCall(sql)) {
cs.registerOutParameter(1, Types.TIMESTAMP);
Expand Down Expand Up @@ -253,7 +253,7 @@ public void testGetObjectAsLocalDateTime() throws SQLException {
@Test
@Tag(Constants.xAzureSQLDW)
public void testGetObjectAsOffsetDateTime() throws SQLException {
String sql = "{CALL " + AbstractSQLGenerator.escapeIdentifier(getObjectOffsetDateTimeProcedureName)
String sql = "{CALL " + getObjectOffsetDateTimeProcedureName
+ " (?, ?)}";
try (Connection con = DriverManager.getConnection(connectionString);
CallableStatement cs = con.prepareCall(sql)) {
Expand Down Expand Up @@ -283,7 +283,7 @@ public void testGetObjectAsOffsetDateTime() throws SQLException {
*/
@Test
public void inputParamsTest() throws SQLException {
String call = "{CALL " + AbstractSQLGenerator.escapeIdentifier(inputParamsProcedureName) + " (?,?)}";
String call = "{CALL " + inputParamsProcedureName + " (?,?)}";

// the historical way: no leading '@', parameter names respected (not positional)
try (CallableStatement cs = connection.prepareCall(call)) {
Expand Down Expand Up @@ -338,39 +338,39 @@ public static void cleanup() throws SQLException {
}

private static void createGUIDStoredProcedure(Statement stmt) throws SQLException {
String sql = "CREATE PROCEDURE " + AbstractSQLGenerator.escapeIdentifier(outputProcedureNameGUID)
String sql = "CREATE PROCEDURE " + outputProcedureNameGUID
+ "(@p1 uniqueidentifier OUTPUT) AS SELECT @p1 = c1 FROM "
+ AbstractSQLGenerator.escapeIdentifier(tableNameGUID) + Constants.SEMI_COLON;
+ tableNameGUID + Constants.SEMI_COLON;
stmt.execute(sql);
}

private static void createGUIDTable(Statement stmt) throws SQLException {
String sql = "CREATE TABLE " + AbstractSQLGenerator.escapeIdentifier(tableNameGUID)
String sql = "CREATE TABLE " + tableNameGUID
+ " (c1 uniqueidentifier null)";
stmt.execute(sql);
}

private static void createSetNullProcedure(Statement stmt) throws SQLException {
stmt.execute("create procedure " + AbstractSQLGenerator.escapeIdentifier(setNullProcedureName)
stmt.execute("create procedure " + setNullProcedureName
+ " (@p1 nvarchar(255), @p2 nvarchar(255) output) as select @p2=@p1 return 0");
}

private static void createInputParamsProcedure(Statement stmt) throws SQLException {
String sql = "CREATE PROCEDURE " + AbstractSQLGenerator.escapeIdentifier(inputParamsProcedureName)
String sql = "CREATE PROCEDURE " + inputParamsProcedureName
+ " @p1 nvarchar(max) = N'parameter1', " + " @p2 nvarchar(max) = N'parameter2' " + "AS "
+ "BEGIN " + " SET NOCOUNT ON; " + " SELECT @p1 + @p2 AS result; " + "END ";

stmt.execute(sql);
}

private static void createGetObjectLocalDateTimeProcedure(Statement stmt) throws SQLException {
String sql = "CREATE PROCEDURE " + AbstractSQLGenerator.escapeIdentifier(getObjectLocalDateTimeProcedureName)
String sql = "CREATE PROCEDURE " + getObjectLocalDateTimeProcedureName
+ "(@p1 datetime2(7) OUTPUT) AS " + "SELECT @p1 = '2018-03-11T02:00:00.1234567'";
stmt.execute(sql);
}

private static void createGetObjectOffsetDateTimeProcedure(Statement stmt) throws SQLException {
String sql = "CREATE PROCEDURE " + AbstractSQLGenerator.escapeIdentifier(getObjectOffsetDateTimeProcedureName)
String sql = "CREATE PROCEDURE " + getObjectOffsetDateTimeProcedureName
+ "(@p1 DATETIMEOFFSET OUTPUT, @p2 DATETIMEOFFSET OUTPUT) AS "
+ "SELECT @p1 = '2018-01-02T11:22:33.123456700+12:34', @p2 = NULL";
stmt.execute(sql);
Expand Down
Expand Up @@ -5,6 +5,7 @@
package com.microsoft.sqlserver.jdbc.unit.statement;

import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.fail;

import java.lang.reflect.Field;
Expand Down Expand Up @@ -60,6 +61,51 @@ public void testBatchExceptionAEOn() throws Exception {
testExecuteBatch1UseBulkCopyAPI();
}

@Test
public void testBatchStatementCancellation() throws Exception {
try (Connection connection = PrepUtil.getConnection(connectionString)) {
connection.setAutoCommit(false);

try (PreparedStatement statement = connection.prepareStatement(
"if object_id('test_table') is not null drop table test_table")) {
statement.execute();
}
connection.commit();

try (PreparedStatement statement = connection.prepareStatement(
"create table test_table (column_name bit)")) {
statement.execute();
}
connection.commit();

for (long delayInMilliseconds : new long[] { 1, 2, 4, 8, 16, 32, 64, 128 }) {
for (int numberOfCommands : new int[] { 1, 2, 4, 8, 16, 32, 64 }) {
int parameterCount = 512;

try (PreparedStatement statement = connection.prepareStatement(
"insert into test_table values (?)" + repeat(",(?)", parameterCount - 1))) {

for (int i = 0; i < numberOfCommands; i++) {
for (int j = 0; j < parameterCount; j++) {
statement.setBoolean(j + 1, true);
}
statement.addBatch();
}

Thread cancelThread = cancelAsync(statement, delayInMilliseconds);
try {
statement.executeBatch();
} catch (SQLException e) {
assertEquals(TestResource.getResource("R_queryCancelled"), e.getMessage());
}
cancelThread.join();
}
connection.commit();
}
}
}
}

/**
* Get a PreparedStatement object and call the addBatch() method with 3 SQL statements and call the executeBatch()
* method and it should return array of Integer values of length 3
Expand Down Expand Up @@ -231,6 +277,29 @@ private void modifyConnectionForBulkCopyAPI(SQLServerConnection con) throws Exce
con.setUseBulkCopyForBatchInsert(true);
}

private static String repeat(String string, int count) {
StringBuilder sb = new StringBuilder();
for (int i = 0; i < count; i++) {
sb.append(string);
}
return sb.toString();
}

private static Thread cancelAsync(Statement statement, long delayInMilliseconds) {
Thread thread = new Thread(() -> {
try {
Thread.sleep(delayInMilliseconds);
statement.cancel();
} catch (SQLException | InterruptedException e) {
// does not/must not happen
e.printStackTrace();
throw new IllegalStateException(e);
}
});
thread.start();
return thread;
}

@BeforeAll
public static void testSetup() throws TestAbortedException, Exception {
connectionString = TestUtils.addOrOverrideProperty(connectionString,"trustServerCertificate", "true");
Expand Down
Expand Up @@ -71,6 +71,7 @@ public abstract class AbstractTest {

protected static String trustStorePath = "";

protected static String trustServerCertificate = "";
protected static String windowsKeyPath = null;
protected static String javaKeyPath = null;
protected static String javaKeyAliases = null;
Expand Down Expand Up @@ -134,6 +135,10 @@ public static void setup() throws Exception {
applicationKey = getConfiguredProperty("applicationKey");
tenantID = getConfiguredProperty("tenantID");

trustServerCertificate = getConfiguredProperty("trustServerCertificate", "true");
connectionString = TestUtils.addOrOverrideProperty(connectionString, "trustServerCertificate",
trustServerCertificate);

javaKeyPath = TestUtils.getCurrentClassPath() + Constants.JKS_NAME;

keyIDs = getConfiguredProperty("keyID", "").split(Constants.SEMI_COLON);
Expand Down

0 comments on commit 9588222

Please sign in to comment.