Skip to content

Commit

Permalink
Fix handling non null-terminated string_views in LookupByKey (#8203)
Browse files Browse the repository at this point in the history
* Reproduce the error in a unit test

Reproduces #8200

* Overload KeyCompareWithValue to work for string-like objects

This fixes #8200.

* Extra tests

---------

Co-authored-by: Derek Bailey <derekbailey@google.com>
  • Loading branch information
mpawlowski-eyeo and dbaileychess committed Mar 25, 2024
1 parent 67eb95d commit 0cfb7eb
Show file tree
Hide file tree
Showing 9 changed files with 117 additions and 0 deletions.
42 changes: 42 additions & 0 deletions include/flatbuffers/reflection_generated.h
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,12 @@ struct KeyValue FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table {
int KeyCompareWithValue(const char *_key) const {
return strcmp(key()->c_str(), _key);
}
template<typename StringType>
int KeyCompareWithValue(const StringType& _key) const {
if (key()->c_str() < _key) return -1;
if (_key < key()->c_str()) return 1;
return 0;
}
const ::flatbuffers::String *value() const {
return GetPointer<const ::flatbuffers::String *>(VT_VALUE);
}
Expand Down Expand Up @@ -464,6 +470,12 @@ struct Enum FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table {
int KeyCompareWithValue(const char *_name) const {
return strcmp(name()->c_str(), _name);
}
template<typename StringType>
int KeyCompareWithValue(const StringType& _name) const {
if (name()->c_str() < _name) return -1;
if (_name < name()->c_str()) return 1;
return 0;
}
const ::flatbuffers::Vector<::flatbuffers::Offset<reflection::EnumVal>> *values() const {
return GetPointer<const ::flatbuffers::Vector<::flatbuffers::Offset<reflection::EnumVal>> *>(VT_VALUES);
}
Expand Down Expand Up @@ -616,6 +628,12 @@ struct Field FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table {
int KeyCompareWithValue(const char *_name) const {
return strcmp(name()->c_str(), _name);
}
template<typename StringType>
int KeyCompareWithValue(const StringType& _name) const {
if (name()->c_str() < _name) return -1;
if (_name < name()->c_str()) return 1;
return 0;
}
const reflection::Type *type() const {
return GetPointer<const reflection::Type *>(VT_TYPE);
}
Expand Down Expand Up @@ -834,6 +852,12 @@ struct Object FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table {
int KeyCompareWithValue(const char *_name) const {
return strcmp(name()->c_str(), _name);
}
template<typename StringType>
int KeyCompareWithValue(const StringType& _name) const {
if (name()->c_str() < _name) return -1;
if (_name < name()->c_str()) return 1;
return 0;
}
const ::flatbuffers::Vector<::flatbuffers::Offset<reflection::Field>> *fields() const {
return GetPointer<const ::flatbuffers::Vector<::flatbuffers::Offset<reflection::Field>> *>(VT_FIELDS);
}
Expand Down Expand Up @@ -986,6 +1010,12 @@ struct RPCCall FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table {
int KeyCompareWithValue(const char *_name) const {
return strcmp(name()->c_str(), _name);
}
template<typename StringType>
int KeyCompareWithValue(const StringType& _name) const {
if (name()->c_str() < _name) return -1;
if (_name < name()->c_str()) return 1;
return 0;
}
const reflection::Object *request() const {
return GetPointer<const reflection::Object *>(VT_REQUEST);
}
Expand Down Expand Up @@ -1102,6 +1132,12 @@ struct Service FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table {
int KeyCompareWithValue(const char *_name) const {
return strcmp(name()->c_str(), _name);
}
template<typename StringType>
int KeyCompareWithValue(const StringType& _name) const {
if (name()->c_str() < _name) return -1;
if (_name < name()->c_str()) return 1;
return 0;
}
const ::flatbuffers::Vector<::flatbuffers::Offset<reflection::RPCCall>> *calls() const {
return GetPointer<const ::flatbuffers::Vector<::flatbuffers::Offset<reflection::RPCCall>> *>(VT_CALLS);
}
Expand Down Expand Up @@ -1221,6 +1257,12 @@ struct SchemaFile FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table {
int KeyCompareWithValue(const char *_filename) const {
return strcmp(filename()->c_str(), _filename);
}
template<typename StringType>
int KeyCompareWithValue(const StringType& _filename) const {
if (filename()->c_str() < _filename) return -1;
if (_filename < filename()->c_str()) return 1;
return 0;
}
/// Names of included files, relative to project root.
const ::flatbuffers::Vector<::flatbuffers::Offset<::flatbuffers::String>> *included_filenames() const {
return GetPointer<const ::flatbuffers::Vector<::flatbuffers::Offset<::flatbuffers::String>> *>(VT_INCLUDED_FILENAMES);
Expand Down
12 changes: 12 additions & 0 deletions src/idl_gen_cpp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2417,8 +2417,20 @@ class CppGenerator : public BaseGenerator {

// Generate KeyCompareWithValue function
if (is_string) {
// Compares key against a null-terminated char array.
code_ += " int KeyCompareWithValue(const char *_{{FIELD_NAME}}) const {";
code_ += " return strcmp({{FIELD_NAME}}()->c_str(), _{{FIELD_NAME}});";
code_ += " }";
// Compares key against any string-like object (e.g. std::string_view or
// std::string) that implements operator< comparison with const char*.
code_ += " template<typename StringType>";
code_ +=
" int KeyCompareWithValue(const StringType& _{{FIELD_NAME}}) const "
"{";
code_ +=
" if ({{FIELD_NAME}}()->c_str() < _{{FIELD_NAME}}) return -1;";
code_ += " if (_{{FIELD_NAME}} < {{FIELD_NAME}}()->c_str()) return 1;";
code_ += " return 0;";
} else if (is_array) {
const auto &elem_type = field.value.type.VectorType();
std::string input_type = "::flatbuffers::Array<" +
Expand Down
6 changes: 6 additions & 0 deletions tests/cpp17/generated_cpp17/monster_test_generated.h
Original file line number Diff line number Diff line change
Expand Up @@ -1436,6 +1436,12 @@ struct Monster FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table {
int KeyCompareWithValue(const char *_name) const {
return strcmp(name()->c_str(), _name);
}
template<typename StringType>
int KeyCompareWithValue(const StringType& _name) const {
if (name()->c_str() < _name) return -1;
if (_name < name()->c_str()) return 1;
return 0;
}
const ::flatbuffers::Vector<uint8_t> *inventory() const {
return GetPointer<const ::flatbuffers::Vector<uint8_t> *>(VT_INVENTORY);
}
Expand Down
6 changes: 6 additions & 0 deletions tests/key_field/key_field_sample_generated.h
Original file line number Diff line number Diff line change
Expand Up @@ -598,6 +598,12 @@ struct FooTable FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table {
int KeyCompareWithValue(const char *_c) const {
return strcmp(c()->c_str(), _c);
}
template<typename StringType>
int KeyCompareWithValue(const StringType& _c) const {
if (c()->c_str() < _c) return -1;
if (_c < c()->c_str()) return 1;
return 0;
}
const ::flatbuffers::Vector<const keyfield::sample::Baz *> *d() const {
return GetPointer<const ::flatbuffers::Vector<const keyfield::sample::Baz *> *>(VT_D);
}
Expand Down
27 changes: 27 additions & 0 deletions tests/monster_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -313,6 +313,33 @@ void AccessFlatBufferTest(const uint8_t *flatbuf, size_t length, bool pooled) {
TEST_NOTNULL(vecoftables->LookupByKey("Fred"));
TEST_NOTNULL(vecoftables->LookupByKey("Wilma"));

// Verify the same objects are returned for char*-based and string-based
// lookups.
TEST_EQ(vecoftables->LookupByKey("Barney"),
vecoftables->LookupByKey(std::string("Barney")));
TEST_EQ(vecoftables->LookupByKey("Fred"),
vecoftables->LookupByKey(std::string("Fred")));
TEST_EQ(vecoftables->LookupByKey("Wilma"),
vecoftables->LookupByKey(std::string("Wilma")));

#ifdef FLATBUFFERS_HAS_STRING_VIEW
// Tests for LookupByKey with a key that is a truncated
// version of a longer, invalid key.
const std::string invalid_key = "Barney123";
std::string_view valid_truncated_key = invalid_key;
valid_truncated_key.remove_suffix(3); // "Barney"
TEST_NOTNULL(vecoftables->LookupByKey(valid_truncated_key));
TEST_EQ(vecoftables->LookupByKey("Barney"),
vecoftables->LookupByKey(valid_truncated_key));

// Tests for LookupByKey with a key that is a truncated
// version of a longer, valid key.
const std::string valid_key = "Barney";
std::string_view invalid_truncated_key = valid_key;
invalid_truncated_key.remove_suffix(3); // "Bar"
TEST_NULL(vecoftables->LookupByKey(invalid_truncated_key));
#endif // FLATBUFFERS_HAS_STRING_VIEW

// Test accessing a vector of sorted structs
auto vecofstructs = monster->testarrayofsortedstruct();
if (vecofstructs) { // not filled in monster_test.bfbs
Expand Down
6 changes: 6 additions & 0 deletions tests/monster_test_generated.h
Original file line number Diff line number Diff line change
Expand Up @@ -1432,6 +1432,12 @@ struct Monster FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table {
int KeyCompareWithValue(const char *_name) const {
return strcmp(name()->c_str(), _name);
}
template<typename StringType>
int KeyCompareWithValue(const StringType& _name) const {
if (name()->c_str() < _name) return -1;
if (_name < name()->c_str()) return 1;
return 0;
}
const ::flatbuffers::Vector<uint8_t> *inventory() const {
return GetPointer<const ::flatbuffers::Vector<uint8_t> *>(VT_INVENTORY);
}
Expand Down
6 changes: 6 additions & 0 deletions tests/monster_test_suffix/ext_only/monster_test_generated.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1423,6 +1423,12 @@ struct Monster FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table {
int KeyCompareWithValue(const char *_name) const {
return strcmp(name()->c_str(), _name);
}
template<typename StringType>
int KeyCompareWithValue(const StringType& _name) const {
if (name()->c_str() < _name) return -1;
if (_name < name()->c_str()) return 1;
return 0;
}
const ::flatbuffers::Vector<uint8_t> *inventory() const {
return GetPointer<const ::flatbuffers::Vector<uint8_t> *>(VT_INVENTORY);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1423,6 +1423,12 @@ struct Monster FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table {
int KeyCompareWithValue(const char *_name) const {
return strcmp(name()->c_str(), _name);
}
template<typename StringType>
int KeyCompareWithValue(const StringType& _name) const {
if (name()->c_str() < _name) return -1;
if (_name < name()->c_str()) return 1;
return 0;
}
const ::flatbuffers::Vector<uint8_t> *inventory() const {
return GetPointer<const ::flatbuffers::Vector<uint8_t> *>(VT_INVENTORY);
}
Expand Down
6 changes: 6 additions & 0 deletions tests/monster_test_suffix/monster_test_suffix.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1423,6 +1423,12 @@ struct Monster FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table {
int KeyCompareWithValue(const char *_name) const {
return strcmp(name()->c_str(), _name);
}
template<typename StringType>
int KeyCompareWithValue(const StringType& _name) const {
if (name()->c_str() < _name) return -1;
if (_name < name()->c_str()) return 1;
return 0;
}
const ::flatbuffers::Vector<uint8_t> *inventory() const {
return GetPointer<const ::flatbuffers::Vector<uint8_t> *>(VT_INVENTORY);
}
Expand Down

0 comments on commit 0cfb7eb

Please sign in to comment.