diff --git a/spring-messaging/src/main/java/org/springframework/messaging/rsocket/DefaultRSocketRequester.java b/spring-messaging/src/main/java/org/springframework/messaging/rsocket/DefaultRSocketRequester.java index 249ce6106e7b..be879bbcfbeb 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/rsocket/DefaultRSocketRequester.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/rsocket/DefaultRSocketRequester.java @@ -57,7 +57,7 @@ final class DefaultRSocketRequester implements RSocketRequester { private final RSocketStrategies strategies; - private final DataBuffer emptyDataBuffer; + private final Mono emptyBufferMono; DefaultRSocketRequester( @@ -73,7 +73,7 @@ final class DefaultRSocketRequester implements RSocketRequester { this.dataMimeType = dataMimeType; this.metadataMimeType = metadataMimeType; this.strategies = strategies; - this.emptyDataBuffer = this.strategies.dataBufferFactory().wrap(new byte[0]); + this.emptyBufferMono = Mono.just(this.strategies.dataBufferFactory().wrap(new byte[0])); } @@ -193,7 +193,7 @@ else if (adapter != null) { } if (isVoid(elementType) || (adapter != null && adapter.isNoValue())) { - this.payloadMono = firstPayload(Mono.when(publisher).then(Mono.just(emptyDataBuffer))); + this.payloadMono = Mono.when(publisher).then(firstPayload(emptyBufferMono)); this.payloadFlux = null; return; } @@ -204,7 +204,7 @@ else if (adapter != null) { if (adapter != null && !adapter.isMultiValue()) { Mono data = Mono.from(publisher) .map(value -> encodeData(value, elementType, encoder)) - .defaultIfEmpty(emptyDataBuffer); + .switchIfEmpty(emptyBufferMono); this.payloadMono = firstPayload(data); this.payloadFlux = null; return; @@ -213,7 +213,7 @@ else if (adapter != null) { this.payloadMono = null; this.payloadFlux = Flux.from(publisher) .map(value -> encodeData(value, elementType, encoder)) - .defaultIfEmpty(emptyDataBuffer) + .switchIfEmpty(emptyBufferMono) .switchOnFirst((signal, inner) -> { DataBuffer data = signal.get(); if (data != null) { @@ -250,12 +250,7 @@ private Mono firstPayload(Mono encodedData) { @Override public Mono send() { - return getPayloadMonoRequired().flatMap(rsocket::fireAndForget); - } - - private Mono getPayloadMonoRequired() { - Assert.state(this.payloadFlux == null, "No RSocket interaction model for Flux request to Mono response."); - return this.payloadMono != null ? this.payloadMono : firstPayload(Mono.just(emptyDataBuffer)); + return getPayloadMono().flatMap(rsocket::fireAndForget); } @Override @@ -268,19 +263,9 @@ public Mono retrieveMono(ParameterizedTypeReference dataTypeRef) { return retrieveMono(ResolvableType.forType(dataTypeRef)); } - @Override - public Flux retrieveFlux(Class dataType) { - return retrieveFlux(ResolvableType.forClass(dataType)); - } - - @Override - public Flux retrieveFlux(ParameterizedTypeReference dataTypeRef) { - return retrieveFlux(ResolvableType.forType(dataTypeRef)); - } - @SuppressWarnings("unchecked") private Mono retrieveMono(ResolvableType elementType) { - Mono payloadMono = getPayloadMonoRequired().flatMap(rsocket::requestResponse); + Mono payloadMono = getPayloadMono().flatMap(rsocket::requestResponse); if (isVoid(elementType)) { return (Mono) payloadMono.then(); @@ -291,11 +276,22 @@ private Mono retrieveMono(ResolvableType elementType) { .map(dataBuffer -> decoder.decode(dataBuffer, elementType, dataMimeType, EMPTY_HINTS)); } + @Override + public Flux retrieveFlux(Class dataType) { + return retrieveFlux(ResolvableType.forClass(dataType)); + } + + @Override + public Flux retrieveFlux(ParameterizedTypeReference dataTypeRef) { + return retrieveFlux(ResolvableType.forType(dataTypeRef)); + } + @SuppressWarnings("unchecked") private Flux retrieveFlux(ResolvableType elementType) { - Flux payloadFlux = this.payloadMono != null ? - this.payloadMono.flatMapMany(rsocket::requestStream) : - rsocket.requestChannel(this.payloadFlux); + + Flux payloadFlux = (this.payloadFlux != null ? + rsocket.requestChannel(this.payloadFlux) : + getPayloadMono().flatMapMany(rsocket::requestStream)); if (isVoid(elementType)) { return payloadFlux.thenMany(Flux.empty()); @@ -306,6 +302,11 @@ private Flux retrieveFlux(ResolvableType elementType) { (T) decoder.decode(dataBuffer, elementType, dataMimeType, EMPTY_HINTS)); } + private Mono getPayloadMono() { + Assert.state(this.payloadFlux == null, "No RSocket interaction with Flux request and Mono response."); + return this.payloadMono != null ? this.payloadMono : firstPayload(emptyBufferMono); + } + private DataBuffer retainDataAndReleasePayload(Payload payload) { return PayloadUtils.retainDataAndReleasePayload(payload, bufferFactory()); } diff --git a/spring-messaging/src/test/java/org/springframework/messaging/rsocket/DefaultRSocketRequesterTests.java b/spring-messaging/src/test/java/org/springframework/messaging/rsocket/DefaultRSocketRequesterTests.java index 7392d901eeb3..878f3776c76f 100644 --- a/spring-messaging/src/test/java/org/springframework/messaging/rsocket/DefaultRSocketRequesterTests.java +++ b/spring-messaging/src/test/java/org/springframework/messaging/rsocket/DefaultRSocketRequesterTests.java @@ -145,15 +145,6 @@ public void sendWithoutData() { assertThat(this.rsocket.getSavedPayload().getDataUtf8()).isEqualTo(""); } - @Test - public void sendMonoWithoutData() { - this.requester.route("toA").retrieveMono(String.class).block(Duration.ofSeconds(5)); - - assertThat(this.rsocket.getSavedMethodName()).isEqualTo("requestResponse"); - assertThat(this.rsocket.getSavedPayload().getMetadataUtf8()).isEqualTo("toA"); - assertThat(this.rsocket.getSavedPayload().getDataUtf8()).isEqualTo(""); - } - @Test public void testSendWithAsyncMetadata() { @@ -205,6 +196,15 @@ public void retrieveMonoVoid() { assertThat(this.rsocket.getSavedMethodName()).isEqualTo("requestResponse"); } + @Test + public void retrieveMonoWithoutData() { + this.requester.route("toA").retrieveMono(String.class).block(Duration.ofSeconds(5)); + + assertThat(this.rsocket.getSavedMethodName()).isEqualTo("requestResponse"); + assertThat(this.rsocket.getSavedPayload().getMetadataUtf8()).isEqualTo("toA"); + assertThat(this.rsocket.getSavedPayload().getDataUtf8()).isEqualTo(""); + } + @Test public void retrieveFlux() { String[] values = new String[] {"bodyA", "bodyB", "bodyC"}; @@ -227,11 +227,20 @@ public void retrieveFluxVoid() { assertThat(this.rsocket.getSavedMethodName()).isEqualTo("requestStream"); } + @Test + public void retrieveFluxWithoutData() { + this.requester.route("toA").retrieveFlux(String.class).blockLast(Duration.ofSeconds(5)); + + assertThat(this.rsocket.getSavedMethodName()).isEqualTo("requestStream"); + assertThat(this.rsocket.getSavedPayload().getMetadataUtf8()).isEqualTo("toA"); + assertThat(this.rsocket.getSavedPayload().getDataUtf8()).isEqualTo(""); + } + @Test public void fluxToMonoIsRejected() { assertThatIllegalStateException() .isThrownBy(() -> this.requester.route("").data(Flux.just("a", "b")).retrieveMono(String.class)) - .withMessage("No RSocket interaction model for Flux request to Mono response."); + .withMessage("No RSocket interaction with Flux request and Mono response."); } private Payload toPayload(String value) {