Skip to content

Commit

Permalink
Fix calling procedures with output parameters by their four-part synt…
Browse files Browse the repository at this point in the history
…ax (#2349)

* Corrected four part syntax regression

* JDK 8 correction
  • Loading branch information
tkyc committed Mar 19, 2024
1 parent 41710d2 commit aa46637
Show file tree
Hide file tree
Showing 3 changed files with 84 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,8 @@ public class SQLServerPreparedStatement extends SQLServerStatement implements IS

private boolean isCallEscapeSyntax;

private boolean isFourPartSyntax;

/** Parameter positions in processed SQL statement text. */
final int[] userSQLParamPositions;

Expand Down Expand Up @@ -144,6 +146,11 @@ private void setPreparedStatementHandle(int handle) {
*/
private static final Pattern execEscapePattern = Pattern.compile("^\\s*(?i)(?:exec|execute)\\b");

/**
* Regex for four part syntax
*/
private static final Pattern fourPartSyntaxPattern = Pattern.compile("(.+)\\.(.+)\\.(.+)\\.(.+)");

/** Returns the prepared statement SQL */
@Override
public String toString() {
Expand Down Expand Up @@ -271,6 +278,7 @@ private boolean resetPrepStmtHandle(boolean discardCurrentCacheItem) {
userSQL = parsedSQL.processedSQL;
isExecEscapeSyntax = isExecEscapeSyntax(sql);
isCallEscapeSyntax = isCallEscapeSyntax(sql);
isFourPartSyntax = isFourPartSyntax(sql);
userSQLParamPositions = parsedSQL.parameterPositions;
initParams(userSQLParamPositions.length);
useBulkCopyForBatchInsert = conn.getUseBulkCopyForBatchInsert();
Expand Down Expand Up @@ -1234,10 +1242,12 @@ boolean callRPCDirectly(Parameter[] params) throws SQLServerException {
// 2. There must be parameters
// 3. Parameters must not be a TVP type
// 4. Compliant CALL escape syntax
// If isExecEscapeSyntax is true, EXEC escape syntax is used then use prior behaviour to
// execute the procedure
// If isExecEscapeSyntax is true, EXEC escape syntax is used then use prior behaviour of
// wrapping call to execute the procedure
// If isFourPartSyntax is true, sproc is being executed against linked server, then
// use prior behaviour of wrapping call to execute procedure
return (null != procedureName && paramCount != 0 && !isTVPType(params) && isCallEscapeSyntax
&& !isExecEscapeSyntax);
&& !isExecEscapeSyntax && !isFourPartSyntax);
}

/**
Expand Down Expand Up @@ -1265,6 +1275,10 @@ private boolean isCallEscapeSyntax(String sql) {
return callEscapePattern.matcher(sql).find();
}

private boolean isFourPartSyntax(String sql) {
return fourPartSyntaxPattern.matcher(sql).find();
}

/**
* Executes sp_prepare to prepare a parameterized statement and sets the prepared statement handle
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1150,6 +1150,66 @@ public void testExecDocumentedSystemStoredProceduresIndexedParameters() throws S
}
}

@Test
@Tag(Constants.reqExternalSetup)
@Tag(Constants.xAzureSQLDB)
@Tag(Constants.xAzureSQLDW)
@Tag(Constants.xAzureSQLMI)
public void testFourPartSyntaxCallEscapeSyntax() throws SQLException {
String table = "serverList";

try (Statement stmt = connection.createStatement()) {
stmt.execute("IF OBJECT_ID(N'" + table + "') IS NOT NULL DROP TABLE " + table);
stmt.execute("CREATE TABLE " + table + " (serverName varchar(100),network varchar(100),serverStatus varchar(4000), id int, collation varchar(100), connectTimeout int, queryTimeout int)");
stmt.execute("INSERT " + table + " EXEC sp_helpserver");

ResultSet rs = stmt.executeQuery("SELECT COUNT(*) FROM " + table + " WHERE serverName = N'" + linkedServer + "'");
rs.next();

if (rs.getInt(1) == 1) {
stmt.execute("EXEC sp_dropserver @server='" + linkedServer + "';");
}

stmt.execute("EXEC sp_addlinkedserver @server='" + linkedServer + "';");
stmt.execute("EXEC sp_addlinkedsrvlogin @rmtsrvname=N'" + linkedServer + "', @rmtuser=N'" + remoteUser + "', @rmtpassword=N'" + remotePassword + "'");
stmt.execute("EXEC sp_serveroption '" + linkedServer + "', 'rpc', true;");
stmt.execute("EXEC sp_serveroption '" + linkedServer + "', 'rpc out', true;");
}

SQLServerDataSource ds = new SQLServerDataSource();
ds.setServerName(linkedServer);
ds.setUser(remoteUser);
ds.setPassword(remotePassword);
ds.setEncrypt(false);
ds.setTrustServerCertificate(true);

try (Connection linkedServerConnection = ds.getConnection(); Statement stmt = linkedServerConnection.createStatement()) {
stmt.execute("create or alter procedure dbo.TestAdd(@Num1 int, @Num2 int, @Result int output) as begin set @Result = @Num1 + @Num2; end;");
}

try (CallableStatement cstmt = connection.prepareCall("{call [" + linkedServer + "].master.dbo.TestAdd(?,?,?)}")) {
int sum = 11;
int param0 = 1;
int param1 = 10;
cstmt.setInt(1, param0);
cstmt.setInt(2, param1);
cstmt.registerOutParameter(3, Types.INTEGER);
cstmt.execute();
assertEquals(sum, cstmt.getInt(3));
}

try (CallableStatement cstmt = connection.prepareCall("exec [" + linkedServer + "].master.dbo.TestAdd ?,?,?")) {
int sum = 11;
int param0 = 1;
int param1 = 10;
cstmt.setInt(1, param0);
cstmt.setInt(2, param1);
cstmt.registerOutParameter(3, Types.INTEGER);
cstmt.execute();
assertEquals(sum, cstmt.getInt(3));
}
}

/**
* Cleanup after test
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,9 @@ public abstract class AbstractTest {
protected static String tenantID;
protected static String[] keyIDs = null;

protected static String linkedServer = null;
protected static String remoteUser = null;
protected static String remotePassword = null;
protected static String[] enclaveServer = null;
protected static String[] enclaveAttestationUrl = null;
protected static String[] enclaveAttestationProtocol = null;
Expand Down Expand Up @@ -197,6 +200,10 @@ public static void setup() throws Exception {

clientKeyPassword = getConfiguredProperty("clientKeyPassword", "");

linkedServer = getConfiguredProperty("linkedServer", null);
remoteUser = getConfiguredProperty("remoteUser", null);
remotePassword = getConfiguredProperty("remotePassword", null);

kerberosServer = getConfiguredProperty("kerberosServer", null);
kerberosServerPort = getConfiguredProperty("kerberosServerPort", null);

Expand Down

0 comments on commit aa46637

Please sign in to comment.