Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add serialize id check for 2.6 #7912

Merged
merged 5 commits into from May 29, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Expand Up @@ -664,4 +664,13 @@ public class Constants {
public static final String ENABLE_NATIVE_JAVA_GENERIC_SERIALIZE = "dubbo.security.serialize.generic.native-java-enable";

public static final String SERIALIZE_BLOCKED_LIST_FILE_PATH = "security/serialize.blockedlist";

public static final String DEFAULT_VERSION = "0.0.0";

public static final String SERIALIZATION_SECURITY_CHECK_KEY = "serialization.security.check";

public static final String SERIALIZATION_ID_KEY = "serialization_id";

public static final String INVOCATION_KEY = "invocation";

}
Expand Up @@ -30,6 +30,7 @@
import com.alibaba.dubbo.config.model.ApplicationModel;
import com.alibaba.dubbo.config.model.ProviderModel;
import com.alibaba.dubbo.config.support.Parameter;
import com.alibaba.dubbo.remoting.transport.CodecSupport;
import com.alibaba.dubbo.rpc.Exporter;
import com.alibaba.dubbo.rpc.Invoker;
import com.alibaba.dubbo.rpc.Protocol;
Expand Down Expand Up @@ -317,6 +318,7 @@ protected synchronized void doExport() {
path = interfaceName;
}
doExportUrls();
CodecSupport.addProviderSupportedSerialization(getUniqueServiceName(), getExportedUrls());
ProviderModel providerModel = new ProviderModel(getUniqueServiceName(), this, ref);
ApplicationModel.initProviderModel(getUniqueServiceName(), providerModel);
}
Expand Down
Expand Up @@ -16,6 +16,9 @@
*/
package com.alibaba.dubbo.remoting.exchange;

import java.util.HashMap;
import java.util.Map;

