Skip to content

Commit

Permalink
feat: add tests for DML with Returning clause (#936)
Browse files Browse the repository at this point in the history
This PR adds tests for running DML with Returning clause using the JDBC driver, and incorporates the following:
- Integration tests for running DML statements with Returning clause using PreparedStatement.
- Unit tests for running DML statements with Returning clause using JdbcStatement, for each of the available JDBC APIs `execute`, `executeUpdate`, `executeQuery`, `executeBatchUpdate`.
- The JDBC driver does not require any code changes for supporting DML with Returning clause, as all the required changes will be made in the Connection API (Connection API changes are being tracked at https://togithub.com/googleapis/java-spanner/pull/1978).
  • Loading branch information
rajatbhatta committed Nov 21, 2022
1 parent 9016c38 commit 8a86467
Show file tree
Hide file tree
Showing 3 changed files with 237 additions and 2 deletions.
Expand Up @@ -87,7 +87,7 @@ public long executeLargeUpdate(String sql) throws SQLException {
switch (result.getResultType()) {
case RESULT_SET:
throw JdbcSqlExceptionFactory.of(
"The statement is not an update or DDL statement", Code.INVALID_ARGUMENT);
"The statement is not a non-returning DML or DDL statement", Code.INVALID_ARGUMENT);
case UPDATE_COUNT:
return result.getUpdateCount();
case NO_RESULT:
Expand Down
Expand Up @@ -17,6 +17,12 @@
package com.google.cloud.spanner.jdbc;

import static com.google.common.truth.Truth.assertThat;
import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertThrows;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
import static org.mockito.Mockito.anyList;
import static org.mockito.Mockito.mock;
Expand Down Expand Up @@ -53,6 +59,8 @@ public class JdbcStatementTest {
private static final String SELECT = "SELECT 1";
private static final String UPDATE = "UPDATE FOO SET BAR=1 WHERE BAZ=2";
private static final String LARGE_UPDATE = "UPDATE FOO SET BAR=1 WHERE 1=1";
private static final String DML_RETURNING_GSQL = "UPDATE FOO SET BAR=1 WHERE 1=1 THEN RETURN *";
private static final String DML_RETURNING_PG = "UPDATE FOO SET BAR=1 WHERE 1=1 RETURNING *";
private static final String DDL = "CREATE INDEX FOO ON BAR(ID)";

@Parameter public Dialect dialect;
Expand All @@ -62,11 +70,20 @@ public static Object[] data() {
return Dialect.values();
}

private String getDmlReturningSql() {
if (dialect == Dialect.GOOGLE_STANDARD_SQL) {
return DML_RETURNING_GSQL;
}
return DML_RETURNING_PG;
}

@SuppressWarnings("unchecked")
private JdbcStatement createStatement() throws SQLException {
Connection spanner = mock(Connection.class);
when(spanner.getDialect()).thenReturn(dialect);

final String DML_RETURNING_SQL = getDmlReturningSql();

com.google.cloud.spanner.ResultSet resultSet = mock(com.google.cloud.spanner.ResultSet.class);
when(resultSet.next()).thenReturn(true, false);
when(resultSet.getColumnType(0)).thenReturn(Type.int64());
Expand All @@ -88,6 +105,19 @@ private JdbcStatement createStatement() throws SQLException {
when(spanner.execute(com.google.cloud.spanner.Statement.of(LARGE_UPDATE)))
.thenReturn(largeUpdateResult);

com.google.cloud.spanner.ResultSet dmlReturningResultSet =
mock(com.google.cloud.spanner.ResultSet.class);
when(dmlReturningResultSet.next()).thenReturn(true, false);
when(dmlReturningResultSet.getColumnCount()).thenReturn(1);
when(dmlReturningResultSet.getColumnType(0)).thenReturn(Type.int64());
when(dmlReturningResultSet.getLong(0)).thenReturn(1L);

StatementResult dmlReturningResult = mock(StatementResult.class);
when(dmlReturningResult.getResultType()).thenReturn(ResultType.RESULT_SET);
when(dmlReturningResult.getResultSet()).thenReturn(dmlReturningResultSet);
when(spanner.execute(com.google.cloud.spanner.Statement.of(DML_RETURNING_SQL)))
.thenReturn(dmlReturningResult);

StatementResult ddlResult = mock(StatementResult.class);
when(ddlResult.getResultType()).thenReturn(ResultType.NO_RESULT);
when(spanner.execute(com.google.cloud.spanner.Statement.of(DDL))).thenReturn(ddlResult);
Expand All @@ -96,6 +126,8 @@ private JdbcStatement createStatement() throws SQLException {
when(spanner.executeQuery(com.google.cloud.spanner.Statement.of(UPDATE)))
.thenThrow(
SpannerExceptionFactory.newSpannerException(ErrorCode.INVALID_ARGUMENT, "not a query"));
when(spanner.executeQuery(com.google.cloud.spanner.Statement.of(DML_RETURNING_SQL)))
.thenReturn(dmlReturningResultSet);
when(spanner.executeQuery(com.google.cloud.spanner.Statement.of(DDL)))
.thenThrow(
SpannerExceptionFactory.newSpannerException(ErrorCode.INVALID_ARGUMENT, "not a query"));
Expand All @@ -109,6 +141,10 @@ private JdbcStatement createStatement() throws SQLException {
.thenThrow(
SpannerExceptionFactory.newSpannerException(
ErrorCode.INVALID_ARGUMENT, "not an update"));
when(spanner.executeUpdate(com.google.cloud.spanner.Statement.of(DML_RETURNING_SQL)))
.thenThrow(
SpannerExceptionFactory.newSpannerException(
ErrorCode.FAILED_PRECONDITION, "cannot execute dml returning over executeUpdate"));

when(spanner.executeBatchUpdate(anyList()))
.thenAnswer(
Expand Down Expand Up @@ -219,6 +255,20 @@ public void testExecuteWithDdlStatement() throws SQLException {
assertThat(statement.getUpdateCount()).isEqualTo(JdbcConstants.STATEMENT_NO_RESULT);
}

@Test
public void testExecuteWithDmlReturningStatement() throws SQLException {
Statement statement = createStatement();
boolean res = statement.execute(getDmlReturningSql());
assertTrue(res);
assertEquals(statement.getUpdateCount(), JdbcConstants.STATEMENT_RESULT_SET);
try (ResultSet rs = statement.getResultSet()) {
assertNotNull(rs);
assertTrue(rs.next());
assertEquals(rs.getLong(1), 1L);
assertFalse(rs.next());
}
}

@Test
public void testExecuteWithGeneratedKeys() throws SQLException {
Statement statement = createStatement();
Expand Down Expand Up @@ -257,6 +307,17 @@ public void testExecuteQueryWithUpdateStatement() {
}
}

@Test
public void testExecuteQueryWithDmlReturningStatement() throws SQLException {
Statement statement = createStatement();
try (ResultSet rs = statement.executeQuery(getDmlReturningSql())) {
assertNotNull(rs);
assertTrue(rs.next());
assertEquals(rs.getLong(1), 1L);
assertFalse(rs.next());
}
}

@Test
public void testExecuteQueryWithDdlStatement() {
try {
Expand Down Expand Up @@ -353,12 +414,29 @@ public void testExecuteUpdateWithSelectStatement() {
} catch (SQLException e) {
assertThat(
JdbcExceptionMatcher.matchCodeAndMessage(
Code.INVALID_ARGUMENT, "The statement is not an update or DDL statement")
Code.INVALID_ARGUMENT,
"The statement is not a non-returning DML or DDL statement")
.matches(e))
.isTrue();
}
}

@Test
public void testExecuteUpdateWithDmlReturningStatement() {
try {
Statement statement = createStatement();
SQLException e =
assertThrows(SQLException.class, () -> statement.executeUpdate(getDmlReturningSql()));
assertTrue(
JdbcExceptionMatcher.matchCodeAndMessage(
Code.INVALID_ARGUMENT,
"The statement is not a non-returning DML or DDL statement")
.matches(e));
} catch (SQLException e) {
// ignore exception.
}
}

@Test
public void testExecuteUpdateWithDdlStatement() throws SQLException {
Statement statement = createStatement();
Expand Down Expand Up @@ -438,6 +516,19 @@ public void testDmlBatch() throws SQLException {
}
}

@Test
public void testDmlBatchWithDmlReturning() throws SQLException {
try (Statement statement = createStatement()) {
// Verify that multiple batches can be executed on the same statement.
for (int i = 0; i < 2; i++) {
statement.addBatch(getDmlReturningSql());
statement.addBatch(getDmlReturningSql());
statement.addBatch(getDmlReturningSql());
assertArrayEquals(statement.executeBatch(), new int[] {1, 1, 1});
}
}
}

@Test
public void testLargeDmlBatch() throws SQLException {
try (Statement statement = createStatement()) {
Expand Down
Expand Up @@ -36,6 +36,7 @@
import com.google.cloud.spanner.testing.EmulatorSpannerHelper;
import com.google.common.base.Strings;
import com.google.common.io.BaseEncoding;
import com.google.common.io.CharStreams;
import java.io.IOException;
import java.io.InputStream;
import java.io.StringReader;
Expand Down Expand Up @@ -263,6 +264,24 @@ private void setPreparedStatement(Connection connection, PreparedStatement ps, D
ps.setArray(6, connection.createArrayOf("INT64", this.ticketPrices));
}
}

private void assertEqualsFields(Connection connection, ResultSet rs, Dialect dialect)
throws SQLException {
assertEquals(rs.getLong(1), this.venueId);
assertEquals(rs.getLong(2), this.singerId);
if (dialect == Dialect.POSTGRESQL) {
assertEquals(rs.getString(3), this.concertDate.toString());
assertEquals(rs.getString(4), this.beginTime.toString());
assertEquals(rs.getString(5), this.endTime.toString());
} else {
assertEquals(rs.getDate(3), this.concertDate);
assertEquals(rs.getTimestamp(4), this.beginTime);
assertEquals(rs.getTimestamp(5), this.endTime);
assertArrayEquals(
(Object[]) rs.getArray(6).getArray(),
(Object[]) connection.createArrayOf("INT64", this.ticketPrices).getArray());
}
}
}

private static Date parseDate(String value) {
Expand Down Expand Up @@ -333,6 +352,34 @@ private String getConcertsInsertQuery(Dialect dialect) {
return "INSERT INTO Concerts (VenueId, SingerId, ConcertDate, BeginTime, EndTime, TicketPrices) VALUES (?,?,?,?,?,?);";
}

private String getConcertsInsertReturningQuery(Dialect dialect) {
if (dialect == Dialect.POSTGRESQL) {
return "INSERT INTO Concerts (VenueId, SingerId, ConcertDate, BeginTime, EndTime) VALUES (?,?,?,?,?) RETURNING *;";
}
return "INSERT INTO Concerts (VenueId, SingerId, ConcertDate, BeginTime, EndTime, TicketPrices) VALUES (?,?,?,?,?,?) THEN RETURN *;";
}

private String getSingersInsertReturningQuery(Dialect dialect) {
if (dialect == Dialect.POSTGRESQL) {
return "INSERT INTO Singers (SingerId, FirstName, LastName, SingerInfo, BirthDate) values (?,?,?,?,?) RETURNING *";
}
return "INSERT INTO Singers (SingerId, FirstName, LastName, SingerInfo, BirthDate) values (?,?,?,?,?) THEN RETURN *";
}

private String getAlbumsInsertReturningQuery(Dialect dialect) {
if (dialect == Dialect.POSTGRESQL) {
return "INSERT INTO Albums (SingerId, AlbumId, AlbumTitle, MarketingBudget) VALUES (?,?,?,?) RETURNING *";
}
return "INSERT INTO Albums (SingerId, AlbumId, AlbumTitle, MarketingBudget) VALUES (?,?,?,?) THEN RETURN *";
}

private String getSongsInsertReturningQuery(Dialect dialect) {
if (dialect == Dialect.POSTGRESQL) {
return "INSERT INTO Songs (SingerId, AlbumId, TrackId, SongName, Duration, SongGenre) VALUES (?,?,?,?,?,?) RETURNING *;";
}
return "INSERT INTO Songs (SingerId, AlbumId, TrackId, SongName, Duration, SongGenre) VALUES (?,?,?,?,?,?) THEN RETURN *;";
}

private int getConcertExpectedParamCount(Dialect dialect) {
if (dialect == Dialect.POSTGRESQL) {
return 5;
Expand Down Expand Up @@ -1150,6 +1197,103 @@ private void assertDefaultParameterMetaData(ParameterMetaData pmd, int expectedP
}
}

@Test
public void test12_InsertReturningTestData() throws SQLException {
assumeFalse(
"Emulator does not support DML with returning clause",
EmulatorSpannerHelper.isUsingEmulator());
try (Connection connection = createConnection(env, database)) {
connection.setAutoCommit(false);
// Delete existing rows from tables populated by other tests,
// so that this test can populate rows from scratch.
Statement deleteStatements = connection.createStatement();
deleteStatements.addBatch("DELETE FROM Concerts WHERE TRUE");
deleteStatements.addBatch("DELETE FROM Songs WHERE TRUE");
deleteStatements.addBatch("DELETE FROM Albums WHERE TRUE");
deleteStatements.addBatch("DELETE FROM Singers WHERE TRUE");
deleteStatements.executeBatch();
try (PreparedStatement ps =
connection.prepareStatement(getSingersInsertReturningQuery(dialect.dialect))) {
assertDefaultParameterMetaData(ps.getParameterMetaData(), 5);
for (Singer singer : createSingers()) {
singer.setPreparedStatement(ps, getDialect());
assertInsertSingerParameterMetadata(ps.getParameterMetaData());
ps.addBatch();
// check that adding the current params to a batch will not reset the metadata
assertInsertSingerParameterMetadata(ps.getParameterMetaData());
}
int[] results = ps.executeBatch();
for (int res : results) {
assertEquals(1, res);
}
}
try (PreparedStatement ps =
connection.prepareStatement(getAlbumsInsertReturningQuery(dialect.dialect))) {
assertDefaultParameterMetaData(ps.getParameterMetaData(), 4);
for (Album album : createAlbums()) {
ps.setLong(1, album.singerId);
ps.setLong(2, album.albumId);
ps.setString(3, album.albumTitle);
ps.setLong(4, album.marketingBudget);
assertInsertAlbumParameterMetadata(ps.getParameterMetaData());
try (ResultSet rs = ps.executeQuery()) {
rs.next();
assertEquals(rs.getLong(1), album.singerId);
assertEquals(rs.getLong(2), album.albumId);
assertEquals(rs.getString(3), album.albumTitle);
assertEquals(rs.getLong(4), album.marketingBudget);
}
// check that calling executeQuery will not reset the metadata
assertInsertAlbumParameterMetadata(ps.getParameterMetaData());
}
}
try (PreparedStatement ps =
connection.prepareStatement(getSongsInsertReturningQuery(dialect.dialect))) {
assertDefaultParameterMetaData(ps.getParameterMetaData(), 6);
for (Song song : createSongs()) {
ps.setByte(1, (byte) song.singerId);
ps.setInt(2, (int) song.albumId);
ps.setShort(3, (short) song.songId);
ps.setNString(4, song.songName);
ps.setLong(5, song.duration);
ps.setCharacterStream(6, new StringReader(song.songGenre));
assertInsertSongParameterMetadata(ps.getParameterMetaData());
try (ResultSet rs = ps.executeQuery()) {
rs.next();
assertEquals(rs.getByte(1), (byte) song.singerId);
assertEquals(rs.getInt(2), (int) song.albumId);
assertEquals(rs.getShort(3), (short) song.songId);
assertEquals(rs.getNString(4), song.songName);
assertEquals(rs.getLong(5), song.duration);
assertEquals(
CharStreams.toString(rs.getCharacterStream(6)),
CharStreams.toString(new StringReader(song.songGenre)));
}
// check that calling executeQuery will not reset the metadata
assertInsertSongParameterMetadata(ps.getParameterMetaData());
}
} catch (IOException e) {
// ignore exception.
}
try (PreparedStatement ps =
connection.prepareStatement(getConcertsInsertReturningQuery(dialect.dialect))) {
assertDefaultParameterMetaData(
ps.getParameterMetaData(), getConcertExpectedParamCount(dialect.dialect));
for (Concert concert : createConcerts()) {
concert.setPreparedStatement(connection, ps, getDialect());
assertInsertConcertParameterMetadata(ps.getParameterMetaData());
try (ResultSet rs = ps.executeQuery()) {
rs.next();
concert.assertEqualsFields(connection, rs, dialect.dialect);
}
// check that calling executeQuery will not reset the meta data
assertInsertConcertParameterMetadata(ps.getParameterMetaData());
}
}
connection.commit();
}
}

private List<String> readValuesFromFile(String filename) {
StringBuilder builder = new StringBuilder();
try (InputStream stream = ITJdbcPreparedStatementTest.class.getResourceAsStream(filename)) {
Expand Down

0 comments on commit 8a86467

Please sign in to comment.