diff --git a/spring-context-support/src/main/java/org/springframework/scheduling/quartz/LocalDataSourceJobStore.java b/spring-context-support/src/main/java/org/springframework/scheduling/quartz/LocalDataSourceJobStore.java index d47fa28c0ea4..e0b12f443f91 100644 --- a/spring-context-support/src/main/java/org/springframework/scheduling/quartz/LocalDataSourceJobStore.java +++ b/spring-context-support/src/main/java/org/springframework/scheduling/quartz/LocalDataSourceJobStore.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2019 the original author or authors. + * Copyright 2002-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,6 +17,7 @@ package org.springframework.scheduling.quartz; import java.sql.Connection; +import java.sql.DatabaseMetaData; import java.sql.SQLException; import javax.sql.DataSource; @@ -147,7 +148,8 @@ public void initialize() { // No, if HSQL is the platform, we really don't want to use locks... try { - String productName = JdbcUtils.extractDatabaseMetaData(this.dataSource, "getDatabaseProductName"); + String productName = JdbcUtils.extractDatabaseMetaData(this.dataSource, + DatabaseMetaData::getDatabaseProductName); productName = JdbcUtils.commonDatabaseName(productName); if (productName != null && productName.toLowerCase().contains("hsql")) { setUseDBLocks(false); diff --git a/spring-jdbc/src/main/java/org/springframework/jdbc/core/metadata/CallMetaDataProviderFactory.java b/spring-jdbc/src/main/java/org/springframework/jdbc/core/metadata/CallMetaDataProviderFactory.java index ff93bfc9a859..2e8480b914c2 100644 --- a/spring-jdbc/src/main/java/org/springframework/jdbc/core/metadata/CallMetaDataProviderFactory.java +++ b/spring-jdbc/src/main/java/org/springframework/jdbc/core/metadata/CallMetaDataProviderFactory.java @@ -73,7 +73,7 @@ private CallMetaDataProviderFactory() { */ public static CallMetaDataProvider createMetaDataProvider(DataSource dataSource, final CallMetaDataContext context) { try { - return (CallMetaDataProvider) JdbcUtils.extractDatabaseMetaData(dataSource, databaseMetaData -> { + return JdbcUtils.extractDatabaseMetaData(dataSource, databaseMetaData -> { String databaseProductName = JdbcUtils.commonDatabaseName(databaseMetaData.getDatabaseProductName()); boolean accessProcedureColumnMetaData = context.isAccessCallParameterMetaData(); if (context.isFunction()) { diff --git a/spring-jdbc/src/main/java/org/springframework/jdbc/core/metadata/TableMetaDataProviderFactory.java b/spring-jdbc/src/main/java/org/springframework/jdbc/core/metadata/TableMetaDataProviderFactory.java index 8084c9897107..854d980cacd8 100644 --- a/spring-jdbc/src/main/java/org/springframework/jdbc/core/metadata/TableMetaDataProviderFactory.java +++ b/spring-jdbc/src/main/java/org/springframework/jdbc/core/metadata/TableMetaDataProviderFactory.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2018 the original author or authors. + * Copyright 2002-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -49,9 +49,8 @@ private TableMetaDataProviderFactory() { */ public static TableMetaDataProvider createMetaDataProvider(DataSource dataSource, TableMetaDataContext context) { try { - return (TableMetaDataProvider) JdbcUtils.extractDatabaseMetaData(dataSource, databaseMetaData -> { - String databaseProductName = - JdbcUtils.commonDatabaseName(databaseMetaData.getDatabaseProductName()); + return JdbcUtils.extractDatabaseMetaData(dataSource, databaseMetaData -> { + String databaseProductName = JdbcUtils.commonDatabaseName(databaseMetaData.getDatabaseProductName()); boolean accessTableColumnMetaData = context.isAccessTableColumnMetaData(); TableMetaDataProvider provider; diff --git a/spring-jdbc/src/main/java/org/springframework/jdbc/support/DatabaseMetaDataCallback.java b/spring-jdbc/src/main/java/org/springframework/jdbc/support/DatabaseMetaDataCallback.java index 701e3e172f5a..0f4197dc48f9 100644 --- a/spring-jdbc/src/main/java/org/springframework/jdbc/support/DatabaseMetaDataCallback.java +++ b/spring-jdbc/src/main/java/org/springframework/jdbc/support/DatabaseMetaDataCallback.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2016 the original author or authors. + * Copyright 2002-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -26,10 +26,12 @@ * and handled correctly by the JdbcUtils class. * * @author Thomas Risberg - * @see JdbcUtils#extractDatabaseMetaData + * @author Juergen Hoeller + * @param the result type + * @see JdbcUtils#extractDatabaseMetaData(javax.sql.DataSource, DatabaseMetaDataCallback) */ @FunctionalInterface -public interface DatabaseMetaDataCallback { +public interface DatabaseMetaDataCallback { /** * Implementations must implement this method to process the meta-data @@ -42,6 +44,6 @@ public interface DatabaseMetaDataCallback { * @throws MetaDataAccessException in case of other failures while * extracting meta-data (for example, reflection failure) */ - Object processMetaData(DatabaseMetaData dbmd) throws SQLException, MetaDataAccessException; + T processMetaData(DatabaseMetaData dbmd) throws SQLException, MetaDataAccessException; } diff --git a/spring-jdbc/src/main/java/org/springframework/jdbc/support/JdbcUtils.java b/spring-jdbc/src/main/java/org/springframework/jdbc/support/JdbcUtils.java index c13a854c1259..6fe7e9ea125a 100644 --- a/spring-jdbc/src/main/java/org/springframework/jdbc/support/JdbcUtils.java +++ b/spring-jdbc/src/main/java/org/springframework/jdbc/support/JdbcUtils.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2019 the original author or authors. + * Copyright 2002-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -315,26 +315,44 @@ else if (obj instanceof java.sql.Date) { /** * Extract database meta-data via the given DatabaseMetaDataCallback. - *

This method will open a connection to the database and retrieve the database meta-data. - * Since this method is called before the exception translation feature is configured for - * a datasource, this method can not rely on the SQLException translation functionality. - *

Any exceptions will be wrapped in a MetaDataAccessException. This is a checked exception - * and any calling code should catch and handle this exception. You can just log the - * error and hope for the best, but there is probably a more serious error that will - * reappear when you try to access the database again. + *

This method will open a connection to the database and retrieve its meta-data. + * Since this method is called before the exception translation feature is configured + * for a DataSource, this method can not rely on SQLException translation itself. + *

Any exceptions will be wrapped in a MetaDataAccessException. This is a checked + * exception and any calling code should catch and handle this exception. You can just + * log the error and hope for the best, but there is probably a more serious error that + * will reappear when you try to access the database again. * @param dataSource the DataSource to extract meta-data for * @param action callback that will do the actual work * @return object containing the extracted information, as returned by * the DatabaseMetaDataCallback's {@code processMetaData} method * @throws MetaDataAccessException if meta-data access failed + * @see java.sql.DatabaseMetaData */ - public static Object extractDatabaseMetaData(DataSource dataSource, DatabaseMetaDataCallback action) + public static T extractDatabaseMetaData(DataSource dataSource, DatabaseMetaDataCallback action) throws MetaDataAccessException { Connection con = null; try { con = DataSourceUtils.getConnection(dataSource); - DatabaseMetaData metaData = con.getMetaData(); + DatabaseMetaData metaData; + try { + metaData = con.getMetaData(); + } + catch (SQLException ex) { + if (DataSourceUtils.isConnectionTransactional(con, dataSource)) { + // Probably a closed thread-bound Connection - retry against fresh Connection + DataSourceUtils.releaseConnection(con, dataSource); + con = null; + logger.debug("Failed to obtain DatabaseMetaData from transactional Connection - " + + "retrying against fresh Connection", ex); + con = dataSource.getConnection(); + metaData = con.getMetaData(); + } + else { + throw ex; + } + } if (metaData == null) { // should only happen in test environments throw new MetaDataAccessException("DatabaseMetaData returned by Connection [" + con + "] was null"); @@ -365,7 +383,11 @@ public static Object extractDatabaseMetaData(DataSource dataSource, DatabaseMeta * @throws MetaDataAccessException if we couldn't access the DatabaseMetaData * or failed to invoke the specified method * @see java.sql.DatabaseMetaData + * @deprecated as of 5.2.9, in favor of + * {@link #extractDatabaseMetaData(DataSource, DatabaseMetaDataCallback)} + * with a lambda expression or method reference and a generically typed result */ + @Deprecated @SuppressWarnings("unchecked") public static T extractDatabaseMetaData(DataSource dataSource, final String metaDataMethodName) throws MetaDataAccessException { diff --git a/spring-jdbc/src/main/java/org/springframework/jdbc/support/SQLErrorCodeSQLExceptionTranslator.java b/spring-jdbc/src/main/java/org/springframework/jdbc/support/SQLErrorCodeSQLExceptionTranslator.java index 2493f0b30949..f50380fcdec0 100644 --- a/spring-jdbc/src/main/java/org/springframework/jdbc/support/SQLErrorCodeSQLExceptionTranslator.java +++ b/spring-jdbc/src/main/java/org/springframework/jdbc/support/SQLErrorCodeSQLExceptionTranslator.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2018 the original author or authors. + * Copyright 2002-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -35,6 +35,8 @@ import org.springframework.jdbc.BadSqlGrammarException; import org.springframework.jdbc.InvalidResultSetAccessException; import org.springframework.lang.Nullable; +import org.springframework.util.function.SingletonSupplier; +import org.springframework.util.function.SupplierUtils; /** * Implementation of {@link SQLExceptionTranslator} that analyzes vendor-specific error codes. @@ -76,7 +78,7 @@ public class SQLErrorCodeSQLExceptionTranslator extends AbstractFallbackSQLExcep /** Error codes used by this translator. */ @Nullable - private SQLErrorCodes sqlErrorCodes; + private SingletonSupplier sqlErrorCodes; /** @@ -120,7 +122,7 @@ public SQLErrorCodeSQLExceptionTranslator(String dbName) { */ public SQLErrorCodeSQLExceptionTranslator(SQLErrorCodes sec) { this(); - this.sqlErrorCodes = sec; + this.sqlErrorCodes = SingletonSupplier.of(sec); } @@ -134,7 +136,9 @@ public SQLErrorCodeSQLExceptionTranslator(SQLErrorCodes sec) { * @see java.sql.DatabaseMetaData#getDatabaseProductName() */ public void setDataSource(DataSource dataSource) { - this.sqlErrorCodes = SQLErrorCodesFactory.getInstance().getErrorCodes(dataSource); + this.sqlErrorCodes = + SingletonSupplier.of(() -> SQLErrorCodesFactory.getInstance().resolveErrorCodes(dataSource)); + this.sqlErrorCodes.get(); // try early initialization - otherwise the supplier will retry later } /** @@ -146,7 +150,7 @@ public void setDataSource(DataSource dataSource) { * @see java.sql.DatabaseMetaData#getDatabaseProductName() */ public void setDatabaseProductName(String dbName) { - this.sqlErrorCodes = SQLErrorCodesFactory.getInstance().getErrorCodes(dbName); + this.sqlErrorCodes = SingletonSupplier.of(SQLErrorCodesFactory.getInstance().getErrorCodes(dbName)); } /** @@ -154,7 +158,7 @@ public void setDatabaseProductName(String dbName) { * @param sec custom error codes to use */ public void setSqlErrorCodes(@Nullable SQLErrorCodes sec) { - this.sqlErrorCodes = sec; + this.sqlErrorCodes = SingletonSupplier.ofNullable(sec); } /** @@ -164,7 +168,7 @@ public void setSqlErrorCodes(@Nullable SQLErrorCodes sec) { */ @Nullable public SQLErrorCodes getSqlErrorCodes() { - return this.sqlErrorCodes; + return SupplierUtils.resolve(this.sqlErrorCodes); } @@ -175,7 +179,6 @@ protected DataAccessException doTranslate(String task, @Nullable String sql, SQL if (sqlEx instanceof BatchUpdateException && sqlEx.getNextException() != null) { SQLException nestedSqlEx = sqlEx.getNextException(); if (nestedSqlEx.getErrorCode() > 0 || nestedSqlEx.getSQLState() != null) { - logger.debug("Using nested SQLException from the BatchUpdateException"); sqlEx = nestedSqlEx; } } @@ -187,8 +190,9 @@ protected DataAccessException doTranslate(String task, @Nullable String sql, SQL } // Next, try the custom SQLException translator, if available. - if (this.sqlErrorCodes != null) { - SQLExceptionTranslator customTranslator = this.sqlErrorCodes.getCustomSqlExceptionTranslator(); + SQLErrorCodes sqlErrorCodes = getSqlErrorCodes(); + if (sqlErrorCodes != null) { + SQLExceptionTranslator customTranslator = sqlErrorCodes.getCustomSqlExceptionTranslator(); if (customTranslator != null) { DataAccessException customDex = customTranslator.translate(task, sql, sqlEx); if (customDex != null) { @@ -198,9 +202,9 @@ protected DataAccessException doTranslate(String task, @Nullable String sql, SQL } // Check SQLErrorCodes with corresponding error code, if available. - if (this.sqlErrorCodes != null) { + if (sqlErrorCodes != null) { String errorCode; - if (this.sqlErrorCodes.isUseSqlStateForTranslation()) { + if (sqlErrorCodes.isUseSqlStateForTranslation()) { errorCode = sqlEx.getSQLState(); } else { @@ -215,7 +219,7 @@ protected DataAccessException doTranslate(String task, @Nullable String sql, SQL if (errorCode != null) { // Look for defined custom translations first. - CustomSQLErrorCodesTranslation[] customTranslations = this.sqlErrorCodes.getCustomTranslations(); + CustomSQLErrorCodesTranslation[] customTranslations = sqlErrorCodes.getCustomTranslations(); if (customTranslations != null) { for (CustomSQLErrorCodesTranslation customTranslation : customTranslations) { if (Arrays.binarySearch(customTranslation.getErrorCodes(), errorCode) >= 0 && @@ -230,43 +234,43 @@ protected DataAccessException doTranslate(String task, @Nullable String sql, SQL } } // Next, look for grouped error codes. - if (Arrays.binarySearch(this.sqlErrorCodes.getBadSqlGrammarCodes(), errorCode) >= 0) { + if (Arrays.binarySearch(sqlErrorCodes.getBadSqlGrammarCodes(), errorCode) >= 0) { logTranslation(task, sql, sqlEx, false); return new BadSqlGrammarException(task, (sql != null ? sql : ""), sqlEx); } - else if (Arrays.binarySearch(this.sqlErrorCodes.getInvalidResultSetAccessCodes(), errorCode) >= 0) { + else if (Arrays.binarySearch(sqlErrorCodes.getInvalidResultSetAccessCodes(), errorCode) >= 0) { logTranslation(task, sql, sqlEx, false); return new InvalidResultSetAccessException(task, (sql != null ? sql : ""), sqlEx); } - else if (Arrays.binarySearch(this.sqlErrorCodes.getDuplicateKeyCodes(), errorCode) >= 0) { + else if (Arrays.binarySearch(sqlErrorCodes.getDuplicateKeyCodes(), errorCode) >= 0) { logTranslation(task, sql, sqlEx, false); return new DuplicateKeyException(buildMessage(task, sql, sqlEx), sqlEx); } - else if (Arrays.binarySearch(this.sqlErrorCodes.getDataIntegrityViolationCodes(), errorCode) >= 0) { + else if (Arrays.binarySearch(sqlErrorCodes.getDataIntegrityViolationCodes(), errorCode) >= 0) { logTranslation(task, sql, sqlEx, false); return new DataIntegrityViolationException(buildMessage(task, sql, sqlEx), sqlEx); } - else if (Arrays.binarySearch(this.sqlErrorCodes.getPermissionDeniedCodes(), errorCode) >= 0) { + else if (Arrays.binarySearch(sqlErrorCodes.getPermissionDeniedCodes(), errorCode) >= 0) { logTranslation(task, sql, sqlEx, false); return new PermissionDeniedDataAccessException(buildMessage(task, sql, sqlEx), sqlEx); } - else if (Arrays.binarySearch(this.sqlErrorCodes.getDataAccessResourceFailureCodes(), errorCode) >= 0) { + else if (Arrays.binarySearch(sqlErrorCodes.getDataAccessResourceFailureCodes(), errorCode) >= 0) { logTranslation(task, sql, sqlEx, false); return new DataAccessResourceFailureException(buildMessage(task, sql, sqlEx), sqlEx); } - else if (Arrays.binarySearch(this.sqlErrorCodes.getTransientDataAccessResourceCodes(), errorCode) >= 0) { + else if (Arrays.binarySearch(sqlErrorCodes.getTransientDataAccessResourceCodes(), errorCode) >= 0) { logTranslation(task, sql, sqlEx, false); return new TransientDataAccessResourceException(buildMessage(task, sql, sqlEx), sqlEx); } - else if (Arrays.binarySearch(this.sqlErrorCodes.getCannotAcquireLockCodes(), errorCode) >= 0) { + else if (Arrays.binarySearch(sqlErrorCodes.getCannotAcquireLockCodes(), errorCode) >= 0) { logTranslation(task, sql, sqlEx, false); return new CannotAcquireLockException(buildMessage(task, sql, sqlEx), sqlEx); } - else if (Arrays.binarySearch(this.sqlErrorCodes.getDeadlockLoserCodes(), errorCode) >= 0) { + else if (Arrays.binarySearch(sqlErrorCodes.getDeadlockLoserCodes(), errorCode) >= 0) { logTranslation(task, sql, sqlEx, false); return new DeadlockLoserDataAccessException(buildMessage(task, sql, sqlEx), sqlEx); } - else if (Arrays.binarySearch(this.sqlErrorCodes.getCannotSerializeTransactionCodes(), errorCode) >= 0) { + else if (Arrays.binarySearch(sqlErrorCodes.getCannotSerializeTransactionCodes(), errorCode) >= 0) { logTranslation(task, sql, sqlEx, false); return new CannotSerializeTransactionException(buildMessage(task, sql, sqlEx), sqlEx); } @@ -276,7 +280,7 @@ else if (Arrays.binarySearch(this.sqlErrorCodes.getCannotSerializeTransactionCod // We couldn't identify it more precisely - let's hand it over to the SQLState fallback translator. if (logger.isDebugEnabled()) { String codes; - if (this.sqlErrorCodes != null && this.sqlErrorCodes.isUseSqlStateForTranslation()) { + if (sqlErrorCodes != null && sqlErrorCodes.isUseSqlStateForTranslation()) { codes = "SQL state '" + sqlEx.getSQLState() + "', error code '" + sqlEx.getErrorCode(); } else { diff --git a/spring-jdbc/src/main/java/org/springframework/jdbc/support/SQLErrorCodesFactory.java b/spring-jdbc/src/main/java/org/springframework/jdbc/support/SQLErrorCodesFactory.java index 74a9f1e6f851..2b78fa740f02 100644 --- a/spring-jdbc/src/main/java/org/springframework/jdbc/support/SQLErrorCodesFactory.java +++ b/spring-jdbc/src/main/java/org/springframework/jdbc/support/SQLErrorCodesFactory.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2018 the original author or authors. + * Copyright 2002-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,6 +16,7 @@ package org.springframework.jdbc.support; +import java.sql.DatabaseMetaData; import java.util.Collections; import java.util.Map; @@ -159,6 +160,7 @@ protected Resource loadResource(String path) { *

No need for a database meta-data lookup. * @param databaseName the database name (must not be {@code null}) * @return the {@code SQLErrorCodes} instance for the given database + * (never {@code null}; potentially empty) * @throws IllegalArgumentException if the supplied database name is {@code null} */ public SQLErrorCodes getErrorCodes(String databaseName) { @@ -195,9 +197,27 @@ public SQLErrorCodes getErrorCodes(String databaseName) { * instance if no {@code SQLErrorCodes} were found. * @param dataSource the {@code DataSource} identifying the database * @return the corresponding {@code SQLErrorCodes} object + * (never {@code null}; potentially empty) * @see java.sql.DatabaseMetaData#getDatabaseProductName() */ public SQLErrorCodes getErrorCodes(DataSource dataSource) { + SQLErrorCodes sec = resolveErrorCodes(dataSource); + return (sec != null ? sec : new SQLErrorCodes()); + } + + /** + * Return {@link SQLErrorCodes} for the given {@link DataSource}, + * evaluating "databaseProductName" from the + * {@link java.sql.DatabaseMetaData}, or {@code null} if case + * of a JDBC meta-data access problem. + * @param dataSource the {@code DataSource} identifying the database + * @return the corresponding {@code SQLErrorCodes} object, + * or {@code null} in case of a JDBC meta-data access problem + * @since 5.2.9 + * @see java.sql.DatabaseMetaData#getDatabaseProductName() + */ + @Nullable + public SQLErrorCodes resolveErrorCodes(DataSource dataSource) { Assert.notNull(dataSource, "DataSource must not be null"); if (logger.isDebugEnabled()) { logger.debug("Looking up default SQLErrorCodes for DataSource [" + identify(dataSource) + "]"); @@ -212,16 +232,16 @@ public SQLErrorCodes getErrorCodes(DataSource dataSource) { if (sec == null) { // We could not find it - got to look it up. try { - String name = JdbcUtils.extractDatabaseMetaData(dataSource, "getDatabaseProductName"); + String name = JdbcUtils.extractDatabaseMetaData(dataSource, + DatabaseMetaData::getDatabaseProductName); if (StringUtils.hasLength(name)) { return registerDatabase(dataSource, name); } } catch (MetaDataAccessException ex) { - logger.warn("Error while extracting database name - falling back to empty error codes", ex); + logger.warn("Error while extracting database name", ex); } - // Fallback is to return an empty SQLErrorCodes instance. - return new SQLErrorCodes(); + return null; } } } diff --git a/spring-jdbc/src/test/java/org/springframework/jdbc/support/SQLErrorCodeSQLExceptionTranslatorTests.java b/spring-jdbc/src/test/java/org/springframework/jdbc/support/SQLErrorCodeSQLExceptionTranslatorTests.java index 17105ccf2742..2bff726ba3a3 100644 --- a/spring-jdbc/src/test/java/org/springframework/jdbc/support/SQLErrorCodeSQLExceptionTranslatorTests.java +++ b/spring-jdbc/src/test/java/org/springframework/jdbc/support/SQLErrorCodeSQLExceptionTranslatorTests.java @@ -17,10 +17,15 @@ package org.springframework.jdbc.support; import java.sql.BatchUpdateException; +import java.sql.Connection; import java.sql.DataTruncation; +import java.sql.DatabaseMetaData; import java.sql.SQLException; +import javax.sql.DataSource; + import org.junit.jupiter.api.Test; +import org.mockito.Mockito; import org.springframework.dao.CannotAcquireLockException; import org.springframework.dao.CannotSerializeTransactionException; @@ -35,6 +40,9 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; +import static org.mockito.BDDMockito.given; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; /** * @author Rod Johnson @@ -79,7 +87,7 @@ public void errorCodeTranslation() { SQLException dupKeyEx = new SQLException("", "", 10); DataAccessException dksex = sext.translate("task", "SQL", dupKeyEx); - assertThat(DataIntegrityViolationException.class.isAssignableFrom(dksex.getClass())).as("Not instance of DataIntegrityViolationException").isTrue(); + assertThat(DataIntegrityViolationException.class.isInstance(dksex)).as("Not instance of DataIntegrityViolationException").isTrue(); // Test fallback. We assume that no database will ever return this error code, // but 07xxx will be bad grammar picked up by the fallback SQLState translator @@ -152,14 +160,13 @@ public void customExceptionTranslation() { final SQLErrorCodes customErrorCodes = new SQLErrorCodes(); final CustomSQLErrorCodesTranslation customTranslation = new CustomSQLErrorCodesTranslation(); - customErrorCodes.setBadSqlGrammarCodes(new String[] {"1", "2"}); - customErrorCodes.setDataIntegrityViolationCodes(new String[] {"3", "4"}); - customTranslation.setErrorCodes(new String[] {"1"}); + customErrorCodes.setBadSqlGrammarCodes("1", "2"); + customErrorCodes.setDataIntegrityViolationCodes("3", "4"); + customTranslation.setErrorCodes("1"); customTranslation.setExceptionClass(CustomErrorCodeException.class); - customErrorCodes.setCustomTranslations(new CustomSQLErrorCodesTranslation[] {customTranslation}); + customErrorCodes.setCustomTranslations(customTranslation); - SQLErrorCodeSQLExceptionTranslator sext = new SQLErrorCodeSQLExceptionTranslator(); - sext.setSqlErrorCodes(customErrorCodes); + SQLErrorCodeSQLExceptionTranslator sext = new SQLErrorCodeSQLExceptionTranslator(customErrorCodes); // Should custom translate this SQLException badSqlEx = new SQLException("", "", 1); @@ -176,4 +183,28 @@ public void customExceptionTranslation() { customTranslation.setExceptionClass(String.class)); } + @Test + public void dataSourceInitialization() throws Exception { + SQLException connectionException = new SQLException(); + SQLException duplicateKeyException = new SQLException("test", "", 1); + + DataSource dataSource = mock(DataSource.class); + given(dataSource.getConnection()).willThrow(connectionException); + + SQLErrorCodeSQLExceptionTranslator sext = new SQLErrorCodeSQLExceptionTranslator(dataSource); + assertThat(sext.translate("test", null, duplicateKeyException)).isNotInstanceOf(DuplicateKeyException.class); + + DatabaseMetaData databaseMetaData = mock(DatabaseMetaData.class); + given(databaseMetaData.getDatabaseProductName()).willReturn("Oracle"); + + Connection connection = mock(Connection.class); + given(connection.getMetaData()).willReturn(databaseMetaData); + + Mockito.reset(dataSource); + given(dataSource.getConnection()).willReturn(connection); + assertThat(sext.translate("test", null, duplicateKeyException)).isInstanceOf(DuplicateKeyException.class); + + verify(connection).close(); + } + } diff --git a/spring-jdbc/src/test/java/org/springframework/jdbc/support/SQLErrorCodesFactoryTests.java b/spring-jdbc/src/test/java/org/springframework/jdbc/support/SQLErrorCodesFactoryTests.java index f2d69fa14930..7ebf7267551a 100644 --- a/spring-jdbc/src/test/java/org/springframework/jdbc/support/SQLErrorCodesFactoryTests.java +++ b/spring-jdbc/src/test/java/org/springframework/jdbc/support/SQLErrorCodesFactoryTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2019 the original author or authors. + * Copyright 2002-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -31,6 +31,7 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.reset; import static org.mockito.Mockito.verify; /** @@ -39,6 +40,7 @@ * @author Rod Johnson * @author Thomas Risberg * @author Stephane Nicoll + * @author Juergen Hoeller */ public class SQLErrorCodesFactoryTests { @@ -239,7 +241,11 @@ public void testDataSourceWithNullMetadata() throws Exception { SQLErrorCodes sec = SQLErrorCodesFactory.getInstance().getErrorCodes(dataSource); assertIsEmpty(sec); + verify(connection).close(); + reset(connection); + sec = SQLErrorCodesFactory.getInstance().resolveErrorCodes(dataSource); + assertThat(sec).isNull(); verify(connection).close(); } @@ -252,12 +258,9 @@ public void testGetFromDataSourceWithSQLException() throws Exception { SQLErrorCodes sec = SQLErrorCodesFactory.getInstance().getErrorCodes(dataSource); assertIsEmpty(sec); - } - private void assertIsEmpty(SQLErrorCodes sec) { - // Codes should be empty - assertThat(sec.getBadSqlGrammarCodes().length).isEqualTo(0); - assertThat(sec.getDataIntegrityViolationCodes().length).isEqualTo(0); + sec = SQLErrorCodesFactory.getInstance().resolveErrorCodes(dataSource); + assertThat(sec).isNull(); } private SQLErrorCodes getErrorCodesFromDataSource(String productName, SQLErrorCodesFactory factory) throws Exception { @@ -270,17 +273,9 @@ private SQLErrorCodes getErrorCodesFromDataSource(String productName, SQLErrorCo DataSource dataSource = mock(DataSource.class); given(dataSource.getConnection()).willReturn(connection); - SQLErrorCodesFactory secf = null; - if (factory != null) { - secf = factory; - } - else { - secf = SQLErrorCodesFactory.getInstance(); - } - + SQLErrorCodesFactory secf = (factory != null ? factory : SQLErrorCodesFactory.getInstance()); SQLErrorCodes sec = secf.getErrorCodes(dataSource); - SQLErrorCodes sec2 = secf.getErrorCodes(dataSource); assertThat(sec).as("Cached per DataSource").isSameAs(sec2); @@ -375,4 +370,9 @@ protected Resource loadResource(String path) { assertIsEmpty(sec); } + private void assertIsEmpty(SQLErrorCodes sec) { + assertThat(sec.getBadSqlGrammarCodes().length).isEqualTo(0); + assertThat(sec.getDataIntegrityViolationCodes().length).isEqualTo(0); + } + }