Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Avoid race condition with unprepared_statement #36871

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Expand Up @@ -15,28 +15,30 @@ def to_sql(arel_or_sql_string, binds = [])
end

def to_sql_and_binds(arel_or_sql_string, binds = []) # :nodoc:
if arel_or_sql_string.respond_to?(:ast)
unless binds.empty?
raise "Passing bind parameters with an arel AST is forbidden. " \
"The values must be stored on the AST directly"
end
@lock.synchronize do
if arel_or_sql_string.respond_to?(:ast)
unless binds.empty?
raise "Passing bind parameters with an arel AST is forbidden. " \
"The values must be stored on the AST directly"
end

if prepared_statements
sql, binds = visitor.compile(arel_or_sql_string.ast, collector)
if prepared_statements
sql, binds = visitor.compile(arel_or_sql_string.ast, collector)

if binds.length > bind_params_length
unprepared_statement do
sql, binds = to_sql_and_binds(arel_or_sql_string)
visitor.preparable = false
if binds.length > bind_params_length
unprepared_statement do
sql, binds = to_sql_and_binds(arel_or_sql_string)
visitor.preparable = false
end
end
else
sql = visitor.compile(arel_or_sql_string.ast, collector)
end
[sql.freeze, binds]
else
sql = visitor.compile(arel_or_sql_string.ast, collector)
visitor.preparable = false if prepared_statements
[arel_or_sql_string.dup.freeze, binds]
end
[sql.freeze, binds]
else
visitor.preparable = false if prepared_statements
[arel_or_sql_string.dup.freeze, binds]
end
end
private :to_sql_and_binds
Expand All @@ -57,17 +59,19 @@ def cacheable_query(klass, arel) # :nodoc:

# Returns an ActiveRecord::Result instance.
def select_all(arel, name = nil, binds = [], preparable: nil)
arel = arel_from_relation(arel)
sql, binds = to_sql_and_binds(arel, binds)
@lock.synchronize do
arel = arel_from_relation(arel)
sql, binds = to_sql_and_binds(arel, binds)

if preparable.nil?
preparable = prepared_statements ? visitor.preparable : false
end
if preparable.nil?
preparable = prepared_statements ? visitor.preparable : false
end

if prepared_statements && preparable
select_prepared(sql, name, binds)
else
select(sql, name, binds)
if prepared_statements && preparable
select_prepared(sql, name, binds)
else
select(sql, name, binds)
end
end
end

Expand Down
Expand Up @@ -94,35 +94,35 @@ def clear_query_cache
end

def select_all(arel, name = nil, binds = [], preparable: nil)
if @query_cache_enabled && !locked?(arel)
arel = arel_from_relation(arel)
sql, binds = to_sql_and_binds(arel, binds)
@lock.synchronize do
if @query_cache_enabled && !locked?(arel)
arel = arel_from_relation(arel)
sql, binds = to_sql_and_binds(arel, binds)

if preparable.nil?
preparable = prepared_statements ? visitor.preparable : false
end
if preparable.nil?
preparable = prepared_statements ? visitor.preparable : false
end

cache_sql(sql, name, binds) { super(sql, name, binds, preparable: preparable) }
else
super
cache_sql(sql, name, binds) { super(sql, name, binds, preparable: preparable) }
else
super
end
end
end

private
def cache_sql(sql, name, binds)
@lock.synchronize do
result =
if @query_cache[sql].key?(binds)
ActiveSupport::Notifications.instrument(
"sql.active_record",
cache_notification_info(sql, name, binds)
)
@query_cache[sql][binds]
else
@query_cache[sql][binds] = yield
end
result.dup
end
result =
if @query_cache[sql].key?(binds)
ActiveSupport::Notifications.instrument(
"sql.active_record",
cache_notification_info(sql, name, binds)
)
@query_cache[sql][binds]
else
@query_cache[sql][binds] = yield
end
result.dup
end

# Database adapters can override this method to
Expand Down
Expand Up @@ -258,10 +258,12 @@ def seconds_idle # :nodoc:
end

def unprepared_statement
old_prepared_statements, @prepared_statements = @prepared_statements, false
yield
ensure
@prepared_statements = old_prepared_statements
@lock.synchronize do
old_prepared_statements, @prepared_statements = @prepared_statements, false
yield
ensure
@prepared_statements = old_prepared_statements
end
end

# Returns the human-readable name of the adapter. Use mixed case - one
Expand Down
Expand Up @@ -6,10 +6,12 @@ module MySQL
module DatabaseStatements
# Returns an ActiveRecord::Result instance.
def select_all(*) # :nodoc:
result = if ExplainRegistry.collect? && prepared_statements
unprepared_statement { super }
else
super
result = @lock.synchronize do
if ExplainRegistry.collect? && prepared_statements
unprepared_statement { super }
else
super
end
end
@connection.abandon_results!
result
Expand Down
31 changes: 18 additions & 13 deletions activerecord/lib/active_record/core.rb
Expand Up @@ -168,11 +168,14 @@ def find(*ids) # :nodoc:

key = primary_key

statement = cached_find_by_statement(key) { |params|
where(key => params.bind).limit(1)
}
record = connection.lock.synchronize do
statement = cached_find_by_statement(key) { |params|
where(key => params.bind).limit(1)
}

statement.execute([id], connection)&.first
end

record = statement.execute([id], connection)&.first
unless record
raise RecordNotFound.new("Couldn't find #{name} with '#{key}'=#{id}", name, key, id)
end
Expand All @@ -194,16 +197,18 @@ def find_by(*args) # :nodoc:

keys = hash.keys

statement = cached_find_by_statement(keys) { |params|
wheres = keys.each_with_object({}) { |param, o|
o[param] = params.bind
connection.lock.synchronize do
statement = cached_find_by_statement(keys) { |params|
wheres = keys.each_with_object({}) { |param, o|
o[param] = params.bind
}
where(wheres).limit(1)
}
where(wheres).limit(1)
}
begin
statement.execute(hash.values, connection)&.first
rescue TypeError
raise ActiveRecord::StatementInvalid
begin
statement.execute(hash.values, connection)&.first
rescue TypeError
raise ActiveRecord::StatementInvalid
end
end
end

Expand Down