diff --git a/spring-core/kotlin-coroutines/src/main/kotlin/org/springframework/core/CoroutinesUtils.kt b/spring-core/kotlin-coroutines/src/main/kotlin/org/springframework/core/CoroutinesUtils.kt index 64b1af36f0f3..edfd76d8c932 100644 --- a/spring-core/kotlin-coroutines/src/main/kotlin/org/springframework/core/CoroutinesUtils.kt +++ b/spring-core/kotlin-coroutines/src/main/kotlin/org/springframework/core/CoroutinesUtils.kt @@ -26,6 +26,7 @@ import kotlinx.coroutines.reactive.awaitFirstOrNull import kotlinx.coroutines.reactor.asFlux import kotlinx.coroutines.reactor.mono +import org.reactivestreams.Publisher import reactor.core.publisher.Mono import java.lang.reflect.InvocationTargetException import java.lang.reflect.Method @@ -51,28 +52,29 @@ internal fun monoToDeferred(source: Mono) = GlobalScope.async(Dispatchers.Unconfined) { source.awaitFirstOrNull() } /** - * Invoke a suspending function converting it to [Mono] or [reactor.core.publisher.Flux] - * if necessary. + * Return {@code true} if the method is a suspending function. + * + * @author Sebastien Deleuze + * @since 5.2.2 + */ +internal fun isSuspendingFunction(method: Method) = method.kotlinFunction!!.isSuspend + +/** + * Invoke a suspending function and converts it to [Mono] or [reactor.core.publisher.Flux]. * * @author Sebastien Deleuze * @since 5.2 */ @Suppress("UNCHECKED_CAST") -internal fun invokeSuspendingFunction(method: Method, bean: Any, vararg args: Any?): Any? { +internal fun invokeSuspendingFunction(method: Method, bean: Any, vararg args: Any?): Publisher<*> { val function = method.kotlinFunction!! - return if (function.isSuspend) { - val mono = mono(Dispatchers.Unconfined) { - function.callSuspend(bean, *args.sliceArray(0..(args.size-2))) - .let { if (it == Unit) null else it } - }.onErrorMap(InvocationTargetException::class.java) { it.targetException } - if (function.returnType.classifier == Flow::class) { - mono.flatMapMany { (it as Flow).asFlux() } - } - else { - mono - } + val mono = mono(Dispatchers.Unconfined) { + function.callSuspend(bean, *args.sliceArray(0..(args.size-2))).let { if (it == Unit) null else it } + }.onErrorMap(InvocationTargetException::class.java) { it.targetException } + return if (function.returnType.classifier == Flow::class) { + mono.flatMapMany { (it as Flow).asFlux() } } else { - function.call(bean, *args) + mono } } diff --git a/spring-core/src/main/java/org/springframework/core/MethodParameter.java b/spring-core/src/main/java/org/springframework/core/MethodParameter.java index e3ae0562bac8..bcb1a357ef92 100644 --- a/spring-core/src/main/java/org/springframework/core/MethodParameter.java +++ b/spring-core/src/main/java/org/springframework/core/MethodParameter.java @@ -30,6 +30,7 @@ import java.util.Optional; import java.util.function.Predicate; +import kotlin.Unit; import kotlin.reflect.KFunction; import kotlin.reflect.KParameter; import kotlin.reflect.jvm.ReflectJvmMapping; @@ -929,6 +930,9 @@ static private Class getReturnType(Method method) { KFunction function = ReflectJvmMapping.getKotlinFunction(method); if (function != null && function.isSuspend()) { Type paramType = ReflectJvmMapping.getJavaType(function.getReturnType()); + if (paramType == Unit.class) { + paramType = void.class; + } return ResolvableType.forType(paramType).resolve(method.getReturnType()); } } diff --git a/spring-messaging/src/main/java/org/springframework/messaging/handler/invocation/reactive/InvocableHandlerMethod.java b/spring-messaging/src/main/java/org/springframework/messaging/handler/invocation/reactive/InvocableHandlerMethod.java index 0db41014e4fb..8ded74c07d65 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/handler/invocation/reactive/InvocableHandlerMethod.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/handler/invocation/reactive/InvocableHandlerMethod.java @@ -127,10 +127,13 @@ public void setReactiveAdapterRegistry(ReactiveAdapterRegistry registry) { public Mono invoke(Message message, Object... providedArgs) { return getMethodArgumentValues(message, providedArgs).flatMap(args -> { Object value; + boolean isSuspendingFunction = false; try { Method method = getBridgedMethod(); ReflectionUtils.makeAccessible(method); - if (KotlinDetector.isKotlinReflectPresent() && KotlinDetector.isKotlinType(method.getDeclaringClass())) { + if (KotlinDetector.isKotlinReflectPresent() && KotlinDetector.isKotlinType(method.getDeclaringClass()) + && CoroutinesUtils.isSuspendingFunction(method)) { + isSuspendingFunction = true; value = CoroutinesUtils.invokeSuspendingFunction(method, getBean(), args); } else { @@ -151,7 +154,8 @@ public Mono invoke(Message message, Object... providedArgs) { } MethodParameter returnType = getReturnType(); - ReactiveAdapter adapter = this.reactiveAdapterRegistry.getAdapter(returnType.getParameterType()); + Class reactiveType = (isSuspendingFunction ? value.getClass() : returnType.getParameterType()); + ReactiveAdapter adapter = this.reactiveAdapterRegistry.getAdapter(reactiveType); return (isAsyncVoidReturnType(returnType, adapter) ? Mono.from(adapter.toPublisher(value)) : Mono.justOrEmpty(value)); }); diff --git a/spring-messaging/src/test/kotlin/org/springframework/messaging/rsocket/RSocketClientToServerCoroutinesIntegrationTests.kt b/spring-messaging/src/test/kotlin/org/springframework/messaging/rsocket/RSocketClientToServerCoroutinesIntegrationTests.kt index cbc5a5a77a73..d72da8a94fef 100644 --- a/spring-messaging/src/test/kotlin/org/springframework/messaging/rsocket/RSocketClientToServerCoroutinesIntegrationTests.kt +++ b/spring-messaging/src/test/kotlin/org/springframework/messaging/rsocket/RSocketClientToServerCoroutinesIntegrationTests.kt @@ -39,6 +39,7 @@ import org.springframework.messaging.handler.annotation.MessageMapping import org.springframework.messaging.rsocket.annotation.support.RSocketMessageHandler import org.springframework.stereotype.Controller import reactor.core.publisher.Flux +import reactor.core.publisher.ReplayProcessor import reactor.test.StepVerifier import java.time.Duration @@ -50,6 +51,34 @@ import java.time.Duration */ class RSocketClientToServerCoroutinesIntegrationTests { + @Test + fun fireAndForget() { + Flux.range(1, 3) + .concatMap { requester.route("receive").data("Hello $it").send() } + .blockLast() + StepVerifier.create(context.getBean(ServerController::class.java).fireForgetPayloads) + .expectNext("Hello 1") + .expectNext("Hello 2") + .expectNext("Hello 3") + .thenAwait(Duration.ofMillis(50)) + .thenCancel() + .verify(Duration.ofSeconds(5)) + } + + @Test + fun fireAndForgetAsync() { + Flux.range(1, 3) + .concatMap { i: Int -> requester.route("receive-async").data("Hello $i").send() } + .blockLast() + StepVerifier.create(context.getBean(ServerController::class.java).fireForgetPayloads) + .expectNext("Hello 1") + .expectNext("Hello 2") + .expectNext("Hello 3") + .thenAwait(Duration.ofMillis(50)) + .thenCancel() + .verify(Duration.ofSeconds(5)) + } + @Test fun echoAsync() { val result = Flux.range(1, 3).concatMap { i -> requester.route("echo-async").data("Hello " + i!!).retrieveMono(String::class.java) } @@ -70,6 +99,16 @@ class RSocketClientToServerCoroutinesIntegrationTests { .verify(Duration.ofSeconds(5)) } + @Test + fun echoStreamAsync() { + val result = requester.route("echo-stream-async").data("Hello").retrieveFlux(String::class.java) + + StepVerifier.create(result) + .expectNext("Hello 0").expectNextCount(6).expectNext("Hello 7") + .thenCancel() + .verify(Duration.ofSeconds(5)) + } + @Test fun echoChannel() { val result = requester.route("echo-channel") @@ -106,6 +145,19 @@ class RSocketClientToServerCoroutinesIntegrationTests { @Controller class ServerController { + val fireForgetPayloads = ReplayProcessor.create() + + @MessageMapping("receive") + fun receive(payload: String) { + fireForgetPayloads.onNext(payload) + } + + @MessageMapping("receive-async") + suspend fun receiveAsync(payload: String) { + delay(10) + fireForgetPayloads.onNext(payload) + } + @MessageMapping("echo-async") suspend fun echoAsync(payload: String): String { delay(10) @@ -123,6 +175,18 @@ class RSocketClientToServerCoroutinesIntegrationTests { } } + @MessageMapping("echo-stream-async") + suspend fun echoStreamAsync(payload: String): Flow { + delay(10) + var i = 0 + return flow { + while(true) { + delay(10) + emit("$payload ${i++}") + } + } + } + @MessageMapping("echo-channel") fun echoChannel(payloads: Flow) = payloads.map { delay(10) @@ -185,8 +249,6 @@ class RSocketClientToServerCoroutinesIntegrationTests { private lateinit var server: CloseableChannel - private val interceptor = FireAndForgetCountingInterceptor() - private lateinit var requester: RSocketRequester @@ -196,7 +258,6 @@ class RSocketClientToServerCoroutinesIntegrationTests { context = AnnotationConfigApplicationContext(ServerConfig::class.java) server = RSocketFactory.receive() - .addResponderPlugin(interceptor) .frameDecoder(PayloadDecoder.ZERO_COPY) .acceptor(context.getBean(RSocketMessageHandler::class.java).responder()) .transport(TcpServerTransport.create("localhost", 7000)) diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/result/method/InvocableHandlerMethod.java b/spring-webflux/src/main/java/org/springframework/web/reactive/result/method/InvocableHandlerMethod.java index ada5628945eb..c419fb491ed9 100644 --- a/spring-webflux/src/main/java/org/springframework/web/reactive/result/method/InvocableHandlerMethod.java +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/result/method/InvocableHandlerMethod.java @@ -139,7 +139,8 @@ public Mono invoke( try { ReflectionUtils.makeAccessible(getBridgedMethod()); Method method = getBridgedMethod(); - if (KotlinDetector.isKotlinReflectPresent() && KotlinDetector.isKotlinType(method.getDeclaringClass())) { + if (KotlinDetector.isKotlinReflectPresent() && KotlinDetector.isKotlinType(method.getDeclaringClass()) + && CoroutinesUtils.isSuspendingFunction(method)) { value = CoroutinesUtils.invokeSuspendingFunction(method, getBean(), args); } else {