Skip to content

Commit

Permalink
Refine Coroutines annotated controller support
Browse files Browse the repository at this point in the history
This commit refines Coroutines annotated controller support
by considering Kotlin Unit as Java void and using the right
ReactiveAdapter to support all use cases, including suspending
functions that return Flow (usual when using APIs like WebClient).

It also fixes RSocket fire and forget handling and adds related tests
for that use case.

Closes gh-24057
Closes gh-23866
  • Loading branch information
sdeleuze committed Nov 22, 2019
1 parent 21b2fc1 commit 712eac2
Show file tree
Hide file tree
Showing 5 changed files with 93 additions and 21 deletions.
Expand Up @@ -26,6 +26,7 @@ import kotlinx.coroutines.reactive.awaitFirstOrNull
import kotlinx.coroutines.reactor.asFlux

import kotlinx.coroutines.reactor.mono
import org.reactivestreams.Publisher
import reactor.core.publisher.Mono
import java.lang.reflect.InvocationTargetException
import java.lang.reflect.Method
Expand All @@ -51,28 +52,29 @@ internal fun <T: Any> monoToDeferred(source: Mono<T>) =
GlobalScope.async(Dispatchers.Unconfined) { source.awaitFirstOrNull() }

/**
* Invoke a suspending function converting it to [Mono] or [reactor.core.publisher.Flux]
* if necessary.
* Return {@code true} if the method is a suspending function.
*
* @author Sebastien Deleuze
* @since 5.2.2
*/
internal fun isSuspendingFunction(method: Method) = method.kotlinFunction!!.isSuspend

