Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add tests for DML with Returning clause #936

Merged
Expand Up @@ -87,7 +87,7 @@ public long executeLargeUpdate(String sql) throws SQLException {
switch (result.getResultType()) {
case RESULT_SET:
throw JdbcSqlExceptionFactory.of(
"The statement is not an update or DDL statement", Code.INVALID_ARGUMENT);
"The statement is not a non-returning DML or DDL statement", Code.INVALID_ARGUMENT);
case UPDATE_COUNT:
return result.getUpdateCount();
case NO_RESULT:
Expand Down
Expand Up @@ -17,6 +17,12 @@
package com.google.cloud.spanner.jdbc;

import static com.google.common.truth.Truth.assertThat;
import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertThrows;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
import static org.mockito.Mockito.anyList;
import static org.mockito.Mockito.mock;
Expand Down Expand Up @@ -53,6 +59,8 @@ public class JdbcStatementTest {
private static final String SELECT = "SELECT 1";
private static final String UPDATE = "UPDATE FOO SET BAR=1 WHERE BAZ=2";
private static final String LARGE_UPDATE = "UPDATE FOO SET BAR=1 WHERE 1=1";
private static final String DML_RETURNING_GSQL = "UPDATE FOO SET BAR=1 WHERE 1=1 THEN RETURN *";
private static final String DML_RETURNING_PG = "UPDATE FOO SET BAR=1 WHERE 1=1 RETURNING *";
private static final String DDL = "CREATE INDEX FOO ON BAR(ID)";

@Parameter public Dialect dialect;
Expand All @@ -62,11 +70,20 @@ public static Object[] data() {
return Dialect.values();
}

private String getDmlReturningSql() {
if (dialect == Dialect.GOOGLE_STANDARD_SQL) {
return DML_RETURNING_GSQL;
}
return DML_RETURNING_PG;
}

@SuppressWarnings("unchecked")
private JdbcStatement createStatement() throws SQLException {
Connection spanner = mock(Connection.class);
when(spanner.getDialect()).thenReturn(dialect);

final String DML_RETURNING_SQL = getDmlReturningSql();

com.google.cloud.spanner.ResultSet resultSet = mock(com.google.cloud.spanner.ResultSet.class);
when(resultSet.next()).thenReturn(true, false);
when(resultSet.getColumnType(0)).thenReturn(Type.int64());
Expand All @@ -88,6 +105,19 @@ private JdbcStatement createStatement() throws SQLException {
when(spanner.execute(com.google.cloud.spanner.Statement.of(LARGE_UPDATE)))
.thenReturn(largeUpdateResult);

com.google.cloud.spanner.ResultSet dmlReturningResultSet =
mock(com.google.cloud.spanner.ResultSet.class);
when(dmlReturningResultSet.next()).thenReturn(true, false);
when(dmlReturningResultSet.getColumnCount()).thenReturn(1);
when(dmlReturningResultSet.getColumnType(0)).thenReturn(Type.int64());
when(dmlReturningResultSet.getLong(0)).thenReturn(1L);

StatementResult dmlReturningResult = mock(StatementResult.class);
when(dmlReturningResult.getResultType()).thenReturn(ResultType.RESULT_SET);
when(dmlReturningResult.getResultSet()).thenReturn(dmlReturningResultSet);
when(spanner.execute(com.google.cloud.spanner.Statement.of(DML_RETURNING_SQL)))
.thenReturn(dmlReturningResult);

StatementResult ddlResult = mock(StatementResult.class);
when(ddlResult.getResultType()).thenReturn(ResultType.NO_RESULT);
when(spanner.execute(com.google.cloud.spanner.Statement.of(DDL))).thenReturn(ddlResult);
Expand All @@ -96,6 +126,8 @@ private JdbcStatement createStatement() throws SQLException {
when(spanner.executeQuery(com.google.cloud.spanner.Statement.of(UPDATE)))
.thenThrow(
SpannerExceptionFactory.newSpannerException(ErrorCode.INVALID_ARGUMENT, "not a query"));
when(spanner.executeQuery(com.google.cloud.spanner.Statement.of(DML_RETURNING_SQL)))
.thenReturn(dmlReturningResultSet);
when(spanner.executeQuery(com.google.cloud.spanner.Statement.of(DDL)))
.thenThrow(
SpannerExceptionFactory.newSpannerException(ErrorCode.INVALID_ARGUMENT, "not a query"));
Expand All @@ -109,6 +141,10 @@ private JdbcStatement createStatement() throws SQLException {
.thenThrow(
SpannerExceptionFactory.newSpannerException(
ErrorCode.INVALID_ARGUMENT, "not an update"));
when(spanner.executeUpdate(com.google.cloud.spanner.Statement.of(DML_RETURNING_SQL)))
.thenThrow(
SpannerExceptionFactory.newSpannerException(
ErrorCode.FAILED_PRECONDITION, "cannot execute dml returning over executeUpdate"));

when(spanner.executeBatchUpdate(anyList()))
.thenAnswer(
Expand Down Expand Up @@ -219,6 +255,20 @@ public void testExecuteWithDdlStatement() throws SQLException {
assertThat(statement.getUpdateCount()).isEqualTo(JdbcConstants.STATEMENT_NO_RESULT);
}

@Test
public void testExecuteWithDmlReturningStatement() throws SQLException {
Statement statement = createStatement();
boolean res = statement.execute(getDmlReturningSql());
assertTrue(res);
assertEquals(statement.getUpdateCount(), JdbcConstants.STATEMENT_RESULT_SET);
try (ResultSet rs = statement.getResultSet()) {
assertNotNull(rs);
assertTrue(rs.next());
assertEquals(rs.getLong(1), 1L);
assertFalse(rs.next());
}
}

@Test
public void testExecuteWithGeneratedKeys() throws SQLException {
Statement statement = createStatement();
Expand Down Expand Up @@ -257,6 +307,17 @@ public void testExecuteQueryWithUpdateStatement() {
}
}

@Test
public void testExecuteQueryWithDmlReturningStatement() throws SQLException {
Statement statement = createStatement();
try (ResultSet rs = statement.executeQuery(getDmlReturningSql())) {
assertNotNull(rs);
assertTrue(rs.next());
assertEquals(rs.getLong(1), 1L);
assertFalse(rs.next());
}
}

@Test
public void testExecuteQueryWithDdlStatement() {
try {
Expand Down Expand Up @@ -353,12 +414,29 @@ public void testExecuteUpdateWithSelectStatement() {
} catch (SQLException e) {
assertThat(
JdbcExceptionMatcher.matchCodeAndMessage(
Code.INVALID_ARGUMENT, "The statement is not an update or DDL statement")
Code.INVALID_ARGUMENT,
"The statement is not a non-returning DML or DDL statement")
.matches(e))
.isTrue();
}
}

@Test
public void testExecuteUpdateWithDmlReturningStatement() {
try {
Statement statement = createStatement();
SQLException e =
assertThrows(SQLException.class, () -> statement.executeUpdate(getDmlReturningSql()));
assertTrue(
JdbcExceptionMatcher.matchCodeAndMessage(
Code.INVALID_ARGUMENT,
"The statement is not a non-returning DML or DDL statement")
.matches(e));
} catch (SQLException e) {
// ignore exception.
}
}

@Test
public void testExecuteUpdateWithDdlStatement() throws SQLException {
Statement statement = createStatement();
Expand Down Expand Up @@ -438,6 +516,19 @@ public void testDmlBatch() throws SQLException {
}
}

@Test
public void testDmlBatchWithDmlReturning() throws SQLException {
try (Statement statement = createStatement()) {
// Verify that multiple batches can be executed on the same statement.
for (int i = 0; i < 2; i++) {
statement.addBatch(getDmlReturningSql());
statement.addBatch(getDmlReturningSql());
statement.addBatch(getDmlReturningSql());
assertArrayEquals(statement.executeBatch(), new int[] {1, 1, 1});
}
}
}

@Test
public void testLargeDmlBatch() throws SQLException {
try (Statement statement = createStatement()) {
Expand Down
Expand Up @@ -35,6 +35,7 @@
import com.google.cloud.spanner.testing.EmulatorSpannerHelper;
import com.google.common.base.Strings;
import com.google.common.io.BaseEncoding;
import com.google.common.io.CharStreams;
import java.io.IOException;
import java.io.InputStream;
import java.io.StringReader;
Expand Down Expand Up @@ -262,6 +263,24 @@ private void setPreparedStatement(Connection connection, PreparedStatement ps, D
ps.setArray(6, connection.createArrayOf("INT64", this.ticketPrices));
}
}

private void assertEqualsFields(Connection connection, ResultSet rs, Dialect dialect)
throws SQLException {
assertEquals(rs.getLong(1), this.venueId);
assertEquals(rs.getLong(2), this.singerId);
if (dialect == Dialect.POSTGRESQL) {
assertEquals(rs.getString(3), this.concertDate.toString());
assertEquals(rs.getString(4), this.beginTime.toString());
assertEquals(rs.getString(5), this.endTime.toString());
} else {
assertEquals(rs.getDate(3), this.concertDate);
assertEquals(rs.getTimestamp(4), this.beginTime);
assertEquals(rs.getTimestamp(5), this.endTime);
assertArrayEquals(
(Object[]) rs.getArray(6).getArray(),
(Object[]) connection.createArrayOf("INT64", this.ticketPrices).getArray());
}
}
}

private static Date parseDate(String value) {
Expand Down Expand Up @@ -332,6 +351,34 @@ private String getConcertsInsertQuery(Dialect dialect) {
return "INSERT INTO Concerts (VenueId, SingerId, ConcertDate, BeginTime, EndTime, TicketPrices) VALUES (?,?,?,?,?,?);";
}

private String getConcertsInsertReturningQuery(Dialect dialect) {
if (dialect == Dialect.POSTGRESQL) {
return "INSERT INTO Concerts (VenueId, SingerId, ConcertDate, BeginTime, EndTime) VALUES (?,?,?,?,?) RETURNING *;";
}
return "INSERT INTO Concerts (VenueId, SingerId, ConcertDate, BeginTime, EndTime, TicketPrices) VALUES (?,?,?,?,?,?) THEN RETURN *;";
}

private String getSingersInsertReturningQuery(Dialect dialect) {
if (dialect == Dialect.POSTGRESQL) {
return "INSERT INTO Singers (SingerId, FirstName, LastName, SingerInfo, BirthDate) values (?,?,?,?,?) RETURNING *";
}
return "INSERT INTO Singers (SingerId, FirstName, LastName, SingerInfo, BirthDate) values (?,?,?,?,?) THEN RETURN *";
}

private String getAlbumsInsertReturningQuery(Dialect dialect) {
if (dialect == Dialect.POSTGRESQL) {
return "INSERT INTO Albums (SingerId, AlbumId, AlbumTitle, MarketingBudget) VALUES (?,?,?,?) RETURNING *";
}
return "INSERT INTO Albums (SingerId, AlbumId, AlbumTitle, MarketingBudget) VALUES (?,?,?,?) THEN RETURN *";
}

private String getSongsInsertReturningQuery(Dialect dialect) {
if (dialect == Dialect.POSTGRESQL) {
return "INSERT INTO Songs (SingerId, AlbumId, TrackId, SongName, Duration, SongGenre) VALUES (?,?,?,?,?,?) RETURNING *;";
}
return "INSERT INTO Songs (SingerId, AlbumId, TrackId, SongName, Duration, SongGenre) VALUES (?,?,?,?,?,?) THEN RETURN *;";
}

private int getConcertExpectedParamCount(Dialect dialect) {
if (dialect == Dialect.POSTGRESQL) {
return 5;
Expand Down Expand Up @@ -1104,6 +1151,95 @@ private void assertDefaultParameterMetaData(ParameterMetaData pmd, int expectedP
}
}

@Test
public void test12_InsertReturningTestData() throws SQLException {
assumeFalse(
"Emulator does not support DML with returning clause",
EmulatorSpannerHelper.isUsingEmulator());
try (Connection connection = createConnection(env, database)) {
connection.setAutoCommit(false);
try (PreparedStatement ps =
connection.prepareStatement(getSingersInsertReturningQuery(dialect.dialect))) {
assertDefaultParameterMetaData(ps.getParameterMetaData(), 5);
for (Singer singer : createSingers()) {
singer.setPreparedStatement(ps, getDialect());
assertInsertSingerParameterMetadata(ps.getParameterMetaData());
ps.addBatch();
// check that adding the current params to a batch will not reset the metadata
assertInsertSingerParameterMetadata(ps.getParameterMetaData());
}
int[] results = ps.executeBatch();
for (int res : results) {
assertEquals(1, res);
}
}
try (PreparedStatement ps =
connection.prepareStatement(getAlbumsInsertReturningQuery(dialect.dialect))) {
assertDefaultParameterMetaData(ps.getParameterMetaData(), 4);
for (Album album : createAlbums()) {
ps.setLong(1, album.singerId);
ps.setLong(2, album.albumId);
ps.setString(3, album.albumTitle);
ps.setLong(4, album.marketingBudget);
assertInsertAlbumParameterMetadata(ps.getParameterMetaData());
try (ResultSet rs = ps.executeQuery()) {
rs.next();
assertEquals(rs.getLong(1), album.singerId);
assertEquals(rs.getLong(2), album.albumId);
assertEquals(rs.getString(3), album.albumTitle);
assertEquals(rs.getLong(4), album.marketingBudget);
}
// check that calling executeQuery will not reset the metadata
assertInsertAlbumParameterMetadata(ps.getParameterMetaData());
}
}
try (PreparedStatement ps =
connection.prepareStatement(getSongsInsertReturningQuery(dialect.dialect))) {
assertDefaultParameterMetaData(ps.getParameterMetaData(), 6);
for (Song song : createSongs()) {
ps.setByte(1, (byte) song.singerId);
ps.setInt(2, (int) song.albumId);
ps.setShort(3, (short) song.songId);
ps.setNString(4, song.songName);
ps.setLong(5, song.duration);
ps.setCharacterStream(6, new StringReader(song.songGenre));
assertInsertSongParameterMetadata(ps.getParameterMetaData());
try (ResultSet rs = ps.executeQuery()) {
rs.next();
assertEquals(rs.getByte(1), (byte) song.singerId);
assertEquals(rs.getInt(2), (int) song.albumId);
assertEquals(rs.getShort(3), (short) song.songId);
assertEquals(rs.getNString(4), song.songName);
assertEquals(rs.getLong(5), song.duration);
assertEquals(
CharStreams.toString(rs.getCharacterStream(6)),
CharStreams.toString(new StringReader(song.songGenre)));
}
// check that calling executeQuery will not reset the metadata
assertInsertSongParameterMetadata(ps.getParameterMetaData());
}
} catch (IOException e) {
// ignore exception.
}
try (PreparedStatement ps =
connection.prepareStatement(getConcertsInsertReturningQuery(dialect.dialect))) {
assertDefaultParameterMetaData(
ps.getParameterMetaData(), getConcertExpectedParamCount(dialect.dialect));
for (Concert concert : createConcerts()) {
concert.setPreparedStatement(connection, ps, getDialect());
assertInsertConcertParameterMetadata(ps.getParameterMetaData());
try (ResultSet rs = ps.executeQuery()) {
rs.next();
concert.assertEqualsFields(connection, rs, dialect.dialect);
}
// check that calling executeQuery will not reset the meta data
assertInsertConcertParameterMetadata(ps.getParameterMetaData());
}
}
connection.commit();
}
}

private List<String> readValuesFromFile(String filename) {
StringBuilder builder = new StringBuilder();
try (InputStream stream = ITJdbcPreparedStatementTest.class.getResourceAsStream(filename)) {
Expand Down