/**
* Response
*/
Expand Down Expand Up @@ -92,6 +95,8 @@ public class Response {

private Object mResult;

private Map<String, Object> attributes = new HashMap<String, Object>(2);

public Response() {
}

Expand Down Expand Up @@ -164,6 +169,14 @@ public void setErrorMessage(String msg) {
mErrorMsg = msg;
}

public Object getAttribute(String key) {
return attributes.get(key);
}

public void setAttribute(String key, Object value) {
attributes.put(key, value);
}

@Override
public String toString() {
return "Response [id=" + mId + ", version=" + mVersion + ", status=" + mStatus + ", event=" + mEvent
Expand Down
Expand Up @@ -38,6 +38,7 @@
import com.alibaba.dubbo.remoting.transport.CodecSupport;
import com.alibaba.dubbo.remoting.transport.ExceedPayloadLimitException;

import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;

Expand Down Expand Up @@ -155,9 +156,12 @@ protected Object decodeBody(Channel channel, InputStream is, byte[] header) thro
if (status == Response.OK) {
Object data;
if (res.isHeartbeat()) {
data = decodeHeartbeatData(channel, in);
byte[] eventPayload = CodecSupport.getPayload(is);
data = decodeHeartbeatData(channel, CodecSupport.deserialize(channel.getUrl(), new ByteArrayInputStream(eventPayload), proto), eventPayload);
} else if (res.isEvent()) {
data = decodeEventData(channel, in);
byte[] eventPayload = CodecSupport.getPayload(is);
data = decodeEventData(channel,
CodecSupport.deserialize(channel.getUrl(), new ByteArrayInputStream(eventPayload), proto), eventPayload);
} else {
data = decodeResponseData(channel, in, getRequestData(id));
}
Expand All @@ -182,9 +186,13 @@ protected Object decodeBody(Channel channel, InputStream is, byte[] header) thro
ObjectInput in = CodecSupport.deserialize(channel.getUrl(), is, proto);
Object data;
if (req.isHeartbeat()) {
data = decodeHeartbeatData(channel, in);
byte[] eventPayload = CodecSupport.getPayload(is);
data = decodeHeartbeatData(channel,
CodecSupport.deserialize(channel.getUrl(), new ByteArrayInputStream(eventPayload), proto), eventPayload);
} else if (req.isEvent()) {
data = decodeEventData(channel, in);
byte[] eventPayload = CodecSupport.getPayload(is);
data = decodeEventData(channel,
CodecSupport.deserialize(channel.getUrl(), new ByteArrayInputStream(eventPayload), proto), eventPayload);
} else {
data = decodeRequestData(channel, in);
}
Expand Down Expand Up @@ -340,15 +348,6 @@ protected Object decodeData(ObjectInput in) throws IOException {
return decodeRequestData(in);
}

@Deprecated
protected Object decodeHeartbeatData(ObjectInput in) throws IOException {
try {
return in.readObject();
} catch (ClassNotFoundException e) {
throw new IOException(StringUtils.toString("Read object failed.", e));
}
}

protected Object decodeRequestData(ObjectInput in) throws IOException {
try {
return in.readObject();
Expand Down Expand Up @@ -392,21 +391,22 @@ protected Object decodeData(Channel channel, ObjectInput in) throws IOException
return decodeRequestData(channel, in);
}

protected Object decodeEventData(Channel channel, ObjectInput in) throws IOException {
protected Object decodeEventData(Channel channel, ObjectInput in, byte[] eventPayload) throws IOException {
try {
int dataLen = eventPayload.length;
int threshold = Integer.parseInt(System.getProperty("deserialization.event.size", "50"));
if (dataLen > threshold) {
throw new IllegalArgumentException("Event data too long, actual size " + dataLen + ", threshold " + threshold + " rejected for security consideration.");
}
return in.readObject();
} catch (ClassNotFoundException e) {
throw new IOException(StringUtils.toString("Read object failed.", e));
}
}

@Deprecated
protected Object decodeHeartbeatData(Channel channel, ObjectInput in) throws IOException {
try {
return in.readObject();
} catch (ClassNotFoundException e) {
throw new IOException(StringUtils.toString("Read object failed.", e));
}
protected Object decodeHeartbeatData(Channel channel, ObjectInput in, byte[] eventPayload) throws IOException {
return decodeEventData(channel, in, eventPayload);
}

protected Object decodeRequestData(Channel channel, ObjectInput in) throws IOException {
Expand Down
Expand Up @@ -24,6 +24,8 @@
import com.alibaba.dubbo.common.utils.NetUtils;
import com.alibaba.dubbo.remoting.Channel;
import com.alibaba.dubbo.remoting.Codec2;
import com.alibaba.dubbo.remoting.exchange.Request;
import com.alibaba.dubbo.remoting.exchange.Response;

import java.io.IOException;
import java.net.InetSocketAddress;
Expand Down Expand Up @@ -51,6 +53,14 @@ protected Serialization getSerialization(Channel channel) {
return CodecSupport.getSerialization(channel.getUrl());
}

protected Serialization getSerialization(Channel channel, Request req) {
return CodecSupport.getSerialization(channel.getUrl());
}

protected Serialization getSerialization(Channel channel, Response res) {
return CodecSupport.getSerialization(channel.getUrl());
}

protected boolean isClientSide(Channel channel) {
String side = (String) channel.getAttribute(Constants.SIDE_KEY);
if ("client".equals(side)) {
Expand Down
Expand Up @@ -24,18 +24,29 @@
import com.alibaba.dubbo.common.logger.LoggerFactory;
import com.alibaba.dubbo.common.serialize.ObjectInput;
import com.alibaba.dubbo.common.serialize.Serialization;
import com.alibaba.dubbo.common.utils.CollectionUtils;

import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;

import static com.alibaba.dubbo.common.Constants.SERIALIZATION_KEY;

public class CodecSupport {

private static final Logger logger = LoggerFactory.getLogger(CodecSupport.class);
private static Map<Byte, Serialization> ID_SERIALIZATION_MAP = new HashMap<Byte, Serialization>();
private static Map<Byte, String> ID_SERIALIZATIONNAME_MAP = new HashMap<Byte, String>();
private static Map<String, Byte> SERIALIZATIONNAME_ID_MAP = new HashMap<String, Byte>();

private static Map<String, Set<Byte>> PROVIDER_SUPPORTED_SERIALIZATION = new ConcurrentHashMap<String, Set<Byte>>();

static {
Set<String> supportedExtensions = ExtensionLoader.getExtensionLoader(Serialization.class).getSupportedExtensions();
Expand All @@ -51,6 +62,7 @@ public class CodecSupport {
}
ID_SERIALIZATION_MAP.put(idByte, serialization);
ID_SERIALIZATIONNAME_MAP.put(idByte, name);
SERIALIZATIONNAME_ID_MAP.put(name, idByte);
}
}

Expand All @@ -63,23 +75,72 @@ public static Serialization getSerializationById(Byte id) {

public static Serialization getSerialization(URL url) {
return ExtensionLoader.getExtensionLoader(Serialization.class).getExtension(
url.getParameter(Constants.SERIALIZATION_KEY, Constants.DEFAULT_REMOTING_SERIALIZATION));
url.getParameter(SERIALIZATION_KEY, Constants.DEFAULT_REMOTING_SERIALIZATION));
}

public static Serialization getSerialization(URL url, Byte id) throws IOException {
Serialization serialization = getSerializationById(id);
String serializationName = url.getParameter(Constants.SERIALIZATION_KEY, Constants.DEFAULT_REMOTING_SERIALIZATION);
// Check if "serialization id" passed from network matches the id on this side(only take effect for JDK serialization), for security purpose.
if (serialization == null
|| ((id == 3 || id == 7 || id == 4) && !(serializationName.equals(ID_SERIALIZATIONNAME_MAP.get(id))))) {
throw new IOException("Unexpected serialization id:" + id + " received from network, please check if the peer send the right id.");
Serialization result = getSerializationById(id);
if (result == null) {
throw new IOException("Unrecognized serialize type from consumer: " + id);
}
return serialization;
return result;
}

public static ObjectInput deserialize(URL url, InputStream is, byte proto) throws IOException {
Serialization s = getSerialization(url, proto);
return s.deserialize(url, is);
}

/**
* Read all payload to byte[]
*
* @param is
* @return
* @throws IOException
*/
public static byte[] getPayload(InputStream is) throws IOException {
ByteArrayOutputStream baos = new ByteArrayOutputStream();
byte[] buffer = new byte[1024];
int len;
while ((len = is.read(buffer)) > -1) {
baos.write(buffer, 0, len);
}
baos.flush();
return baos.toByteArray();
}

public static Byte getIDByName(String name) {
return SERIALIZATIONNAME_ID_MAP.get(name);
}

public static void checkSerialization(String path, String version, Byte id) throws IOException {
Set<Byte> supportedSerialization = PROVIDER_SUPPORTED_SERIALIZATION.get(path + ":" + version);
if (Constants.DEFAULT_VERSION.equals(version) && CollectionUtils.isEmpty(supportedSerialization)) {
supportedSerialization = PROVIDER_SUPPORTED_SERIALIZATION.get(path);
}
if (CollectionUtils.isEmpty(supportedSerialization)) {
if (logger.isWarnEnabled()) {
logger.warn("Serialization security check is enabled but cannot work as expected because " +
"there's no matched provider model for path " + path + ", version " + version);
}
} else {
if (!supportedSerialization.contains(id)) {
throw new IOException("Unexpected serialization id:" + id + " received from network, please check if the peer send the right id.");
}
}
}

public static void addProviderSupportedSerialization(String serviceName, List<URL> exportedUrls) {
if (CollectionUtils.isNotEmpty(exportedUrls)) {
Set<Byte> supportedSerialization = new HashSet<Byte>();
for (URL url : exportedUrls) {
String serializationName = url.getParameter(SERIALIZATION_KEY, Constants.DEFAULT_REMOTING_SERIALIZATION);
Byte localId = SERIALIZATIONNAME_ID_MAP.get(serializationName);
supportedSerialization.add(localId);
}
PROVIDER_SUPPORTED_SERIALIZATION.put(serviceName, Collections.unmodifiableSet(supportedSerialization));
}
}


}
Expand Up @@ -230,12 +230,14 @@ public void test_Decode_Return_Request_Event_Object() throws IOException {
Person person = new Person();
byte[] request = getRequestBytes(person, header);

System.setProperty("deserialization.event.size", "100");
Request obj = (Request) decode(request);
Assert.assertEquals(person, obj.getData());
Assert.assertEquals(true, obj.isTwoWay());
Assert.assertEquals(true, obj.isEvent());
Assert.assertEquals(Version.getProtocolVersion(), obj.getVersion());
System.out.println(obj);
System.clearProperty("deserialization.event.size");
}

@Test
Expand Down Expand Up @@ -269,7 +271,7 @@ public void test_Decode_Return_Request_Heartbeat_Object() throws IOException {
@Test
public void test_Decode_Return_Request_Object() throws IOException {
//|10011111|20-stats=ok|id=0|length=0
byte[] header = new byte[]{MAGIC_HIGH, MAGIC_LOW, (byte) 0xe2, 20, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
byte[] header = new byte[]{MAGIC_HIGH, MAGIC_LOW, (byte) 0xc2, 20, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
Person person = new Person();
byte[] request = getRequestBytes(person, header);

Expand Down
5 changes: 5 additions & 0 deletions dubbo-rpc/dubbo-rpc-api/pom.xml
Expand Up @@ -39,5 +39,10 @@
<artifactId>dubbo-serialization-api</artifactId>
<version>${project.parent.version}</version>
</dependency>
<dependency>
<groupId>com.alibaba</groupId>
<artifactId>dubbo-remoting-api</artifactId>
<version>${project.parent.version}</version>
</dependency>
</dependencies>
</project>
Expand Up @@ -83,4 +83,9 @@ public interface Invocation {
*/
Invoker<?> getInvoker();

Object put(Object key, Object value);

Object get(Object key);

Map<Object, Object> getAttributes();
}
Expand Up @@ -44,6 +44,8 @@ public class RpcInvocation implements Invocation, Serializable {

private transient Invoker<?> invoker;

private Map<Object, Object> attributes = new HashMap<Object, Object>(2);

public RpcInvocation() {
}

Expand Down Expand Up @@ -204,11 +206,26 @@ public String getAttachment(String key, String defaultValue) {
return value;
}

@Override
public Object put(Object key, Object value) {
return attributes.put(key, value);
}

@Override
public Object get(Object key) {
return attributes.get(key);
}

@Override
public Map<Object, Object> getAttributes() {
return attributes;
}

@Override
public String toString() {
return "RpcInvocation [methodName=" + methodName + ", parameterTypes="
+ Arrays.toString(parameterTypes) + ", arguments=" + Arrays.toString(arguments)
+ ", attachments=" + attachments + "]";
+ ", attachments=" + attachments + ", attributes=" + attributes + "]";
}

}