From 02cca0976f68db240867c7016d1a58b35cfd4a4b Mon Sep 17 00:00:00 2001 From: uglycow Date: Wed, 20 Mar 2019 10:56:44 +0800 Subject: [PATCH] Merge pull request #3644, 3.x dev rx support. * reative support * rsocket support. support using Mono and Flux as return value. * reformat code, remove unused import, add license * optimize import * remove author * support using Mono/Flux as args * remove unused import --- .../rpc/protocol/rsocket/RSocketInvoker.java | 515 +++---- .../rpc/protocol/rsocket/RSocketProtocol.java | 1258 ++++++++++------- .../protocol/rsocket/ResourceDirectory.java | 62 + .../rpc/protocol/rsocket/ResourceInfo.java | 41 + .../protocol/rsocket/RSocketProtocolTest.java | 40 + .../apache/dubbo/rpc/service/DemoService.java | 4 + .../dubbo/rpc/service/DemoServiceImpl.java | 21 + 7 files changed, 1163 insertions(+), 778 deletions(-) create mode 100644 dubbo-rpc/dubbo-rpc-rsocket/src/main/java/org/apache/dubbo/rpc/protocol/rsocket/ResourceDirectory.java create mode 100644 dubbo-rpc/dubbo-rpc-rsocket/src/main/java/org/apache/dubbo/rpc/protocol/rsocket/ResourceInfo.java diff --git a/dubbo-rpc/dubbo-rpc-rsocket/src/main/java/org/apache/dubbo/rpc/protocol/rsocket/RSocketInvoker.java b/dubbo-rpc/dubbo-rpc-rsocket/src/main/java/org/apache/dubbo/rpc/protocol/rsocket/RSocketInvoker.java index 98c7d874688..572f429c0a6 100644 --- a/dubbo-rpc/dubbo-rpc-rsocket/src/main/java/org/apache/dubbo/rpc/protocol/rsocket/RSocketInvoker.java +++ b/dubbo-rpc/dubbo-rpc-rsocket/src/main/java/org/apache/dubbo/rpc/protocol/rsocket/RSocketInvoker.java @@ -1,248 +1,267 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.dubbo.rpc.protocol.rsocket; - -import io.rsocket.Payload; -import io.rsocket.RSocket; -import io.rsocket.util.DefaultPayload; -import org.apache.dubbo.common.Constants; -import org.apache.dubbo.common.URL; -import org.apache.dubbo.common.serialize.Cleanable; -import org.apache.dubbo.common.serialize.ObjectInput; -import org.apache.dubbo.common.serialize.ObjectOutput; -import org.apache.dubbo.common.serialize.Serialization; -import org.apache.dubbo.common.utils.AtomicPositiveInteger; -import org.apache.dubbo.common.utils.ReflectUtils; -import org.apache.dubbo.remoting.transport.CodecSupport; -import org.apache.dubbo.rpc.Invocation; -import org.apache.dubbo.rpc.Invoker; -import org.apache.dubbo.rpc.Result; -import org.apache.dubbo.rpc.RpcContext; -import org.apache.dubbo.rpc.RpcException; -import org.apache.dubbo.rpc.RpcInvocation; -import org.apache.dubbo.rpc.RpcResult; -import org.apache.dubbo.rpc.protocol.AbstractInvoker; -import org.apache.dubbo.rpc.support.RpcUtils; -import reactor.core.Exceptions; -import reactor.core.publisher.Flux; -import reactor.core.publisher.Mono; - -import java.io.ByteArrayInputStream; -import java.io.ByteArrayOutputStream; -import java.io.IOException; -import java.io.InputStream; -import java.nio.ByteBuffer; -import java.util.HashMap; -import java.util.Map; -import java.util.Set; -import java.util.concurrent.locks.ReentrantLock; -import java.util.function.Function; - -public class RSocketInvoker extends AbstractInvoker { - - private final RSocket[] clients; - - private final AtomicPositiveInteger index = new AtomicPositiveInteger(); - - private final String version; - - private final ReentrantLock destroyLock = new ReentrantLock(); - - private final Set> invokers; - - private final Serialization serialization; - - public RSocketInvoker(Class serviceType, URL url, RSocket[] clients, Set> invokers) { - super(serviceType, url, new String[]{Constants.INTERFACE_KEY, Constants.GROUP_KEY, Constants.TOKEN_KEY, Constants.TIMEOUT_KEY}); - this.clients = clients; - // get version. - this.version = url.getParameter(Constants.VERSION_KEY, "0.0.0"); - this.invokers = invokers; - - this.serialization = CodecSupport.getSerialization(getUrl()); - } - - @Override - protected Result doInvoke(final Invocation invocation) throws Throwable { - RpcInvocation inv = (RpcInvocation) invocation; - final String methodName = RpcUtils.getMethodName(invocation); - inv.setAttachment(Constants.PATH_KEY, getUrl().getPath()); - inv.setAttachment(Constants.VERSION_KEY, version); - - RSocket currentClient; - if (clients.length == 1) { - currentClient = clients[0]; - } else { - currentClient = clients[index.getAndIncrement() % clients.length]; - } - try { - //TODO support timeout - int timeout = getUrl().getMethodParameter(methodName, Constants.TIMEOUT_KEY, Constants.DEFAULT_TIMEOUT); - - RpcContext.getContext().setFuture(null); - //encode inv: metadata and data(arg,attachment) - Payload requestPayload = encodeInvocation(invocation); - - Class retType = RpcUtils.getReturnType(invocation); - - if (retType != null && retType.isAssignableFrom(Mono.class)) { - Mono responseMono = currentClient.requestResponse(requestPayload); - Mono bizMono = responseMono.map(new Function() { - @Override - public Object apply(Payload payload) { - return decodeData(payload); - } - }); - RpcResult rpcResult = new RpcResult(); - rpcResult.setValue(bizMono); - return rpcResult; - } else if (retType != null && retType.isAssignableFrom(Flux.class)) { - return requestStream(currentClient, requestPayload); - } else { - //request-reponse - Mono responseMono = currentClient.requestResponse(requestPayload); - FutureSubscriber futureSubscriber = new FutureSubscriber(serialization, retType); - responseMono.subscribe(futureSubscriber); - return (Result) futureSubscriber.get(); - } - - //TODO support stream arg - } catch (Throwable t) { - throw new RpcException(t); - } - } - - - private Result requestStream(RSocket currentClient, Payload requestPayload) { - Flux responseFlux = currentClient.requestStream(requestPayload); - Flux retFlux = responseFlux.map(new Function() { - - @Override - public Object apply(Payload payload) { - return decodeData(payload); - } - }); - - RpcResult rpcResult = new RpcResult(); - rpcResult.setValue(retFlux); - return rpcResult; - } - - - private Object decodeData(Payload payload) { - try { - //TODO save the copy - ByteBuffer dataBuffer = payload.getData(); - byte[] dataBytes = new byte[dataBuffer.remaining()]; - dataBuffer.get(dataBytes, dataBuffer.position(), dataBuffer.remaining()); - InputStream dataInputStream = new ByteArrayInputStream(dataBytes); - ObjectInput in = serialization.deserialize(null, dataInputStream); - int flag = in.readByte(); - if ((flag & RSocketConstants.FLAG_ERROR) != 0) { - Throwable t = (Throwable) in.readObject(); - throw t; - } else { - return in.readObject(); - } - } catch (Throwable t) { - throw Exceptions.propagate(t); - } - } - - @Override - public boolean isAvailable() { - if (!super.isAvailable()) { - return false; - } - for (RSocket client : clients) { - if (client.availability() > 0) { - return true; - } - } - return false; - } - - @Override - public void destroy() { - // in order to avoid closing a client multiple times, a counter is used in case of connection per jvm, every - // time when client.close() is called, counter counts down once, and when counter reaches zero, client will be - // closed. - if (super.isDestroyed()) { - return; - } else { - // double check to avoid dup close - destroyLock.lock(); - try { - if (super.isDestroyed()) { - return; - } - super.destroy(); - if (invokers != null) { - invokers.remove(this); - } - for (RSocket client : clients) { - try { - client.dispose(); - } catch (Throwable t) { - logger.warn(t.getMessage(), t); - } - } - - } finally { - destroyLock.unlock(); - } - } - } - - private Payload encodeInvocation(Invocation invocation) throws IOException { - byte[] metadata = encodeMetadata(invocation); - byte[] data = encodeData(invocation); - return DefaultPayload.create(data, metadata); - } - - private byte[] encodeMetadata(Invocation invocation) throws IOException { - Map metadataMap = new HashMap(); - metadataMap.put(RSocketConstants.SERVICE_NAME_KEY, invocation.getAttachment(Constants.PATH_KEY)); - metadataMap.put(RSocketConstants.SERVICE_VERSION_KEY, invocation.getAttachment(Constants.VERSION_KEY)); - metadataMap.put(RSocketConstants.METHOD_NAME_KEY, invocation.getMethodName()); - metadataMap.put(RSocketConstants.PARAM_TYPE_KEY, ReflectUtils.getDesc(invocation.getParameterTypes())); - metadataMap.put(RSocketConstants.SERIALIZE_TYPE_KEY, (Byte) serialization.getContentTypeId()); - return MetadataCodec.encodeMetadata(metadataMap); - } - - - private byte[] encodeData(Invocation invocation) throws IOException { - ByteArrayOutputStream dataOutputStream = new ByteArrayOutputStream(); - Serialization serialization = CodecSupport.getSerialization(getUrl()); - ObjectOutput out = serialization.serialize(getUrl(), dataOutputStream); - RpcInvocation inv = (RpcInvocation) invocation; - Object[] args = inv.getArguments(); - if (args != null) { - for (int i = 0; i < args.length; i++) { - out.writeObject(args[i]); - } - } - out.writeObject(RpcUtils.getNecessaryAttachments(inv)); - - //clean - out.flushBuffer(); - if (out instanceof Cleanable) { - ((Cleanable) out).cleanup(); - } - return dataOutputStream.toByteArray(); - } -} +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.dubbo.rpc.protocol.rsocket; + +import org.apache.dubbo.common.Constants; +import org.apache.dubbo.common.URL; +import org.apache.dubbo.common.serialize.Cleanable; +import org.apache.dubbo.common.serialize.ObjectInput; +import org.apache.dubbo.common.serialize.ObjectOutput; +import org.apache.dubbo.common.serialize.Serialization; +import org.apache.dubbo.common.utils.AtomicPositiveInteger; +import org.apache.dubbo.common.utils.ReflectUtils; +import org.apache.dubbo.remoting.transport.CodecSupport; +import org.apache.dubbo.rpc.Invocation; +import org.apache.dubbo.rpc.Invoker; +import org.apache.dubbo.rpc.Result; +import org.apache.dubbo.rpc.RpcContext; +import org.apache.dubbo.rpc.RpcException; +import org.apache.dubbo.rpc.RpcInvocation; +import org.apache.dubbo.rpc.RpcResult; +import org.apache.dubbo.rpc.protocol.AbstractInvoker; +import org.apache.dubbo.rpc.support.RpcUtils; + +import io.rsocket.Payload; +import io.rsocket.RSocket; +import io.rsocket.util.DefaultPayload; +import reactor.core.Exceptions; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.InputStream; +import java.nio.ByteBuffer; +import java.util.HashMap; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.locks.ReentrantLock; +import java.util.function.Function; + +public class RSocketInvoker extends AbstractInvoker { + + private final RSocket[] clients; + + private final AtomicPositiveInteger index = new AtomicPositiveInteger(); + + private final String version; + + private final ReentrantLock destroyLock = new ReentrantLock(); + + private final Set> invokers; + + private final Serialization serialization; + + public RSocketInvoker(Class serviceType, URL url, RSocket[] clients, Set> invokers) { + super(serviceType, url, new String[]{Constants.INTERFACE_KEY, Constants.GROUP_KEY, Constants.TOKEN_KEY, Constants.TIMEOUT_KEY}); + this.clients = clients; + // get version. + this.version = url.getParameter(Constants.VERSION_KEY, "0.0.0"); + this.invokers = invokers; + + this.serialization = CodecSupport.getSerialization(getUrl()); + } + + @Override + protected Result doInvoke(final Invocation invocation) throws Throwable { + RpcInvocation inv = (RpcInvocation) invocation; + final String methodName = RpcUtils.getMethodName(invocation); + inv.setAttachment(Constants.PATH_KEY, getUrl().getPath()); + inv.setAttachment(Constants.VERSION_KEY, version); + + RSocket currentClient; + if (clients.length == 1) { + currentClient = clients[0]; + } else { + currentClient = clients[index.getAndIncrement() % clients.length]; + } + try { + //TODO support timeout + int timeout = getUrl().getMethodParameter(methodName, Constants.TIMEOUT_KEY, Constants.DEFAULT_TIMEOUT); + + Class retType = RpcUtils.getReturnType(invocation); + + RpcContext.getContext().setFuture(null); + //encode inv: metadata and data(arg,attachment) + Payload requestPayload = encodeInvocation(invocation); + + if (retType != null && retType.isAssignableFrom(Mono.class)) { + Mono responseMono = currentClient.requestResponse(requestPayload); + Mono bizMono = responseMono.map(new Function() { + @Override + public Object apply(Payload payload) { + return decodeData(payload); + } + }); + RpcResult rpcResult = new RpcResult(); + rpcResult.setValue(bizMono); + return rpcResult; + } else if (retType != null && retType.isAssignableFrom(Flux.class)) { + return requestStream(currentClient, requestPayload); + } else { + //request-reponse + Mono responseMono = currentClient.requestResponse(requestPayload); + FutureSubscriber futureSubscriber = new FutureSubscriber(serialization, retType); + responseMono.subscribe(futureSubscriber); + return (Result) futureSubscriber.get(); + } + + //TODO support stream arg + } catch (Throwable t) { + throw new RpcException(t); + } + } + + + private Result requestStream(RSocket currentClient, Payload requestPayload) { + Flux responseFlux = currentClient.requestStream(requestPayload); + Flux retFlux = responseFlux.map(new Function() { + + @Override + public Object apply(Payload payload) { + Object o = decodeData(payload); + payload.release(); + return o; + } + }); + + RpcResult rpcResult = new RpcResult(); + rpcResult.setValue(retFlux); + return rpcResult; + } + + + private Object decodeData(Payload payload) { + try { + ByteBuffer dataBuffer = payload.getData(); + byte[] dataBytes = new byte[dataBuffer.remaining()]; + dataBuffer.get(dataBytes, dataBuffer.position(), dataBuffer.remaining()); + InputStream dataInputStream = new ByteArrayInputStream(dataBytes); + ObjectInput in = serialization.deserialize(null, dataInputStream); + //TODO save the copy + int flag = in.readByte(); + if ((flag & RSocketConstants.FLAG_ERROR) != 0) { + Throwable t = (Throwable) in.readObject(); + throw t; + } else { + return in.readObject(); + } + } catch (Throwable t) { + throw Exceptions.propagate(t); + } + } + + @Override + public boolean isAvailable() { + if (!super.isAvailable()) { + return false; + } + for (RSocket client : clients) { + if (client.availability() > 0) { + return true; + } + } + return false; + } + + @Override + public void destroy() { + // in order to avoid closing a client multiple times, a counter is used in case of connection per jvm, every + // time when client.close() is called, counter counts down once, and when counter reaches zero, client will be + // closed. + if (super.isDestroyed()) { + return; + } else { + // double check to avoid dup close + destroyLock.lock(); + try { + if (super.isDestroyed()) { + return; + } + super.destroy(); + if (invokers != null) { + invokers.remove(this); + } + for (RSocket client : clients) { + try { + client.dispose(); + } catch (Throwable t) { + logger.warn(t.getMessage(), t); + } + } + + } finally { + destroyLock.unlock(); + } + } + } + + private Payload encodeInvocation(Invocation invocation) throws IOException { + //process stream args + RpcInvocation inv = (RpcInvocation) invocation; + Class[] parameterTypes = invocation.getParameterTypes(); + Object[] args = inv.getArguments(); + if (args != null) { + for (int i = 0; i < args.length; i++) { + if(args[i]!=null) { + Class argClass = args[i].getClass(); + if (Mono.class.isAssignableFrom(argClass)) { + long id = ResourceDirectory.mountResource(args[i]); + args[i] = new ResourceInfo(id, ResourceInfo.RESOURCE_TYPE_MONO); + parameterTypes[i] = ResourceInfo.class; + } else if (Flux.class.isAssignableFrom(argClass)) { + long id = ResourceDirectory.mountResource(args[i]); + args[i] = new ResourceInfo(id, ResourceInfo.RESOURCE_TYPE_FLUX); + parameterTypes[i] = ResourceInfo.class; + } + } + } + } + + //metadata + Map metadataMap = new HashMap(); + metadataMap.put(RSocketConstants.SERVICE_NAME_KEY, invocation.getAttachment(Constants.PATH_KEY)); + metadataMap.put(RSocketConstants.SERVICE_VERSION_KEY, invocation.getAttachment(Constants.VERSION_KEY)); + metadataMap.put(RSocketConstants.METHOD_NAME_KEY, invocation.getMethodName()); + metadataMap.put(RSocketConstants.SERIALIZE_TYPE_KEY, (Byte) serialization.getContentTypeId()); + metadataMap.put(RSocketConstants.PARAM_TYPE_KEY, ReflectUtils.getDesc(parameterTypes)); + byte[] metadata = MetadataCodec.encodeMetadata(metadataMap); + + + //data + ByteArrayOutputStream dataOutputStream = new ByteArrayOutputStream(); + Serialization serialization = CodecSupport.getSerialization(getUrl()); + ObjectOutput out = serialization.serialize(getUrl(), dataOutputStream); + if (args != null) { + for (int i = 0; i < args.length; i++) { + out.writeObject(args[i]); + } + } + out.writeObject(RpcUtils.getNecessaryAttachments(inv)); + + //clean + out.flushBuffer(); + if (out instanceof Cleanable) { + ((Cleanable) out).cleanup(); + } + byte[] data = dataOutputStream.toByteArray(); + + + return DefaultPayload.create(data, metadata); + } +} diff --git a/dubbo-rpc/dubbo-rpc-rsocket/src/main/java/org/apache/dubbo/rpc/protocol/rsocket/RSocketProtocol.java b/dubbo-rpc/dubbo-rpc-rsocket/src/main/java/org/apache/dubbo/rpc/protocol/rsocket/RSocketProtocol.java index c64d5176e3f..be537eb48ba 100644 --- a/dubbo-rpc/dubbo-rpc-rsocket/src/main/java/org/apache/dubbo/rpc/protocol/rsocket/RSocketProtocol.java +++ b/dubbo-rpc/dubbo-rpc-rsocket/src/main/java/org/apache/dubbo/rpc/protocol/rsocket/RSocketProtocol.java @@ -1,530 +1,728 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.dubbo.rpc.protocol.rsocket; - -import io.rsocket.AbstractRSocket; -import io.rsocket.ConnectionSetupPayload; -import io.rsocket.Payload; -import io.rsocket.RSocket; -import io.rsocket.RSocketFactory; -import io.rsocket.SocketAcceptor; -import io.rsocket.transport.netty.client.TcpClientTransport; -import io.rsocket.transport.netty.server.CloseableChannel; -import io.rsocket.transport.netty.server.TcpServerTransport; -import io.rsocket.util.DefaultPayload; -import org.apache.dubbo.common.Constants; -import org.apache.dubbo.common.URL; -import org.apache.dubbo.common.extension.ExtensionLoader; -import org.apache.dubbo.common.logger.Logger; -import org.apache.dubbo.common.logger.LoggerFactory; -import org.apache.dubbo.common.serialize.ObjectInput; -import org.apache.dubbo.common.serialize.ObjectOutput; -import org.apache.dubbo.common.utils.NetUtils; -import org.apache.dubbo.common.utils.ReflectUtils; -import org.apache.dubbo.remoting.RemotingException; -import org.apache.dubbo.remoting.transport.CodecSupport; -import org.apache.dubbo.rpc.Exporter; -import org.apache.dubbo.rpc.Invocation; -import org.apache.dubbo.rpc.Invoker; -import org.apache.dubbo.rpc.Protocol; -import org.apache.dubbo.rpc.Result; -import org.apache.dubbo.rpc.RpcException; -import org.apache.dubbo.rpc.RpcInvocation; -import org.apache.dubbo.rpc.protocol.AbstractProtocol; -import org.apache.dubbo.rpc.support.RpcUtils; -import org.reactivestreams.Publisher; -import reactor.core.Exceptions; -import reactor.core.publisher.Flux; -import reactor.core.publisher.Mono; - -import java.io.ByteArrayInputStream; -import java.io.ByteArrayOutputStream; -import java.io.IOException; -import java.io.InputStream; -import java.net.InetSocketAddress; -import java.nio.ByteBuffer; -import java.time.Duration; -import java.util.ArrayList; -import java.util.Collection; -import java.util.Collections; -import java.util.Map; -import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.ConcurrentMap; -import java.util.function.Function; - -public class RSocketProtocol extends AbstractProtocol { - - public static final String NAME = "rsocket"; - public static final int DEFAULT_PORT = 30880; - private static final Logger log = LoggerFactory.getLogger(RSocketProtocol.class); - private static RSocketProtocol INSTANCE; - - // - private final Map serverMap = new ConcurrentHashMap(); - - // - private final Map referenceClientMap = new ConcurrentHashMap(); - - private final ConcurrentMap locks = new ConcurrentHashMap(); - - public RSocketProtocol() { - INSTANCE = this; - } - - public static RSocketProtocol getRSocketProtocol() { - if (INSTANCE == null) { - ExtensionLoader.getExtensionLoader(Protocol.class).getExtension(RSocketProtocol.NAME); // load - } - return INSTANCE; - } - - public Collection> getExporters() { - return Collections.unmodifiableCollection(exporterMap.values()); - } - - Map> getExporterMap() { - return exporterMap; - } - - Invoker getInvoker(int port, Map metadataMap) throws RemotingException { - String path = (String) metadataMap.get(RSocketConstants.SERVICE_NAME_KEY); - String serviceKey = serviceKey(port, path, (String) metadataMap.get(RSocketConstants.SERVICE_VERSION_KEY), (String) metadataMap.get(Constants.GROUP_KEY)); - RSocketExporter exporter = (RSocketExporter) exporterMap.get(serviceKey); - if (exporter == null) { - //throw new Throwable("Not found exported service: " + serviceKey + " in " + exporterMap.keySet() + ", may be version or group mismatch " + ", channel: consumer: " + channel.getRemoteAddress() + " --> provider: " + channel.getLocalAddress() + ", message:" + inv); - throw new RuntimeException("Not found exported service: " + serviceKey + " in " + exporterMap.keySet() + ", may be version or group mismatch "); - } - - return exporter.getInvoker(); - } - - public Collection> getInvokers() { - return Collections.unmodifiableCollection(invokers); - } - - @Override - public int getDefaultPort() { - return DEFAULT_PORT; - } - - @Override - public Exporter export(Invoker invoker) throws RpcException { - URL url = invoker.getUrl(); - - // export service. - String key = serviceKey(url); - RSocketExporter exporter = new RSocketExporter(invoker, key, exporterMap); - exporterMap.put(key, exporter); - - openServer(url); - return exporter; - } - - private void openServer(URL url) { - String key = url.getAddress(); - //client can export a service which's only for server to invoke - boolean isServer = url.getParameter(Constants.IS_SERVER_KEY, true); - if (isServer) { - CloseableChannel server = serverMap.get(key); - if (server == null) { - synchronized (this) { - server = serverMap.get(key); - if (server == null) { - serverMap.put(key, createServer(url)); - } - } - } - } - } - - private CloseableChannel createServer(URL url) { - try { - String bindIp = url.getParameter(Constants.BIND_IP_KEY, url.getHost()); - int bindPort = url.getParameter(Constants.BIND_PORT_KEY, url.getPort()); - if (url.getParameter(Constants.ANYHOST_KEY, false) || NetUtils.isInvalidLocalHost(bindIp)) { - bindIp = NetUtils.ANYHOST; - } - return RSocketFactory.receive() - .acceptor(new SocketAcceptorImpl(bindPort)) - .transport(TcpServerTransport.create(bindIp, bindPort)) - .start() - .block(); - } catch (Throwable e) { - throw new RpcException("Fail to start server(url: " + url + ") " + e.getMessage(), e); - } - } - - - @Override - public Invoker refer(Class serviceType, URL url) throws RpcException { - // create rpc invoker. - RSocketInvoker invoker = new RSocketInvoker(serviceType, url, getClients(url), invokers); - invokers.add(invoker); - return invoker; - } - - private RSocket[] getClients(URL url) { - // whether to share connection - boolean service_share_connect = false; - int connections = url.getParameter(Constants.CONNECTIONS_KEY, 0); - // if not configured, connection is shared, otherwise, one connection for one service - if (connections == 0) { - service_share_connect = true; - connections = 1; - } - - RSocket[] clients = new RSocket[connections]; - for (int i = 0; i < clients.length; i++) { - if (service_share_connect) { - clients[i] = getSharedClient(url); - } else { - clients[i] = initClient(url); - } - } - return clients; - } - - /** - * Get shared connection - */ - private RSocket getSharedClient(URL url) { - String key = url.getAddress(); - RSocket client = referenceClientMap.get(key); - if (client != null) { - return client; - } - - locks.putIfAbsent(key, new Object()); - synchronized (locks.get(key)) { - if (referenceClientMap.containsKey(key)) { - return referenceClientMap.get(key); - } - - client = initClient(url); - referenceClientMap.put(key, client); - locks.remove(key); - return client; - } - } - - /** - * Create new connection - */ - private RSocket initClient(URL url) { - try { - InetSocketAddress serverAddress = new InetSocketAddress(NetUtils.filterLocalHost(url.getHost()), url.getPort()); - RSocket client = RSocketFactory.connect().keepAliveTickPeriod(Duration.ZERO).keepAliveAckTimeout(Duration.ZERO).acceptor( - rSocket -> - new AbstractRSocket() { - public Mono requestResponse(Payload payload) { - //TODO support Mono arg - throw new UnsupportedOperationException(); - } - - @Override - public Flux requestStream(Payload payload) { - //TODO support Flux arg - throw new UnsupportedOperationException(); - } - }) - .transport(TcpClientTransport.create(serverAddress)) - .start() - .block(); - return client; - } catch (Throwable e) { - throw new RpcException("Fail to create remoting client for service(" + url + "): " + e.getMessage(), e); - } - - } - - @Override - public void destroy() { - for (String key : new ArrayList(serverMap.keySet())) { - CloseableChannel server = serverMap.remove(key); - if (server != null) { - try { - if (logger.isInfoEnabled()) { - logger.info("Close dubbo server: " + server.address()); - } - server.dispose(); - } catch (Throwable t) { - logger.warn(t.getMessage(), t); - } - } - } - - for (String key : new ArrayList(referenceClientMap.keySet())) { - RSocket client = referenceClientMap.remove(key); - if (client != null) { - try { -// if (logger.isInfoEnabled()) { -// logger.info("Close dubbo connect: " + client. + "-->" + client.getRemoteAddress()); -// } - client.dispose(); - } catch (Throwable t) { - logger.warn(t.getMessage(), t); - } - } - } - super.destroy(); - } - - - //server process logic - private class SocketAcceptorImpl implements SocketAcceptor { - - private final int port; - - public SocketAcceptorImpl(int port) { - this.port = port; - } - - @Override - public Mono accept(ConnectionSetupPayload setupPayload, RSocket reactiveSocket) { - return Mono.just( - new AbstractRSocket() { - public Mono requestResponse(Payload payload) { - try { - Map metadata = decodeMetadata(payload); - Byte serializeId = ((Integer) metadata.get(RSocketConstants.SERIALIZE_TYPE_KEY)).byteValue(); - Invocation inv = decodeInvocation(payload, metadata, serializeId); - - Result result = inv.getInvoker().invoke(inv); - - Class retType = RpcUtils.getReturnType(inv); - //ok - if (retType != null && Mono.class.isAssignableFrom(retType)) { - Throwable th = result.getException(); - if (th == null) { - Mono bizMono = (Mono) result.getValue(); - Mono retMono = bizMono.map(new Function() { - @Override - public Payload apply(Object o) { - try { - ByteArrayOutputStream bos = new ByteArrayOutputStream(); - ObjectOutput out = CodecSupport.getSerializationById(serializeId).serialize(null, bos); - out.writeByte((byte) 0); - out.writeObject(o); - out.flushBuffer(); - bos.flush(); - bos.close(); - Payload responsePayload = DefaultPayload.create(bos.toByteArray()); - return responsePayload; - } catch (Throwable t) { - throw Exceptions.propagate(t); - } - } - }).onErrorResume(new Function>() { - @Override - public Publisher apply(Throwable throwable) { - try { - ByteArrayOutputStream bos = new ByteArrayOutputStream(); - ObjectOutput out = CodecSupport.getSerializationById(serializeId).serialize(null, bos); - out.writeByte((byte) RSocketConstants.FLAG_ERROR); - out.writeObject(throwable); - out.flushBuffer(); - bos.flush(); - bos.close(); - Payload errorPayload = DefaultPayload.create(bos.toByteArray()); - return Flux.just(errorPayload); - } catch (Throwable t) { - throw Exceptions.propagate(t); - } - } - }); - - return retMono; - } else { - ByteArrayOutputStream bos = new ByteArrayOutputStream(); - ObjectOutput out = CodecSupport.getSerializationById(serializeId).serialize(null, bos); - out.writeByte((byte) RSocketConstants.FLAG_ERROR); - out.writeObject(th); - out.flushBuffer(); - bos.flush(); - bos.close(); - Payload errorPayload = DefaultPayload.create(bos.toByteArray()); - return Mono.just(errorPayload); - } - - } else { - ByteArrayOutputStream bos = new ByteArrayOutputStream(); - ObjectOutput out = CodecSupport.getSerializationById(serializeId).serialize(null, bos); - int flag = RSocketConstants.FLAG_HAS_ATTACHMENT; - - Throwable th = result.getException(); - if (th == null) { - Object ret = result.getValue(); - if (ret == null) { - flag |= RSocketConstants.FLAG_NULL_VALUE; - out.writeByte((byte) flag); - } else { - out.writeByte((byte) flag); - out.writeObject(ret); - } - } else { - flag |= RSocketConstants.FLAG_ERROR; - out.writeByte((byte) flag); - out.writeObject(th); - } - out.writeObject(result.getAttachments()); - out.flushBuffer(); - bos.flush(); - bos.close(); - - Payload responsePayload = DefaultPayload.create(bos.toByteArray()); - return Mono.just(responsePayload); - } - } catch (Throwable t) { - //application error - return Mono.error(t); - } finally { - payload.release(); - } - } - - public Flux requestStream(Payload payload) { - try { - Map metadata = decodeMetadata(payload); - Byte serializeId = ((Integer) metadata.get(RSocketConstants.SERIALIZE_TYPE_KEY)).byteValue(); - Invocation inv = decodeInvocation(payload, metadata, serializeId); - - Result result = inv.getInvoker().invoke(inv); - //Class retType = RpcUtils.getReturnType(inv); - - Throwable th = result.getException(); - if (th != null) { - Payload errorPayload = encodeError(th, serializeId); - return Flux.just(errorPayload); - } - - Flux flux = (Flux) result.getValue(); - Flux retFlux = flux.map(new Function() { - @Override - public Payload apply(Object o) { - try { - return encodeData(o, serializeId); - } catch (Throwable t) { - throw new RuntimeException(t); - } - } - }).onErrorResume(new Function>() { - @Override - public Publisher apply(Throwable throwable) { - try { - Payload errorPayload = encodeError(throwable, serializeId); - return Flux.just(errorPayload); - } catch (Throwable t) { - throw new RuntimeException(t); - } - } - }); - return retFlux; - } catch (Throwable t) { - return Flux.error(t); - } finally { - payload.release(); - } - } - - private Payload encodeData(Object data, byte serializeId) throws Throwable { - ByteArrayOutputStream bos = new ByteArrayOutputStream(); - ObjectOutput out = CodecSupport.getSerializationById(serializeId).serialize(null, bos); - out.writeByte((byte) 0); - out.writeObject(data); - out.flushBuffer(); - bos.flush(); - bos.close(); - return DefaultPayload.create(bos.toByteArray()); - } - - private Payload encodeError(Throwable throwable, byte serializeId) throws Throwable { - ByteArrayOutputStream bos = new ByteArrayOutputStream(); - ObjectOutput out = CodecSupport.getSerializationById(serializeId).serialize(null, bos); - out.writeByte((byte) RSocketConstants.FLAG_ERROR); - out.writeObject(throwable); - out.flushBuffer(); - bos.flush(); - bos.close(); - return DefaultPayload.create(bos.toByteArray()); - } - - private Map decodeMetadata(Payload payload) throws IOException { - ByteBuffer metadataBuffer = payload.getMetadata(); - byte[] metadataBytes = new byte[metadataBuffer.remaining()]; - metadataBuffer.get(metadataBytes, metadataBuffer.position(), metadataBuffer.remaining()); - return MetadataCodec.decodeMetadata(metadataBytes); - } - - private Invocation decodeInvocation(Payload payload, Map metadata, Byte serializeId) throws RemotingException, IOException, ClassNotFoundException { - Invoker invoker = getInvoker(port, metadata); - - String serviceName = (String) metadata.get(RSocketConstants.SERVICE_NAME_KEY); - String version = (String) metadata.get(RSocketConstants.SERVICE_VERSION_KEY); - String methodName = (String) metadata.get(RSocketConstants.METHOD_NAME_KEY); - String paramType = (String) metadata.get(RSocketConstants.PARAM_TYPE_KEY); - - ByteBuffer dataBuffer = payload.getData(); - byte[] dataBytes = new byte[dataBuffer.remaining()]; - dataBuffer.get(dataBytes, dataBuffer.position(), dataBuffer.remaining()); - - - //TODO how to get remote address - //RpcContext rpcContext = RpcContext.getContext(); - //rpcContext.setRemoteAddress(channel.getRemoteAddress()); - - - RpcInvocation inv = new RpcInvocation(); - inv.setInvoker(invoker); - inv.setAttachment(Constants.PATH_KEY, serviceName); - inv.setAttachment(Constants.VERSION_KEY, version); - inv.setMethodName(methodName); - - - InputStream dataInputStream = new ByteArrayInputStream(dataBytes); - ObjectInput in = CodecSupport.getSerializationById(serializeId).deserialize(null, dataInputStream); - - Object[] args; - Class[] pts; - String desc = paramType; - if (desc.length() == 0) { - pts = new Class[0]; - args = new Object[0]; - } else { - pts = ReflectUtils.desc2classArray(desc); - args = new Object[pts.length]; - for (int i = 0; i < args.length; i++) { - try { - args[i] = in.readObject(pts[i]); - } catch (Exception e) { - if (log.isWarnEnabled()) { - log.warn("Decode argument failed: " + e.getMessage(), e); - } - } - } - } - inv.setParameterTypes(pts); - inv.setArguments(args); - Map map = (Map) in.readObject(Map.class); - if (map != null && map.size() > 0) { - inv.addAttachments(map); - } - return inv; - } - }); - } - } -} +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.dubbo.rpc.protocol.rsocket; + +import io.rsocket.AbstractRSocket; +import io.rsocket.ConnectionSetupPayload; +import io.rsocket.Payload; +import io.rsocket.RSocket; +import io.rsocket.RSocketFactory; +import io.rsocket.SocketAcceptor; +import io.rsocket.transport.netty.client.TcpClientTransport; +import io.rsocket.transport.netty.server.CloseableChannel; +import io.rsocket.transport.netty.server.TcpServerTransport; +import io.rsocket.util.DefaultPayload; +import org.apache.dubbo.common.Constants; +import org.apache.dubbo.common.URL; +import org.apache.dubbo.common.extension.ExtensionLoader; +import org.apache.dubbo.common.logger.Logger; +import org.apache.dubbo.common.logger.LoggerFactory; +import org.apache.dubbo.common.serialize.ObjectInput; +import org.apache.dubbo.common.serialize.ObjectOutput; +import org.apache.dubbo.common.serialize.Serialization; +import org.apache.dubbo.common.utils.NetUtils; +import org.apache.dubbo.common.utils.ReflectUtils; +import org.apache.dubbo.remoting.RemotingException; +import org.apache.dubbo.remoting.transport.CodecSupport; +import org.apache.dubbo.rpc.Exporter; +import org.apache.dubbo.rpc.Invocation; +import org.apache.dubbo.rpc.Invoker; +import org.apache.dubbo.rpc.Protocol; +import org.apache.dubbo.rpc.Result; +import org.apache.dubbo.rpc.RpcException; +import org.apache.dubbo.rpc.RpcInvocation; +import org.apache.dubbo.rpc.protocol.AbstractProtocol; +import org.apache.dubbo.rpc.support.RpcUtils; +import org.reactivestreams.Publisher; +import reactor.core.Exceptions; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.InputStream; +import java.net.InetSocketAddress; +import java.nio.ByteBuffer; +import java.time.Duration; +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; +import java.util.function.Function; + +public class RSocketProtocol extends AbstractProtocol { + + public static final String NAME = "rsocket"; + public static final int DEFAULT_PORT = 30880; + private static final Logger log = LoggerFactory.getLogger(RSocketProtocol.class); + private static RSocketProtocol INSTANCE; + + // + private final Map serverMap = new ConcurrentHashMap(); + + // + private final Map referenceClientMap = new ConcurrentHashMap(); + + private final ConcurrentMap locks = new ConcurrentHashMap(); + + public RSocketProtocol() { + INSTANCE = this; + } + + public static RSocketProtocol getRSocketProtocol() { + if (INSTANCE == null) { + ExtensionLoader.getExtensionLoader(Protocol.class).getExtension(RSocketProtocol.NAME); // load + } + return INSTANCE; + } + + public Collection> getExporters() { + return Collections.unmodifiableCollection(exporterMap.values()); + } + + Map> getExporterMap() { + return exporterMap; + } + + Invoker getInvoker(int port, Map metadataMap) throws RemotingException { + String path = (String) metadataMap.get(RSocketConstants.SERVICE_NAME_KEY); + String serviceKey = serviceKey(port, path, (String) metadataMap.get(RSocketConstants.SERVICE_VERSION_KEY), (String) metadataMap.get(Constants.GROUP_KEY)); + RSocketExporter exporter = (RSocketExporter) exporterMap.get(serviceKey); + if (exporter == null) { + //throw new Throwable("Not found exported service: " + serviceKey + " in " + exporterMap.keySet() + ", may be version or group mismatch " + ", channel: consumer: " + channel.getRemoteAddress() + " --> provider: " + channel.getLocalAddress() + ", message:" + inv); + throw new RuntimeException("Not found exported service: " + serviceKey + " in " + exporterMap.keySet() + ", may be version or group mismatch "); + } + + return exporter.getInvoker(); + } + + public Collection> getInvokers() { + return Collections.unmodifiableCollection(invokers); + } + + @Override + public int getDefaultPort() { + return DEFAULT_PORT; + } + + @Override + public Exporter export(Invoker invoker) throws RpcException { + URL url = invoker.getUrl(); + + // export service. + String key = serviceKey(url); + RSocketExporter exporter = new RSocketExporter(invoker, key, exporterMap); + exporterMap.put(key, exporter); + + openServer(url); + return exporter; + } + + private void openServer(URL url) { + String key = url.getAddress(); + //client can export a service which's only for server to invoke + boolean isServer = url.getParameter(Constants.IS_SERVER_KEY, true); + if (isServer) { + CloseableChannel server = serverMap.get(key); + if (server == null) { + synchronized (this) { + server = serverMap.get(key); + if (server == null) { + serverMap.put(key, createServer(url)); + } + } + } + } + } + + private CloseableChannel createServer(URL url) { + try { + String bindIp = url.getParameter(Constants.BIND_IP_KEY, url.getHost()); + int bindPort = url.getParameter(Constants.BIND_PORT_KEY, url.getPort()); + if (url.getParameter(Constants.ANYHOST_KEY, false) || NetUtils.isInvalidLocalHost(bindIp)) { + bindIp = NetUtils.ANYHOST; + } + return RSocketFactory.receive() + .acceptor(new SocketAcceptorImpl(bindPort)) + .transport(TcpServerTransport.create(bindIp, bindPort)) + .start() + .block(); + } catch (Throwable e) { + throw new RpcException("Fail to start server(url: " + url + ") " + e.getMessage(), e); + } + } + + + @Override + public Invoker refer(Class serviceType, URL url) throws RpcException { + // create rpc invoker. + RSocketInvoker invoker = new RSocketInvoker(serviceType, url, getClients(url), invokers); + invokers.add(invoker); + return invoker; + } + + private RSocket[] getClients(URL url) { + // whether to share connection + boolean service_share_connect = false; + int connections = url.getParameter(Constants.CONNECTIONS_KEY, 0); + // if not configured, connection is shared, otherwise, one connection for one service + if (connections == 0) { + service_share_connect = true; + connections = 1; + } + + RSocket[] clients = new RSocket[connections]; + for (int i = 0; i < clients.length; i++) { + if (service_share_connect) { + clients[i] = getSharedClient(url); + } else { + clients[i] = initClient(url); + } + } + return clients; + } + + /** + * Get shared connection + */ + private RSocket getSharedClient(URL url) { + String key = url.getAddress(); + RSocket client = referenceClientMap.get(key); + if (client != null) { + return client; + } + + locks.putIfAbsent(key, new Object()); + synchronized (locks.get(key)) { + if (referenceClientMap.containsKey(key)) { + return referenceClientMap.get(key); + } + + client = initClient(url); + referenceClientMap.put(key, client); + locks.remove(key); + return client; + } + } + + /** + * Create new connection + */ + private RSocket initClient(URL url) { + try { + InetSocketAddress serverAddress = new InetSocketAddress(NetUtils.filterLocalHost(url.getHost()), url.getPort()); + RSocket client = RSocketFactory.connect().keepAliveTickPeriod(Duration.ZERO).keepAliveAckTimeout(Duration.ZERO).acceptor( + rSocket -> + new AbstractRSocket() { + public Mono requestResponse(Payload payload) { + try { + ByteBuffer metadataBuffer = payload.getMetadata(); + byte[] metadataBytes = new byte[metadataBuffer.remaining()]; + metadataBuffer.get(metadataBytes, metadataBuffer.position(), metadataBuffer.remaining()); + Map metadataMap = MetadataCodec.decodeMetadata(metadataBytes); + Byte serializeId = ((Integer) metadataMap.get(RSocketConstants.SERIALIZE_TYPE_KEY)).byteValue(); + + + ByteBuffer dataBuffer = payload.getData(); + byte[] dataBytes = new byte[dataBuffer.remaining()]; + dataBuffer.get(dataBytes, dataBuffer.position(), dataBuffer.remaining()); + InputStream dataInputStream = new ByteArrayInputStream(dataBytes); + ObjectInput in = CodecSupport.getSerializationById(serializeId).deserialize(null, dataInputStream); + long id = in.readLong(); + + Mono mono = ResourceDirectory.unmountMono(id); + return mono.map(new Function() { + @Override + public Payload apply(Object o) { + try { + ByteArrayOutputStream bos = new ByteArrayOutputStream(); + ObjectOutput out = CodecSupport.getSerializationById(serializeId).serialize(null, bos); + out.writeByte((byte) 0); + out.writeObject(o); + out.flushBuffer(); + bos.flush(); + bos.close(); + Payload responsePayload = DefaultPayload.create(bos.toByteArray()); + return responsePayload; + } catch (Throwable t) { + throw Exceptions.propagate(t); + } + } + }).onErrorResume(new Function>() { + @Override + public Publisher apply(Throwable throwable) { + try { + ByteArrayOutputStream bos = new ByteArrayOutputStream(); + ObjectOutput out = CodecSupport.getSerializationById(serializeId).serialize(null, bos); + out.writeByte((byte) RSocketConstants.FLAG_ERROR); + out.writeObject(throwable); + out.flushBuffer(); + bos.flush(); + bos.close(); + Payload errorPayload = DefaultPayload.create(bos.toByteArray()); + return Flux.just(errorPayload); + } catch (Throwable t) { + throw Exceptions.propagate(t); + } + } + }); + + }catch (Throwable t){ + throw new RuntimeException(t); + } + } + + @Override + public Flux requestStream(Payload payload) { + try { + ByteBuffer metadataBuffer = payload.getMetadata(); + byte[] metadataBytes = new byte[metadataBuffer.remaining()]; + metadataBuffer.get(metadataBytes, metadataBuffer.position(), metadataBuffer.remaining()); + Map metadataMap = MetadataCodec.decodeMetadata(metadataBytes); + Byte serializeId = ((Integer) metadataMap.get(RSocketConstants.SERIALIZE_TYPE_KEY)).byteValue(); + + + ByteBuffer dataBuffer = payload.getData(); + byte[] dataBytes = new byte[dataBuffer.remaining()]; + dataBuffer.get(dataBytes, dataBuffer.position(), dataBuffer.remaining()); + InputStream dataInputStream = new ByteArrayInputStream(dataBytes); + ObjectInput in = CodecSupport.getSerializationById(serializeId).deserialize(null, dataInputStream); + long id = in.readLong(); + + Flux flux = ResourceDirectory.unmountFlux(id); + return flux.map(new Function() { + @Override + public Payload apply(Object o) { + try { + return encodeData(o, serializeId); + } catch (Throwable t) { + throw new RuntimeException(t); + } + } + }).onErrorResume(new Function>() { + @Override + public Publisher apply(Throwable throwable) { + try { + Payload errorPayload = encodeError(throwable, serializeId); + return Flux.just(errorPayload); + } catch (Throwable t) { + throw new RuntimeException(t); + } + } + }); + }catch (Throwable t){ + throw new RuntimeException(t); + } + } + + private Payload encodeData(Object data, byte serializeId) throws Throwable { + ByteArrayOutputStream bos = new ByteArrayOutputStream(); + ObjectOutput out = CodecSupport.getSerializationById(serializeId).serialize(null, bos); + out.writeByte((byte) 0); + out.writeObject(data); + out.flushBuffer(); + bos.flush(); + bos.close(); + return DefaultPayload.create(bos.toByteArray()); + } + + private Payload encodeError(Throwable throwable, byte serializeId) throws Throwable { + ByteArrayOutputStream bos = new ByteArrayOutputStream(); + ObjectOutput out = CodecSupport.getSerializationById(serializeId).serialize(null, bos); + out.writeByte((byte) RSocketConstants.FLAG_ERROR); + out.writeObject(throwable); + out.flushBuffer(); + bos.flush(); + bos.close(); + return DefaultPayload.create(bos.toByteArray()); + } + + }) + .transport(TcpClientTransport.create(serverAddress)) + .start() + .block(); + return client; + } catch (Throwable e) { + throw new RpcException("Fail to create remoting client for service(" + url + "): " + e.getMessage(), e); + } + + } + + @Override + public void destroy() { + for (String key : new ArrayList(serverMap.keySet())) { + CloseableChannel server = serverMap.remove(key); + if (server != null) { + try { + if (logger.isInfoEnabled()) { + logger.info("Close dubbo server: " + server.address()); + } + server.dispose(); + } catch (Throwable t) { + logger.warn(t.getMessage(), t); + } + } + } + + for (String key : new ArrayList(referenceClientMap.keySet())) { + RSocket client = referenceClientMap.remove(key); + if (client != null) { + try { +// if (logger.isInfoEnabled()) { +// logger.info("Close dubbo connect: " + client. + "-->" + client.getRemoteAddress()); +// } + client.dispose(); + } catch (Throwable t) { + logger.warn(t.getMessage(), t); + } + } + } + super.destroy(); + } + + + //server process logic + private class SocketAcceptorImpl implements SocketAcceptor { + + private final int port; + + public SocketAcceptorImpl(int port) { + this.port = port; + } + + @Override + public Mono accept(ConnectionSetupPayload setupPayload, RSocket reactiveSocket) { + return Mono.just( + new AbstractRSocket() { + public Mono requestResponse(Payload payload) { + try { + Map metadata = decodeMetadata(payload); + Byte serializeId = ((Integer) metadata.get(RSocketConstants.SERIALIZE_TYPE_KEY)).byteValue(); + Invocation inv = decodeInvocation(payload, metadata, serializeId); + + Result result = inv.getInvoker().invoke(inv); + + Class retType = RpcUtils.getReturnType(inv); + //ok + if (retType != null && Mono.class.isAssignableFrom(retType)) { + Throwable th = result.getException(); + if (th == null) { + Mono bizMono = (Mono) result.getValue(); + Mono retMono = bizMono.map(new Function() { + @Override + public Payload apply(Object o) { + try { + ByteArrayOutputStream bos = new ByteArrayOutputStream(); + ObjectOutput out = CodecSupport.getSerializationById(serializeId).serialize(null, bos); + out.writeByte((byte) 0); + out.writeObject(o); + out.flushBuffer(); + bos.flush(); + bos.close(); + Payload responsePayload = DefaultPayload.create(bos.toByteArray()); + return responsePayload; + } catch (Throwable t) { + throw Exceptions.propagate(t); + } + } + }).onErrorResume(new Function>() { + @Override + public Publisher apply(Throwable throwable) { + try { + ByteArrayOutputStream bos = new ByteArrayOutputStream(); + ObjectOutput out = CodecSupport.getSerializationById(serializeId).serialize(null, bos); + out.writeByte((byte) RSocketConstants.FLAG_ERROR); + out.writeObject(throwable); + out.flushBuffer(); + bos.flush(); + bos.close(); + Payload errorPayload = DefaultPayload.create(bos.toByteArray()); + return Flux.just(errorPayload); + } catch (Throwable t) { + throw Exceptions.propagate(t); + } + } + }); + + return retMono; + } else { + ByteArrayOutputStream bos = new ByteArrayOutputStream(); + ObjectOutput out = CodecSupport.getSerializationById(serializeId).serialize(null, bos); + out.writeByte((byte) RSocketConstants.FLAG_ERROR); + out.writeObject(th); + out.flushBuffer(); + bos.flush(); + bos.close(); + Payload errorPayload = DefaultPayload.create(bos.toByteArray()); + return Mono.just(errorPayload); + } + + } else { + ByteArrayOutputStream bos = new ByteArrayOutputStream(); + ObjectOutput out = CodecSupport.getSerializationById(serializeId).serialize(null, bos); + int flag = RSocketConstants.FLAG_HAS_ATTACHMENT; + + Throwable th = result.getException(); + if (th == null) { + Object ret = result.getValue(); + if (ret == null) { + flag |= RSocketConstants.FLAG_NULL_VALUE; + out.writeByte((byte) flag); + } else { + out.writeByte((byte) flag); + out.writeObject(ret); + } + } else { + flag |= RSocketConstants.FLAG_ERROR; + out.writeByte((byte) flag); + out.writeObject(th); + } + out.writeObject(result.getAttachments()); + out.flushBuffer(); + bos.flush(); + bos.close(); + + Payload responsePayload = DefaultPayload.create(bos.toByteArray()); + return Mono.just(responsePayload); + } + } catch (Throwable t) { + //application error + return Mono.error(t); + } finally { + payload.release(); + } + } + + public Flux requestStream(Payload payload) { + try { + Map metadata = decodeMetadata(payload); + Byte serializeId = ((Integer) metadata.get(RSocketConstants.SERIALIZE_TYPE_KEY)).byteValue(); + Invocation inv = decodeInvocation(payload, metadata, serializeId); + + Result result = inv.getInvoker().invoke(inv); + //Class retType = RpcUtils.getReturnType(inv); + + Throwable th = result.getException(); + if (th != null) { + Payload errorPayload = encodeError(th, serializeId); + return Flux.just(errorPayload); + } + + Flux flux = (Flux) result.getValue(); + Flux retFlux = flux.map(new Function() { + @Override + public Payload apply(Object o) { + try { + return encodeData(o, serializeId); + } catch (Throwable t) { + throw new RuntimeException(t); + } + } + }).onErrorResume(new Function>() { + @Override + public Publisher apply(Throwable throwable) { + try { + Payload errorPayload = encodeError(throwable, serializeId); + return Flux.just(errorPayload); + } catch (Throwable t) { + throw new RuntimeException(t); + } + } + }); + return retFlux; + } catch (Throwable t) { + return Flux.error(t); + } finally { + payload.release(); + } + } + + private Payload encodeData(Object data, byte serializeId) throws Throwable { + ByteArrayOutputStream bos = new ByteArrayOutputStream(); + ObjectOutput out = CodecSupport.getSerializationById(serializeId).serialize(null, bos); + out.writeByte((byte) 0); + out.writeObject(data); + out.flushBuffer(); + bos.flush(); + bos.close(); + return DefaultPayload.create(bos.toByteArray()); + } + + private Payload encodeError(Throwable throwable, byte serializeId) throws Throwable { + ByteArrayOutputStream bos = new ByteArrayOutputStream(); + ObjectOutput out = CodecSupport.getSerializationById(serializeId).serialize(null, bos); + out.writeByte((byte) RSocketConstants.FLAG_ERROR); + out.writeObject(throwable); + out.flushBuffer(); + bos.flush(); + bos.close(); + return DefaultPayload.create(bos.toByteArray()); + } + + private Map decodeMetadata(Payload payload) throws IOException { + ByteBuffer metadataBuffer = payload.getMetadata(); + byte[] metadataBytes = new byte[metadataBuffer.remaining()]; + metadataBuffer.get(metadataBytes, metadataBuffer.position(), metadataBuffer.remaining()); + return MetadataCodec.decodeMetadata(metadataBytes); + } + + private Invocation decodeInvocation(Payload payload, Map metadata, Byte serializeId) throws RemotingException, IOException, ClassNotFoundException { + Invoker invoker = getInvoker(port, metadata); + + String serviceName = (String) metadata.get(RSocketConstants.SERVICE_NAME_KEY); + String version = (String) metadata.get(RSocketConstants.SERVICE_VERSION_KEY); + String methodName = (String) metadata.get(RSocketConstants.METHOD_NAME_KEY); + String paramType = (String) metadata.get(RSocketConstants.PARAM_TYPE_KEY); + + ByteBuffer dataBuffer = payload.getData(); + byte[] dataBytes = new byte[dataBuffer.remaining()]; + dataBuffer.get(dataBytes, dataBuffer.position(), dataBuffer.remaining()); + + + //TODO how to get remote address + //RpcContext rpcContext = RpcContext.getContext(); + //rpcContext.setRemoteAddress(channel.getRemoteAddress()); + + + RpcInvocation inv = new RpcInvocation(); + inv.setInvoker(invoker); + inv.setAttachment(Constants.PATH_KEY, serviceName); + inv.setAttachment(Constants.VERSION_KEY, version); + inv.setMethodName(methodName); + + + InputStream dataInputStream = new ByteArrayInputStream(dataBytes); + ObjectInput in = CodecSupport.getSerializationById(serializeId).deserialize(null, dataInputStream); + + Object[] args; + Class[] pts; + String desc = paramType; + if (desc.length() == 0) { + pts = new Class[0]; + args = new Object[0]; + } else { + pts = ReflectUtils.desc2classArray(desc); + args = new Object[pts.length]; + for (int i = 0; i < args.length; i++) { + try { + args[i] = in.readObject(pts[i]); + } catch (Exception e) { + if (log.isWarnEnabled()) { + log.warn("Decode argument failed: " + e.getMessage(), e); + } + } + } + } + + //process stream args + for (int i = 0; i < pts.length; i++) { + if (ResourceInfo.class.isAssignableFrom(pts[i])) { + ResourceInfo resourceInfo = (ResourceInfo) args[i]; + if (resourceInfo.getType() == ResourceInfo.RESOURCE_TYPE_MONO) { + pts[i] = Mono.class; + args[i] = getMonoProxy(resourceInfo.getId(), serializeId, reactiveSocket); + } else { + pts[i] = Flux.class; + args[i] = getFluxProxy(resourceInfo.getId(), serializeId, reactiveSocket); + } + } + } + + inv.setParameterTypes(pts); + inv.setArguments(args); + Map map = (Map) in.readObject(Map.class); + if (map != null && map.size() > 0) { + inv.addAttachments(map); + } + return inv; + } + }); + } + + private Mono getMonoProxy(long id, Byte serializeId, RSocket rSocket) throws IOException { + Map metadataMap = new HashMap(); + metadataMap.put(RSocketConstants.SERIALIZE_TYPE_KEY, serializeId); + byte[] metadata = MetadataCodec.encodeMetadata(metadataMap); + + ByteArrayOutputStream bos = new ByteArrayOutputStream(); + ObjectOutput out = CodecSupport.getSerializationById(serializeId).serialize(null, bos); + out.writeLong(id); + out.flushBuffer(); + bos.flush(); + bos.close(); + Payload payload = DefaultPayload.create(bos.toByteArray(), metadata); + + Mono payloads = rSocket.requestResponse(payload); + Mono streamArg = payloads.map(new Function() { + @Override + public Object apply(Payload payload) { + return decodeData(serializeId, payload); + } + }); + return streamArg; + } + + private Flux getFluxProxy(long id, Byte serializeId, RSocket rSocket) throws IOException { + Map metadataMap = new HashMap(); + metadataMap.put(RSocketConstants.SERIALIZE_TYPE_KEY, serializeId); + byte[] metadata = MetadataCodec.encodeMetadata(metadataMap); + + ByteArrayOutputStream bos = new ByteArrayOutputStream(); + ObjectOutput out = CodecSupport.getSerializationById(serializeId).serialize(null, bos); + out.writeLong(id); + out.flushBuffer(); + bos.flush(); + bos.close(); + Payload payload = DefaultPayload.create(bos.toByteArray(), metadata); + + Flux payloads = rSocket.requestStream(payload); + Flux streamArg = payloads.map(new Function() { + @Override + public Object apply(Payload payload) { + return decodeData(serializeId, payload); + } + }); + return streamArg; + } + + private Object decodeData(Byte serializeId, Payload payload) { + try { + Serialization serialization = CodecSupport.getSerializationById(serializeId); + //TODO save the copy + ByteBuffer dataBuffer = payload.getData(); + byte[] dataBytes = new byte[dataBuffer.remaining()]; + dataBuffer.get(dataBytes, dataBuffer.position(), dataBuffer.remaining()); + InputStream dataInputStream = new ByteArrayInputStream(dataBytes); + ObjectInput in = serialization.deserialize(null, dataInputStream); + int flag = in.readByte(); + if ((flag & RSocketConstants.FLAG_ERROR) != 0) { + Throwable t = (Throwable) in.readObject(); + throw t; + } else { + return in.readObject(); + } + } catch (Throwable t) { + throw Exceptions.propagate(t); + } + } + + } +} diff --git a/dubbo-rpc/dubbo-rpc-rsocket/src/main/java/org/apache/dubbo/rpc/protocol/rsocket/ResourceDirectory.java b/dubbo-rpc/dubbo-rpc-rsocket/src/main/java/org/apache/dubbo/rpc/protocol/rsocket/ResourceDirectory.java new file mode 100644 index 00000000000..c1b66ac88db --- /dev/null +++ b/dubbo-rpc/dubbo-rpc-rsocket/src/main/java/org/apache/dubbo/rpc/protocol/rsocket/ResourceDirectory.java @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.dubbo.rpc.protocol.rsocket; + +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicLong; + +public class ResourceDirectory { + + private static AtomicLong idGen = new AtomicLong(1); + + private static ConcurrentHashMap id2ResourceMap = new ConcurrentHashMap(); + + + public static long mountResource(Object resource) { + long id = idGen.getAndIncrement(); + id2ResourceMap.put(id, resource); + return id; + } + + public static Object unmountResource(long id) { + return id2ResourceMap.get(id); + } + + public static long mountMono(Mono mono) { + long id = idGen.getAndIncrement(); + id2ResourceMap.put(id, mono); + return id; + } + + public static long mountFlux(Flux flux) { + long id = idGen.getAndIncrement(); + id2ResourceMap.put(id, flux); + return id; + } + + public static Mono unmountMono(long id) { + return (Mono) id2ResourceMap.get(id); + } + + public static Flux unmountFlux(long id) { + return (Flux) id2ResourceMap.get(id); + } + +} diff --git a/dubbo-rpc/dubbo-rpc-rsocket/src/main/java/org/apache/dubbo/rpc/protocol/rsocket/ResourceInfo.java b/dubbo-rpc/dubbo-rpc-rsocket/src/main/java/org/apache/dubbo/rpc/protocol/rsocket/ResourceInfo.java new file mode 100644 index 00000000000..1c1275bd38a --- /dev/null +++ b/dubbo-rpc/dubbo-rpc-rsocket/src/main/java/org/apache/dubbo/rpc/protocol/rsocket/ResourceInfo.java @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.dubbo.rpc.protocol.rsocket; + +import java.io.Serializable; + +public class ResourceInfo implements Serializable { + + public static final byte RESOURCE_TYPE_MONO = 1; + public static final byte RESOURCE_TYPE_FLUX = 2; + + private final long id; + private final byte type; + + public ResourceInfo(long id, byte type) { + this.id = id; + this.type = type; + } + + public long getId() { + return id; + } + + public byte getType() { + return type; + } +} diff --git a/dubbo-rpc/dubbo-rpc-rsocket/src/test/java/org/apache/dubbo/rpc/protocol/rsocket/RSocketProtocolTest.java b/dubbo-rpc/dubbo-rpc-rsocket/src/test/java/org/apache/dubbo/rpc/protocol/rsocket/RSocketProtocolTest.java index e34a6f76e13..d9733eb4676 100644 --- a/dubbo-rpc/dubbo-rpc-rsocket/src/test/java/org/apache/dubbo/rpc/protocol/rsocket/RSocketProtocolTest.java +++ b/dubbo-rpc/dubbo-rpc-rsocket/src/test/java/org/apache/dubbo/rpc/protocol/rsocket/RSocketProtocolTest.java @@ -117,6 +117,7 @@ public void testDubboProtocolMultiService() throws Exception { assertEquals("world", service.echo("world")); assertEquals("hello world", remote.sayHello("world")); + EchoService serviceEcho = (EchoService) service; assertEquals(serviceEcho.$echo("test"), "test"); @@ -216,4 +217,43 @@ public void accept(String s) { } } + + @Test + public void testRequestMonoWithMonoArg() throws Exception { + DemoService service = new DemoServiceImpl(); + protocol.export(proxy.getInvoker(service, DemoService.class, URL.valueOf("rsocket://127.0.0.1:9020/" + DemoService.class.getName()))); + service = proxy.getProxy(protocol.refer(DemoService.class, URL.valueOf("rsocket://127.0.0.1:9020/" + DemoService.class.getName()).addParameter("timeout", 3000l))); + + Mono result = service.requestMonoWithMonoArg(Mono.just("A"), Mono.just("B")); + result.doOnNext(new Consumer() { + @Override + public void accept(String s) { + assertEquals(s, "A B"); + System.out.println(s); + } + }).block(); + } + + + @Test + public void testRequestFluxWithFluxArg() throws Exception { + DemoService service = new DemoServiceImpl(); + protocol.export(proxy.getInvoker(service, DemoService.class, URL.valueOf("rsocket://127.0.0.1:9020/" + DemoService.class.getName()))); + service = proxy.getProxy(protocol.refer(DemoService.class, URL.valueOf("rsocket://127.0.0.1:9020/" + DemoService.class.getName()).addParameter("timeout", 3000l))); + + { + Flux result = service.requestFluxWithFluxArg(Flux.just("A","B","C"), Flux.just("1","2","3")); + result.doOnNext(new Consumer() { + @Override + public void accept(String s) { + System.out.println(s); + } + }).takeLast(1).doOnNext(new Consumer() { + @Override + public void accept(String s) { + assertEquals(s, "C 3"); + } + }).blockLast(); + } + } } diff --git a/dubbo-rpc/dubbo-rpc-rsocket/src/test/java/org/apache/dubbo/rpc/service/DemoService.java b/dubbo-rpc/dubbo-rpc-rsocket/src/test/java/org/apache/dubbo/rpc/service/DemoService.java index b2b37b4c1d5..8f7a3ddc5ad 100644 --- a/dubbo-rpc/dubbo-rpc-rsocket/src/test/java/org/apache/dubbo/rpc/service/DemoService.java +++ b/dubbo-rpc/dubbo-rpc-rsocket/src/test/java/org/apache/dubbo/rpc/service/DemoService.java @@ -61,5 +61,9 @@ public interface DemoService { Flux requestFluxBizError(String name); + Mono requestMonoWithMonoArg(Mono m1, Mono m2); + + Flux requestFluxWithFluxArg(Flux f1, Flux f2); + } diff --git a/dubbo-rpc/dubbo-rpc-rsocket/src/test/java/org/apache/dubbo/rpc/service/DemoServiceImpl.java b/dubbo-rpc/dubbo-rpc-rsocket/src/test/java/org/apache/dubbo/rpc/service/DemoServiceImpl.java index b67e3e09be4..1ba6d39a7b0 100644 --- a/dubbo-rpc/dubbo-rpc-rsocket/src/test/java/org/apache/dubbo/rpc/service/DemoServiceImpl.java +++ b/dubbo-rpc/dubbo-rpc-rsocket/src/test/java/org/apache/dubbo/rpc/service/DemoServiceImpl.java @@ -23,6 +23,7 @@ import java.util.Map; import java.util.Set; +import java.util.function.BiFunction; import java.util.function.Consumer; public class DemoServiceImpl implements DemoService { @@ -149,5 +150,25 @@ public void accept(FluxSink fluxSink) { }); } + @Override + public Mono requestMonoWithMonoArg(Mono m1, Mono m2) { + return m1.zipWith(m2, new BiFunction() { + @Override + public String apply(String s, String s2) { + return s+" "+s2; + } + }); + } + + @Override + public Flux requestFluxWithFluxArg(Flux f1, Flux f2) { + return f1.zipWith(f2, new BiFunction() { + @Override + public String apply(String s, String s2) { + return s+" "+s2; + } + }); + } + }