Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Apply patch #10542

Merged
merged 2 commits into from Sep 13, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
27 changes: 18 additions & 9 deletions src/google/protobuf/extension_set_inl.h
Expand Up @@ -206,16 +206,21 @@ const char* ExtensionSet::ParseMessageSetItemTmpl(
const char* ptr, const Msg* extendee, internal::InternalMetadata* metadata,
internal::ParseContext* ctx) {
std::string payload;
uint32_t type_id = 0;
bool payload_read = false;
uint32_t type_id;
enum class State { kNoTag, kHasType, kHasPayload, kDone };
State state = State::kNoTag;

while (!ctx->Done(&ptr)) {
uint32_t tag = static_cast<uint8_t>(*ptr++);
if (tag == WireFormatLite::kMessageSetTypeIdTag) {
uint64_t tmp;
ptr = ParseBigVarint(ptr, &tmp);
GOOGLE_PROTOBUF_PARSER_ASSERT(ptr);
type_id = tmp;
if (payload_read) {
if (state == State::kNoTag) {
type_id = tmp;
state = State::kHasType;
} else if (state == State::kHasPayload) {
type_id = tmp;
ExtensionInfo extension;
bool was_packed_on_wire;
if (!FindExtension(2, type_id, extendee, ctx, &extension,
Expand All @@ -241,20 +246,24 @@ const char* ExtensionSet::ParseMessageSetItemTmpl(
GOOGLE_PROTOBUF_PARSER_ASSERT(value->_InternalParse(p, &tmp_ctx) &&
tmp_ctx.EndedAtLimit());
}
type_id = 0;
state = State::kDone;
}
} else if (tag == WireFormatLite::kMessageSetMessageTag) {
if (type_id != 0) {
if (state == State::kHasType) {
ptr = ParseFieldMaybeLazily(static_cast<uint64_t>(type_id) * 8 + 2, ptr,
extendee, metadata, ctx);
GOOGLE_PROTOBUF_PARSER_ASSERT(ptr != nullptr);
type_id = 0;
state = State::kDone;
} else {
std::string tmp;
int32_t size = ReadSize(&ptr);
GOOGLE_PROTOBUF_PARSER_ASSERT(ptr);
ptr = ctx->ReadString(ptr, size, &payload);
ptr = ctx->ReadString(ptr, size, &tmp);
GOOGLE_PROTOBUF_PARSER_ASSERT(ptr);
payload_read = true;
if (state == State::kNoTag) {
payload = std::move(tmp);
state = State::kHasPayload;
}
}
} else {
ptr = ReadTag(ptr - 1, &tag);
Expand Down
26 changes: 18 additions & 8 deletions src/google/protobuf/wire_format.cc
Expand Up @@ -657,9 +657,11 @@ struct WireFormat::MessageSetParser {
const char* _InternalParse(const char* ptr, internal::ParseContext* ctx) {
// Parse a MessageSetItem
auto metadata = reflection->MutableInternalMetadata(msg);
enum class State { kNoTag, kHasType, kHasPayload, kDone };
State state = State::kNoTag;

std::string payload;
uint32_t type_id = 0;
bool payload_read = false;
while (!ctx->Done(&ptr)) {
// We use 64 bit tags in order to allow typeid's that span the whole
// range of 32 bit numbers.
Expand All @@ -668,8 +670,11 @@ struct WireFormat::MessageSetParser {
uint64_t tmp;
ptr = ParseBigVarint(ptr, &tmp);
GOOGLE_PROTOBUF_PARSER_ASSERT(ptr);
type_id = tmp;
if (payload_read) {
if (state == State::kNoTag) {
type_id = tmp;
state = State::kHasType;
} else if (state == State::kHasPayload) {
type_id = tmp;
const FieldDescriptor* field;
if (ctx->data().pool == nullptr) {
field = reflection->FindKnownExtensionByNumber(type_id);
Expand All @@ -696,17 +701,17 @@ struct WireFormat::MessageSetParser {
GOOGLE_PROTOBUF_PARSER_ASSERT(value->_InternalParse(p, &tmp_ctx) &&
tmp_ctx.EndedAtLimit());
}
type_id = 0;
state = State::kDone;
}
continue;
} else if (tag == WireFormatLite::kMessageSetMessageTag) {
if (type_id == 0) {
if (state == State::kNoTag) {
int32_t size = ReadSize(&ptr);
GOOGLE_PROTOBUF_PARSER_ASSERT(ptr);
ptr = ctx->ReadString(ptr, size, &payload);
GOOGLE_PROTOBUF_PARSER_ASSERT(ptr);
payload_read = true;
} else {
state = State::kHasPayload;
} else if (state == State::kHasType) {
// We're now parsing the payload
const FieldDescriptor* field = nullptr;
if (descriptor->IsExtensionNumber(type_id)) {
Expand All @@ -720,7 +725,12 @@ struct WireFormat::MessageSetParser {
ptr = WireFormat::_InternalParseAndMergeField(
msg, ptr, ctx, static_cast<uint64_t>(type_id) * 8 + 2, reflection,
field);
type_id = 0;
state = State::kDone;
} else {
int32_t size = ReadSize(&ptr);
GOOGLE_PROTOBUF_PARSER_ASSERT(ptr);
ptr = ctx->Skip(ptr, size);
GOOGLE_PROTOBUF_PARSER_ASSERT(ptr);
}
} else {
// An unknown field in MessageSetItem.
Expand Down
27 changes: 18 additions & 9 deletions src/google/protobuf/wire_format_lite.h
Expand Up @@ -1845,6 +1845,9 @@ bool ParseMessageSetItemImpl(io::CodedInputStream* input, MS ms) {
// we can parse it later.
std::string message_data;

enum class State { kNoTag, kHasType, kHasPayload, kDone };
State state = State::kNoTag;

while (true) {
const uint32_t tag = input->ReadTagNoLastTag();
if (tag == 0) return false;
Expand All @@ -1853,26 +1856,34 @@ bool ParseMessageSetItemImpl(io::CodedInputStream* input, MS ms) {
case WireFormatLite::kMessageSetTypeIdTag: {
uint32_t type_id;
if (!input->ReadVarint32(&type_id)) return false;
last_type_id = type_id;

if (!message_data.empty()) {
if (state == State::kNoTag) {
last_type_id = type_id;
state = State::kHasType;
} else if (state == State::kHasPayload) {
// We saw some message data before the type_id. Have to parse it
// now.
io::CodedInputStream sub_input(
reinterpret_cast<const uint8_t*>(message_data.data()),
static_cast<int>(message_data.size()));
sub_input.SetRecursionLimit(input->RecursionBudget());
if (!ms.ParseField(last_type_id, &sub_input)) {
if (!ms.ParseField(type_id, &sub_input)) {
return false;
}
message_data.clear();
state = State::kDone;
}

break;
}

case WireFormatLite::kMessageSetMessageTag: {
if (last_type_id == 0) {
if (state == State::kHasType) {
// Already saw type_id, so we can parse this directly.
if (!ms.ParseField(last_type_id, input)) {
return false;
}
state = State::kDone;
} else if (state == State::kNoTag) {
// We haven't seen a type_id yet. Append this data to message_data.
uint32_t length;
if (!input->ReadVarint32(&length)) return false;
Expand All @@ -1883,11 +1894,9 @@ bool ParseMessageSetItemImpl(io::CodedInputStream* input, MS ms) {
auto ptr = reinterpret_cast<uint8_t*>(&message_data[0]);
ptr = io::CodedOutputStream::WriteVarint32ToArray(length, ptr);
if (!input->ReadRaw(ptr, length)) return false;
state = State::kHasPayload;
} else {
// Already saw type_id, so we can parse this directly.
if (!ms.ParseField(last_type_id, input)) {
return false;
}
if (!ms.SkipField(tag, input)) return false;
}

break;
Expand Down
105 changes: 96 additions & 9 deletions src/google/protobuf/wire_format_unittest.inc
Expand Up @@ -49,6 +49,7 @@
#include <google/protobuf/stubs/casts.h>
#include <google/protobuf/stubs/strutil.h>
#include <google/protobuf/stubs/stl_util.h>
#include <google/protobuf/dynamic_message.h>

// clang-format off
#include <google/protobuf/port_def.inc>
Expand Down Expand Up @@ -580,28 +581,54 @@ TEST(WireFormatTest, ParseMessageSet) {
EXPECT_EQ(message_set.DebugString(), dynamic_message_set.DebugString());
}

TEST(WireFormatTest, ParseMessageSetWithReverseTagOrder) {
namespace {
std::string BuildMessageSetItemStart() {
std::string data;
{
UNITTEST::TestMessageSetExtension1 message;
message.set_i(123);
// Build a MessageSet manually with its message content put before its
// type_id.
io::StringOutputStream output_stream(&data);
io::CodedOutputStream coded_output(&output_stream);
coded_output.WriteTag(WireFormatLite::kMessageSetItemStartTag);
}
return data;
}
std::string BuildMessageSetItemEnd() {
std::string data;
{
io::StringOutputStream output_stream(&data);
io::CodedOutputStream coded_output(&output_stream);
coded_output.WriteTag(WireFormatLite::kMessageSetItemEndTag);
}
return data;
}
std::string BuildMessageSetTestExtension1(int value = 123) {
std::string data;
{
UNITTEST::TestMessageSetExtension1 message;
message.set_i(value);
io::StringOutputStream output_stream(&data);
io::CodedOutputStream coded_output(&output_stream);
// Write the message content first.
WireFormatLite::WriteTag(WireFormatLite::kMessageSetMessageNumber,
WireFormatLite::WIRETYPE_LENGTH_DELIMITED,
&coded_output);
coded_output.WriteVarint32(message.ByteSizeLong());
message.SerializeWithCachedSizes(&coded_output);
// Write the type id.
uint32 type_id = message.GetDescriptor()->extension(0)->number();
}
return data;
}
std::string BuildMessageSetItemTypeId(int extension_number) {
std::string data;
{
io::StringOutputStream output_stream(&data);
io::CodedOutputStream coded_output(&output_stream);
WireFormatLite::WriteUInt32(WireFormatLite::kMessageSetTypeIdNumber,
type_id, &coded_output);
coded_output.WriteTag(WireFormatLite::kMessageSetItemEndTag);
extension_number, &coded_output);
}
return data;
}
void ValidateTestMessageSet(const std::string& test_case,
const std::string& data) {
SCOPED_TRACE(test_case);
{
PROTO2_WIREFORMAT_UNITTEST::TestMessageSet message_set;
ASSERT_TRUE(message_set.ParseFromString(data));
Expand All @@ -611,6 +638,11 @@ TEST(WireFormatTest, ParseMessageSetWithReverseTagOrder) {
.GetExtension(
UNITTEST::TestMessageSetExtension1::message_set_extension)
.i());

// Make sure it does not contain anything else.
message_set.ClearExtension(
UNITTEST::TestMessageSetExtension1::message_set_extension);
EXPECT_EQ(message_set.SerializeAsString(), "");
}
{
// Test parse the message via Reflection.
Expand All @@ -626,6 +658,61 @@ TEST(WireFormatTest, ParseMessageSetWithReverseTagOrder) {
UNITTEST::TestMessageSetExtension1::message_set_extension)
.i());
}
{
// Test parse the message via DynamicMessage.
DynamicMessageFactory factory;
std::unique_ptr<Message> msg(
factory
.GetPrototype(
PROTO2_WIREFORMAT_UNITTEST::TestMessageSet::descriptor())
->New());
msg->ParseFromString(data);
auto* reflection = msg->GetReflection();
std::vector<const FieldDescriptor*> fields;
reflection->ListFields(*msg, &fields);
ASSERT_EQ(fields.size(), 1);
const auto& sub = reflection->GetMessage(*msg, fields[0]);
reflection = sub.GetReflection();
EXPECT_EQ(123, reflection->GetInt32(
sub, sub.GetDescriptor()->FindFieldByName("i")));
}
}
} // namespace

TEST(WireFormatTest, ParseMessageSetWithAnyTagOrder) {
std::string start = BuildMessageSetItemStart();
std::string end = BuildMessageSetItemEnd();
std::string id = BuildMessageSetItemTypeId(
UNITTEST::TestMessageSetExtension1::descriptor()->extension(0)->number());
std::string message = BuildMessageSetTestExtension1();

ValidateTestMessageSet("id + message", start + id + message + end);
ValidateTestMessageSet("message + id", start + message + id + end);
}

TEST(WireFormatTest, ParseMessageSetWithDuplicateTags) {
std::string start = BuildMessageSetItemStart();
std::string end = BuildMessageSetItemEnd();
std::string id = BuildMessageSetItemTypeId(
UNITTEST::TestMessageSetExtension1::descriptor()->extension(0)->number());
std::string other_id = BuildMessageSetItemTypeId(123456);
std::string message = BuildMessageSetTestExtension1();
std::string other_message = BuildMessageSetTestExtension1(321);

// Double id
ValidateTestMessageSet("id + other_id + message",
start + id + other_id + message + end);
ValidateTestMessageSet("id + message + other_id",
start + id + message + other_id + end);
ValidateTestMessageSet("message + id + other_id",
start + message + id + other_id + end);
// Double message
ValidateTestMessageSet("id + message + other_message",
start + id + message + other_message + end);
ValidateTestMessageSet("message + id + other_message",
start + message + id + other_message + end);
ValidateTestMessageSet("message + other_message + id",
start + message + other_message + id + end);
}

void SerializeReverseOrder(
Expand Down