Skip to content

Commit

Permalink
Add serialize id check for 2.6 (#7912)
Browse files Browse the repository at this point in the history
  • Loading branch information
AlbumenJ committed May 29, 2021
1 parent 3a9172f commit 077da34
Show file tree
Hide file tree
Showing 17 changed files with 286 additions and 37 deletions.
Expand Up @@ -664,4 +664,13 @@ public class Constants {
public static final String ENABLE_NATIVE_JAVA_GENERIC_SERIALIZE = "dubbo.security.serialize.generic.native-java-enable";

public static final String SERIALIZE_BLOCKED_LIST_FILE_PATH = "security/serialize.blockedlist";

public static final String DEFAULT_VERSION = "0.0.0";

public static final String SERIALIZATION_SECURITY_CHECK_KEY = "serialization.security.check";

public static final String SERIALIZATION_ID_KEY = "serialization_id";

public static final String INVOCATION_KEY = "invocation";

}
Expand Up @@ -30,6 +30,7 @@
import com.alibaba.dubbo.config.model.ApplicationModel;
import com.alibaba.dubbo.config.model.ProviderModel;
import com.alibaba.dubbo.config.support.Parameter;
import com.alibaba.dubbo.remoting.transport.CodecSupport;
import com.alibaba.dubbo.rpc.Exporter;
import com.alibaba.dubbo.rpc.Invoker;
import com.alibaba.dubbo.rpc.Protocol;
Expand Down Expand Up @@ -317,6 +318,7 @@ protected synchronized void doExport() {
path = interfaceName;
}
doExportUrls();
CodecSupport.addProviderSupportedSerialization(getUniqueServiceName(), getExportedUrls());
ProviderModel providerModel = new ProviderModel(getUniqueServiceName(), this, ref);
ApplicationModel.initProviderModel(getUniqueServiceName(), providerModel);
}
Expand Down
Expand Up @@ -16,6 +16,9 @@
*/
package com.alibaba.dubbo.remoting.exchange;

import java.util.HashMap;
import java.util.Map;