/**
* Invoke a suspending function and converts it to [Mono] or [reactor.core.publisher.Flux].
*
* @author Sebastien Deleuze
* @since 5.2
*/
@Suppress("UNCHECKED_CAST")
internal fun invokeSuspendingFunction(method: Method, bean: Any, vararg args: Any?): Any? {
internal fun invokeSuspendingFunction(method: Method, bean: Any, vararg args: Any?): Publisher<*> {
val function = method.kotlinFunction!!
return if (function.isSuspend) {
val mono = mono(Dispatchers.Unconfined) {
function.callSuspend(bean, *args.sliceArray(0..(args.size-2)))
.let { if (it == Unit) null else it }
}.onErrorMap(InvocationTargetException::class.java) { it.targetException }
if (function.returnType.classifier == Flow::class) {
mono.flatMapMany { (it as Flow<Any>).asFlux() }
}
else {
mono
}
val mono = mono(Dispatchers.Unconfined) {
function.callSuspend(bean, *args.sliceArray(0..(args.size-2))).let { if (it == Unit) null else it }
}.onErrorMap(InvocationTargetException::class.java) { it.targetException }
return if (function.returnType.classifier == Flow::class) {
mono.flatMapMany { (it as Flow<Any>).asFlux() }
}
else {
function.call(bean, *args)
mono
}
}
Expand Up @@ -30,6 +30,7 @@
import java.util.Optional;
import java.util.function.Predicate;

import kotlin.Unit;
import kotlin.reflect.KFunction;
import kotlin.reflect.KParameter;
import kotlin.reflect.jvm.ReflectJvmMapping;
Expand Down Expand Up @@ -929,6 +930,9 @@ static private Class<?> getReturnType(Method method) {
KFunction<?> function = ReflectJvmMapping.getKotlinFunction(method);
if (function != null && function.isSuspend()) {
Type paramType = ReflectJvmMapping.getJavaType(function.getReturnType());
if (paramType == Unit.class) {
paramType = void.class;
}
return ResolvableType.forType(paramType).resolve(method.getReturnType());
}
}
Expand Down
Expand Up @@ -127,10 +127,13 @@ public void setReactiveAdapterRegistry(ReactiveAdapterRegistry registry) {
public Mono<Object> invoke(Message<?> message, Object... providedArgs) {
return getMethodArgumentValues(message, providedArgs).flatMap(args -> {
Object value;
boolean isSuspendingFunction = false;
try {
Method method = getBridgedMethod();
ReflectionUtils.makeAccessible(method);
if (KotlinDetector.isKotlinReflectPresent() && KotlinDetector.isKotlinType(method.getDeclaringClass())) {
if (KotlinDetector.isKotlinReflectPresent() && KotlinDetector.isKotlinType(method.getDeclaringClass())
&& CoroutinesUtils.isSuspendingFunction(method)) {
isSuspendingFunction = true;
value = CoroutinesUtils.invokeSuspendingFunction(method, getBean(), args);
}
else {
Expand All @@ -151,7 +154,8 @@ public Mono<Object> invoke(Message<?> message, Object... providedArgs) {
}

MethodParameter returnType = getReturnType();
ReactiveAdapter adapter = this.reactiveAdapterRegistry.getAdapter(returnType.getParameterType());
Class<?> reactiveType = (isSuspendingFunction ? value.getClass() : returnType.getParameterType());
ReactiveAdapter adapter = this.reactiveAdapterRegistry.getAdapter(reactiveType);
return (isAsyncVoidReturnType(returnType, adapter) ?
Mono.from(adapter.toPublisher(value)) : Mono.justOrEmpty(value));
});
Expand Down
Expand Up @@ -39,6 +39,7 @@ import org.springframework.messaging.handler.annotation.MessageMapping
import org.springframework.messaging.rsocket.annotation.support.RSocketMessageHandler
import org.springframework.stereotype.Controller
import reactor.core.publisher.Flux
import reactor.core.publisher.ReplayProcessor
import reactor.test.StepVerifier
import java.time.Duration

Expand All @@ -50,6 +51,34 @@ import java.time.Duration
*/
class RSocketClientToServerCoroutinesIntegrationTests {

@Test
fun fireAndForget() {
Flux.range(1, 3)
.concatMap { requester.route("receive").data("Hello $it").send() }
.blockLast()
StepVerifier.create(context.getBean(ServerController::class.java).fireForgetPayloads)
.expectNext("Hello 1")
.expectNext("Hello 2")
.expectNext("Hello 3")
.thenAwait(Duration.ofMillis(50))
.thenCancel()
.verify(Duration.ofSeconds(5))
}

@Test
fun fireAndForgetAsync() {
Flux.range(1, 3)
.concatMap { i: Int -> requester.route("receive-async").data("Hello $i").send() }
.blockLast()
StepVerifier.create(context.getBean(ServerController::class.java).fireForgetPayloads)
.expectNext("Hello 1")
.expectNext("Hello 2")
.expectNext("Hello 3")
.thenAwait(Duration.ofMillis(50))
.thenCancel()
.verify(Duration.ofSeconds(5))
}

@Test
fun echoAsync() {
val result = Flux.range(1, 3).concatMap { i -> requester.route("echo-async").data("Hello " + i!!).retrieveMono(String::class.java) }
Expand All @@ -70,6 +99,16 @@ class RSocketClientToServerCoroutinesIntegrationTests {
.verify(Duration.ofSeconds(5))
}

@Test
fun echoStreamAsync() {
val result = requester.route("echo-stream-async").data("Hello").retrieveFlux(String::class.java)

StepVerifier.create(result)
.expectNext("Hello 0").expectNextCount(6).expectNext("Hello 7")
.thenCancel()
.verify(Duration.ofSeconds(5))
}

@Test
fun echoChannel() {
val result = requester.route("echo-channel")
Expand Down Expand Up @@ -106,6 +145,19 @@ class RSocketClientToServerCoroutinesIntegrationTests {
@Controller
class ServerController {

val fireForgetPayloads = ReplayProcessor.create<String>()

@MessageMapping("receive")
fun receive(payload: String) {
fireForgetPayloads.onNext(payload)
}

@MessageMapping("receive-async")
suspend fun receiveAsync(payload: String) {
delay(10)
fireForgetPayloads.onNext(payload)
}

@MessageMapping("echo-async")
suspend fun echoAsync(payload: String): String {
delay(10)
Expand All @@ -123,6 +175,18 @@ class RSocketClientToServerCoroutinesIntegrationTests {
}
}

@MessageMapping("echo-stream-async")
suspend fun echoStreamAsync(payload: String): Flow<String> {
delay(10)
var i = 0
return flow {
while(true) {
delay(10)
emit("$payload ${i++}")
}
}
}

@MessageMapping("echo-channel")
fun echoChannel(payloads: Flow<String>) = payloads.map {
delay(10)
Expand Down Expand Up @@ -185,8 +249,6 @@ class RSocketClientToServerCoroutinesIntegrationTests {

private lateinit var server: CloseableChannel

private val interceptor = FireAndForgetCountingInterceptor()

private lateinit var requester: RSocketRequester


Expand All @@ -196,7 +258,6 @@ class RSocketClientToServerCoroutinesIntegrationTests {
context = AnnotationConfigApplicationContext(ServerConfig::class.java)

server = RSocketFactory.receive()
.addResponderPlugin(interceptor)
.frameDecoder(PayloadDecoder.ZERO_COPY)
.acceptor(context.getBean(RSocketMessageHandler::class.java).responder())
.transport(TcpServerTransport.create("localhost", 7000))
Expand Down
Expand Up @@ -139,7 +139,8 @@ public Mono<HandlerResult> invoke(
try {
ReflectionUtils.makeAccessible(getBridgedMethod());
Method method = getBridgedMethod();
if (KotlinDetector.isKotlinReflectPresent() && KotlinDetector.isKotlinType(method.getDeclaringClass())) {
if (KotlinDetector.isKotlinReflectPresent() && KotlinDetector.isKotlinType(method.getDeclaringClass())
&& CoroutinesUtils.isSuspendingFunction(method)) {
value = CoroutinesUtils.invokeSuspendingFunction(method, getBean(), args);
}
else {
Expand Down

0 comments on commit 712eac2

Please sign in to comment.