Skip to content

Commit

Permalink
fix: DataSources broken by connection failover urls (#1039) (#1457)
Browse files Browse the repository at this point in the history
* fix DataSources broken by connection failover urls (#1039)

* removing java8 string join method and fixing indentation (#1039)

* preserve original test "bds" as we are modifying ours (#1039)

* storing old bds reference before changing it (#1039)

* moving the url-modifiying test out of BaseDataSourceTest (#1039)
  • Loading branch information
teicher authored and davecramer committed Sep 26, 2019
1 parent 831115c commit bd9485e
Show file tree
Hide file tree
Showing 3 changed files with 228 additions and 33 deletions.
160 changes: 128 additions & 32 deletions pgjdbc/src/main/java/org/postgresql/ds/common/BaseDataSource.java
Expand Up @@ -23,6 +23,7 @@
import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.SQLException;
import java.util.Arrays;
import java.util.Properties;
import java.util.logging.Level;
import java.util.logging.Logger;
Expand All @@ -39,16 +40,16 @@
*
* @author Aaron Mulder (ammulder@chariotsolutions.com)
*/
public abstract class BaseDataSource implements CommonDataSource, Referenceable {

public abstract class BaseDataSource implements CommonDataSource, Referenceable {
private static final Logger LOGGER = Logger.getLogger(BaseDataSource.class.getName());

// Standard properties, defined in the JDBC 2.0 Optional Package spec
private String serverName = "localhost";
private String[] serverNames = new String[] {"localhost"};
private String databaseName = "";
private String user;
private String password;
private int portNumber = 0;
private int[] portNumbers = new int[] {0};

// Map for all other properties
private Properties properties = new Properties();
Expand Down Expand Up @@ -129,22 +130,51 @@ public void setLogWriter(PrintWriter printWriter) {
* Gets the name of the host the PostgreSQL database is running on.
*
* @return name of the host the PostgreSQL database is running on
* @deprecated use {@link #getServerNames()}
*/
@Deprecated
public String getServerName() {
return serverName;
return serverNames[0];
}

/**
* Gets the name of the host(s) the PostgreSQL database is running on.
*
* @return name of the host(s) the PostgreSQL database is running on
*/
public String[] getServerNames() {
return serverNames;
}

/**
* Sets the name of the host the PostgreSQL database is running on. If this is changed, it will
* only affect future calls to getConnection. The default value is <tt>localhost</tt>.
*
* @param serverName name of the host the PostgreSQL database is running on
* @deprecated use {@link #setServerNames(String[])}
*/
@Deprecated
public void setServerName(String serverName) {
if (serverName == null || serverName.equals("")) {
this.serverName = "localhost";
this.setServerNames(new String[] { serverName });
}

/**
* Sets the name of the host(s) the PostgreSQL database is running on. If this is changed, it will
* only affect future calls to getConnection. The default value is <tt>localhost</tt>.
*
* @param serverNames name of the host(s) the PostgreSQL database is running on
*/
public void setServerNames(String[] serverNames) {
if (serverNames == null || serverNames.length == 0) {
this.serverNames = new String[] {"localhost"};
} else {
this.serverName = serverName;
serverNames = Arrays.copyOf(serverNames, serverNames.length);
for (int i = 0; i < serverNames.length; i++) {
if (serverNames[i] == null || serverNames[i].equals("")) {
serverNames[i] = "localhost";
}
}
this.serverNames = serverNames;
}
}

Expand Down Expand Up @@ -221,20 +251,50 @@ public void setPassword(String password) {
* Gets the port which the PostgreSQL server is listening on for TCP/IP connections.
*
* @return The port, or 0 if the default port will be used.
* @deprecated use {@link #getPortNumbers()}
*/
@Deprecated
public int getPortNumber() {
return portNumber;
if (portNumbers == null || portNumbers.length == 0) {
return 0;
}
return portNumbers[0];
}

/**
* Gets the port(s) which the PostgreSQL server is listening on for TCP/IP connections.
*
* @return The port(s), or 0 if the default port will be used.
*/
public int[] getPortNumbers() {
return portNumbers;
}

/**
* Gets the port which the PostgreSQL server is listening on for TCP/IP connections. Be sure the
* Sets the port which the PostgreSQL server is listening on for TCP/IP connections. Be sure the
* -i flag is passed to postmaster when PostgreSQL is started. If this is not set, or set to 0,
* the default port will be used.
*
* @param portNumber port which the PostgreSQL server is listening on for TCP/IP
* @deprecated use {@link #setPortNumbers(int[])}
*/
@Deprecated
public void setPortNumber(int portNumber) {
this.portNumber = portNumber;
setPortNumbers(new int[] { portNumber });
}

/**
* Sets the port(s) which the PostgreSQL server is listening on for TCP/IP connections. Be sure the
* -i flag is passed to postmaster when PostgreSQL is started. If this is not set, or set to 0,
* the default port will be used.
*
* @param portNumbers port(s) which the PostgreSQL server is listening on for TCP/IP
*/
public void setPortNumbers(int[] portNumbers) {
if (portNumbers == null || portNumbers.length == 0) {
portNumbers = new int[] { 0 };
}
this.portNumbers = Arrays.copyOf(portNumbers, portNumbers.length);
}

/**
Expand Down Expand Up @@ -1085,9 +1145,14 @@ public void setLoggerFile(String loggerFile) {
public String getUrl() {
StringBuilder url = new StringBuilder(100);
url.append("jdbc:postgresql://");
url.append(serverName);
if (portNumber != 0) {
url.append(":").append(portNumber);
for (int i = 0; i < serverNames.length; i++) {
if (i > 0) {
url.append(",");
}
url.append(serverNames[i]);
if (portNumbers != null && portNumbers.length >= i && portNumbers[i] != 0) {
url.append(":").append(portNumbers[i]);
}
}
url.append("/").append(URLCoder.encode(databaseName));

Expand Down Expand Up @@ -1179,23 +1244,28 @@ public void setProperty(PGProperty property, String value) {
}
switch (property) {
case PG_HOST:
serverName = value;
setServerNames(value.split(","));
break;
case PG_PORT:
try {
portNumber = Integer.parseInt(value);
} catch (NumberFormatException e) {
portNumber = 0;
String[] ps = value.split(",");
int[] ports = new int[ps.length];
for (int i = 0 ; i < ps.length; i++) {
try {
ports[i] = Integer.parseInt(ps[i]);
} catch (NumberFormatException e) {
ports[i] = 0;
}
}
setPortNumbers(ports);
break;
case PG_DBNAME:
databaseName = value;
setDatabaseName(value);
break;
case USER:
user = value;
setUser(value);
break;
case PASSWORD:
password = value;
setPassword(value);
break;
default:
properties.setProperty(property.getName(), value);
Expand All @@ -1213,10 +1283,25 @@ protected Reference createReference() {

public Reference getReference() throws NamingException {
Reference ref = createReference();
ref.add(new StringRefAddr("serverName", serverName));
if (portNumber != 0) {
ref.add(new StringRefAddr("portNumber", Integer.toString(portNumber)));
StringBuilder serverString = new StringBuilder();
for (int i = 0; i < serverNames.length; i++) {
if (i > 0) {
serverString.append(",");
}
String serverName = serverNames[i];
serverString.append(serverName);
}
ref.add(new StringRefAddr("serverName", serverString.toString()));

StringBuilder portString = new StringBuilder();
for (int i = 0; i < portNumbers.length; i++) {
if (i > 0) {
portString.append(",");
}
int p = portNumbers[i];
portString.append(Integer.toString(p));
}
ref.add(new StringRefAddr("portNumber", portString.toString()));
ref.add(new StringRefAddr("databaseName", databaseName));
if (user != null) {
ref.add(new StringRefAddr("user", user));
Expand All @@ -1236,11 +1321,22 @@ public Reference getReference() throws NamingException {

public void setFromReference(Reference ref) {
databaseName = getReferenceProperty(ref, "databaseName");
String port = getReferenceProperty(ref, "portNumber");
if (port != null) {
portNumber = Integer.parseInt(port);
String portNumberString = getReferenceProperty(ref, "portNumber");
if (portNumberString != null) {
String[] ps = portNumberString.split(",");
int[] ports = new int[ps.length];
for (int i = 0; i < ps.length; i++) {
try {
ports[i] = Integer.parseInt(ps[i]);
} catch (NumberFormatException e) {
ports[i] = 0;
}
}
setPortNumbers(ports);
} else {
setPortNumbers(null);
}
serverName = getReferenceProperty(ref, "serverName");
setServerNames(getReferenceProperty(ref, "serverName").split(","));

for (PGProperty property : PGProperty.values()) {
setProperty(property, getReferenceProperty(ref, property.getName()));
Expand All @@ -1256,21 +1352,21 @@ private static String getReferenceProperty(Reference ref, String propertyName) {
}

protected void writeBaseObject(ObjectOutputStream out) throws IOException {
out.writeObject(serverName);
out.writeObject(serverNames);
out.writeObject(databaseName);
out.writeObject(user);
out.writeObject(password);
out.writeInt(portNumber);
out.writeObject(portNumbers);

out.writeObject(properties);
}

protected void readBaseObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
serverName = (String) in.readObject();
serverNames = (String[]) in.readObject();
databaseName = (String) in.readObject();
user = (String) in.readObject();
password = (String) in.readObject();
portNumber = in.readInt();
portNumbers = (int[]) in.readObject();

properties = (Properties) in.readObject();
}
Expand Down
@@ -0,0 +1,98 @@
/*
* Copyright (c) 2004, PostgreSQL Global Development Group
* See the LICENSE file in the project root for more information.
*/

package org.postgresql.test.jdbc2.optional;

import static org.junit.Assert.assertEquals;

import org.postgresql.ds.common.BaseDataSource;

import org.junit.Test;

import java.io.IOException;
import javax.naming.NamingException;

/**
* tests that failover urls survive the parse/rebuild roundtrip with and without specific ports
*/
public class BaseDataSourceFailoverUrlsTest {

private static final String DEFAULT_PORT = "5432";

@Test
public void testFullDefault() throws ClassNotFoundException, NamingException, IOException {
roundTripFromUrl("jdbc:postgresql://server/database", "jdbc:postgresql://server:" + DEFAULT_PORT + "/database");
}

@Test
public void testTwoNoPorts() throws ClassNotFoundException, NamingException, IOException {
roundTripFromUrl("jdbc:postgresql://server1,server2/database", "jdbc:postgresql://server1:" + DEFAULT_PORT + ",server2:" + DEFAULT_PORT + "/database");
}

@Test
public void testTwoWithPorts() throws ClassNotFoundException, NamingException, IOException {
roundTripFromUrl("jdbc:postgresql://server1:1234,server2:2345/database", "jdbc:postgresql://server1:1234,server2:2345/database");
}

@Test
public void testTwoFirstPort() throws ClassNotFoundException, NamingException, IOException {
roundTripFromUrl("jdbc:postgresql://server1,server2:2345/database", "jdbc:postgresql://server1:" + DEFAULT_PORT + ",server2:2345/database");
}

@Test
public void testTwoLastPort() throws ClassNotFoundException, NamingException, IOException {
roundTripFromUrl("jdbc:postgresql://server1:2345,server2/database", "jdbc:postgresql://server1:2345,server2:" + DEFAULT_PORT + "/database");
}

@Test
public void testNullPorts() {
BaseDataSource bds = newDS();
bds.setDatabaseName("database");
bds.setPortNumbers(null);
assertUrlWithoutParamsEquals("jdbc:postgresql://localhost/database", bds.getURL());
assertEquals(0, bds.getPortNumber());
assertEquals(0, bds.getPortNumbers()[0]);
}

@Test
public void testEmptyPorts() {
BaseDataSource bds = newDS();
bds.setDatabaseName("database");
bds.setPortNumbers(new int[0]);
assertUrlWithoutParamsEquals("jdbc:postgresql://localhost/database", bds.getURL());
assertEquals(0, bds.getPortNumber());
assertEquals(0, bds.getPortNumbers()[0]);
}

private BaseDataSource newDS() {
return new BaseDataSource() {
@Override
public String getDescription() {
return "BaseDataSourceFailoverUrlsTest-DS";
}
};
}

private void roundTripFromUrl(String in, String expected) throws NamingException, ClassNotFoundException, IOException {
BaseDataSource bds = newDS();

bds.setUrl(in);
assertUrlWithoutParamsEquals(expected, bds.getURL());

bds.setFromReference(bds.getReference());
assertUrlWithoutParamsEquals(expected, bds.getURL());

bds.initializeFrom(bds);
assertUrlWithoutParamsEquals(expected, bds.getURL());
}

private static String jdbcUrlStripParams(String in) {
return in.replaceAll("\\?.*$", "");
}

private static void assertUrlWithoutParamsEquals(String expected, String url) {
assertEquals(expected, jdbcUrlStripParams(url));
}
}
Expand Up @@ -20,7 +20,8 @@
SimpleDataSourceWithSetURLTest.class,
ConnectionPoolTest.class,
PoolingDataSourceTest.class,
CaseOptimiserDataSourceTest.class})
CaseOptimiserDataSourceTest.class,
BaseDataSourceFailoverUrlsTest.class})
public class OptionalTestSuite {

}

0 comments on commit bd9485e

Please sign in to comment.