From 842b424acd8a4cbffd7f0dffae33fcce61fbce25 Mon Sep 17 00:00:00 2001 From: Rossen Stoyanchev Date: Mon, 18 Nov 2019 17:27:09 +0000 Subject: [PATCH] Use method signature to refine RSocket @MessageMapping Before this change an @MessageMapping could be matched to any RSocket interaction type, which is arguably too flexible, makes it difficult to reason what would happen in case of a significant mismatch of cardinality, e.g. request for Fire-And-Forget (1-to-0) mapped to a method that returns Flux, and could result in payloads being ignored, or not seen unintentionally. This commit checks @ConnectMapping method on startup and rejects them if they return any values (sync or async). It also refines each @MessageMapping to match only the RSocket interaction type it fits based on the input and output cardinality of the handler method. Subsequently if a request is not matched, we'll do a second search to identify partial matches (by route only) and raise a helpful error that explains which interaction type is actually supported. The reference docs has been updated to explain the options. Closes gh-23999 --- .../handler/annotation/MessageMapping.java | 36 +++++-- .../AbstractMethodMessageHandler.java | 33 +++++- ...andlerMethodArgumentResolverComposite.java | 6 +- .../invocation/reactive/InvocableHelper.java | 8 ++ .../RSocketFrameTypeMessageCondition.java | 88 ++++++++++++--- .../support/RSocketMessageHandler.java | 101 ++++++++++++++++-- .../rsocket/RSocketBufferLeakTests.java | 15 +-- ...RSocketClientToServerIntegrationTests.java | 4 +- ...RSocketFrameTypeMessageConditionTests.java | 27 ++++- .../support/RSocketMessageHandlerTests.java | 100 +++++++++++++++++ ...lientToServerCoroutinesIntegrationTests.kt | 4 +- src/docs/asciidoc/rsocket.adoc | 94 +++++++++++++--- 12 files changed, 450 insertions(+), 66 deletions(-) diff --git a/spring-messaging/src/main/java/org/springframework/messaging/handler/annotation/MessageMapping.java b/spring-messaging/src/main/java/org/springframework/messaging/handler/annotation/MessageMapping.java index 27c4076c8d0a..d6530a69873e 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/handler/annotation/MessageMapping.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/handler/annotation/MessageMapping.java @@ -64,18 +64,34 @@ * authenticated user. * * - *

How the return value is handled depends on the processing scenario. For - * STOMP over WebSocket, it is turned into a message and sent to a default response - * destination or to a custom destination specified with an {@link SendTo @SendTo} - * or {@link org.springframework.messaging.simp.annotation.SendToUser @SendToUser} - * annotation. For RSocket, the response is used to reply to the stream request. + *

Return value handling depends on the processing scenario: + *

* - *

Specializations of this annotation including - * {@link org.springframework.messaging.simp.annotation.SubscribeMapping @SubscribeMapping} or + *

Specializations of this annotation include + * {@link org.springframework.messaging.simp.annotation.SubscribeMapping @SubscribeMapping} + * (e.g. STOMP subscriptions) and * {@link org.springframework.messaging.rsocket.annotation.ConnectMapping @ConnectMapping} - * further narrow the mapping by message type. Both can be combined with a - * type-level {@code @MessageMapping} for declaring a common pattern prefix - * (or prefixes). + * (e.g. RSocket connections). Both narrow the primary mapping further and also match + * against the message type. Both can be combined with a type-level + * {@code @MessageMapping} that declares a common pattern prefix (or prefixes). + * + *

For further details on the use of this annotation in different contexts, + * see the following sections of the Spring Framework reference: + *

* *

