Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DEBT: Enforce thread-safety for ONNX opset schema API #5291

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
13 changes: 13 additions & 0 deletions onnx/defs/schema.h
Original file line number Diff line number Diff line change
Expand Up @@ -1216,6 +1216,11 @@ class OpSchemaRegistry final : public ISchemaRegistry {
OpSchemaRegisterOnce(OpSchema& op_schema, int opset_version_to_load = 0, bool fail_duplicate_schema = true) {
ONNX_TRY {
op_schema.Finalize();

// Acquires lock to thread-guard schema map access
auto* registry = OpSchemaRegistry::Instance();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think every time the map returned by OpSchemaRegistry::Instance() is accessed should be protected against concurrent accesses. Function GetMapWithoutEnsuringRegistration is called by another function not protected by the mutex. Maybe the datarace comes from it. Maybe it would be safer to create a specific class to hold the map storing the schemas and protect the accesses to this class instead of looking to every place this container is accessed and protect it.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Where is the function called not protected by the mutex? I assume that this is a dynamic race detection in the given unit-test, where only registration/deregistration are called in parallel, both of which seem to invoke the function under a mutex?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

GetMapWithoutEnsuringRegistration is used in other places to access the registered schemas. If a schema is removed while another function is counting the number of schemas, it could lead to some unstable state. It is unlikely to happen but it could happen in a mutlithread scenarios. Every call to GetMapWithoutEnsuringRegistration should be protected with a lock.

std::unique_lock<std::mutex> lock(registry->schema_map_mutex_);

auto& m = GetMapWithoutEnsuringRegistration();
auto& op_name = op_schema.Name();
auto& op_domain = op_schema.domain();
Expand Down Expand Up @@ -1310,6 +1315,7 @@ class OpSchemaRegistry final : public ISchemaRegistry {
// Deregister all ONNX opset schemas from domain
// Domain with default value ONNX_DOMAIN means ONNX.
static void OpSchemaDeregisterAll(const std::string& domain = ONNX_DOMAIN) {
std::unique_lock<std::mutex> lock(Instance()->schema_map_mutex_);
auto& schema_map = GetMapWithoutEnsuringRegistration();
// schema_map stores operator schemas in the format of
// <OpName, <Domain, <OperatorSetVersion, OpSchema>>>
Expand Down Expand Up @@ -1384,6 +1390,9 @@ class OpSchemaRegistry final : public ISchemaRegistry {
// within this class
OpSchemaRegistry() = default;

// Allows OpSchemaRegisterOnce to use Instance()
friend OpSchemaRegisterOnce;

/**
* @brief Returns the underlying string to OpSchema map.
*
Expand All @@ -1398,6 +1407,10 @@ class OpSchemaRegistry final : public ISchemaRegistry {
static OpName_Domain_Version_Schema_Map& map();
static int loaded_schema_version;

// To be accessed via singleton Instance()->schema_map_mutex_
// to thread-guard the schema map access
std::mutex schema_map_mutex_;

public:
static const std::vector<OpSchema> get_all_schemas_with_history() {
std::vector<OpSchema> r;
Expand Down