/**
* Response
*/
Expand Down Expand Up @@ -92,6 +95,8 @@ public class Response {

private Object mResult;

private Map<String, Object> attributes = new HashMap<String, Object>(2);

public Response() {
}

Expand Down Expand Up @@ -164,6 +169,14 @@ public void setErrorMessage(String msg) {
mErrorMsg = msg;
}

public Object getAttribute(String key) {
return attributes.get(key);
}

public void setAttribute(String key, Object value) {
attributes.put(key, value);
}

@Override
public String toString() {
return "Response [id=" + mId + ", version=" + mVersion + ", status=" + mStatus + ", event=" + mEvent
Expand Down
Expand Up @@ -38,6 +38,7 @@
import com.alibaba.dubbo.remoting.transport.CodecSupport;
import com.alibaba.dubbo.remoting.transport.ExceedPayloadLimitException;

import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;

Expand Down Expand Up @@ -155,9 +156,12 @@ protected Object decodeBody(Channel channel, InputStream is, byte[] header) thro
if (status == Response.OK) {
Object data;
if (res.isHeartbeat()) {
data = decodeHeartbeatData(channel, in);
byte[] eventPayload = CodecSupport.getPayload(is);
data = decodeHeartbeatData(channel, CodecSupport.deserialize(channel.getUrl(), new ByteArrayInputStream(eventPayload), proto), eventPayload);
} else if (res.isEvent()) {
data = decodeEventData(channel, in);
byte[] eventPayload = CodecSupport.getPayload(is);
data = decodeEventData(channel,
CodecSupport.deserialize(channel.getUrl(), new ByteArrayInputStream(eventPayload), proto), eventPayload);
} else {
data = decodeResponseData(channel, in, getRequestData(id));
}
Expand All @@ -182,9 +186,13 @@ protected Object decodeBody(Channel channel, InputStream is, byte[] header) thro
ObjectInput in = CodecSupport.deserialize(channel.getUrl(), is, proto);
Object data;
if (req.isHeartbeat()) {
data = decodeHeartbeatData(channel, in);
byte[] eventPayload = CodecSupport.getPayload(is);
data = decodeHeartbeatData(channel,
CodecSupport.deserialize(channel.getUrl(), new ByteArrayInputStream(eventPayload), proto), eventPayload);
} else if (req.isEvent()) {
data = decodeEventData(channel, in);
byte[] eventPayload = CodecSupport.getPayload(is);
data = decodeEventData(channel,
CodecSupport.deserialize(channel.getUrl(), new ByteArrayInputStream(eventPayload), proto), eventPayload);
} else {
data = decodeRequestData(channel, in);
}
Expand Down Expand Up @@ -340,15 +348,6 @@ protected Object decodeData(ObjectInput in) throws IOException {
return decodeRequestData(in);
}

@Deprecated
protected Object decodeHeartbeatData(ObjectInput in) throws IOException {
try {
return in.readObject();
} catch (ClassNotFoundException e) {
throw new IOException(StringUtils.toString("Read object failed.", e));
}
}

protected Object decodeRequestData(ObjectInput in) throws IOException {
try {
return in.readObject();
Expand Down Expand Up @@ -392,21 +391,22 @@ protected Object decodeData(Channel channel, ObjectInput in) throws IOException
return decodeRequestData(channel, in);
}

protected Object decodeEventData(Channel channel, ObjectInput in) throws IOException {
protected Object decodeEventData(Channel channel, ObjectInput in, byte[] eventPayload) throws IOException {
try {
int dataLen = eventPayload.length;
int threshold = Integer.parseInt(System.getProperty("deserialization.event.size", "50"));
if (dataLen > threshold) {
throw new IllegalArgumentException("Event data too long, actual size " + dataLen + ", threshold " + threshold + " rejected for security consideration.");
}
return in.readObject();
} catch (ClassNotFoundException e) {
throw new IOException(StringUtils.toString("Read object failed.", e));
}
}

@Deprecated
protected Object decodeHeartbeatData(Channel channel, ObjectInput in) throws IOException {
try {
return in.readObject();
} catch (ClassNotFoundException e) {
throw new IOException(StringUtils.toString("Read object failed.", e));
}
protected Object decodeHeartbeatData(Channel channel, ObjectInput in, byte[] eventPayload) throws IOException {
return decodeEventData(channel, in, eventPayload);
}

protected Object decodeRequestData(Channel channel, ObjectInput in) throws IOException {
Expand Down
Expand Up @@ -24,6 +24,8 @@
import com.alibaba.dubbo.common.utils.NetUtils;
import com.alibaba.dubbo.remoting.Channel;
import com.alibaba.dubbo.remoting.Codec2;
import com.alibaba.dubbo.remoting.exchange.Request;
import com.alibaba.dubbo.remoting.exchange.Response;

import java.io.IOException;
import java.net.InetSocketAddress;
Expand Down Expand Up @@ -51,6 +53,14 @@ protected Serialization getSerialization(Channel channel) {
return CodecSupport.getSerialization(channel.getUrl());
}

protected Serialization getSerialization(Channel channel, Request req) {
return CodecSupport.getSerialization(channel.getUrl());
}

protected Serialization getSerialization(Channel channel, Response res) {
return CodecSupport.getSerialization(channel.getUrl());
}

protected boolean isClientSide(Channel channel) {
String side = (String) channel.getAttribute(Constants.SIDE_KEY);
if ("client".equals(side)) {
Expand Down
Expand Up @@ -24,18 +24,29 @@
import com.alibaba.dubbo.common.logger.LoggerFactory;
import com.alibaba.dubbo.common.serialize.ObjectInput;
import com.alibaba.dubbo.common.serialize.Serialization;
import com.alibaba.dubbo.common.utils.CollectionUtils;

import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;

import static com.alibaba.dubbo.common.Constants.SERIALIZATION_KEY;

public class CodecSupport {

private static final Logger logger = LoggerFactory.getLogger(CodecSupport.class);
private static Map<Byte, Serialization> ID_SERIALIZATION_MAP = new HashMap<Byte, Serialization>();
private static Map<Byte, String> ID_SERIALIZATIONNAME_MAP = new HashMap<Byte, String>();
private static Map<String, Byte> SERIALIZATIONNAME_ID_MAP = new HashMap<String, Byte>();

private static Map<String, Set<Byte>> PROVIDER_SUPPORTED_SERIALIZATION = new ConcurrentHashMap<String, Set<Byte>>();

static {
Set<String> supportedExtensions = ExtensionLoader.getExtensionLoader(Serialization.class).getSupportedExtensions();
Expand All @@ -51,6 +62,7 @@ public class CodecSupport {
}
ID_SERIALIZATION_MAP.put(idByte, serialization);
ID_SERIALIZATIONNAME_MAP.put(idByte, name);
SERIALIZATIONNAME_ID_MAP.put(name, idByte);
}
}

Expand All @@ -63,23 +75,72 @@ public static Serialization getSerializationById(Byte id) {

public static Serialization getSerialization(URL url) {
return ExtensionLoader.getExtensionLoader(Serialization.class).getExtension(
url.getParameter(Constants.SERIALIZATION_KEY, Constants.DEFAULT_REMOTING_SERIALIZATION));
url.getParameter(SERIALIZATION_KEY, Constants.DEFAULT_REMOTING_SERIALIZATION));
}

public static Serialization getSerialization(URL url, Byte id) throws IOException {
Serialization serialization = getSerializationById(id);
String serializationName = url.getParameter(Constants.SERIALIZATION_KEY, Constants.DEFAULT_REMOTING_SERIALIZATION);
// Check if "serialization id" passed from network matches the id on this side(only take effect for JDK serialization), for security purpose.
if (serialization == null
|| ((id == 3 || id == 7 || id == 4) && !(serializationName.equals(ID_SERIALIZATIONNAME_MAP.get(id))))) {
throw new IOException("Unexpected serialization id:" + id + " received from network, please check if the peer send the right id.");
Serialization result = getSerializationById(id);
if (result == null) {
throw new IOException("Unrecognized serialize type from consumer: " + id);
}
return serialization;
return result;
}

public static ObjectInput deserialize(URL url, InputStream is, byte proto) throws IOException {
Serialization s = getSerialization(url, proto);
return s.deserialize(url, is);
}

/**
* Read all payload to byte[]
*
* @param is
* @return
* @throws IOException
*/
public static byte[] getPayload(InputStream is) throws IOException {
ByteArrayOutputStream baos = new ByteArrayOutputStream();
byte[] buffer = new byte[1024];
int len;
while ((len = is.read(buffer)) > -1) {
baos.write(buffer, 0, len);
}
baos.flush();
return baos.toByteArray();
}

public static Byte getIDByName(String name) {
return SERIALIZATIONNAME_ID_MAP.get(name);
}

public static void checkSerialization(String path, String version, Byte id) throws IOException {
Set<Byte> supportedSerialization = PROVIDER_SUPPORTED_SERIALIZATION.get(path + ":" + version);
if (Constants.DEFAULT_VERSION.equals(version) && CollectionUtils.isEmpty(supportedSerialization)) {
supportedSerialization = PROVIDER_SUPPORTED_SERIALIZATION.get(path);
}
if (CollectionUtils.isEmpty(supportedSerialization)) {
if (logger.isWarnEnabled()) {
logger.warn("Serialization security check is enabled but cannot work as expected because " +
"there's no matched provider model for path " + path + ", version " + version);
}
} else {
if (!supportedSerialization.contains(id)) {
throw new IOException("Unexpected serialization id:" + id + " received from network, please check if the peer send the right id.");
}
}
}

public static void addProviderSupportedSerialization(String serviceName, List<URL> exportedUrls) {
if (CollectionUtils.isNotEmpty(exportedUrls)) {
Set<Byte> supportedSerialization = new HashSet<Byte>();
for (URL url : exportedUrls) {
String serializationName = url.getParameter(SERIALIZATION_KEY, Constants.DEFAULT_REMOTING_SERIALIZATION);
Byte localId = SERIALIZATIONNAME_ID_MAP.get(serializationName);
supportedSerialization.add(localId);
}
PROVIDER_SUPPORTED_SERIALIZATION.put(serviceName, Collections.unmodifiableSet(supportedSerialization));
}
}


}
Expand Up @@ -230,12 +230,14 @@ public void test_Decode_Return_Request_Event_Object() throws IOException {
Person person = new Person();
byte[] request = getRequestBytes(person, header);

System.setProperty("deserialization.event.size", "100");
Request obj = (Request) decode(request);
Assert.assertEquals(person, obj.getData());
Assert.assertEquals(true, obj.isTwoWay());
Assert.assertEquals(true, obj.isEvent());
Assert.assertEquals(Version.getProtocolVersion(), obj.getVersion());
System.out.println(obj);
System.clearProperty("deserialization.event.size");
}

@Test
Expand Down Expand Up @@ -269,7 +271,7 @@ public void test_Decode_Return_Request_Heartbeat_Object() throws IOException {
@Test
public void test_Decode_Return_Request_Object() throws IOException {
//|10011111|20-stats=ok|id=0|length=0
byte[] header = new byte[]{MAGIC_HIGH, MAGIC_LOW, (byte) 0xe2, 20, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
byte[] header = new byte[]{MAGIC_HIGH, MAGIC_LOW, (byte) 0xc2, 20, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
Person person = new Person();
byte[] request = getRequestBytes(person, header);

Expand Down
5 changes: 5 additions & 0 deletions dubbo-rpc/dubbo-rpc-api/pom.xml
Expand Up @@ -39,5 +39,10 @@
<artifactId>dubbo-serialization-api</artifactId>
<version>${project.parent.version}</version>
</dependency>
<dependency>
<groupId>com.alibaba</groupId>
<artifactId>dubbo-remoting-api</artifactId>
<version>${project.parent.version}</version>
</dependency>
</dependencies>
</project>
Expand Up @@ -83,4 +83,9 @@ public interface Invocation {
*/
Invoker<?> getInvoker();

Object put(Object key, Object value);

Object get(Object key);

Map<Object, Object> getAttributes();
}
Expand Up @@ -44,6 +44,8 @@ public class RpcInvocation implements Invocation, Serializable {

private transient Invoker<?> invoker;

private Map<Object, Object> attributes = new HashMap<Object, Object>(2);

public RpcInvocation() {
}

Expand Down Expand Up @@ -204,11 +206,26 @@ public String getAttachment(String key, String defaultValue) {
return value;
}

@Override
public Object put(Object key, Object value) {
return attributes.put(key, value);
}

@Override
public Object get(Object key) {
return attributes.get(key);
}

@Override
public Map<Object, Object> getAttributes() {
return attributes;
}

@Override
public String toString() {
return "RpcInvocation [methodName=" + methodName + ", parameterTypes="
+ Arrays.toString(parameterTypes) + ", arguments=" + Arrays.toString(arguments)
+ ", attachments=" + attachments + "]";
+ ", attachments=" + attachments + ", attributes=" + attributes + "]";
}

}

0 comments on commit 077da34

Please sign in to comment.