NOTE: When using controller interfaces (e.g. for AOP proxying), * make sure to consistently put all your mapping annotations - such as diff --git a/spring-messaging/src/main/java/org/springframework/messaging/handler/invocation/reactive/AbstractMethodMessageHandler.java b/spring-messaging/src/main/java/org/springframework/messaging/handler/invocation/reactive/AbstractMethodMessageHandler.java index d4a04d086be9..b0de4fde096a 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/handler/invocation/reactive/AbstractMethodMessageHandler.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/handler/invocation/reactive/AbstractMethodMessageHandler.java @@ -232,6 +232,15 @@ public MultiValueMap getDestinationLookup() { return CollectionUtils.unmodifiableMultiValueMap(this.destinationLookup); } + /** + * Return the argument resolvers initialized during {@link #afterPropertiesSet()}. + * Primarily for internal use in sub-classes. + * @since 5.2.2 + */ + protected HandlerMethodArgumentResolverComposite getArgumentResolvers() { + return this.invocableHelper.getArgumentResolvers(); + } + @Override public void afterPropertiesSet() { @@ -377,6 +386,7 @@ protected final void registerHandlerMethod(Object handler, Method method, T mapp oldHandlerMethod.getBean() + "' bean method\n" + oldHandlerMethod + " mapped."); } + mapping = extendMapping(mapping, newHandlerMethod); this.handlerMethods.put(mapping, newHandlerMethod); for (String pattern : getDirectLookupMappings(mapping)) { @@ -402,6 +412,21 @@ private HandlerMethod createHandlerMethod(Object handler, Method method) { return handlerMethod; } + /** + * This method is invoked just before mappings are added. It allows + * sub-classes to update the mapping with the {@link HandlerMethod} in mind. + * This can be useful when the method signature is used to refine the + * mapping, e.g. based on the cardinality of input and output. + *

By default this method returns the mapping that is passed in. + * @param mapping the mapping to be added + * @param handlerMethod the target handler for the mapping + * @return a new mapping or the same + * @since 5.2.2 + */ + protected T extendMapping(T mapping, HandlerMethod handlerMethod) { + return mapping; + } + /** * Return String-based destinations for the given mapping, if any, that can * be used to find matches with a direct lookup (i.e. non-patterns). @@ -414,7 +439,13 @@ private HandlerMethod createHandlerMethod(Object handler, Method method) { @Override public Mono handleMessage(Message message) throws MessagingException { - Match match = getHandlerMethod(message); + Match match = null; + try { + match = getHandlerMethod(message); + } + catch (Exception ex) { + return Mono.error(ex); + } if (match == null) { // handleNoMatch would have been invoked already return Mono.empty(); diff --git a/spring-messaging/src/main/java/org/springframework/messaging/handler/invocation/reactive/HandlerMethodArgumentResolverComposite.java b/spring-messaging/src/main/java/org/springframework/messaging/handler/invocation/reactive/HandlerMethodArgumentResolverComposite.java index fa0f28c244a6..6822b29f6344 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/handler/invocation/reactive/HandlerMethodArgumentResolverComposite.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/handler/invocation/reactive/HandlerMethodArgumentResolverComposite.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2018 the original author or authors. + * Copyright 2002-2019 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -38,7 +38,7 @@ * @author Rossen Stoyanchev * @since 5.2 */ -class HandlerMethodArgumentResolverComposite implements HandlerMethodArgumentResolver { +public class HandlerMethodArgumentResolverComposite implements HandlerMethodArgumentResolver { protected final Log logger = LogFactory.getLog(getClass()); @@ -125,7 +125,7 @@ public Mono resolveArgument(MethodParameter parameter, Message messag * the given method parameter. */ @Nullable - private HandlerMethodArgumentResolver getArgumentResolver(MethodParameter parameter) { + public HandlerMethodArgumentResolver getArgumentResolver(MethodParameter parameter) { HandlerMethodArgumentResolver result = this.argumentResolverCache.get(parameter); if (result == null) { for (HandlerMethodArgumentResolver methodArgumentResolver : this.argumentResolvers) { diff --git a/spring-messaging/src/main/java/org/springframework/messaging/handler/invocation/reactive/InvocableHelper.java b/spring-messaging/src/main/java/org/springframework/messaging/handler/invocation/reactive/InvocableHelper.java index b144ae59a809..d4ba44d4a4d8 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/handler/invocation/reactive/InvocableHelper.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/handler/invocation/reactive/InvocableHelper.java @@ -80,6 +80,14 @@ public void addArgumentResolvers(List r this.argumentResolvers.addResolvers(resolvers); } + /** + * Return the configured resolvers. + * @since 5.2.2 + */ + public HandlerMethodArgumentResolverComposite getArgumentResolvers() { + return this.argumentResolvers; + } + /** * Add the return value handlers to use for message handling and exception * handling methods. diff --git a/spring-messaging/src/main/java/org/springframework/messaging/rsocket/annotation/support/RSocketFrameTypeMessageCondition.java b/spring-messaging/src/main/java/org/springframework/messaging/rsocket/annotation/support/RSocketFrameTypeMessageCondition.java index ebbadd3cf433..3288df528087 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/rsocket/annotation/support/RSocketFrameTypeMessageCondition.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/rsocket/annotation/support/RSocketFrameTypeMessageCondition.java @@ -56,6 +56,13 @@ public class RSocketFrameTypeMessageCondition extends AbstractMessageCondition frameTypes; @@ -68,6 +75,10 @@ public RSocketFrameTypeMessageCondition(Collection frameTypes) { this.frameTypes = Collections.unmodifiableSet(new LinkedHashSet<>(frameTypes)); } + private RSocketFrameTypeMessageCondition() { + this.frameTypes = Collections.emptySet(); + } + public Set getFrameTypes() { return this.frameTypes; @@ -124,18 +135,71 @@ public int compareTo(RSocketFrameTypeMessageCondition other, Message message) } - /** Condition to match the initial SETUP frame and subsequent metadata pushes. */ - public static final RSocketFrameTypeMessageCondition CONNECT_CONDITION = - new RSocketFrameTypeMessageCondition( - FrameType.SETUP, - FrameType.METADATA_PUSH); + /** + * Return a condition for matching the RSocket request interaction type with + * that is selected based on the delcared request and response cardinality + * of some handler method. + *

The table below shows the selections made: + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + *
Request CardinalityResponse CardinalityInteraction Types
0,10Fire-And-Forget, Request-Response
0,11Request-Response
0,12Request-Stream
2AnyRequest-Channel
+ * @param cardinalityIn -- the request cardinality: 1 for a single payload, + * 2 for many payloads, and 0 if input is not handled. + * @param cardinalityOut -- the response cardinality: 0 for no output + * payloads, 1 for a single payload, and 2 for many payloads. + * @return a condition to use for matching the interaction type + * @since 5.2.2 + */ + public static RSocketFrameTypeMessageCondition getCondition(int cardinalityIn, int cardinalityOut) { + switch (cardinalityIn) { + case 0: + case 1: + switch (cardinalityOut) { + case 0: return FF_RR_CONDITION; + case 1: return RR_CONDITION; + case 2: return RS_CONDITION; + default: throw new IllegalStateException("Invalid cardinality: " + cardinalityOut); + } + case 2: + return RC_CONDITION; + default: + throw new IllegalStateException("Invalid cardinality: " + cardinalityIn); + } + } + + + private static final RSocketFrameTypeMessageCondition FF_CONDITION = from(FrameType.REQUEST_FNF); + private static final RSocketFrameTypeMessageCondition RR_CONDITION = from(FrameType.REQUEST_RESPONSE); + private static final RSocketFrameTypeMessageCondition RS_CONDITION = from(FrameType.REQUEST_STREAM); + private static final RSocketFrameTypeMessageCondition RC_CONDITION = from(FrameType.REQUEST_CHANNEL); + private static final RSocketFrameTypeMessageCondition FF_RR_CONDITION = FF_CONDITION.combine(RR_CONDITION); - /** Condition to match one of the 4 stream request types. */ - public static final RSocketFrameTypeMessageCondition REQUEST_CONDITION = - new RSocketFrameTypeMessageCondition( - FrameType.REQUEST_FNF, - FrameType.REQUEST_RESPONSE, - FrameType.REQUEST_STREAM, - FrameType.REQUEST_CHANNEL); + private static RSocketFrameTypeMessageCondition from(FrameType... frameTypes) { + return new RSocketFrameTypeMessageCondition(frameTypes); + } } diff --git a/spring-messaging/src/main/java/org/springframework/messaging/rsocket/annotation/support/RSocketMessageHandler.java b/spring-messaging/src/main/java/org/springframework/messaging/rsocket/annotation/support/RSocketMessageHandler.java index 7ee94e85c6c6..17aa48e99b01 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/rsocket/annotation/support/RSocketMessageHandler.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/rsocket/annotation/support/RSocketMessageHandler.java @@ -19,6 +19,9 @@ import java.lang.reflect.AnnotatedElement; import java.util.ArrayList; import java.util.List; +import java.util.Set; +import java.util.function.Predicate; +import java.util.stream.Collectors; import io.rsocket.ConnectionSetupPayload; import io.rsocket.RSocket; @@ -28,6 +31,8 @@ import reactor.core.publisher.Mono; import org.springframework.beans.BeanUtils; +import org.springframework.core.MethodParameter; +import org.springframework.core.ReactiveAdapter; import org.springframework.core.ReactiveAdapterRegistry; import org.springframework.core.annotation.AnnotatedElementUtils; import org.springframework.core.codec.Decoder; @@ -37,8 +42,11 @@ import org.springframework.messaging.MessageDeliveryException; import org.springframework.messaging.handler.CompositeMessageCondition; import org.springframework.messaging.handler.DestinationPatternsMessageCondition; +import org.springframework.messaging.handler.HandlerMethod; +import org.springframework.messaging.handler.MessageCondition; import org.springframework.messaging.handler.annotation.MessageMapping; import org.springframework.messaging.handler.annotation.reactive.MessageMappingMessageHandler; +import org.springframework.messaging.handler.annotation.reactive.PayloadMethodArgumentResolver; import org.springframework.messaging.handler.invocation.reactive.HandlerMethodReturnValueHandler; import org.springframework.messaging.rsocket.ClientRSocketFactoryConfigurer; import org.springframework.messaging.rsocket.MetadataExtractor; @@ -55,12 +63,27 @@ * Extension of {@link MessageMappingMessageHandler} for handling RSocket * requests with {@link ConnectMapping @ConnectMapping} and * {@link MessageMapping @MessageMapping} methods. - *

Use {@link #responder()} to obtain a {@link SocketAcceptor} adapter to - * plug in as responder into an {@link io.rsocket.RSocketFactory}. - *

Use {@link #clientResponder(RSocketStrategies, Object...)} to obtain a - * client responder configurer + * + *

For server scenarios this class can be declared as a bean in Spring + * configuration and that would detect {@code @MessageMapping} methods in + * {@code @Controller} beans. What beans are checked can be changed through a + * {@link #setHandlerPredicate(Predicate) handlerPredicate}. Given an instance + * of this class, you can then use {@link #responder()} to obtain a + * {@link SocketAcceptor} adapter to register with the + * {@link io.rsocket.RSocketFactory}. + * + *

For client scenarios, possibly in the same process as a server, consider + * consider using the static factory method + * {@link #clientResponder(RSocketStrategies, Object...)} to obtain a client + * responder to be registered with an * {@link org.springframework.messaging.rsocket.RSocketRequester.Builder#rsocketFactory - * RSocketRequester}. + * RSocketRequester.Builder}. + * + *

For {@code @MessageMapping} methods, this class automatically determines + * the RSocket interaction type based on the input and output cardinality of the + * method. See the + * + * "Annotated Responders" section of the Spring Framework reference for more details. * * @author Rossen Stoyanchev * @since 5.2 @@ -263,6 +286,17 @@ public void afterPropertiesSet() { getArgumentResolverConfigurer().addCustomResolver(new RSocketRequesterMethodArgumentResolver()); super.afterPropertiesSet(); + + getHandlerMethods().forEach((composite, handler) -> { + if (composite.getMessageConditions().contains(RSocketFrameTypeMessageCondition.CONNECT_CONDITION)) { + MethodParameter returnType = handler.getReturnType(); + if (getCardinality(returnType) > 0) { + throw new IllegalStateException( + "Invalid @ConnectMapping method. " + + "Return type must be void or a void async type: " + handler); + } + } + }); } @Override @@ -279,10 +313,9 @@ protected List initReturnValueHandler protected CompositeMessageCondition getCondition(AnnotatedElement element) { MessageMapping ann1 = AnnotatedElementUtils.findMergedAnnotation(element, MessageMapping.class); if (ann1 != null && ann1.value().length > 0) { - String[] patterns = processDestinations(ann1.value()); return new CompositeMessageCondition( - RSocketFrameTypeMessageCondition.REQUEST_CONDITION, - new DestinationPatternsMessageCondition(patterns, obtainRouteMatcher())); + RSocketFrameTypeMessageCondition.EMPTY_CONDITION, + new DestinationPatternsMessageCondition(processDestinations(ann1.value()), obtainRouteMatcher())); } ConnectMapping ann2 = AnnotatedElementUtils.findMergedAnnotation(element, ConnectMapping.class); if (ann2 != null) { @@ -294,6 +327,45 @@ protected CompositeMessageCondition getCondition(AnnotatedElement element) { return null; } + @Override + protected CompositeMessageCondition extendMapping(CompositeMessageCondition composite, HandlerMethod handler) { + + List> conditions = composite.getMessageConditions(); + Assert.isTrue(conditions.size() == 2 && + conditions.get(0) instanceof RSocketFrameTypeMessageCondition && + conditions.get(1) instanceof DestinationPatternsMessageCondition, + "Unexpected message condition types"); + + if (conditions.get(0) != RSocketFrameTypeMessageCondition.EMPTY_CONDITION) { + return composite; + } + + int responseCardinality = getCardinality(handler.getReturnType()); + int requestCardinality = 0; + for (MethodParameter parameter : handler.getMethodParameters()) { + if (getArgumentResolvers().getArgumentResolver(parameter) instanceof PayloadMethodArgumentResolver) { + requestCardinality = getCardinality(parameter); + } + } + + return new CompositeMessageCondition( + RSocketFrameTypeMessageCondition.getCondition(requestCardinality, responseCardinality), + conditions.get(1)); + } + + private int getCardinality(MethodParameter parameter) { + Class clazz = parameter.getParameterType(); + ReactiveAdapter adapter = getReactiveAdapterRegistry().getAdapter(clazz); + if (adapter == null) { + return clazz.equals(void.class) ? 0 : 1; + } + else if (parameter.nested().getNestedParameterType().equals(Void.class)) { + return 0; + } + else { + return adapter.isMultiValue() ? 2 : 1; + } + } @Override protected void handleNoMatch(@Nullable RouteMatcher.Route destination, Message message) { @@ -306,7 +378,18 @@ protected void handleNoMatch(@Nullable RouteMatcher.Route destination, Message frameTypes = getHandlerMethods().keySet().stream() + .map(CompositeMessageCondition::getMessageConditions) + .filter(conditions -> conditions.get(1).getMatchingCondition(message) != null) + .map(conditions -> (RSocketFrameTypeMessageCondition) conditions.get(0)) + .flatMap(condition -> condition.getFrameTypes().stream()) + .collect(Collectors.toSet()); + + throw new MessageDeliveryException(frameTypes.isEmpty() ? + "No handler for destination '" + destination + "'" : + "Destination '" + destination + "' does not support " + frameType + ". " + + "Supported interaction(s): " + frameTypes); } /** diff --git a/spring-messaging/src/test/java/org/springframework/messaging/rsocket/RSocketBufferLeakTests.java b/spring-messaging/src/test/java/org/springframework/messaging/rsocket/RSocketBufferLeakTests.java index 6c288d587d40..b175f0990890 100644 --- a/spring-messaging/src/test/java/org/springframework/messaging/rsocket/RSocketBufferLeakTests.java +++ b/spring-messaging/src/test/java/org/springframework/messaging/rsocket/RSocketBufferLeakTests.java @@ -137,22 +137,16 @@ public void subscriptionTimeErrorForHandleAndReply() { @Test public void errorSignalWithExceptionHandler() { - Mono result = requester.route("error-signal").data("foo").retrieveMono(String.class); + Flux result = requester.route("error-signal").data("foo").retrieveFlux(String.class); StepVerifier.create(result).expectNext("Handled 'bad input'").expectComplete().verify(Duration.ofSeconds(5)); } @Test public void ignoreInput() { - Flux result = requester.route("ignore-input").data("a").retrieveFlux(String.class); + Mono result = requester.route("ignore-input").data("a").retrieveMono(String.class); StepVerifier.create(result).expectNext("bar").thenCancel().verify(Duration.ofSeconds(5)); } - @Test - public void retrieveMonoFromFluxResponderMethod() { - Mono result = requester.route("request-stream").data("foo").retrieveMono(String.class); - StepVerifier.create(result).expectNext("foo-1").expectComplete().verify(Duration.ofSeconds(5)); - } - @Controller static class ServerController { @@ -188,11 +182,6 @@ public String handleIllegalArgument(IllegalArgumentException ex) { Mono ignoreInput() { return Mono.delay(Duration.ofMillis(10)).map(l -> "bar"); } - - @MessageMapping("request-stream") - Flux stream(String payload) { - return Flux.range(1,100).delayElements(Duration.ofMillis(10)).map(idx -> payload + "-" + idx); - } } diff --git a/spring-messaging/src/test/java/org/springframework/messaging/rsocket/RSocketClientToServerIntegrationTests.java b/spring-messaging/src/test/java/org/springframework/messaging/rsocket/RSocketClientToServerIntegrationTests.java index 3b823757d4e9..db15789167b8 100644 --- a/spring-messaging/src/test/java/org/springframework/messaging/rsocket/RSocketClientToServerIntegrationTests.java +++ b/spring-messaging/src/test/java/org/springframework/messaging/rsocket/RSocketClientToServerIntegrationTests.java @@ -159,13 +159,13 @@ public void echoChannel() { @Test public void voidReturnValue() { - Flux result = requester.route("void-return-value").data("Hello").retrieveFlux(String.class); + Mono result = requester.route("void-return-value").data("Hello").retrieveMono(String.class); StepVerifier.create(result).expectComplete().verify(Duration.ofSeconds(5)); } @Test public void voidReturnValueFromExceptionHandler() { - Flux result = requester.route("void-return-value").data("bad").retrieveFlux(String.class); + Mono result = requester.route("void-return-value").data("bad").retrieveMono(String.class); StepVerifier.create(result).expectComplete().verify(Duration.ofSeconds(5)); } diff --git a/spring-messaging/src/test/java/org/springframework/messaging/rsocket/annotation/support/RSocketFrameTypeMessageConditionTests.java b/spring-messaging/src/test/java/org/springframework/messaging/rsocket/annotation/support/RSocketFrameTypeMessageConditionTests.java index 417baf1274e0..715ba8f43fd5 100644 --- a/spring-messaging/src/test/java/org/springframework/messaging/rsocket/annotation/support/RSocketFrameTypeMessageConditionTests.java +++ b/spring-messaging/src/test/java/org/springframework/messaging/rsocket/annotation/support/RSocketFrameTypeMessageConditionTests.java @@ -26,6 +26,8 @@ import org.springframework.messaging.support.MessageBuilder; import static org.assertj.core.api.Assertions.assertThat; +import static org.springframework.messaging.rsocket.annotation.support.RSocketFrameTypeMessageCondition.CONNECT_CONDITION; +import static org.springframework.messaging.rsocket.annotation.support.RSocketFrameTypeMessageCondition.EMPTY_CONDITION; /** * Unit tests for {@link RSocketFrameTypeMessageCondition}. @@ -33,16 +35,37 @@ */ public class RSocketFrameTypeMessageConditionTests { + private static final RSocketFrameTypeMessageCondition FNF_RR_CONDITION = + new RSocketFrameTypeMessageCondition(FrameType.REQUEST_FNF, FrameType.REQUEST_RESPONSE); + + @Test public void getMatchingCondition() { Message message = message(FrameType.REQUEST_RESPONSE); - RSocketFrameTypeMessageCondition condition = condition(FrameType.REQUEST_FNF, FrameType.REQUEST_RESPONSE); - RSocketFrameTypeMessageCondition actual = condition.getMatchingCondition(message); + RSocketFrameTypeMessageCondition actual = FNF_RR_CONDITION.getMatchingCondition(message); assertThat(actual).isNotNull(); assertThat(actual.getFrameTypes()).hasSize(1).containsOnly(FrameType.REQUEST_RESPONSE); } + @Test + public void getMatchingConditionEmpty() { + Message message = message(FrameType.REQUEST_RESPONSE); + RSocketFrameTypeMessageCondition actual = EMPTY_CONDITION.getMatchingCondition(message); + + assertThat(actual).isNull(); + } + + @Test + public void combine() { + + assertThat(EMPTY_CONDITION.combine(CONNECT_CONDITION).getFrameTypes()) + .containsExactly(FrameType.SETUP, FrameType.METADATA_PUSH); + + assertThat(EMPTY_CONDITION.combine(new RSocketFrameTypeMessageCondition(FrameType.REQUEST_FNF)).getFrameTypes()) + .containsExactly(FrameType.REQUEST_FNF); + } + @Test public void compareTo() { Message message = message(null); diff --git a/spring-messaging/src/test/java/org/springframework/messaging/rsocket/annotation/support/RSocketMessageHandlerTests.java b/spring-messaging/src/test/java/org/springframework/messaging/rsocket/annotation/support/RSocketMessageHandlerTests.java index 52c124e07901..c7d0a80a1d87 100644 --- a/spring-messaging/src/test/java/org/springframework/messaging/rsocket/annotation/support/RSocketMessageHandlerTests.java +++ b/spring-messaging/src/test/java/org/springframework/messaging/rsocket/annotation/support/RSocketMessageHandlerTests.java @@ -21,6 +21,9 @@ import io.rsocket.frame.FrameType; import org.junit.jupiter.api.Test; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; import org.springframework.core.ReactiveAdapterRegistry; import org.springframework.core.codec.ByteArrayDecoder; @@ -170,6 +173,69 @@ private static void testMapping(Object controller, String... expectedPatterns) { } } + @Test + public void rejectConnectMappingMethodsThatCanReply() { + + RSocketMessageHandler handler = new RSocketMessageHandler(); + handler.setHandlers(Collections.singletonList(new InvalidConnectMappingController())); + assertThatThrownBy(handler::afterPropertiesSet) + .hasMessage("Invalid @ConnectMapping method. " + + "Return type must be void or a void async type: " + + "public java.lang.String org.springframework.messaging.rsocket.annotation.support." + + "RSocketMessageHandlerTests$InvalidConnectMappingController.connectString()"); + + handler = new RSocketMessageHandler(); + handler.setHandlers(Collections.singletonList(new AnotherInvalidConnectMappingController())); + assertThatThrownBy(handler::afterPropertiesSet) + .hasMessage("Invalid @ConnectMapping method. " + + "Return type must be void or a void async type: " + + "public reactor.core.publisher.Mono " + + "org.springframework.messaging.rsocket.annotation.support." + + "RSocketMessageHandlerTests$AnotherInvalidConnectMappingController.connectString()"); + } + + @Test + public void ignoreFireAndForgetToHandlerThatCanReply() { + + InteractionMismatchController controller = new InteractionMismatchController(); + + RSocketMessageHandler handler = new RSocketMessageHandler(); + handler.setHandlers(Collections.singletonList(controller)); + handler.afterPropertiesSet(); + + MessageHeaderAccessor headers = new MessageHeaderAccessor(); + headers.setLeaveMutable(true); + RouteMatcher.Route route = handler.getRouteMatcher().parseRoute("mono-string"); + headers.setHeader(DestinationPatternsMessageCondition.LOOKUP_DESTINATION_HEADER, route); + headers.setHeader(RSocketFrameTypeMessageCondition.FRAME_TYPE_HEADER, FrameType.REQUEST_FNF); + Message message = MessageBuilder.createMessage(Mono.empty(), headers.getMessageHeaders()); + + // Simply dropped and logged (error cannot propagate to client) + StepVerifier.create(handler.handleMessage(message)).expectComplete().verify(); + assertThat(controller.invokeCount).isEqualTo(0); + } + + @Test + public void rejectRequestResponseToStreamingHandler() { + + RSocketMessageHandler handler = new RSocketMessageHandler(); + handler.setHandlers(Collections.singletonList(new InteractionMismatchController())); + handler.afterPropertiesSet(); + + MessageHeaderAccessor headers = new MessageHeaderAccessor(); + headers.setLeaveMutable(true); + RouteMatcher.Route route = handler.getRouteMatcher().parseRoute("flux-string"); + headers.setHeader(DestinationPatternsMessageCondition.LOOKUP_DESTINATION_HEADER, route); + headers.setHeader(RSocketFrameTypeMessageCondition.FRAME_TYPE_HEADER, FrameType.REQUEST_RESPONSE); + Message message = MessageBuilder.createMessage(Mono.empty(), headers.getMessageHeaders()); + + StepVerifier.create(handler.handleMessage(message)) + .expectErrorMessage( + "Destination 'flux-string' does not support REQUEST_RESPONSE. " + + "Supported interaction(s): [REQUEST_STREAM]") + .verify(); + } + @Test public void handleNoMatch() { @@ -222,4 +288,38 @@ public void handleAll() { } } + + private static class InvalidConnectMappingController { + + @ConnectMapping + public String connectString() { + return ""; + } + } + + private static class AnotherInvalidConnectMappingController { + + @ConnectMapping + public Mono connectString() { + return Mono.empty(); + } + } + + private static class InteractionMismatchController { + + private int invokeCount; + + @MessageMapping("mono-string") + public Mono messageMonoString() { + this.invokeCount++; + return Mono.empty(); + } + + @MessageMapping("flux-string") + public Flux messageFluxString() { + this.invokeCount++; + return Flux.empty(); + } + } + } diff --git a/spring-messaging/src/test/kotlin/org/springframework/messaging/rsocket/RSocketClientToServerCoroutinesIntegrationTests.kt b/spring-messaging/src/test/kotlin/org/springframework/messaging/rsocket/RSocketClientToServerCoroutinesIntegrationTests.kt index b02efc11d188..cbc5a5a77a73 100644 --- a/spring-messaging/src/test/kotlin/org/springframework/messaging/rsocket/RSocketClientToServerCoroutinesIntegrationTests.kt +++ b/spring-messaging/src/test/kotlin/org/springframework/messaging/rsocket/RSocketClientToServerCoroutinesIntegrationTests.kt @@ -84,13 +84,13 @@ class RSocketClientToServerCoroutinesIntegrationTests { @Test fun unitReturnValue() { - val result = requester.route("unit-return-value").data("Hello").retrieveFlux(String::class.java) + val result = requester.route("unit-return-value").data("Hello").retrieveMono(String::class.java) StepVerifier.create(result).expectComplete().verify(Duration.ofSeconds(5)) } @Test fun unitReturnValueFromExceptionHandler() { - val result = requester.route("unit-return-value").data("bad").retrieveFlux(String::class.java) + val result = requester.route("unit-return-value").data("bad").retrieveMono(String::class.java) StepVerifier.create(result).expectComplete().verify(Duration.ofSeconds(5)) } diff --git a/src/docs/asciidoc/rsocket.adoc b/src/docs/asciidoc/rsocket.adoc index 34a1e23e7274..752699289eb9 100644 --- a/src/docs/asciidoc/rsocket.adoc +++ b/src/docs/asciidoc/rsocket.adoc @@ -586,7 +586,7 @@ indicates only that the message was successfully sent, and not that it was handl == Annotated Responders RSocket responders can be implemented as `@MessageMapping` and `@ConnectMapping` methods. -`@MessageMapping` methods handle individual requests, and `@ConnectMapping` methods handle +`@MessageMapping` methods handle individual requests while `@ConnectMapping` methods handle connection-level events (setup and metadata push). Annotated responders are supported symmetrically, for responding from the server side and for responding from the client side. @@ -760,20 +760,90 @@ class RadarsController { } ---- -You don't need to explicit specify the RSocket interaction type. Simply declare the -expected input and output, and a route pattern. The supporting infrastructure will adapt -matching requests. +The above `@MessageMapping` method responds to a Request-Stream interaction having the +route "locate.radars.within". It supports a flexible method signature with the option to +use the following method arguments: -The following additional arguments are supported for `@MessageMapping` methods: +[cols="1,3",options="header"] +|=== +| Method Argument +| Description -* `RSocketRequester` -- the requester for the connection associated with the request, - to make requests to the remote end. -* `@DestinationVariable` -- the value for a variable from the pattern, e.g. +| `@Payload` +| The payload of the request. This can be a concrete value of asynchronous types like + `Mono` or `Flux`. + + *Note:* Use of the annotation is optional. A method argument that is not a simple type + and is not any of the other supported arguments, is assumed to be the expected payload. + +| `RSocketRequester` +| Requester for making requests to the remote end. + +| `@DestinationVariable` +| Value extracted from the route based on variables in the mapping pattern, e.g. `@MessageMapping("find.radar.{id}")`. -* `@Header` -- access to a metadata value registered for extraction, as described in - <>. -* `@Headers Map` -- access to all metadata values registered for - extraction, as described in <>. + +| `@Header` +| Metadata value registered for extraction as described in <>. + +| `@Headers Map` +| All metadata values registered for extraction as described in <>. + +|=== + +The return value is expected to be one or more Objects to be serialized as response +payloads. That can be asynchronous types like `Mono` or `Flux`, a concrete value, or +either `void` or a no-value asynchronous type such as `Mono`. + +The RSocket interaction type that an `@MessageMapping` method supports is determined from +the cardinality of the input (i.e. `@Payload` argument) and of the output, where +cardinality means the following: + +[%autowidth] +[cols=2*,options="header"] +|=== +| Cardinality +| Description + +| 1 +| Either an explicit value, or a single-value asynchronous type such as `Mono`. + +| Many +| A multi-value asynchronous type such as `Flux`. + +| 0 +| For input this means the method does not have an `@Payload` argument. + + For output this is `void` or a no-value asynchronous type such as `Mono`. +|=== + +The table below shows all input and output cardinality combinations and the corresponding +interaction type(s): + +[%autowidth] +[cols=3*,options="header"] +|=== +| Input Cardinality +| Output Cardinality +| Interaction Types + +| 0, 1 +| 0 +| Fire-and-Forget, Request-Response + +| 0, 1 +| 1 +| Request-Response + +| 0, 1 +| Many +| Request-Stream + +| Many +| 0, 1, Many +| Request-Channel + +|===