From 9bf751a46c6366e0061132b33555741bd902043d Mon Sep 17 00:00:00 2001 From: MarkMarkyMarkus Date: Wed, 20 Sep 2023 18:06:56 +0300 Subject: [PATCH] Add Params interface to carry binding --- .../r2dbc/core/DatabaseClient.java | 56 ++++++ .../r2dbc/core/DefaultDatabaseClient.java | 180 +++++++++--------- .../r2dbc/core/DefaultParamsImpl.java | 130 +++++++++++++ ...bstractDatabaseClientIntegrationTests.java | 17 +- .../core/DefaultDatabaseClientUnitTests.java | 18 +- 5 files changed, 294 insertions(+), 107 deletions(-) create mode 100644 spring-r2dbc/src/main/java/org/springframework/r2dbc/core/DefaultParamsImpl.java diff --git a/spring-r2dbc/src/main/java/org/springframework/r2dbc/core/DatabaseClient.java b/spring-r2dbc/src/main/java/org/springframework/r2dbc/core/DatabaseClient.java index 4ad0cd7c818a..12b3513b255d 100644 --- a/spring-r2dbc/src/main/java/org/springframework/r2dbc/core/DatabaseClient.java +++ b/spring-r2dbc/src/main/java/org/springframework/r2dbc/core/DatabaseClient.java @@ -22,8 +22,10 @@ import java.util.function.Function; import java.util.function.Predicate; import java.util.function.Supplier; +import java.util.function.UnaryOperator; import io.r2dbc.spi.ConnectionFactory; +import io.r2dbc.spi.Parameter; import io.r2dbc.spi.Readable; import io.r2dbc.spi.Result; import io.r2dbc.spi.Row; @@ -210,6 +212,13 @@ interface GenericExecuteSpec { */ GenericExecuteSpec bindProperties(Object source); + /** + * Bind values from the given mapping function. + * @param paramsFunction a function that maps from {@link Params} to {@link Params} with new bindings + * @since 6.1 + */ + GenericExecuteSpec bind(UnaryOperator paramsFunction); + /** * Add the given filter to the end of the filter chain. *

Filter functions are typically used to invoke methods on the Statement @@ -300,4 +309,51 @@ default GenericExecuteSpec filter(Function then(); } + /** + * Interface that defines common functionality for parameter binding. + */ + interface Params { + /** + * See {@link GenericExecuteSpec#bind(String, Object)}. + */ + Params bind(String name, Object value); + + /** + * See {@link GenericExecuteSpec#bind(int, Object)}. + */ + Params bind(int index, Object value); + + /** + * See {@link GenericExecuteSpec#bindNull(String, Class)}. + */ + Params bindNull(String name, Class type); + + /** + * See {@link GenericExecuteSpec#bind(int, Object)}. + */ + Params bindNull(int index, Class type); + + /** + * See {@link GenericExecuteSpec#bindValues(Map)}. + */ + Params bindValues(Map source); + + /** + * See {@link GenericExecuteSpec#bindProperties(Object)}. + */ + Params bindProperties(Object source); + + /** + * Get bindings by index. + * @return index based params + */ + Map byIndex(); + + /** + * Get bindings by name. + * @return name based params + */ + Map byName(); + } + } diff --git a/spring-r2dbc/src/main/java/org/springframework/r2dbc/core/DefaultDatabaseClient.java b/spring-r2dbc/src/main/java/org/springframework/r2dbc/core/DefaultDatabaseClient.java index dde1f7f3a9d4..77340108a3ca 100644 --- a/spring-r2dbc/src/main/java/org/springframework/r2dbc/core/DefaultDatabaseClient.java +++ b/spring-r2dbc/src/main/java/org/springframework/r2dbc/core/DefaultDatabaseClient.java @@ -16,25 +16,24 @@ package org.springframework.r2dbc.core; -import java.beans.PropertyDescriptor; import java.lang.reflect.InvocationHandler; import java.lang.reflect.InvocationTargetException; import java.lang.reflect.Method; import java.lang.reflect.Proxy; +import java.util.ArrayList; import java.util.Collections; -import java.util.LinkedHashMap; import java.util.List; import java.util.Map; import java.util.concurrent.atomic.AtomicBoolean; import java.util.function.BiFunction; import java.util.function.Function; import java.util.function.Supplier; +import java.util.function.UnaryOperator; import java.util.stream.Collectors; import io.r2dbc.spi.Connection; import io.r2dbc.spi.ConnectionFactory; import io.r2dbc.spi.Parameter; -import io.r2dbc.spi.Parameters; import io.r2dbc.spi.R2dbcException; import io.r2dbc.spi.Readable; import io.r2dbc.spi.Result; @@ -48,7 +47,6 @@ import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; -import org.springframework.beans.BeanUtils; import org.springframework.dao.InvalidDataAccessApiUsageException; import org.springframework.lang.Nullable; import org.springframework.r2dbc.connection.ConnectionFactoryUtils; @@ -56,7 +54,6 @@ import org.springframework.r2dbc.core.binding.BindTarget; import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; -import org.springframework.util.ReflectionUtils; import org.springframework.util.StringUtils; /** @@ -224,124 +221,99 @@ private static String getSql(Object object) { */ class DefaultGenericExecuteSpec implements GenericExecuteSpec { - final Map byIndex; + final Params params; - final Map byName; + final List batchParams; final Supplier sqlSupplier; final StatementFilterFunction filterFunction; DefaultGenericExecuteSpec(Supplier sqlSupplier) { - this.byIndex = Collections.emptyMap(); - this.byName = Collections.emptyMap(); + this.params = DefaultParamsImpl.EMPTY_PARAMS; + this.batchParams = Collections.emptyList(); this.sqlSupplier = sqlSupplier; this.filterFunction = StatementFilterFunction.EMPTY_FILTER; } - DefaultGenericExecuteSpec(Map byIndex, Map byName, - Supplier sqlSupplier, StatementFilterFunction filterFunction) { - - this.byIndex = byIndex; - this.byName = byName; + DefaultGenericExecuteSpec(Params params, List batchParams, Supplier sqlSupplier, + StatementFilterFunction filterFunction) { + this.params = params; + this.batchParams = batchParams; this.sqlSupplier = sqlSupplier; this.filterFunction = filterFunction; } - @SuppressWarnings("deprecation") - private Parameter resolveParameter(Object value) { - if (value instanceof Parameter param) { - return param; - } - else if (value instanceof org.springframework.r2dbc.core.Parameter param) { - Object paramValue = param.getValue(); - return (paramValue != null ? Parameters.in(paramValue) : Parameters.in(param.getType())); - } - else { - return Parameters.in(value); - } - } - @Override public DefaultGenericExecuteSpec bind(int index, Object value) { assertNotPreparedOperation(); - Assert.notNull(value, () -> String.format( - "Value at index %d must not be null. Use bindNull(…) instead.", index)); - Map byIndex = new LinkedHashMap<>(this.byIndex); - byIndex.put(index, resolveParameter(value)); + Params newParams = this.params.bind(index, value); - return new DefaultGenericExecuteSpec(byIndex, this.byName, this.sqlSupplier, this.filterFunction); + return new DefaultGenericExecuteSpec(newParams, this.batchParams, this.sqlSupplier, this.filterFunction); } @Override public DefaultGenericExecuteSpec bindNull(int index, Class type) { assertNotPreparedOperation(); - Map byIndex = new LinkedHashMap<>(this.byIndex); - byIndex.put(index, Parameters.in(type)); + Params newParams = this.params.bindNull(index, type); - return new DefaultGenericExecuteSpec(byIndex, this.byName, this.sqlSupplier, this.filterFunction); + return new DefaultGenericExecuteSpec(newParams, this.batchParams, this.sqlSupplier, this.filterFunction); } @Override public DefaultGenericExecuteSpec bind(String name, Object value) { assertNotPreparedOperation(); - Assert.hasText(name, "Parameter name must not be null or empty"); - Assert.notNull(value, () -> String.format( - "Value for parameter %s must not be null. Use bindNull(…) instead.", name)); + Params newParams = this.params.bind(name, value); - Map byName = new LinkedHashMap<>(this.byName); - byName.put(name, resolveParameter(value)); - - return new DefaultGenericExecuteSpec(this.byIndex, byName, this.sqlSupplier, this.filterFunction); + return new DefaultGenericExecuteSpec(newParams, this.batchParams, this.sqlSupplier, this.filterFunction); } @Override public DefaultGenericExecuteSpec bindNull(String name, Class type) { assertNotPreparedOperation(); - Assert.hasText(name, "Parameter name must not be null or empty"); - Map byName = new LinkedHashMap<>(this.byName); - byName.put(name, Parameters.in(type)); + Params newParams = this.params.bindNull(name, type); - return new DefaultGenericExecuteSpec(this.byIndex, byName, this.sqlSupplier, this.filterFunction); + return new DefaultGenericExecuteSpec(newParams, this.batchParams, this.sqlSupplier, this.filterFunction); } @Override public GenericExecuteSpec bindValues(Map source) { assertNotPreparedOperation(); - Assert.notNull(source, "Parameter source must not be null"); - Map target = new LinkedHashMap<>(this.byName); - source.forEach((name, value) -> target.put(name, resolveParameter(value))); + Params newParams = this.params.bindValues(source); - return new DefaultGenericExecuteSpec(this.byIndex, target, this.sqlSupplier, this.filterFunction); + return new DefaultGenericExecuteSpec(newParams, this.batchParams, this.sqlSupplier, this.filterFunction); } @Override public DefaultGenericExecuteSpec bindProperties(Object source) { assertNotPreparedOperation(); - Assert.notNull(source, "Parameter source must not be null"); - - Map byName = new LinkedHashMap<>(this.byName); - for (PropertyDescriptor pd : BeanUtils.getPropertyDescriptors(source.getClass())) { - if (pd.getReadMethod() != null && pd.getReadMethod().getDeclaringClass() != Object.class) { - ReflectionUtils.makeAccessible(pd.getReadMethod()); - Object value = ReflectionUtils.invokeMethod(pd.getReadMethod(), source); - byName.put(pd.getName(), (value != null ? Parameters.in(value) : Parameters.in(pd.getPropertyType()))); - } - } - return new DefaultGenericExecuteSpec(this.byIndex, byName, this.sqlSupplier, this.filterFunction); + Params newParams = this.params.bindProperties(source); + + return new DefaultGenericExecuteSpec(newParams, this.batchParams, this.sqlSupplier, this.filterFunction); + } + + @Override + public GenericExecuteSpec bind(UnaryOperator paramsFunction) { + assertNotPreparedOperation(); + + Assert.notNull(paramsFunction, "Params function must not be null"); + List newBatchParams = new ArrayList<>(this.batchParams); + newBatchParams.add(paramsFunction.apply(DefaultParamsImpl.EMPTY_PARAMS)); + + return new DefaultGenericExecuteSpec(this.params, newBatchParams, this.sqlSupplier, this.filterFunction); } @Override public DefaultGenericExecuteSpec filter(StatementFilterFunction filter) { Assert.notNull(filter, "StatementFilterFunction must not be null"); - return new DefaultGenericExecuteSpec( - this.byIndex, this.byName, this.sqlSupplier, this.filterFunction.andThen(filter)); + return new DefaultGenericExecuteSpec(this.params, this.batchParams, this.sqlSupplier, + this.filterFunction.andThen(filter)); } @Override @@ -396,18 +368,26 @@ private ResultFunction getResultFunction(Supplier sqlSupplier) { return statement; } - if (DefaultDatabaseClient.this.namedParameterExpander != null) { - Map remainderByName = new LinkedHashMap<>(this.byName); - Map remainderByIndex = new LinkedHashMap<>(this.byIndex); + List allParams = new ArrayList<>(this.batchParams); + + if (!this.params.byIndex().isEmpty() || !this.params.byName().isEmpty()) { + allParams.add(this.params); + } + if (!allParams.isEmpty() && DefaultDatabaseClient.this.namedParameterExpander != null) { List parameterNames = DefaultDatabaseClient.this.namedParameterExpander.getParameterNames(sql); - MapBindParameterSource namedBindings = retrieveParameters( - sql, parameterNames, remainderByName, remainderByIndex); + List> operations = new ArrayList<>(allParams.size()); + + for (Params params : allParams) { + MapBindParameterSource namedBindings = retrieveParameters(sql, parameterNames, params); PreparedOperation operation = DefaultDatabaseClient.this.namedParameterExpander.expand( sql, DefaultDatabaseClient.this.bindMarkersFactory, namedBindings); - String expanded = getRequiredSql(operation); + operations.add(operation); + } + + String expanded = getRequiredSql(operations); if (logger.isTraceEnabled()) { logger.trace("Expanded SQL [" + expanded + "]"); } @@ -415,18 +395,22 @@ private ResultFunction getResultFunction(Supplier sqlSupplier) { Statement statement = connection.createStatement(expanded); BindTarget bindTarget = new StatementWrapper(statement); - operation.bindTo(bindTarget); + for (int i = 0; i < operations.size(); i++) { + PreparedOperation operation = operations.get(i); + operation.bindTo(bindTarget); + if (operations.size() > 1 && i != operations.size() - 1) { + statement.add(); + } + } - bindByName(statement, remainderByName); - bindByIndex(statement, remainderByIndex); + applyBindings(statement, allParams); return statement; } Statement statement = connection.createStatement(sql); - bindByIndex(statement, this.byIndex); - bindByName(statement, this.byName); + applyBindings(statement, allParams); return statement; }; @@ -435,6 +419,19 @@ private ResultFunction getResultFunction(Supplier sqlSupplier) { DefaultDatabaseClient.this.executeFunction); } + private void applyBindings(Statement statement, List params) { + for (int i = 0; i < params.size(); i++) { + Params parameter = params.get(i); + if (!parameter.byIndex().isEmpty() || !parameter.byName().isEmpty()) { + bindByIndex(statement, parameter.byIndex()); + bindByName(statement, parameter.byName()); + if (params.size() > 1 && i != params.size() - 1) { + statement.add(); + } + } + } + } + private FetchSpec execute(Supplier sqlSupplier, Function> resultAdapter) { ResultFunction resultHandler = getResultFunction(sqlSupplier); return new DefaultFetchSpec<>(DefaultDatabaseClient.this, resultHandler, @@ -448,12 +445,11 @@ private Flux flatMap(Supplier sqlSupplier, Function parameterNames, - Map remainderByName, Map remainderByIndex) { + private MapBindParameterSource retrieveParameters(String sql, List parameterNames, Params params) { Map namedBindings = CollectionUtils.newLinkedHashMap(parameterNames.size()); for (String parameterName : parameterNames) { - Parameter parameter = getParameter(remainderByName, remainderByIndex, parameterNames, parameterName); + Parameter parameter = getParameter(params.byName(), params.byIndex(), parameterNames, parameterName); if (parameter == null) { throw new InvalidDataAccessApiUsageException( String.format("No parameter specified for [%s] in query [%s]", parameterName, sql)); @@ -464,18 +460,15 @@ private MapBindParameterSource retrieveParameters(String sql, List param } @Nullable - private Parameter getParameter(Map remainderByName, - Map remainderByIndex, List parameterNames, String parameterName) { - - if (this.byName.containsKey(parameterName)) { - remainderByName.remove(parameterName); - return this.byName.get(parameterName); + private Parameter getParameter(Map byName, Map byIndex, + List parameterNames, String parameterName) { + if (byName.containsKey(parameterName)) { + return byName.remove(parameterName); } int index = parameterNames.indexOf(parameterName); - if (this.byIndex.containsKey(index)) { - remainderByIndex.remove(index); - return this.byIndex.get(index); + if (byIndex.containsKey(index)) { + return byIndex.remove(index); } return null; @@ -495,6 +488,19 @@ private void bindByIndex(Statement statement, Map byIndex) { byIndex.forEach(statement::bind); } + private String getRequiredSql(List> operations) throws IllegalArgumentException { + return operations + .stream() + .map(this::getRequiredSql) + .reduce((prevSql, nextSql) -> { + if (prevSql.equals(nextSql)) { + return nextSql; + } else { + throw new IllegalArgumentException("Resulting SQL is not the same!"); + } + }).orElseThrow(() -> new IllegalArgumentException("Operations must not be empty!")); + } + private String getRequiredSql(Supplier sqlSupplier) { String sql = sqlSupplier.get(); Assert.state(StringUtils.hasText(sql), "SQL returned by supplier must not be empty"); diff --git a/spring-r2dbc/src/main/java/org/springframework/r2dbc/core/DefaultParamsImpl.java b/spring-r2dbc/src/main/java/org/springframework/r2dbc/core/DefaultParamsImpl.java new file mode 100644 index 000000000000..60f138cf028a --- /dev/null +++ b/spring-r2dbc/src/main/java/org/springframework/r2dbc/core/DefaultParamsImpl.java @@ -0,0 +1,130 @@ +/* + * Copyright 2002-2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.r2dbc.core; + +import io.r2dbc.spi.Parameters; +import io.r2dbc.spi.Parameter; +import org.springframework.beans.BeanUtils; +import org.springframework.util.Assert; +import org.springframework.util.ReflectionUtils; +import org.springframework.r2dbc.core.DatabaseClient.Params; + +import java.beans.PropertyDescriptor; +import java.util.Collections; +import java.util.LinkedHashMap; +import java.util.Map; + +/** + * Default {@link DatabaseClient.Params} implementation. + * + * @param byIndex index based bindings + * @param byName name based bindings + * @author Mark Shiryaev + * @since 6.1 + */ +record DefaultParamsImpl(Map byIndex, Map byName) implements Params { + public static final DefaultParamsImpl EMPTY_PARAMS = new DefaultParamsImpl( + Collections.emptyMap(), Collections.emptyMap()); + + @Override + public Params bind(String name, Object value) { + Assert.hasText(name, "Parameter name must not be null or empty"); + Assert.notNull(value, () -> String.format( + "Value for parameter %s must not be null. Use bindNull(…) instead.", name)); + + Map byName = new LinkedHashMap<>(this.byName); + byName.put(name, resolveParameter(value)); + return new DefaultParamsImpl(this.byIndex, byName); + } + + @Override + public Params bind(int index, Object value) { + Assert.notNull(value, () -> String.format( + "Value at index %d must not be null. Use bindNull(…) instead.", index)); + + Map byIndex = new LinkedHashMap<>(this.byIndex); + byIndex.put(index, resolveParameter(value)); + + return new DefaultParamsImpl(byIndex, this.byName); + } + + @Override + public Params bindNull(String name, Class type) { + Assert.hasText(name, "Parameter name must not be null or empty"); + + Map byName = new LinkedHashMap<>(this.byName); + byName.put(name, Parameters.in(type)); + + return new DefaultParamsImpl(this.byIndex, byName); + } + + @Override + public Params bindNull(int index, Class type) { + Map byIndex = new LinkedHashMap<>(this.byIndex); + byIndex.put(index, Parameters.in(type)); + + return new DefaultParamsImpl(byIndex, this.byName); + } + + @Override + public Map byIndex() { + return this.byIndex; + } + + @Override + public Map byName() { + return this.byName; + } + + @Override + public Params bindValues(Map source) { + Assert.notNull(source, "Parameter source must not be null"); + + Map byName = new LinkedHashMap<>(this.byName); + source.forEach((name, value) -> byName.put(name, resolveParameter(value))); + + return new DefaultParamsImpl(this.byIndex, byName); + } + + @Override + public Params bindProperties(Object source) { + Assert.notNull(source, "Parameter source must not be null"); + + Map byName = new LinkedHashMap<>(this.byName); + for (PropertyDescriptor pd : BeanUtils.getPropertyDescriptors(source.getClass())) { + if (pd.getReadMethod() != null && pd.getReadMethod().getDeclaringClass() != Object.class) { + ReflectionUtils.makeAccessible(pd.getReadMethod()); + Object value = ReflectionUtils.invokeMethod(pd.getReadMethod(), source); + byName.put(pd.getName(), (value != null ? Parameters.in(value) : Parameters.in(pd.getPropertyType()))); + } + } + + return new DefaultParamsImpl(this.byIndex, byName); + } + + @SuppressWarnings("deprecation") + private Parameter resolveParameter(Object value) { + if (value instanceof Parameter param) { + return param; + } else if (value instanceof org.springframework.r2dbc.core.Parameter param) { + Object paramValue = param.getValue(); + return (paramValue != null ? Parameters.in(paramValue) : Parameters.in(param.getType())); + } else { + return Parameters.in(value); + } + } +} diff --git a/spring-r2dbc/src/test/java/org/springframework/r2dbc/core/AbstractDatabaseClientIntegrationTests.java b/spring-r2dbc/src/test/java/org/springframework/r2dbc/core/AbstractDatabaseClientIntegrationTests.java index 43a021b95b97..08d6c2529abe 100644 --- a/spring-r2dbc/src/test/java/org/springframework/r2dbc/core/AbstractDatabaseClientIntegrationTests.java +++ b/spring-r2dbc/src/test/java/org/springframework/r2dbc/core/AbstractDatabaseClientIntegrationTests.java @@ -143,16 +143,14 @@ public void executeBatchInsert() { DatabaseClient databaseClient = DatabaseClient.create(connectionFactory); databaseClient.sql("INSERT INTO legoset (id, name, manual) VALUES(:id, :name, :manual)") - .bind("id", 42055) - .bind("name", "SCHAUFELRADBAGGER") - .bindNull("manual", Integer.class) - .add() - .bind("id", 2021) - .bind("name", "TOM") - .bindNull("manual", Integer.class) + .bind(params -> params + .bind("id", 42055) + .bind("name", "SCHAUFELRADBAGGER") + .bindNull("manual", Integer.class)) + .bind(params -> params.bind("id", 2021).bind("name", "TOM").bindNull("manual", Integer.class)) .fetch().rowsUpdated() .as(StepVerifier::create) - .expectNextMatches(updatedRows -> updatedRows.equals(2)) + .expectNextMatches(updatedRows -> updatedRows.equals(2L)) .verifyComplete(); databaseClient.sql("SELECT id FROM legoset") @@ -172,8 +170,7 @@ public void shouldThrowIllegalArgumentException() { DatabaseClient databaseClient = DatabaseClient.create(connectionFactory); databaseClient.sql("INSERT INTO legoset (id, name, manual) VALUES(:my_list)") - .bind("my_list", Arrays.asList(1, "Bob", 1)) - .add() + .bind(params -> params.bind("my_list", Arrays.asList(1, "Bob", 1))) .bind("my_list", Arrays.asList(2, "Alice", 1, "next")) .fetch().rowsUpdated() .as(StepVerifier::create) diff --git a/spring-r2dbc/src/test/java/org/springframework/r2dbc/core/DefaultDatabaseClientUnitTests.java b/spring-r2dbc/src/test/java/org/springframework/r2dbc/core/DefaultDatabaseClientUnitTests.java index 15390d2be6e0..5c2116551ae7 100644 --- a/spring-r2dbc/src/test/java/org/springframework/r2dbc/core/DefaultDatabaseClientUnitTests.java +++ b/spring-r2dbc/src/test/java/org/springframework/r2dbc/core/DefaultDatabaseClientUnitTests.java @@ -224,22 +224,20 @@ void executeBatchShouldBindValues() { DatabaseClient databaseClient = databaseClientBuilder.build(); databaseClient.sql("INSERT INTO table VALUES ($1)") - .bind(0, Parameter.from("foo")) - .add() - .bind(0, Parameter.from("bar")) + .bind(params -> params.bind(0, Parameters.in("foo"))) + .bind(params -> params.bind(0, Parameters.in("bar"))) .then().as(StepVerifier::create).verifyComplete(); - verify(statement).bind(0, "foo"); - verify(statement).bind(0, "bar"); + verify(statement).bind(0, Parameters.in("foo")); + verify(statement).bind(0, Parameters.in("bar")); databaseClient.sql("INSERT INTO table VALUES ($1)") - .bind("$1", "foo") - .add() - .bind("$1", "bar") + .bind(params -> params.bind("$1", "foo")) + .bind(params -> params.bind("$1", "bar")) .then().as(StepVerifier::create).verifyComplete(); - verify(statement).bind("$1", "foo"); - verify(statement).bind("$1", "bar"); + verify(statement).bind("$1", Parameters.in("foo")); + verify(statement).bind("$1", Parameters.in("bar")); } @Test