diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/result/method/InvocableHandlerMethod.java b/spring-webflux/src/main/java/org/springframework/web/reactive/result/method/InvocableHandlerMethod.java index e81709b4ac37..6cf351560bec 100644 --- a/spring-webflux/src/main/java/org/springframework/web/reactive/result/method/InvocableHandlerMethod.java +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/result/method/InvocableHandlerMethod.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2021 the original author or authors. + * Copyright 2002-2022 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -24,6 +24,7 @@ import java.util.List; import java.util.stream.Stream; +import org.reactivestreams.Publisher; import reactor.core.publisher.Mono; import org.springframework.core.CoroutinesUtils; @@ -129,15 +130,16 @@ public void setReactiveAdapterRegistry(ReactiveAdapterRegistry registry) { * @param providedArgs optional list of argument values to match by type * @return a Mono with a {@link HandlerResult} */ - @SuppressWarnings("KotlinInternalInJava") + @SuppressWarnings({"KotlinInternalInJava", "unchecked"}) public Mono invoke( ServerWebExchange exchange, BindingContext bindingContext, Object... providedArgs) { return getMethodArgumentValues(exchange, bindingContext, providedArgs).flatMap(args -> { Object value; + Method method = getBridgedMethod(); + boolean isSuspendingFunction = KotlinDetector.isSuspendingFunction(method); try { - Method method = getBridgedMethod(); - if (KotlinDetector.isSuspendingFunction(method)) { + if (isSuspendingFunction) { value = CoroutinesUtils.invokeSuspendingFunction(method, getBean(), args); } else { @@ -163,10 +165,16 @@ public Mono invoke( } MethodParameter returnType = getReturnType(); - ReactiveAdapter adapter = this.reactiveAdapterRegistry.getAdapter(returnType.getParameterType()); - boolean asyncVoid = isAsyncVoidReturnType(returnType, adapter); - if ((value == null || asyncVoid) && isResponseHandled(args, exchange)) { - return (asyncVoid ? Mono.from(adapter.toPublisher(value)) : Mono.empty()); + if (isResponseHandled(args, exchange)) { + Class parameterType = returnType.getParameterType(); + ReactiveAdapter adapter = this.reactiveAdapterRegistry.getAdapter(parameterType); + boolean asyncVoid = isAsyncVoidReturnType(returnType, adapter); + if (value == null || asyncVoid) { + return (asyncVoid ? Mono.from(adapter.toPublisher(value)) : Mono.empty()); + } + if (isSuspendingFunction && parameterType == void.class) { + return Mono.from((Publisher) value); + } } HandlerResult result = new HandlerResult(this, value, returnType, bindingContext); diff --git a/spring-webflux/src/test/kotlin/org/springframework/web/reactive/result/KotlinInvocableHandlerMethodTests.kt b/spring-webflux/src/test/kotlin/org/springframework/web/reactive/result/KotlinInvocableHandlerMethodTests.kt index 133046955275..7975a7dd853e 100644 --- a/spring-webflux/src/test/kotlin/org/springframework/web/reactive/result/KotlinInvocableHandlerMethodTests.kt +++ b/spring-webflux/src/test/kotlin/org/springframework/web/reactive/result/KotlinInvocableHandlerMethodTests.kt @@ -34,6 +34,7 @@ import org.springframework.web.reactive.result.method.annotation.ContinuationHan import reactor.core.publisher.Mono import reactor.test.StepVerifier import java.lang.reflect.Method +import java.time.Duration import kotlin.reflect.jvm.javaMethod class KotlinInvocableHandlerMethodTests { @@ -89,11 +90,9 @@ class KotlinInvocableHandlerMethodTests { val response = this.exchange.response this.resolvers.add(stubResolver(response)) val method = CoroutinesController::response.javaMethod!! - val result = invoke(CoroutinesController(), method) + val result = invokeForResult(CoroutinesController(), method, response) - StepVerifier.create(result) - .consumeNextWith { StepVerifier.create(it.returnValue as Mono<*>).verifyComplete() } - .verifyComplete() + assertThat(result).`as`("Expected no result (i.e. fully handled)").isNull() assertThat(this.exchange.response.headers.getFirst("foo")).isEqualTo("bar") } @@ -105,6 +104,10 @@ class KotlinInvocableHandlerMethodTests { assertHandlerResultValue(result, "success:foo") } + private fun invokeForResult(handler: Any, method: Method, vararg providedArgs: Any): HandlerResult? { + return invoke(handler, method, *providedArgs).block(Duration.ofSeconds(5)) + } + private fun invoke(handler: Any, method: Method, vararg providedArgs: Any?): Mono { val invocable = InvocableHandlerMethod(handler, method) invocable.setArgumentResolvers(this.resolvers)