diff --git a/src/main/java/com/google/cloud/spanner/jdbc/JdbcStatement.java b/src/main/java/com/google/cloud/spanner/jdbc/JdbcStatement.java index f37eb21a..051ed9d6 100644 --- a/src/main/java/com/google/cloud/spanner/jdbc/JdbcStatement.java +++ b/src/main/java/com/google/cloud/spanner/jdbc/JdbcStatement.java @@ -87,7 +87,7 @@ public long executeLargeUpdate(String sql) throws SQLException { switch (result.getResultType()) { case RESULT_SET: throw JdbcSqlExceptionFactory.of( - "The statement is not an update or DDL statement", Code.INVALID_ARGUMENT); + "The statement is not a non-returning DML or DDL statement", Code.INVALID_ARGUMENT); case UPDATE_COUNT: return result.getUpdateCount(); case NO_RESULT: diff --git a/src/test/java/com/google/cloud/spanner/jdbc/JdbcStatementTest.java b/src/test/java/com/google/cloud/spanner/jdbc/JdbcStatementTest.java index 3ee9edcb..c7b6a6cb 100644 --- a/src/test/java/com/google/cloud/spanner/jdbc/JdbcStatementTest.java +++ b/src/test/java/com/google/cloud/spanner/jdbc/JdbcStatementTest.java @@ -17,6 +17,12 @@ package com.google.cloud.spanner.jdbc; import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; import static org.mockito.Mockito.anyList; import static org.mockito.Mockito.mock; @@ -53,6 +59,8 @@ public class JdbcStatementTest { private static final String SELECT = "SELECT 1"; private static final String UPDATE = "UPDATE FOO SET BAR=1 WHERE BAZ=2"; private static final String LARGE_UPDATE = "UPDATE FOO SET BAR=1 WHERE 1=1"; + private static final String DML_RETURNING_GSQL = "UPDATE FOO SET BAR=1 WHERE 1=1 THEN RETURN *"; + private static final String DML_RETURNING_PG = "UPDATE FOO SET BAR=1 WHERE 1=1 RETURNING *"; private static final String DDL = "CREATE INDEX FOO ON BAR(ID)"; @Parameter public Dialect dialect; @@ -62,11 +70,20 @@ public static Object[] data() { return Dialect.values(); } + private String getDmlReturningSql() { + if (dialect == Dialect.GOOGLE_STANDARD_SQL) { + return DML_RETURNING_GSQL; + } + return DML_RETURNING_PG; + } + @SuppressWarnings("unchecked") private JdbcStatement createStatement() throws SQLException { Connection spanner = mock(Connection.class); when(spanner.getDialect()).thenReturn(dialect); + final String DML_RETURNING_SQL = getDmlReturningSql(); + com.google.cloud.spanner.ResultSet resultSet = mock(com.google.cloud.spanner.ResultSet.class); when(resultSet.next()).thenReturn(true, false); when(resultSet.getColumnType(0)).thenReturn(Type.int64()); @@ -88,6 +105,19 @@ private JdbcStatement createStatement() throws SQLException { when(spanner.execute(com.google.cloud.spanner.Statement.of(LARGE_UPDATE))) .thenReturn(largeUpdateResult); + com.google.cloud.spanner.ResultSet dmlReturningResultSet = + mock(com.google.cloud.spanner.ResultSet.class); + when(dmlReturningResultSet.next()).thenReturn(true, false); + when(dmlReturningResultSet.getColumnCount()).thenReturn(1); + when(dmlReturningResultSet.getColumnType(0)).thenReturn(Type.int64()); + when(dmlReturningResultSet.getLong(0)).thenReturn(1L); + + StatementResult dmlReturningResult = mock(StatementResult.class); + when(dmlReturningResult.getResultType()).thenReturn(ResultType.RESULT_SET); + when(dmlReturningResult.getResultSet()).thenReturn(dmlReturningResultSet); + when(spanner.execute(com.google.cloud.spanner.Statement.of(DML_RETURNING_SQL))) + .thenReturn(dmlReturningResult); + StatementResult ddlResult = mock(StatementResult.class); when(ddlResult.getResultType()).thenReturn(ResultType.NO_RESULT); when(spanner.execute(com.google.cloud.spanner.Statement.of(DDL))).thenReturn(ddlResult); @@ -96,6 +126,8 @@ private JdbcStatement createStatement() throws SQLException { when(spanner.executeQuery(com.google.cloud.spanner.Statement.of(UPDATE))) .thenThrow( SpannerExceptionFactory.newSpannerException(ErrorCode.INVALID_ARGUMENT, "not a query")); + when(spanner.executeQuery(com.google.cloud.spanner.Statement.of(DML_RETURNING_SQL))) + .thenReturn(dmlReturningResultSet); when(spanner.executeQuery(com.google.cloud.spanner.Statement.of(DDL))) .thenThrow( SpannerExceptionFactory.newSpannerException(ErrorCode.INVALID_ARGUMENT, "not a query")); @@ -109,6 +141,10 @@ private JdbcStatement createStatement() throws SQLException { .thenThrow( SpannerExceptionFactory.newSpannerException( ErrorCode.INVALID_ARGUMENT, "not an update")); + when(spanner.executeUpdate(com.google.cloud.spanner.Statement.of(DML_RETURNING_SQL))) + .thenThrow( + SpannerExceptionFactory.newSpannerException( + ErrorCode.FAILED_PRECONDITION, "cannot execute dml returning over executeUpdate")); when(spanner.executeBatchUpdate(anyList())) .thenAnswer( @@ -219,6 +255,20 @@ public void testExecuteWithDdlStatement() throws SQLException { assertThat(statement.getUpdateCount()).isEqualTo(JdbcConstants.STATEMENT_NO_RESULT); } + @Test + public void testExecuteWithDmlReturningStatement() throws SQLException { + Statement statement = createStatement(); + boolean res = statement.execute(getDmlReturningSql()); + assertTrue(res); + assertEquals(statement.getUpdateCount(), JdbcConstants.STATEMENT_RESULT_SET); + try (ResultSet rs = statement.getResultSet()) { + assertNotNull(rs); + assertTrue(rs.next()); + assertEquals(rs.getLong(1), 1L); + assertFalse(rs.next()); + } + } + @Test public void testExecuteWithGeneratedKeys() throws SQLException { Statement statement = createStatement(); @@ -257,6 +307,17 @@ public void testExecuteQueryWithUpdateStatement() { } } + @Test + public void testExecuteQueryWithDmlReturningStatement() throws SQLException { + Statement statement = createStatement(); + try (ResultSet rs = statement.executeQuery(getDmlReturningSql())) { + assertNotNull(rs); + assertTrue(rs.next()); + assertEquals(rs.getLong(1), 1L); + assertFalse(rs.next()); + } + } + @Test public void testExecuteQueryWithDdlStatement() { try { @@ -353,12 +414,29 @@ public void testExecuteUpdateWithSelectStatement() { } catch (SQLException e) { assertThat( JdbcExceptionMatcher.matchCodeAndMessage( - Code.INVALID_ARGUMENT, "The statement is not an update or DDL statement") + Code.INVALID_ARGUMENT, + "The statement is not a non-returning DML or DDL statement") .matches(e)) .isTrue(); } } + @Test + public void testExecuteUpdateWithDmlReturningStatement() { + try { + Statement statement = createStatement(); + SQLException e = + assertThrows(SQLException.class, () -> statement.executeUpdate(getDmlReturningSql())); + assertTrue( + JdbcExceptionMatcher.matchCodeAndMessage( + Code.INVALID_ARGUMENT, + "The statement is not a non-returning DML or DDL statement") + .matches(e)); + } catch (SQLException e) { + // ignore exception. + } + } + @Test public void testExecuteUpdateWithDdlStatement() throws SQLException { Statement statement = createStatement(); @@ -438,6 +516,19 @@ public void testDmlBatch() throws SQLException { } } + @Test + public void testDmlBatchWithDmlReturning() throws SQLException { + try (Statement statement = createStatement()) { + // Verify that multiple batches can be executed on the same statement. + for (int i = 0; i < 2; i++) { + statement.addBatch(getDmlReturningSql()); + statement.addBatch(getDmlReturningSql()); + statement.addBatch(getDmlReturningSql()); + assertArrayEquals(statement.executeBatch(), new int[] {1, 1, 1}); + } + } + } + @Test public void testLargeDmlBatch() throws SQLException { try (Statement statement = createStatement()) { diff --git a/src/test/java/com/google/cloud/spanner/jdbc/it/ITJdbcPreparedStatementTest.java b/src/test/java/com/google/cloud/spanner/jdbc/it/ITJdbcPreparedStatementTest.java index a452fdb0..c1d84169 100644 --- a/src/test/java/com/google/cloud/spanner/jdbc/it/ITJdbcPreparedStatementTest.java +++ b/src/test/java/com/google/cloud/spanner/jdbc/it/ITJdbcPreparedStatementTest.java @@ -36,6 +36,7 @@ import com.google.cloud.spanner.testing.EmulatorSpannerHelper; import com.google.common.base.Strings; import com.google.common.io.BaseEncoding; +import com.google.common.io.CharStreams; import java.io.IOException; import java.io.InputStream; import java.io.StringReader; @@ -263,6 +264,24 @@ private void setPreparedStatement(Connection connection, PreparedStatement ps, D ps.setArray(6, connection.createArrayOf("INT64", this.ticketPrices)); } } + + private void assertEqualsFields(Connection connection, ResultSet rs, Dialect dialect) + throws SQLException { + assertEquals(rs.getLong(1), this.venueId); + assertEquals(rs.getLong(2), this.singerId); + if (dialect == Dialect.POSTGRESQL) { + assertEquals(rs.getString(3), this.concertDate.toString()); + assertEquals(rs.getString(4), this.beginTime.toString()); + assertEquals(rs.getString(5), this.endTime.toString()); + } else { + assertEquals(rs.getDate(3), this.concertDate); + assertEquals(rs.getTimestamp(4), this.beginTime); + assertEquals(rs.getTimestamp(5), this.endTime); + assertArrayEquals( + (Object[]) rs.getArray(6).getArray(), + (Object[]) connection.createArrayOf("INT64", this.ticketPrices).getArray()); + } + } } private static Date parseDate(String value) { @@ -333,6 +352,34 @@ private String getConcertsInsertQuery(Dialect dialect) { return "INSERT INTO Concerts (VenueId, SingerId, ConcertDate, BeginTime, EndTime, TicketPrices) VALUES (?,?,?,?,?,?);"; } + private String getConcertsInsertReturningQuery(Dialect dialect) { + if (dialect == Dialect.POSTGRESQL) { + return "INSERT INTO Concerts (VenueId, SingerId, ConcertDate, BeginTime, EndTime) VALUES (?,?,?,?,?) RETURNING *;"; + } + return "INSERT INTO Concerts (VenueId, SingerId, ConcertDate, BeginTime, EndTime, TicketPrices) VALUES (?,?,?,?,?,?) THEN RETURN *;"; + } + + private String getSingersInsertReturningQuery(Dialect dialect) { + if (dialect == Dialect.POSTGRESQL) { + return "INSERT INTO Singers (SingerId, FirstName, LastName, SingerInfo, BirthDate) values (?,?,?,?,?) RETURNING *"; + } + return "INSERT INTO Singers (SingerId, FirstName, LastName, SingerInfo, BirthDate) values (?,?,?,?,?) THEN RETURN *"; + } + + private String getAlbumsInsertReturningQuery(Dialect dialect) { + if (dialect == Dialect.POSTGRESQL) { + return "INSERT INTO Albums (SingerId, AlbumId, AlbumTitle, MarketingBudget) VALUES (?,?,?,?) RETURNING *"; + } + return "INSERT INTO Albums (SingerId, AlbumId, AlbumTitle, MarketingBudget) VALUES (?,?,?,?) THEN RETURN *"; + } + + private String getSongsInsertReturningQuery(Dialect dialect) { + if (dialect == Dialect.POSTGRESQL) { + return "INSERT INTO Songs (SingerId, AlbumId, TrackId, SongName, Duration, SongGenre) VALUES (?,?,?,?,?,?) RETURNING *;"; + } + return "INSERT INTO Songs (SingerId, AlbumId, TrackId, SongName, Duration, SongGenre) VALUES (?,?,?,?,?,?) THEN RETURN *;"; + } + private int getConcertExpectedParamCount(Dialect dialect) { if (dialect == Dialect.POSTGRESQL) { return 5; @@ -1150,6 +1197,103 @@ private void assertDefaultParameterMetaData(ParameterMetaData pmd, int expectedP } } + @Test + public void test12_InsertReturningTestData() throws SQLException { + assumeFalse( + "Emulator does not support DML with returning clause", + EmulatorSpannerHelper.isUsingEmulator()); + try (Connection connection = createConnection(env, database)) { + connection.setAutoCommit(false); + // Delete existing rows from tables populated by other tests, + // so that this test can populate rows from scratch. + Statement deleteStatements = connection.createStatement(); + deleteStatements.addBatch("DELETE FROM Concerts WHERE TRUE"); + deleteStatements.addBatch("DELETE FROM Songs WHERE TRUE"); + deleteStatements.addBatch("DELETE FROM Albums WHERE TRUE"); + deleteStatements.addBatch("DELETE FROM Singers WHERE TRUE"); + deleteStatements.executeBatch(); + try (PreparedStatement ps = + connection.prepareStatement(getSingersInsertReturningQuery(dialect.dialect))) { + assertDefaultParameterMetaData(ps.getParameterMetaData(), 5); + for (Singer singer : createSingers()) { + singer.setPreparedStatement(ps, getDialect()); + assertInsertSingerParameterMetadata(ps.getParameterMetaData()); + ps.addBatch(); + // check that adding the current params to a batch will not reset the metadata + assertInsertSingerParameterMetadata(ps.getParameterMetaData()); + } + int[] results = ps.executeBatch(); + for (int res : results) { + assertEquals(1, res); + } + } + try (PreparedStatement ps = + connection.prepareStatement(getAlbumsInsertReturningQuery(dialect.dialect))) { + assertDefaultParameterMetaData(ps.getParameterMetaData(), 4); + for (Album album : createAlbums()) { + ps.setLong(1, album.singerId); + ps.setLong(2, album.albumId); + ps.setString(3, album.albumTitle); + ps.setLong(4, album.marketingBudget); + assertInsertAlbumParameterMetadata(ps.getParameterMetaData()); + try (ResultSet rs = ps.executeQuery()) { + rs.next(); + assertEquals(rs.getLong(1), album.singerId); + assertEquals(rs.getLong(2), album.albumId); + assertEquals(rs.getString(3), album.albumTitle); + assertEquals(rs.getLong(4), album.marketingBudget); + } + // check that calling executeQuery will not reset the metadata + assertInsertAlbumParameterMetadata(ps.getParameterMetaData()); + } + } + try (PreparedStatement ps = + connection.prepareStatement(getSongsInsertReturningQuery(dialect.dialect))) { + assertDefaultParameterMetaData(ps.getParameterMetaData(), 6); + for (Song song : createSongs()) { + ps.setByte(1, (byte) song.singerId); + ps.setInt(2, (int) song.albumId); + ps.setShort(3, (short) song.songId); + ps.setNString(4, song.songName); + ps.setLong(5, song.duration); + ps.setCharacterStream(6, new StringReader(song.songGenre)); + assertInsertSongParameterMetadata(ps.getParameterMetaData()); + try (ResultSet rs = ps.executeQuery()) { + rs.next(); + assertEquals(rs.getByte(1), (byte) song.singerId); + assertEquals(rs.getInt(2), (int) song.albumId); + assertEquals(rs.getShort(3), (short) song.songId); + assertEquals(rs.getNString(4), song.songName); + assertEquals(rs.getLong(5), song.duration); + assertEquals( + CharStreams.toString(rs.getCharacterStream(6)), + CharStreams.toString(new StringReader(song.songGenre))); + } + // check that calling executeQuery will not reset the metadata + assertInsertSongParameterMetadata(ps.getParameterMetaData()); + } + } catch (IOException e) { + // ignore exception. + } + try (PreparedStatement ps = + connection.prepareStatement(getConcertsInsertReturningQuery(dialect.dialect))) { + assertDefaultParameterMetaData( + ps.getParameterMetaData(), getConcertExpectedParamCount(dialect.dialect)); + for (Concert concert : createConcerts()) { + concert.setPreparedStatement(connection, ps, getDialect()); + assertInsertConcertParameterMetadata(ps.getParameterMetaData()); + try (ResultSet rs = ps.executeQuery()) { + rs.next(); + concert.assertEqualsFields(connection, rs, dialect.dialect); + } + // check that calling executeQuery will not reset the meta data + assertInsertConcertParameterMetadata(ps.getParameterMetaData()); + } + } + connection.commit(); + } + } + private List readValuesFromFile(String filename) { StringBuilder builder = new StringBuilder(); try (InputStream stream = ITJdbcPreparedStatementTest.class.getResourceAsStream(filename)) {