diff --git a/spring-messaging/src/main/java/org/springframework/messaging/rsocket/DefaultRSocketRequesterBuilder.java b/spring-messaging/src/main/java/org/springframework/messaging/rsocket/DefaultRSocketRequesterBuilder.java index 1b117c941ef2..fcec3a4c1302 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/rsocket/DefaultRSocketRequesterBuilder.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/rsocket/DefaultRSocketRequesterBuilder.java @@ -23,8 +23,10 @@ import java.util.List; import java.util.Map; import java.util.function.Consumer; +import java.util.function.Function; import io.rsocket.Payload; +import io.rsocket.RSocket; import io.rsocket.frame.decoder.PayloadDecoder; import io.rsocket.metadata.WellKnownMimeType; import io.rsocket.transport.ClientTransport; @@ -174,10 +176,6 @@ public Mono connectWebSocket(URI uri) { @Override public Mono connect(ClientTransport transport) { - return Mono.defer(() -> doConnect(transport)); - } - - private Mono doConnect(ClientTransport transport) { RSocketStrategies rsocketStrategies = getRSocketStrategies(); Assert.isTrue(!rsocketStrategies.encoders().isEmpty(), "No encoders"); Assert.isTrue(!rsocketStrategies.decoders().isEmpty(), "No decoders"); @@ -186,21 +184,28 @@ private Mono doConnect(ClientTransport transport) { MimeTypeUtils.parseMimeType(WellKnownMimeType.MESSAGE_RSOCKET_COMPOSITE_METADATA.getString()); MimeType dataMimeType = getDataMimeType(rsocketStrategies); + Mono setupPayload = getSetupPayload(dataMimeType, metaMimeType, rsocketStrategies); + Function> connectFunction; if (rsocketConnectorPresent) { - return getSetupPayload(dataMimeType, metaMimeType, rsocketStrategies) - .flatMap(payload -> - new RSocketConnectorHelper().connect( - this.rsocketConnectorConfigurers, this.rsocketFactoryConfigurers, - metaMimeType, dataMimeType, payload, rsocketStrategies, transport)); + connectFunction = payload -> new RSocketConnectorHelper().getRSocketMono( + this.rsocketConnectorConfigurers, this.rsocketFactoryConfigurers, + metaMimeType, dataMimeType, setupPayload, rsocketStrategies, transport, payload); } else { - return getSetupPayload(dataMimeType, metaMimeType, rsocketStrategies) - .flatMap(payload -> - new RSocketFactoryHelper().connect( - this.rsocketFactoryConfigurers, metaMimeType, dataMimeType, payload, - rsocketStrategies, transport)); + connectFunction = payload -> new RSocketFactoryHelper().getRSocketMono( + this.rsocketFactoryConfigurers, metaMimeType, dataMimeType, + setupPayload, rsocketStrategies, transport, payload); } + + // In RSocket 1.0.2 we can pass a Mono for the setup Payload. Until then we have to + // resolve it and then cache the Mono because it may be a ReconnectMono. + + return setupPayload + .map(connectFunction) + .cache() + .flatMap(mono -> mono.map(rsocket -> + new DefaultRSocketRequester(rsocket, dataMimeType, metaMimeType, rsocketStrategies))); } private RSocketStrategies getRSocketStrategies() { @@ -285,14 +290,13 @@ private Mono getSetupPayload( } + @SuppressWarnings("deprecation") private static class RSocketConnectorHelper { - @SuppressWarnings("deprecation") - Mono connect( - List connectorConfigurers, + Mono getRSocketMono(List connectorConfigurers, List factoryConfigurers, - MimeType metaMimeType, MimeType dataMimeType, Payload setupPayload, - RSocketStrategies rsocketStrategies, ClientTransport transport) { + MimeType metaMimeType, MimeType dataMimeType, Mono setupPayload, + RSocketStrategies rsocketStrategies, ClientTransport transport, Payload payload) { io.rsocket.core.RSocketConnector connector = io.rsocket.core.RSocketConnector.create(); connectorConfigurers.forEach(c -> c.configure(connector)); @@ -307,16 +311,13 @@ Mono connect( connector.payloadDecoder(PayloadDecoder.ZERO_COPY); } + connector.metadataMimeType(metaMimeType.toString()); + connector.dataMimeType(dataMimeType.toString()); + if (setupPayload != EMPTY_SETUP_PAYLOAD) { - connector.setupPayload(setupPayload); + connector.setupPayload(payload); } - - return connector - .metadataMimeType(metaMimeType.toString()) - .dataMimeType(dataMimeType.toString()) - .connect(transport) - .map(rsocket -> new DefaultRSocketRequester( - rsocket, dataMimeType, metaMimeType, rsocketStrategies)); + return connector.connect(transport); } } @@ -324,10 +325,9 @@ Mono connect( @SuppressWarnings("deprecation") private static class RSocketFactoryHelper { - Mono connect( - List configurers, - MimeType metaMimeType, MimeType dataMimeType, Payload setupPayload, - RSocketStrategies rsocketStrategies, ClientTransport transport) { + Mono getRSocketMono(List configurers, + MimeType metaMimeType, MimeType dataMimeType, Mono setupPayload, + RSocketStrategies rsocketStrategies, ClientTransport transport, Payload payload) { io.rsocket.RSocketFactory.ClientRSocketFactory factory = io.rsocket.RSocketFactory.connect(); configurers.forEach(c -> c.configure(factory)); @@ -336,16 +336,12 @@ Mono connect( factory.frameDecoder(PayloadDecoder.ZERO_COPY); } + factory.metadataMimeType(metaMimeType.toString()); + factory.dataMimeType(dataMimeType.toString()); if (setupPayload != EMPTY_SETUP_PAYLOAD) { - factory.setupPayload(setupPayload); + factory.setupPayload(payload); } - - return factory.metadataMimeType(metaMimeType.toString()) - .dataMimeType(dataMimeType.toString()) - .transport(transport) - .start() - .map(rsocket -> new DefaultRSocketRequester( - rsocket, dataMimeType, metaMimeType, rsocketStrategies)); + return factory.transport(transport).start(); } }