From 712eac2915f73e9a59940e8a6a62123d6cea38d0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20Deleuze?= Date: Fri, 22 Nov 2019 12:11:07 +0100 Subject: [PATCH] Refine Coroutines annotated controller support This commit refines Coroutines annotated controller support by considering Kotlin Unit as Java void and using the right ReactiveAdapter to support all use cases, including suspending functions that return Flow (usual when using APIs like WebClient). It also fixes RSocket fire and forget handling and adds related tests for that use case. Closes gh-24057 Closes gh-23866 --- .../springframework/core/CoroutinesUtils.kt | 32 ++++----- .../springframework/core/MethodParameter.java | 4 ++ .../reactive/InvocableHandlerMethod.java | 8 ++- ...lientToServerCoroutinesIntegrationTests.kt | 67 ++++++++++++++++++- .../result/method/InvocableHandlerMethod.java | 3 +- 5 files changed, 93 insertions(+), 21 deletions(-) diff --git a/spring-core/kotlin-coroutines/src/main/kotlin/org/springframework/core/CoroutinesUtils.kt b/spring-core/kotlin-coroutines/src/main/kotlin/org/springframework/core/CoroutinesUtils.kt index 64b1af36f0f3..edfd76d8c932 100644 --- a/spring-core/kotlin-coroutines/src/main/kotlin/org/springframework/core/CoroutinesUtils.kt +++ b/spring-core/kotlin-coroutines/src/main/kotlin/org/springframework/core/CoroutinesUtils.kt @@ -26,6 +26,7 @@ import kotlinx.coroutines.reactive.awaitFirstOrNull import kotlinx.coroutines.reactor.asFlux import kotlinx.coroutines.reactor.mono +import org.reactivestreams.Publisher import reactor.core.publisher.Mono import java.lang.reflect.InvocationTargetException import java.lang.reflect.Method @@ -51,28 +52,29 @@ internal fun monoToDeferred(source: Mono) = GlobalScope.async(Dispatchers.Unconfined) { source.awaitFirstOrNull() } /** - * Invoke a suspending function converting it to [Mono] or [reactor.core.publisher.Flux] - * if necessary. + * Return {@code true} if the method is a suspending function. + * + * @author Sebastien Deleuze + * @since 5.2.2 + */ +internal fun isSuspendingFunction(method: Method) = method.kotlinFunction!!.isSuspend + +/** + * Invoke a suspending function and converts it to [Mono] or [reactor.core.publisher.Flux]. * * @author Sebastien Deleuze * @since 5.2 */ @Suppress("UNCHECKED_CAST") -internal fun invokeSuspendingFunction(method: Method, bean: Any, vararg args: Any?): Any? { +internal fun invokeSuspendingFunction(method: Method, bean: Any, vararg args: Any?): Publisher<*> { val function = method.kotlinFunction!! - return if (function.isSuspend) { - val mono = mono(Dispatchers.Unconfined) { - function.callSuspend(bean, *args.sliceArray(0..(args.size-2))) - .let { if (it == Unit) null else it } - }.onErrorMap(InvocationTargetException::class.java) { it.targetException } - if (function.returnType.classifier == Flow::class) { - mono.flatMapMany { (it as Flow).asFlux() } - } - else { - mono - } + val mono = mono(Dispatchers.Unconfined) { + function.callSuspend(bean, *args.sliceArray(0..(args.size-2))).let { if (it == Unit) null else it } + }.onErrorMap(InvocationTargetException::class.java) { it.targetException } + return if (function.returnType.classifier == Flow::class) { + mono.flatMapMany { (it as Flow).asFlux() } } else { - function.call(bean, *args) + mono } } diff --git a/spring-core/src/main/java/org/springframework/core/MethodParameter.java b/spring-core/src/main/java/org/springframework/core/MethodParameter.java index e3ae0562bac8..bcb1a357ef92 100644 --- a/spring-core/src/main/java/org/springframework/core/MethodParameter.java +++ b/spring-core/src/main/java/org/springframework/core/MethodParameter.java @@ -30,6 +30,7 @@ import java.util.Optional; import java.util.function.Predicate; +import kotlin.Unit; import kotlin.reflect.KFunction; import kotlin.reflect.KParameter; import kotlin.reflect.jvm.ReflectJvmMapping; @@ -929,6 +930,9 @@ static private Class getReturnType(Method method) { KFunction function = ReflectJvmMapping.getKotlinFunction(method); if (function != null && function.isSuspend()) { Type paramType = ReflectJvmMapping.getJavaType(function.getReturnType()); + if (paramType == Unit.class) { + paramType = void.class; + } return ResolvableType.forType(paramType).resolve(method.getReturnType()); } } diff --git a/spring-messaging/src/main/java/org/springframework/messaging/handler/invocation/reactive/InvocableHandlerMethod.java b/spring-messaging/src/main/java/org/springframework/messaging/handler/invocation/reactive/InvocableHandlerMethod.java index 0db41014e4fb..8ded74c07d65 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/handler/invocation/reactive/InvocableHandlerMethod.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/handler/invocation/reactive/InvocableHandlerMethod.java @@ -127,10 +127,13 @@ public void setReactiveAdapterRegistry(ReactiveAdapterRegistry registry) { public Mono invoke(Message message, Object... providedArgs) { return getMethodArgumentValues(message, providedArgs).flatMap(args -> { Object value; + boolean isSuspendingFunction = false; try { Method method = getBridgedMethod(); ReflectionUtils.makeAccessible(method); - if (KotlinDetector.isKotlinReflectPresent() && KotlinDetector.isKotlinType(method.getDeclaringClass())) { + if (KotlinDetector.isKotlinReflectPresent() && KotlinDetector.isKotlinType(method.getDeclaringClass()) + && CoroutinesUtils.isSuspendingFunction(method)) { + isSuspendingFunction = true; value = CoroutinesUtils.invokeSuspendingFunction(method, getBean(), args); } else { @@ -151,7 +154,8 @@ public Mono invoke(Message message, Object... providedArgs) { } MethodParameter returnType = getReturnType(); - ReactiveAdapter adapter = this.reactiveAdapterRegistry.getAdapter(returnType.getParameterType()); + Class reactiveType = (isSuspendingFunction ? value.getClass() : returnType.getParameterType()); + ReactiveAdapter adapter = this.reactiveAdapterRegistry.getAdapter(reactiveType); return (isAsyncVoidReturnType(returnType, adapter) ? Mono.from(adapter.toPublisher(value)) : Mono.justOrEmpty(value)); }); diff --git a/spring-messaging/src/test/kotlin/org/springframework/messaging/rsocket/RSocketClientToServerCoroutinesIntegrationTests.kt b/spring-messaging/src/test/kotlin/org/springframework/messaging/rsocket/RSocketClientToServerCoroutinesIntegrationTests.kt index cbc5a5a77a73..d72da8a94fef 100644 --- a/spring-messaging/src/test/kotlin/org/springframework/messaging/rsocket/RSocketClientToServerCoroutinesIntegrationTests.kt +++ b/spring-messaging/src/test/kotlin/org/springframework/messaging/rsocket/RSocketClientToServerCoroutinesIntegrationTests.kt @@ -39,6 +39,7 @@ import org.springframework.messaging.handler.annotation.MessageMapping import org.springframework.messaging.rsocket.annotation.support.RSocketMessageHandler import org.springframework.stereotype.Controller import reactor.core.publisher.Flux +import reactor.core.publisher.ReplayProcessor import reactor.test.StepVerifier import java.time.Duration @@ -50,6 +51,34 @@ import java.time.Duration */ class RSocketClientToServerCoroutinesIntegrationTests { + @Test + fun fireAndForget() { + Flux.range(1, 3) + .concatMap { requester.route("receive").data("Hello $it").send() } + .blockLast() + StepVerifier.create(context.getBean(ServerController::class.java).fireForgetPayloads) + .expectNext("Hello 1") + .expectNext("Hello 2") + .expectNext("Hello 3") + .thenAwait(Duration.ofMillis(50)) + .thenCancel() + .verify(Duration.ofSeconds(5)) + } + + @Test + fun fireAndForgetAsync() { + Flux.range(1, 3) + .concatMap { i: Int -> requester.route("receive-async").data("Hello $i").send() } + .blockLast() + StepVerifier.create(context.getBean(ServerController::class.java).fireForgetPayloads) + .expectNext("Hello 1") + .expectNext("Hello 2") + .expectNext("Hello 3") + .thenAwait(Duration.ofMillis(50)) + .thenCancel() + .verify(Duration.ofSeconds(5)) + } + @Test fun echoAsync() { val result = Flux.range(1, 3).concatMap { i -> requester.route("echo-async").data("Hello " + i!!).retrieveMono(String::class.java) } @@ -70,6 +99,16 @@ class RSocketClientToServerCoroutinesIntegrationTests { .verify(Duration.ofSeconds(5)) } + @Test + fun echoStreamAsync() { + val result = requester.route("echo-stream-async").data("Hello").retrieveFlux(String::class.java) + + StepVerifier.create(result) + .expectNext("Hello 0").expectNextCount(6).expectNext("Hello 7") + .thenCancel() + .verify(Duration.ofSeconds(5)) + } + @Test fun echoChannel() { val result = requester.route("echo-channel") @@ -106,6 +145,19 @@ class RSocketClientToServerCoroutinesIntegrationTests { @Controller class ServerController { + val fireForgetPayloads = ReplayProcessor.create() + + @MessageMapping("receive") + fun receive(payload: String) { + fireForgetPayloads.onNext(payload) + } + + @MessageMapping("receive-async") + suspend fun receiveAsync(payload: String) { + delay(10) + fireForgetPayloads.onNext(payload) + } + @MessageMapping("echo-async") suspend fun echoAsync(payload: String): String { delay(10) @@ -123,6 +175,18 @@ class RSocketClientToServerCoroutinesIntegrationTests { } } + @MessageMapping("echo-stream-async") + suspend fun echoStreamAsync(payload: String): Flow { + delay(10) + var i = 0 + return flow { + while(true) { + delay(10) + emit("$payload ${i++}") + } + } + } + @MessageMapping("echo-channel") fun echoChannel(payloads: Flow) = payloads.map { delay(10) @@ -185,8 +249,6 @@ class RSocketClientToServerCoroutinesIntegrationTests { private lateinit var server: CloseableChannel - private val interceptor = FireAndForgetCountingInterceptor() - private lateinit var requester: RSocketRequester @@ -196,7 +258,6 @@ class RSocketClientToServerCoroutinesIntegrationTests { context = AnnotationConfigApplicationContext(ServerConfig::class.java) server = RSocketFactory.receive() - .addResponderPlugin(interceptor) .frameDecoder(PayloadDecoder.ZERO_COPY) .acceptor(context.getBean(RSocketMessageHandler::class.java).responder()) .transport(TcpServerTransport.create("localhost", 7000)) diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/result/method/InvocableHandlerMethod.java b/spring-webflux/src/main/java/org/springframework/web/reactive/result/method/InvocableHandlerMethod.java index ada5628945eb..c419fb491ed9 100644 --- a/spring-webflux/src/main/java/org/springframework/web/reactive/result/method/InvocableHandlerMethod.java +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/result/method/InvocableHandlerMethod.java @@ -139,7 +139,8 @@ public Mono invoke( try { ReflectionUtils.makeAccessible(getBridgedMethod()); Method method = getBridgedMethod(); - if (KotlinDetector.isKotlinReflectPresent() && KotlinDetector.isKotlinType(method.getDeclaringClass())) { + if (KotlinDetector.isKotlinReflectPresent() && KotlinDetector.isKotlinType(method.getDeclaringClass()) + && CoroutinesUtils.isSuspendingFunction(method)) { value = CoroutinesUtils.invokeSuspendingFunction(method, getBean(), args); } else {