diff --git a/pgjdbc/src/main/java/org/postgresql/core/TypeInfo.java b/pgjdbc/src/main/java/org/postgresql/core/TypeInfo.java index 9e77632a3a..0934a31104 100644 --- a/pgjdbc/src/main/java/org/postgresql/core/TypeInfo.java +++ b/pgjdbc/src/main/java/org/postgresql/core/TypeInfo.java @@ -36,6 +36,8 @@ void addCoreType(String pgTypeName, Integer oid, Integer sqlType, String javaCla */ int getSQLType(String pgTypeName) throws SQLException; + int getJavaArrayType(String className) throws SQLException; + /** * Look up the oid for a given postgresql type name. This is the inverse of * {@link #getPGType(int)}. diff --git a/pgjdbc/src/main/java/org/postgresql/jdbc/PgPreparedStatement.java b/pgjdbc/src/main/java/org/postgresql/jdbc/PgPreparedStatement.java index 92e261d815..2531dc6277 100644 --- a/pgjdbc/src/main/java/org/postgresql/jdbc/PgPreparedStatement.java +++ b/pgjdbc/src/main/java/org/postgresql/jdbc/PgPreparedStatement.java @@ -739,18 +739,31 @@ public void setObject(@Positive int parameterIndex, @Nullable Object in, } } + private Class getArrayType(Class type) { + Class subType = type.getComponentType(); + while (subType != null) { + type = subType; + subType = type.getComponentType(); + } + return type; + } + private void setObjectArray(int parameterIndex, A in) throws SQLException { final ArrayEncoding.ArrayEncoder arraySupport = ArrayEncoding.getArrayEncoder(in); final TypeInfo typeInfo = connection.getTypeInfo(); - final int oid = arraySupport.getDefaultArrayTypeOid(); + int oid = arraySupport.getDefaultArrayTypeOid(); if (arraySupport.supportBinaryRepresentation(oid) && connection.getPreferQueryMode() != PreferQueryMode.SIMPLE) { bindBytes(parameterIndex, arraySupport.toBinaryRepresentation(connection, in, oid), oid); } else { if (oid == Oid.UNSPECIFIED) { - throw new SQLFeatureNotSupportedException(); + Class arrayType = getArrayType(in.getClass()); + oid = typeInfo.getJavaArrayType(arrayType.getName()); + if (oid == Oid.UNSPECIFIED) { + throw new SQLFeatureNotSupportedException(); + } } final int baseOid = typeInfo.getPGArrayElement(oid); final String baseType = castNonNull(typeInfo.getPGType(baseOid)); diff --git a/pgjdbc/src/main/java/org/postgresql/jdbc/TypeInfoCache.java b/pgjdbc/src/main/java/org/postgresql/jdbc/TypeInfoCache.java index 5f24171fc2..e946df2144 100644 --- a/pgjdbc/src/main/java/org/postgresql/jdbc/TypeInfoCache.java +++ b/pgjdbc/src/main/java/org/postgresql/jdbc/TypeInfoCache.java @@ -52,6 +52,8 @@ public class TypeInfoCache implements TypeInfo { // pgname (String) -> oid (Integer) private Map pgNameToOid; + private Map javaArrayTypeToOid; + // pgname (String) -> extension pgobject (Class) private Map> pgNameToPgObject; @@ -140,6 +142,7 @@ public TypeInfoCache(BaseConnection conn, int unknownLength) { this.unknownLength = unknownLength; oidToPgName = new HashMap((int) Math.round(types.length * 1.5)); pgNameToOid = new HashMap((int) Math.round(types.length * 1.5)); + javaArrayTypeToOid = new HashMap((int) Math.round(types.length * 1.5)); pgNameToJavaClass = new HashMap((int) Math.round(types.length * 1.5)); pgNameToPgObject = new HashMap>((int) Math.round(types.length * 1.5)); pgArrayToPgType = new HashMap((int) Math.round(types.length * 1.5)); @@ -168,6 +171,7 @@ public synchronized void addCoreType(String pgTypeName, Integer oid, Integer sql pgNameToJavaClass.put(pgTypeName, javaClass); pgNameToOid.put(pgTypeName, oid); oidToPgName.put(oid, pgTypeName); + javaArrayTypeToOid.put(javaClass, arrayOid); pgArrayToPgType.put(arrayOid, oid); pgNameToSQLType.put(pgTypeName, sqlType); oidToSQLType.put(oid, sqlType); @@ -319,6 +323,15 @@ public synchronized int getSQLType(String pgTypeName) throws SQLException { return i; } + @Override + public synchronized int getJavaArrayType(String className) throws SQLException { + Integer oid = javaArrayTypeToOid.get(className); + if (oid == null) { + return Oid.UNSPECIFIED; + } + return oid; + } + public synchronized int getSQLType(int typeOid) throws SQLException { if (typeOid == Oid.UNSPECIFIED) { return Types.OTHER; diff --git a/pgjdbc/src/test/java/org/postgresql/jdbc/ArraysTestSuite.java b/pgjdbc/src/test/java/org/postgresql/jdbc/ArraysTestSuite.java index a34c8e8ef6..0cb8395012 100644 --- a/pgjdbc/src/test/java/org/postgresql/jdbc/ArraysTestSuite.java +++ b/pgjdbc/src/test/java/org/postgresql/jdbc/ArraysTestSuite.java @@ -24,7 +24,8 @@ LongObjectArraysTest.class, ShortArraysTest.class, ShortObjectArraysTest.class, - StringArraysTest.class + StringArraysTest.class, + UUIDArrayTest.class }) public class ArraysTestSuite { } diff --git a/pgjdbc/src/test/java/org/postgresql/jdbc/UUIDArrayTest.java b/pgjdbc/src/test/java/org/postgresql/jdbc/UUIDArrayTest.java new file mode 100644 index 0000000000..a1d7984235 --- /dev/null +++ b/pgjdbc/src/test/java/org/postgresql/jdbc/UUIDArrayTest.java @@ -0,0 +1,139 @@ +/* + * Copyright (c) 2022, PostgreSQL Global Development Group + * See the LICENSE file in the project root for more information. + */ + +package org.postgresql.jdbc; + +import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assumptions.assumeTrue; + +import org.postgresql.core.ServerVersion; +import org.postgresql.test.TestUtil; + +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; + +import java.sql.Connection; +import java.sql.PreparedStatement; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.sql.Statement; +import java.util.UUID; + +class UUIDArrayTest { + + private static Connection con; + private static final String TABLE_NAME = "uuid_table"; + private static final String INSERT1 = "INSERT INTO " + TABLE_NAME + + " (id, data1) VALUES (?, ?)"; + private static final String INSERT2 = "INSERT INTO " + TABLE_NAME + + " (id, data2) VALUES (?, ?)"; + private static final String SELECT1 = "SELECT data1 FROM " + TABLE_NAME + + " WHERE id = ?"; + private static final String SELECT2 = "SELECT data2 FROM " + TABLE_NAME + + " WHERE id = ?"; + private static final UUID[] uids1 = new UUID[]{UUID.randomUUID(), UUID.randomUUID()}; + private static final UUID[][] uids2 = new UUID[][]{uids1}; + + @BeforeAll + public static void setUp() throws Exception { + con = TestUtil.openDB(); + assumeTrue(TestUtil.haveMinimumServerVersion(con, ServerVersion.v9_6)); + try (Statement stmt = con.createStatement()) { + stmt.execute("CREATE TABLE " + TABLE_NAME + + " (id int PRIMARY KEY, data1 UUID[], data2 UUID[][])"); + } + } + + @AfterAll + public static void tearDown() throws Exception { + try (Statement stmt = con.createStatement()) { + stmt.execute("DROP TABLE IF EXISTS " + TABLE_NAME); + } + TestUtil.closeDB(con); + } + + @Test + void test1DWithCreateArrayOf() throws SQLException { + try (Connection c = assertDoesNotThrow(() -> TestUtil.openDB()); + PreparedStatement stmt1 = c.prepareStatement(INSERT1); + PreparedStatement stmt2 = c.prepareStatement(SELECT1)) { + stmt1.setInt(1, 100); + stmt1.setArray(2, c.createArrayOf("uuid", uids1)); + stmt1.execute(); + + stmt2.setInt(1, 100); + stmt2.execute(); + try (ResultSet rs = stmt2.getResultSet()) { + assertTrue(rs.next()); + UUID[] array = (UUID[])rs.getArray(1).getArray(); + assertEquals(uids1[0], array[0]); + assertEquals(uids1[1], array[1]); + } + } + } + + @Test + void test1DWithSetObject() throws SQLException { + try (Connection c = assertDoesNotThrow(() -> TestUtil.openDB()); + PreparedStatement stmt1 = c.prepareStatement(INSERT1); + PreparedStatement stmt2 = c.prepareStatement(SELECT1)) { + stmt1.setInt(1, 101); + stmt1.setObject(2, uids1); + stmt1.execute(); + + stmt2.setInt(1, 101); + stmt2.execute(); + try (ResultSet rs = stmt2.getResultSet()) { + assertTrue(rs.next()); + UUID[] array = (UUID[])rs.getArray(1).getArray(); + assertEquals(uids1[0], array[0]); + assertEquals(uids1[1], array[1]); + } + } + } + + @Test + void test2DWithCreateArrayOf() throws SQLException { + try (Connection c = assertDoesNotThrow(() -> TestUtil.openDB()); + PreparedStatement stmt1 = c.prepareStatement(INSERT2); + PreparedStatement stmt2 = c.prepareStatement(SELECT2)) { + stmt1.setInt(1, 200); + stmt1.setArray(2, c.createArrayOf("uuid", uids2)); + stmt1.execute(); + + stmt2.setInt(1, 200); + stmt2.execute(); + try (ResultSet rs = stmt2.getResultSet()) { + assertTrue(rs.next()); + UUID[][] array = (UUID[][])rs.getArray(1).getArray(); + assertEquals(uids2[0][0], array[0][0]); + assertEquals(uids2[0][1], array[0][1]); + } + } + } + + @Test + void test2DWithSetObject() throws SQLException { + try (Connection c = assertDoesNotThrow(() -> TestUtil.openDB()); + PreparedStatement stmt1 = c.prepareStatement(INSERT2); + PreparedStatement stmt2 = c.prepareStatement(SELECT2)) { + stmt1.setInt(1, 201); + stmt1.setObject(2, uids2); + stmt1.execute(); + + stmt2.setInt(1, 201); + stmt2.execute(); + try (ResultSet rs = stmt2.getResultSet()) { + assertTrue(rs.next()); + UUID[][] array = (UUID[][])rs.getArray(1).getArray(); + assertEquals(uids2[0][0], array[0][0]); + assertEquals(uids2[0][1], array[0][1]); + } + } + } +}