Skip to content

Commit

Permalink
Merge pull request #1856 from kazuki43zoo/gh-1237
Browse files Browse the repository at this point in the history
Allow using actual argument name as bind parameter on a single collection
  • Loading branch information
kazuki43zoo committed Mar 21, 2020
2 parents dc02ced + ce14cec commit 060a397
Show file tree
Hide file tree
Showing 6 changed files with 371 additions and 16 deletions.
39 changes: 37 additions & 2 deletions src/main/java/org/apache/ibatis/reflection/ParamNameResolver.java
Expand Up @@ -17,8 +17,11 @@

import java.lang.annotation.Annotation;
import java.lang.reflect.Method;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.SortedMap;
import java.util.TreeMap;

Expand All @@ -32,6 +35,8 @@ public class ParamNameResolver {

public static final String GENERIC_NAME_PREFIX = "param";

private final boolean useActualParamName;

/**
* <p>
* The key is the index and the value is the name of the parameter.<br />
Expand All @@ -50,6 +55,7 @@ public class ParamNameResolver {
private boolean hasParamAnnotation;

public ParamNameResolver(Configuration config, Method method) {
this.useActualParamName = config.isUseActualParamName();
final Class<?>[] paramTypes = method.getParameterTypes();
final Annotation[][] paramAnnotations = method.getParameterAnnotations();
final SortedMap<Integer, String> map = new TreeMap<>();
Expand All @@ -70,7 +76,7 @@ public ParamNameResolver(Configuration config, Method method) {
}
if (name == null) {
// @Param was not specified.
if (config.isUseActualParamName()) {
if (useActualParamName) {
name = getActualParamName(method, paramIndex);
}
if (name == null) {
Expand Down Expand Up @@ -118,7 +124,8 @@ public Object getNamedParams(Object[] args) {
if (args == null || paramCount == 0) {
return null;
} else if (!hasParamAnnotation && paramCount == 1) {
return args[names.firstKey()];
Object value = args[names.firstKey()];
return wrapToMapIfCollection(value, useActualParamName ? names.get(0) : null);
} else {
final Map<String, Object> param = new ParamMap<>();
int i = 0;
Expand All @@ -135,4 +142,32 @@ public Object getNamedParams(Object[] args) {
return param;
}
}

/**
* Wrap to a {@link ParamMap} if object is {@link Collection} or array.
*
* @param object a parameter object
* @param actualParamName an actual parameter name
* (If specify a name, set an object to {@link ParamMap} with specified name)
* @return a {@link ParamMap}
* @since 3.5.5
*/
public static Object wrapToMapIfCollection(Object object, String actualParamName) {
if (object instanceof Collection) {
ParamMap<Object> map = new ParamMap<>();
map.put("collection", object);
if (object instanceof List) {
map.put("list", object);
}
Optional.ofNullable(actualParamName).ifPresent(name -> map.put(name, object));
return map;
} else if (object != null && object.getClass().isArray()) {
ParamMap<Object> map = new ParamMap<>();
map.put("array", object);
Optional.ofNullable(actualParamName).ifPresent(name -> map.put(name, object));
return map;
}
return object;
}

}
@@ -1,5 +1,5 @@
/**
* Copyright 2009-2019 the original author or authors.
* Copyright 2009-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -34,6 +34,7 @@
import org.apache.ibatis.executor.result.DefaultMapResultHandler;
import org.apache.ibatis.executor.result.DefaultResultContext;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.reflection.ParamNameResolver;
import org.apache.ibatis.session.Configuration;
import org.apache.ibatis.session.ResultHandler;
import org.apache.ibatis.session.RowBounds;
Expand Down Expand Up @@ -317,21 +318,13 @@ private boolean isCommitOrRollbackRequired(boolean force) {
}

private Object wrapCollection(final Object object) {
if (object instanceof Collection) {
StrictMap<Object> map = new StrictMap<>();
map.put("collection", object);
if (object instanceof List) {
map.put("list", object);
}
return map;
} else if (object != null && object.getClass().isArray()) {
StrictMap<Object> map = new StrictMap<>();
map.put("array", object);
return map;
}
return object;
return ParamNameResolver.wrapToMapIfCollection(object, null);
}

/**
* @deprecated Since 3.5.5
*/
@Deprecated
public static class StrictMap<V> extends HashMap<String, V> {

private static final long serialVersionUID = -5741767162221585340L;
Expand Down
@@ -0,0 +1,148 @@
/**
* Copyright 2009-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.ibatis.submitted.param_name_resolve;

import org.apache.ibatis.annotations.Select;
import org.apache.ibatis.io.Resources;
import org.apache.ibatis.jdbc.ScriptRunner;
import org.apache.ibatis.session.SqlSession;
import org.apache.ibatis.session.SqlSessionFactory;
import org.apache.ibatis.session.SqlSessionFactoryBuilder;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;

import java.io.Reader;
import java.sql.Connection;
import java.util.Arrays;
import java.util.List;

import static org.junit.Assert.assertEquals;

class ActualParamNameTest {

private static SqlSessionFactory sqlSessionFactory;

@BeforeAll
static void setUp() throws Exception {
// create an SqlSessionFactory
try (Reader reader = Resources
.getResourceAsReader("org/apache/ibatis/submitted/param_name_resolve/mybatis-config.xml")) {
sqlSessionFactory = new SqlSessionFactoryBuilder().build(reader);
sqlSessionFactory.getConfiguration().addMapper(Mapper.class);
}

// populate in-memory database
try (Connection conn = sqlSessionFactory.getConfiguration().getEnvironment().getDataSource().getConnection();
Reader reader = Resources
.getResourceAsReader("org/apache/ibatis/submitted/param_name_resolve/CreateDB.sql")) {
ScriptRunner runner = new ScriptRunner(conn);
runner.setLogWriter(null);
runner.runScript(reader);
}
}

@Test
void testSingleListParameterWhenUseActualParamNameIsTrue() {
try (SqlSession sqlSession = sqlSessionFactory.openSession()) {
Mapper mapper = sqlSession.getMapper(Mapper.class);
// use actual name
{
long count = mapper.getUserCountUsingList(Arrays.asList(1, 2));
assertEquals(2, count);
}
// use 'collection' as alias
{
long count = mapper.getUserCountUsingListWithAliasIsCollection(Arrays.asList(1, 2));
assertEquals(2, count);
}
// use 'list' as alias
{
long count = mapper.getUserCountUsingListWithAliasIsList(Arrays.asList(1, 2));
assertEquals(2, count);
}
}
}

@Test
void testSingleArrayParameterWhenUseActualParamNameIsTrue() {
try (SqlSession sqlSession = sqlSessionFactory.openSession()) {
Mapper mapper = sqlSession.getMapper(Mapper.class);
// use actual name
{
long count = mapper.getUserCountUsingArray(1, 2);
assertEquals(2, count);
}
// use 'array' as alias
{
long count = mapper.getUserCountUsingArrayWithAliasArray(1, 2);
assertEquals(2, count);
}
}
}

interface Mapper {
@Select({
"<script>",
" select count(*) from users u where u.id in",
" <foreach item='item' index='index' collection='ids' open='(' separator=',' close=')'>",
" #{item}",
" </foreach>",
"</script>"
})
Long getUserCountUsingList(List<Integer> ids);

@Select({
"<script>",
" select count(*) from users u where u.id in",
" <foreach item='item' index='index' collection='collection' open='(' separator=',' close=')'>",
" #{item}",
" </foreach>",
"</script>"
})
Long getUserCountUsingListWithAliasIsCollection(List<Integer> ids);

@Select({
"<script>",
" select count(*) from users u where u.id in",
" <foreach item='item' index='index' collection='list' open='(' separator=',' close=')'>",
" #{item}",
" </foreach>",
"</script>"
})
Long getUserCountUsingListWithAliasIsList(List<Integer> ids);

@Select({
"<script>",
" select count(*) from users u where u.id in",
" <foreach item='item' index='index' collection='ids' open='(' separator=',' close=')'>",
" #{item}",
" </foreach>",
"</script>"
})
Long getUserCountUsingArray(Integer... ids);

@Select({
"<script>",
" select count(*) from users u where u.id in",
" <foreach item='item' index='index' collection='array' open='(' separator=',' close=')'>",
" #{item}",
" </foreach>",
"</script>"
})
Long getUserCountUsingArrayWithAliasArray(Integer... ids);
}

}
@@ -0,0 +1,24 @@
--
-- Copyright 2009-2020 the original author or authors.
--
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
--

drop table users if exists;

create table users (
id int,
name varchar(20)
);

insert into users (id, name) values (1, 'User1'), (2, 'User2'), (3, 'User3');

0 comments on commit 060a397

Please sign in to comment.