Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Apply patch #10547

Merged
merged 1 commit into from Sep 13, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
27 changes: 18 additions & 9 deletions src/google/protobuf/extension_set_inl.h
Expand Up @@ -206,16 +206,21 @@ const char* ExtensionSet::ParseMessageSetItemTmpl(
const char* ptr, const Msg* containing_type,
internal::InternalMetadata* metadata, internal::ParseContext* ctx) {
std::string payload;
uint32 type_id = 0;
bool payload_read = false;
uint32 type_id;
enum class State { kNoTag, kHasType, kHasPayload, kDone };
State state = State::kNoTag;

while (!ctx->Done(&ptr)) {
uint32 tag = static_cast<uint8>(*ptr++);
if (tag == WireFormatLite::kMessageSetTypeIdTag) {
uint64 tmp;
ptr = ParseBigVarint(ptr, &tmp);
GOOGLE_PROTOBUF_PARSER_ASSERT(ptr);
type_id = tmp;
if (payload_read) {
if (state == State::kNoTag) {
type_id = tmp;
state = State::kHasType;
} else if (state == State::kHasPayload) {
type_id = tmp;
ExtensionInfo extension;
bool was_packed_on_wire;
if (!FindExtension(2, type_id, containing_type, ctx, &extension,
Expand All @@ -241,20 +246,24 @@ const char* ExtensionSet::ParseMessageSetItemTmpl(
GOOGLE_PROTOBUF_PARSER_ASSERT(value->_InternalParse(p, &tmp_ctx) &&
tmp_ctx.EndedAtLimit());
}
type_id = 0;
state = State::kDone;
}
} else if (tag == WireFormatLite::kMessageSetMessageTag) {
if (type_id != 0) {
if (state == State::kHasType) {
ptr = ParseFieldMaybeLazily(static_cast<uint64>(type_id) * 8 + 2, ptr,
containing_type, metadata, ctx);
GOOGLE_PROTOBUF_PARSER_ASSERT(ptr != nullptr);
type_id = 0;
state = State::kDone;
} else {
std::string tmp;
int32 size = ReadSize(&ptr);
GOOGLE_PROTOBUF_PARSER_ASSERT(ptr);
ptr = ctx->ReadString(ptr, size, &payload);
ptr = ctx->ReadString(ptr, size, &tmp);
GOOGLE_PROTOBUF_PARSER_ASSERT(ptr);
payload_read = true;
if (state == State::kNoTag) {
payload = std::move(tmp);
state = State::kHasPayload;
}
}
} else {
ptr = ReadTag(ptr - 1, &tag);
Expand Down
26 changes: 18 additions & 8 deletions src/google/protobuf/wire_format.cc
Expand Up @@ -657,9 +657,11 @@ struct WireFormat::MessageSetParser {
const char* _InternalParse(const char* ptr, internal::ParseContext* ctx) {
// Parse a MessageSetItem
auto metadata = reflection->MutableInternalMetadata(msg);
enum class State { kNoTag, kHasType, kHasPayload, kDone };
State state = State::kNoTag;

std::string payload;
uint32 type_id = 0;
bool payload_read = false;
while (!ctx->Done(&ptr)) {
// We use 64 bit tags in order to allow typeid's that span the whole
// range of 32 bit numbers.
Expand All @@ -668,8 +670,11 @@ struct WireFormat::MessageSetParser {
uint64 tmp;
ptr = ParseBigVarint(ptr, &tmp);
GOOGLE_PROTOBUF_PARSER_ASSERT(ptr);
type_id = tmp;
if (payload_read) {
if (state == State::kNoTag) {
type_id = tmp;
state = State::kHasType;
} else if (state == State::kHasPayload) {
type_id = tmp;
const FieldDescriptor* field;
if (ctx->data().pool == nullptr) {
field = reflection->FindKnownExtensionByNumber(type_id);
Expand All @@ -696,17 +701,17 @@ struct WireFormat::MessageSetParser {
GOOGLE_PROTOBUF_PARSER_ASSERT(value->_InternalParse(p, &tmp_ctx) &&
tmp_ctx.EndedAtLimit());
}
type_id = 0;
state = State::kDone;
}
continue;
} else if (tag == WireFormatLite::kMessageSetMessageTag) {
if (type_id == 0) {
if (state == State::kNoTag) {
int32 size = ReadSize(&ptr);
GOOGLE_PROTOBUF_PARSER_ASSERT(ptr);
ptr = ctx->ReadString(ptr, size, &payload);
GOOGLE_PROTOBUF_PARSER_ASSERT(ptr);
payload_read = true;
} else {
state = State::kHasPayload;
} else if (state == State::kHasType) {
// We're now parsing the payload
const FieldDescriptor* field = nullptr;
if (descriptor->IsExtensionNumber(type_id)) {
Expand All @@ -720,7 +725,12 @@ struct WireFormat::MessageSetParser {
ptr = WireFormat::_InternalParseAndMergeField(
msg, ptr, ctx, static_cast<uint64>(type_id) * 8 + 2, reflection,
field);
type_id = 0;
state = State::kDone;
} else {
int32 size = ReadSize(&ptr);
GOOGLE_PROTOBUF_PARSER_ASSERT(ptr);
ptr = ctx->Skip(ptr, size);
GOOGLE_PROTOBUF_PARSER_ASSERT(ptr);
}
} else {
// An unknown field in MessageSetItem.
Expand Down
27 changes: 18 additions & 9 deletions src/google/protobuf/wire_format_lite.h
Expand Up @@ -1798,6 +1798,9 @@ bool ParseMessageSetItemImpl(io::CodedInputStream* input, MS ms) {
// we can parse it later.
std::string message_data;

enum class State { kNoTag, kHasType, kHasPayload, kDone };
State state = State::kNoTag;

while (true) {
const uint32 tag = input->ReadTagNoLastTag();
if (tag == 0) return false;
Expand All @@ -1806,26 +1809,34 @@ bool ParseMessageSetItemImpl(io::CodedInputStream* input, MS ms) {
case WireFormatLite::kMessageSetTypeIdTag: {
uint32 type_id;
if (!input->ReadVarint32(&type_id)) return false;
last_type_id = type_id;

if (!message_data.empty()) {
if (state == State::kNoTag) {
last_type_id = type_id;
state = State::kHasType;
} else if (state == State::kHasPayload) {
// We saw some message data before the type_id. Have to parse it
// now.
io::CodedInputStream sub_input(
reinterpret_cast<const uint8*>(message_data.data()),
static_cast<int>(message_data.size()));
sub_input.SetRecursionLimit(input->RecursionBudget());
if (!ms.ParseField(last_type_id, &sub_input)) {
if (!ms.ParseField(type_id, &sub_input)) {
return false;
}
message_data.clear();
state = State::kDone;
}

break;
}

case WireFormatLite::kMessageSetMessageTag: {
if (last_type_id == 0) {
if (state == State::kHasType) {
// Already saw type_id, so we can parse this directly.
if (!ms.ParseField(last_type_id, input)) {
return false;
}
state = State::kDone;
} else if (state == State::kNoTag) {
// We haven't seen a type_id yet. Append this data to message_data.
uint32 length;
if (!input->ReadVarint32(&length)) return false;
Expand All @@ -1836,11 +1847,9 @@ bool ParseMessageSetItemImpl(io::CodedInputStream* input, MS ms) {
auto ptr = reinterpret_cast<uint8*>(&message_data[0]);
ptr = io::CodedOutputStream::WriteVarint32ToArray(length, ptr);
if (!input->ReadRaw(ptr, length)) return false;
state = State::kHasPayload;
} else {
// Already saw type_id, so we can parse this directly.
if (!ms.ParseField(last_type_id, input)) {
return false;
}
if (!ms.SkipField(tag, input)) return false;
}

break;
Expand Down
109 changes: 98 additions & 11 deletions src/google/protobuf/wire_format_unittest.cc
Expand Up @@ -46,6 +46,7 @@
#include <google/protobuf/io/zero_copy_stream_impl.h>
#include <google/protobuf/io/zero_copy_stream_impl_lite.h>
#include <google/protobuf/descriptor.h>
#include <google/protobuf/dynamic_message.h>
#include <google/protobuf/wire_format_lite.h>
#include <google/protobuf/testing/googletest.h>
#include <google/protobuf/stubs/logging.h>
Expand Down Expand Up @@ -585,41 +586,72 @@ TEST(WireFormatTest, ParseMessageSet) {
EXPECT_EQ(message_set.DebugString(), dynamic_message_set.DebugString());
}

TEST(WireFormatTest, ParseMessageSetWithReverseTagOrder) {
namespace {
std::string BuildMessageSetItemStart() {
std::string data;
{
unittest::TestMessageSetExtension1 message;
message.set_i(123);
// Build a MessageSet manually with its message content put before its
// type_id.
io::StringOutputStream output_stream(&data);
io::CodedOutputStream coded_output(&output_stream);
coded_output.WriteTag(WireFormatLite::kMessageSetItemStartTag);
}
return data;
}
std::string BuildMessageSetItemEnd() {
std::string data;
{
io::StringOutputStream output_stream(&data);
io::CodedOutputStream coded_output(&output_stream);
coded_output.WriteTag(WireFormatLite::kMessageSetItemEndTag);
}
return data;
}
std::string BuildMessageSetTestExtension1(int value = 123) {
std::string data;
{
unittest::TestMessageSetExtension1 message;
message.set_i(value);
io::StringOutputStream output_stream(&data);
io::CodedOutputStream coded_output(&output_stream);
// Write the message content first.
WireFormatLite::WriteTag(WireFormatLite::kMessageSetMessageNumber,
WireFormatLite::WIRETYPE_LENGTH_DELIMITED,
&coded_output);
coded_output.WriteVarint32(message.ByteSizeLong());
message.SerializeWithCachedSizes(&coded_output);
// Write the type id.
uint32 type_id = message.GetDescriptor()->extension(0)->number();
}
return data;
}
std::string BuildMessageSetItemTypeId(int extension_number) {
std::string data;
{
io::StringOutputStream output_stream(&data);
io::CodedOutputStream coded_output(&output_stream);
WireFormatLite::WriteUInt32(WireFormatLite::kMessageSetTypeIdNumber,
type_id, &coded_output);
coded_output.WriteTag(WireFormatLite::kMessageSetItemEndTag);
extension_number, &coded_output);
}
return data;
}
void ValidateTestMessageSet(const std::string& test_case,
const std::string& data) {
SCOPED_TRACE(test_case);
{
proto2_wireformat_unittest::TestMessageSet message_set;
::proto2_wireformat_unittest::TestMessageSet message_set;
ASSERT_TRUE(message_set.ParseFromString(data));

EXPECT_EQ(123,
message_set
.GetExtension(
unittest::TestMessageSetExtension1::message_set_extension)
.i());

// Make sure it does not contain anything else.
message_set.ClearExtension(
unittest::TestMessageSetExtension1::message_set_extension);
EXPECT_EQ(message_set.SerializeAsString(), "");
}
{
// Test parse the message via Reflection.
proto2_wireformat_unittest::TestMessageSet message_set;
::proto2_wireformat_unittest::TestMessageSet message_set;
io::CodedInputStream input(reinterpret_cast<const uint8*>(data.data()),
data.size());
EXPECT_TRUE(WireFormat::ParseAndMergePartial(&input, &message_set));
Expand All @@ -631,6 +663,61 @@ TEST(WireFormatTest, ParseMessageSetWithReverseTagOrder) {
unittest::TestMessageSetExtension1::message_set_extension)
.i());
}
{
// Test parse the message via DynamicMessage.
DynamicMessageFactory factory;
std::unique_ptr<Message> msg(
factory
.GetPrototype(
::proto2_wireformat_unittest::TestMessageSet::descriptor())
->New());
msg->ParseFromString(data);
auto* reflection = msg->GetReflection();
std::vector<const FieldDescriptor*> fields;
reflection->ListFields(*msg, &fields);
ASSERT_EQ(fields.size(), 1);
const auto& sub = reflection->GetMessage(*msg, fields[0]);
reflection = sub.GetReflection();
EXPECT_EQ(123, reflection->GetInt32(
sub, sub.GetDescriptor()->FindFieldByName("i")));
}
}
} // namespace

TEST(WireFormatTest, ParseMessageSetWithAnyTagOrder) {
std::string start = BuildMessageSetItemStart();
std::string end = BuildMessageSetItemEnd();
std::string id = BuildMessageSetItemTypeId(
unittest::TestMessageSetExtension1::descriptor()->extension(0)->number());
std::string message = BuildMessageSetTestExtension1();

ValidateTestMessageSet("id + message", start + id + message + end);
ValidateTestMessageSet("message + id", start + message + id + end);
}

TEST(WireFormatTest, ParseMessageSetWithDuplicateTags) {
std::string start = BuildMessageSetItemStart();
std::string end = BuildMessageSetItemEnd();
std::string id = BuildMessageSetItemTypeId(
unittest::TestMessageSetExtension1::descriptor()->extension(0)->number());
std::string other_id = BuildMessageSetItemTypeId(123456);
std::string message = BuildMessageSetTestExtension1();
std::string other_message = BuildMessageSetTestExtension1(321);

// Double id
ValidateTestMessageSet("id + other_id + message",
start + id + other_id + message + end);
ValidateTestMessageSet("id + message + other_id",
start + id + message + other_id + end);
ValidateTestMessageSet("message + id + other_id",
start + message + id + other_id + end);
// Double message
ValidateTestMessageSet("id + message + other_message",
start + id + message + other_message + end);
ValidateTestMessageSet("message + id + other_message",
start + message + id + other_message + end);
ValidateTestMessageSet("message + other_message + id",
start + message + other_message + id + end);
}

void SerializeReverseOrder(
Expand Down