Skip to content

Commit

Permalink
fix: make setObject accept UUID array (#2587)
Browse files Browse the repository at this point in the history
* fix: make setObject accept UUID array
  • Loading branch information
sasavilic committed Aug 12, 2022
1 parent 0b097fd commit 96f2561
Show file tree
Hide file tree
Showing 5 changed files with 171 additions and 3 deletions.
2 changes: 2 additions & 0 deletions pgjdbc/src/main/java/org/postgresql/core/TypeInfo.java
Expand Up @@ -36,6 +36,8 @@ void addCoreType(String pgTypeName, Integer oid, Integer sqlType, String javaCla
*/
int getSQLType(String pgTypeName) throws SQLException;

int getJavaArrayType(String className) throws SQLException;

/**
* Look up the oid for a given postgresql type name. This is the inverse of
* {@link #getPGType(int)}.
Expand Down
17 changes: 15 additions & 2 deletions pgjdbc/src/main/java/org/postgresql/jdbc/PgPreparedStatement.java
Expand Up @@ -739,18 +739,31 @@ public void setObject(@Positive int parameterIndex, @Nullable Object in,
}
}

private Class<?> getArrayType(Class<?> type) {
Class<?> subType = type.getComponentType();
while (subType != null) {
type = subType;
subType = type.getComponentType();
}
return type;
}

private <A extends @NonNull Object> void setObjectArray(int parameterIndex, A in) throws SQLException {
final ArrayEncoding.ArrayEncoder<A> arraySupport = ArrayEncoding.getArrayEncoder(in);

final TypeInfo typeInfo = connection.getTypeInfo();

final int oid = arraySupport.getDefaultArrayTypeOid();
int oid = arraySupport.getDefaultArrayTypeOid();

if (arraySupport.supportBinaryRepresentation(oid) && connection.getPreferQueryMode() != PreferQueryMode.SIMPLE) {
bindBytes(parameterIndex, arraySupport.toBinaryRepresentation(connection, in, oid), oid);
} else {
if (oid == Oid.UNSPECIFIED) {
throw new SQLFeatureNotSupportedException();
Class<?> arrayType = getArrayType(in.getClass());
oid = typeInfo.getJavaArrayType(arrayType.getName());
if (oid == Oid.UNSPECIFIED) {
throw new SQLFeatureNotSupportedException();
}
}
final int baseOid = typeInfo.getPGArrayElement(oid);
final String baseType = castNonNull(typeInfo.getPGType(baseOid));
Expand Down
13 changes: 13 additions & 0 deletions pgjdbc/src/main/java/org/postgresql/jdbc/TypeInfoCache.java
Expand Up @@ -52,6 +52,8 @@ public class TypeInfoCache implements TypeInfo {
// pgname (String) -> oid (Integer)
private Map<String, Integer> pgNameToOid;

private Map<String, Integer> javaArrayTypeToOid;

// pgname (String) -> extension pgobject (Class)
private Map<String, Class<? extends PGobject>> pgNameToPgObject;

Expand Down Expand Up @@ -140,6 +142,7 @@ public TypeInfoCache(BaseConnection conn, int unknownLength) {
this.unknownLength = unknownLength;
oidToPgName = new HashMap<Integer, String>((int) Math.round(types.length * 1.5));
pgNameToOid = new HashMap<String, Integer>((int) Math.round(types.length * 1.5));
javaArrayTypeToOid = new HashMap<String, Integer>((int) Math.round(types.length * 1.5));
pgNameToJavaClass = new HashMap<String, String>((int) Math.round(types.length * 1.5));
pgNameToPgObject = new HashMap<String, Class<? extends PGobject>>((int) Math.round(types.length * 1.5));
pgArrayToPgType = new HashMap<Integer, Integer>((int) Math.round(types.length * 1.5));
Expand Down Expand Up @@ -168,6 +171,7 @@ public synchronized void addCoreType(String pgTypeName, Integer oid, Integer sql
pgNameToJavaClass.put(pgTypeName, javaClass);
pgNameToOid.put(pgTypeName, oid);
oidToPgName.put(oid, pgTypeName);
javaArrayTypeToOid.put(javaClass, arrayOid);
pgArrayToPgType.put(arrayOid, oid);
pgNameToSQLType.put(pgTypeName, sqlType);
oidToSQLType.put(oid, sqlType);
Expand Down Expand Up @@ -319,6 +323,15 @@ public synchronized int getSQLType(String pgTypeName) throws SQLException {
return i;
}

@Override
public synchronized int getJavaArrayType(String className) throws SQLException {
Integer oid = javaArrayTypeToOid.get(className);
if (oid == null) {
return Oid.UNSPECIFIED;
}
return oid;
}

public synchronized int getSQLType(int typeOid) throws SQLException {
if (typeOid == Oid.UNSPECIFIED) {
return Types.OTHER;
Expand Down
Expand Up @@ -24,7 +24,8 @@
LongObjectArraysTest.class,
ShortArraysTest.class,
ShortObjectArraysTest.class,
StringArraysTest.class
StringArraysTest.class,
UUIDArrayTest.class
})
public class ArraysTestSuite {
}
139 changes: 139 additions & 0 deletions pgjdbc/src/test/java/org/postgresql/jdbc/UUIDArrayTest.java
@@ -0,0 +1,139 @@
/*
* Copyright (c) 2022, PostgreSQL Global Development Group
* See the LICENSE file in the project root for more information.
*/

package org.postgresql.jdbc;

import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.junit.jupiter.api.Assumptions.assumeTrue;

import org.postgresql.core.ServerVersion;
import org.postgresql.test.TestUtil;

import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;

import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Statement;
import java.util.UUID;

class UUIDArrayTest {

private static Connection con;
private static final String TABLE_NAME = "uuid_table";
private static final String INSERT1 = "INSERT INTO " + TABLE_NAME
+ " (id, data1) VALUES (?, ?)";
private static final String INSERT2 = "INSERT INTO " + TABLE_NAME
+ " (id, data2) VALUES (?, ?)";
private static final String SELECT1 = "SELECT data1 FROM " + TABLE_NAME
+ " WHERE id = ?";
private static final String SELECT2 = "SELECT data2 FROM " + TABLE_NAME
+ " WHERE id = ?";
private static final UUID[] uids1 = new UUID[]{UUID.randomUUID(), UUID.randomUUID()};
private static final UUID[][] uids2 = new UUID[][]{uids1};

@BeforeAll
public static void setUp() throws Exception {
con = TestUtil.openDB();
assumeTrue(TestUtil.haveMinimumServerVersion(con, ServerVersion.v9_6));
try (Statement stmt = con.createStatement()) {
stmt.execute("CREATE TABLE " + TABLE_NAME
+ " (id int PRIMARY KEY, data1 UUID[], data2 UUID[][])");
}
}

@AfterAll
public static void tearDown() throws Exception {
try (Statement stmt = con.createStatement()) {
stmt.execute("DROP TABLE IF EXISTS " + TABLE_NAME);
}
TestUtil.closeDB(con);
}

@Test
void test1DWithCreateArrayOf() throws SQLException {
try (Connection c = assertDoesNotThrow(() -> TestUtil.openDB());
PreparedStatement stmt1 = c.prepareStatement(INSERT1);
PreparedStatement stmt2 = c.prepareStatement(SELECT1)) {
stmt1.setInt(1, 100);
stmt1.setArray(2, c.createArrayOf("uuid", uids1));
stmt1.execute();

stmt2.setInt(1, 100);
stmt2.execute();
try (ResultSet rs = stmt2.getResultSet()) {
assertTrue(rs.next());
UUID[] array = (UUID[])rs.getArray(1).getArray();
assertEquals(uids1[0], array[0]);
assertEquals(uids1[1], array[1]);
}
}
}

@Test
void test1DWithSetObject() throws SQLException {
try (Connection c = assertDoesNotThrow(() -> TestUtil.openDB());
PreparedStatement stmt1 = c.prepareStatement(INSERT1);
PreparedStatement stmt2 = c.prepareStatement(SELECT1)) {
stmt1.setInt(1, 101);
stmt1.setObject(2, uids1);
stmt1.execute();

stmt2.setInt(1, 101);
stmt2.execute();
try (ResultSet rs = stmt2.getResultSet()) {
assertTrue(rs.next());
UUID[] array = (UUID[])rs.getArray(1).getArray();
assertEquals(uids1[0], array[0]);
assertEquals(uids1[1], array[1]);
}
}
}

@Test
void test2DWithCreateArrayOf() throws SQLException {
try (Connection c = assertDoesNotThrow(() -> TestUtil.openDB());
PreparedStatement stmt1 = c.prepareStatement(INSERT2);
PreparedStatement stmt2 = c.prepareStatement(SELECT2)) {
stmt1.setInt(1, 200);
stmt1.setArray(2, c.createArrayOf("uuid", uids2));
stmt1.execute();

stmt2.setInt(1, 200);
stmt2.execute();
try (ResultSet rs = stmt2.getResultSet()) {
assertTrue(rs.next());
UUID[][] array = (UUID[][])rs.getArray(1).getArray();
assertEquals(uids2[0][0], array[0][0]);
assertEquals(uids2[0][1], array[0][1]);
}
}
}

@Test
void test2DWithSetObject() throws SQLException {
try (Connection c = assertDoesNotThrow(() -> TestUtil.openDB());
PreparedStatement stmt1 = c.prepareStatement(INSERT2);
PreparedStatement stmt2 = c.prepareStatement(SELECT2)) {
stmt1.setInt(1, 201);
stmt1.setObject(2, uids2);
stmt1.execute();

stmt2.setInt(1, 201);
stmt2.execute();
try (ResultSet rs = stmt2.getResultSet()) {
assertTrue(rs.next());
UUID[][] array = (UUID[][])rs.getArray(1).getArray();
assertEquals(uids2[0][0], array[0][0]);
assertEquals(uids2[0][1], array[0][1]);
}
}
}
}

0 comments on commit 96f2561

Please sign in to comment.