diff --git a/spring-core/src/main/java/org/springframework/core/MethodParameter.java b/spring-core/src/main/java/org/springframework/core/MethodParameter.java index 040819bea667..bb1bbff89621 100644 --- a/spring-core/src/main/java/org/springframework/core/MethodParameter.java +++ b/spring-core/src/main/java/org/springframework/core/MethodParameter.java @@ -26,12 +26,11 @@ import java.lang.reflect.ParameterizedType; import java.lang.reflect.Type; import java.util.HashMap; -import java.util.List; import java.util.Map; import java.util.Optional; import java.util.function.Predicate; -import java.util.stream.Collectors; +import kotlin.coroutines.Continuation; import kotlin.reflect.KFunction; import kotlin.reflect.KParameter; import kotlin.reflect.jvm.ReflectJvmMapping; @@ -398,7 +397,7 @@ private MethodParameter nested(int nestingLevel, @Nullable Integer typeIndex) { * either in the form of Java 8's {@link java.util.Optional}, any variant * of a parameter-level {@code Nullable} annotation (such as from JSR-305 * or the FindBugs set of annotations), or a language-level nullable type - * declaration in Kotlin. + * declaration or {@code Continuation} parameter in Kotlin. * @since 4.3 */ public boolean isOptional() { @@ -867,37 +866,39 @@ private static int validateIndex(Executable executable, int parameterIndex) { private static class KotlinDelegate { /** - * Check whether the specified {@link MethodParameter} represents a nullable Kotlin type - * or an optional parameter (with a default value in the Kotlin declaration). + * Check whether the specified {@link MethodParameter} represents a nullable Kotlin type, + * an optional parameter (with a default value in the Kotlin declaration) or a {@link Continuation} parameter + * used in suspending functions. */ public static boolean isOptional(MethodParameter param) { Method method = param.getMethod(); - Constructor ctor = param.getConstructor(); int index = param.getParameterIndex(); if (method != null && index == -1) { KFunction function = ReflectJvmMapping.getKotlinFunction(method); return (function != null && function.getReturnType().isMarkedNullable()); } - else { - KFunction function = null; - Predicate predicate = null; - if (method != null) { - function = ReflectJvmMapping.getKotlinFunction(method); - predicate = p -> KParameter.Kind.VALUE.equals(p.getKind()); - } - else if (ctor != null) { - function = ReflectJvmMapping.getKotlinFunction(ctor); - predicate = p -> KParameter.Kind.VALUE.equals(p.getKind()) || - KParameter.Kind.INSTANCE.equals(p.getKind()); + KFunction function; + Predicate predicate; + if (method != null) { + if (param.parameterType == Continuation.class) { + return true; } - if (function != null) { - List parameters = function.getParameters(); - KParameter parameter = parameters - .stream() - .filter(predicate) - .collect(Collectors.toList()) - .get(index); - return (parameter.getType().isMarkedNullable() || parameter.isOptional()); + function = ReflectJvmMapping.getKotlinFunction(method); + predicate = p -> KParameter.Kind.VALUE.equals(p.getKind()); + } + else { + function = ReflectJvmMapping.getKotlinFunction(param.getConstructor()); + predicate = p -> KParameter.Kind.VALUE.equals(p.getKind()) || + KParameter.Kind.INSTANCE.equals(p.getKind()); + } + if (function != null) { + int i = 0; + for (KParameter kParameter : function.getParameters()) { + if (predicate.test(kParameter)) { + if (index == i++) { + return (kParameter.getType().isMarkedNullable() || kParameter.isOptional()); + } + } } } return false; diff --git a/spring-core/src/test/kotlin/org/springframework/core/KotlinMethodParameterTests.kt b/spring-core/src/test/kotlin/org/springframework/core/KotlinMethodParameterTests.kt index b88fbaa64b87..419698d2ff6c 100644 --- a/spring-core/src/test/kotlin/org/springframework/core/KotlinMethodParameterTests.kt +++ b/spring-core/src/test/kotlin/org/springframework/core/KotlinMethodParameterTests.kt @@ -20,6 +20,7 @@ import org.assertj.core.api.Assertions.assertThat import org.junit.jupiter.api.Test import java.lang.reflect.Method import java.lang.reflect.TypeVariable +import kotlin.coroutines.Continuation import kotlin.reflect.full.declaredFunctions import kotlin.reflect.jvm.javaMethod @@ -101,6 +102,13 @@ class KotlinMethodParameterTests { assertThat(returnGenericParameterType("suspendFun8")).isEqualTo(Object::class.java) } + @Test + fun `Continuation parameter is optional`() { + val method = this::class.java.getDeclaredMethod("suspendFun", String::class.java, Continuation::class.java) + assertThat(MethodParameter(method, 0).isOptional).isFalse() + assertThat(MethodParameter(method, 1).isOptional).isTrue() + } + private fun returnParameterType(funName: String) = returnMethodParameter(funName).parameterType private fun returnGenericParameterType(funName: String) = returnMethodParameter(funName).genericParameterType private fun returnGenericParameterTypeName(funName: String) = returnGenericParameterType(funName).typeName