diff --git a/pgjdbc/src/main/java/org/postgresql/ssl/LazyKeyManager.java b/pgjdbc/src/main/java/org/postgresql/ssl/LazyKeyManager.java index 473ab9cb9e..a5392e0507 100644 --- a/pgjdbc/src/main/java/org/postgresql/ssl/LazyKeyManager.java +++ b/pgjdbc/src/main/java/org/postgresql/ssl/LazyKeyManager.java @@ -102,11 +102,27 @@ public void throwKeyManagerException() throws PSQLException { if (certchain == null) { return null; } else { - X500Principal ourissuer = certchain[certchain.length - 1].getIssuerX500Principal(); + X509Certificate cert = certchain[certchain.length - 1]; + X500Principal ourissuer = cert.getIssuerX500Principal(); + String certKeyType = cert.getPublicKey().getAlgorithm(); + boolean keyTypeFound = false; boolean found = false; - for (Principal issuer : issuers) { - if (ourissuer.equals(issuer)) { - found = true; + if (keyType != null && keyType.length > 0) { + for (String kt : keyType) { + if (kt.equalsIgnoreCase(certKeyType)) { + keyTypeFound = true; + } + } + } else { + // If no key types were passed in, assume we don't care + // about checking that the cert uses a particular key type. + keyTypeFound = true; + } + if (keyTypeFound) { + for (Principal issuer : issuers) { + if (ourissuer.equals(issuer)) { + found = keyTypeFound; + } } } return (found ? "user" : null); diff --git a/pgjdbc/src/main/java/org/postgresql/ssl/PKCS12KeyManager.java b/pgjdbc/src/main/java/org/postgresql/ssl/PKCS12KeyManager.java index 89d7383164..602e5702cf 100644 --- a/pgjdbc/src/main/java/org/postgresql/ssl/PKCS12KeyManager.java +++ b/pgjdbc/src/main/java/org/postgresql/ssl/PKCS12KeyManager.java @@ -67,7 +67,7 @@ public void throwKeyManagerException() throws PSQLException { } @Override - public @Nullable String chooseClientAlias(String[] strings, Principal @Nullable [] principals, + public @Nullable String chooseClientAlias(String[] keyType, Principal @Nullable [] principals, @Nullable Socket socket) { if (principals == null || principals.length == 0) { // Postgres 8.4 and earlier do not send the list of accepted certificate authorities @@ -81,11 +81,27 @@ public void throwKeyManagerException() throws PSQLException { if (certchain == null) { return null; } else { - X500Principal ourissuer = certchain[certchain.length - 1].getIssuerX500Principal(); + X509Certificate cert = certchain[certchain.length - 1]; + X500Principal ourissuer = cert.getIssuerX500Principal(); + String certKeyType = cert.getPublicKey().getAlgorithm(); + boolean keyTypeFound = false; boolean found = false; - for (Principal issuer : principals) { - if (ourissuer.equals(issuer)) { - found = true; + if (keyType != null && keyType.length > 0) { + for (String kt : keyType) { + if (kt.equalsIgnoreCase(certKeyType)) { + keyTypeFound = true; + } + } + } else { + // If no key types were passed in, assume we don't care + // about checking that the cert uses a particular key type. + keyTypeFound = true; + } + if (keyTypeFound) { + for (Principal issuer : principals) { + if (ourissuer.equals(issuer)) { + found = keyTypeFound; + } } } return (found ? "user" : null); diff --git a/pgjdbc/src/test/java/org/postgresql/test/ssl/LazyKeyManagerTest.java b/pgjdbc/src/test/java/org/postgresql/test/ssl/LazyKeyManagerTest.java index 87d3aca87e..c83bf1c0e6 100644 --- a/pgjdbc/src/test/java/org/postgresql/test/ssl/LazyKeyManagerTest.java +++ b/pgjdbc/src/test/java/org/postgresql/test/ssl/LazyKeyManagerTest.java @@ -20,6 +20,7 @@ import javax.security.auth.callback.CallbackHandler; import javax.security.auth.callback.PasswordCallback; import javax.security.auth.callback.UnsupportedCallbackException; +import javax.security.auth.x500.X500Principal; public class LazyKeyManagerTest { @@ -45,6 +46,32 @@ public void testLoadKey() throws Exception { Assert.assertNotNull(pk); } + @Test + public void testChooseClientAlias() throws Exception { + LazyKeyManager lazyKeyManager = new LazyKeyManager( + TestUtil.getSslTestCertPath("goodclient.crt"), + TestUtil.getSslTestCertPath("goodclient.pk8"), + new TestCallbackHandler("sslpwd"), + true); + X500Principal testPrincipal = new X500Principal("CN=root certificate, O=PgJdbc test, ST=CA, C=US"); + X500Principal[] issuers = new X500Principal[]{testPrincipal}; + + String validKeyType = lazyKeyManager.chooseClientAlias(new String[]{"RSA"}, issuers, null); + Assert.assertNotNull(validKeyType); + + String ignoresCase = lazyKeyManager.chooseClientAlias(new String[]{"rsa"}, issuers, null); + Assert.assertNotNull(ignoresCase); + + String invalidKeyType = lazyKeyManager.chooseClientAlias(new String[]{"EC"}, issuers, null); + Assert.assertNull(invalidKeyType); + + String containsValidKeyType = lazyKeyManager.chooseClientAlias(new String[]{"EC","RSA"}, issuers, null); + Assert.assertNotNull(containsValidKeyType); + + String ignoresBlank = lazyKeyManager.chooseClientAlias(new String[]{}, issuers, null); + Assert.assertNotNull(ignoresBlank); + } + public static class TestCallbackHandler implements CallbackHandler { char [] password; diff --git a/pgjdbc/src/test/java/org/postgresql/test/ssl/PKCS12KeyTest.java b/pgjdbc/src/test/java/org/postgresql/test/ssl/PKCS12KeyTest.java index c0696c947d..43935eebf6 100644 --- a/pgjdbc/src/test/java/org/postgresql/test/ssl/PKCS12KeyTest.java +++ b/pgjdbc/src/test/java/org/postgresql/test/ssl/PKCS12KeyTest.java @@ -6,14 +6,22 @@ package org.postgresql.test.ssl; import org.postgresql.PGProperty; +import org.postgresql.ssl.PKCS12KeyManager; import org.postgresql.test.TestUtil; import org.junit.Assert; import org.junit.Test; +import java.io.IOException; import java.sql.Connection; import java.util.Properties; +import javax.security.auth.callback.Callback; +import javax.security.auth.callback.CallbackHandler; +import javax.security.auth.callback.PasswordCallback; +import javax.security.auth.callback.UnsupportedCallbackException; +import javax.security.auth.x500.X500Principal; + public class PKCS12KeyTest { @Test public void TestGoodClientP12() throws Exception { @@ -29,4 +37,50 @@ public void TestGoodClientP12() throws Exception { Assert.assertTrue("SSL should be in use", sslUsed); } } + + @Test + public void TestChooseClientAlias() throws Exception { + PKCS12KeyManager pkcs12KeyManager = new PKCS12KeyManager(TestUtil.getSslTestCertPath("goodclient.p12"), new TestCallbackHandler("sslpwd")); + X500Principal testPrincipal = new X500Principal("CN=root certificate, O=PgJdbc test, ST=CA, C=US"); + X500Principal[] issuers = new X500Principal[]{testPrincipal}; + + String validKeyType = pkcs12KeyManager.chooseClientAlias(new String[]{"RSA"}, issuers, null); + Assert.assertNotNull(validKeyType); + + String ignoresCase = pkcs12KeyManager.chooseClientAlias(new String[]{"rsa"}, issuers, null); + Assert.assertNotNull(ignoresCase); + + String invalidKeyType = pkcs12KeyManager.chooseClientAlias(new String[]{"EC"}, issuers, null); + Assert.assertNull(invalidKeyType); + + String containsValidKeyType = pkcs12KeyManager.chooseClientAlias(new String[]{"EC","RSA"}, issuers, null); + Assert.assertNotNull(containsValidKeyType); + + String ignoresBlank = pkcs12KeyManager.chooseClientAlias(new String[]{}, issuers, null); + Assert.assertNotNull(ignoresBlank); + } + + public static class TestCallbackHandler implements CallbackHandler { + char [] password; + + public TestCallbackHandler(String password) { + if (password != null) { + this.password = password.toCharArray(); + } + } + + @Override + public void handle(Callback[] callbacks) throws IOException, UnsupportedCallbackException { + for (Callback callback : callbacks) { + if (!(callback instanceof PasswordCallback)) { + throw new UnsupportedCallbackException(callback); + } + PasswordCallback pwdCallback = (PasswordCallback) callback; + if (password != null) { + pwdCallback.setPassword(password); + continue; + } + } + } + } }