Skip to content

Commit

Permalink
Treewalker: Improve performance by avoiding memory allocations (#309)
Browse files Browse the repository at this point in the history
Previously the treewalker would unnecessary call "to_h" on each Protobuf
message in the parsetree, in order to get the field names to walk. This
caused unnecessary copies of the message, increasing memory usage and
slowing down the tree walk. Instead, use the Protobuf descriptor and its
field descriptors to walk the message.

Additionally this also optimizes the case where a block with 1 argument
is used for the tree walk, since we don't need to handle the location,
avoiding unnecessary copies of the field name string.

Together these changes result in about a 5x speed up in some use cases.
  • Loading branch information
lfittl committed Jan 29, 2024
1 parent 34789ad commit f23f1df
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 14 deletions.
2 changes: 1 addition & 1 deletion lib/pg_query/param_refs.rb
Expand Up @@ -3,7 +3,7 @@ class ParserResult
def param_refs # rubocop:disable Metrics/CyclomaticComplexity
results = []

treewalker! @tree do |_, _, node, location|
treewalker_with_location! @tree do |_, _, node, location|
case node
when PgQuery::ParamRef
# Ignore param refs inside type casts, as these are already handled
Expand Down
49 changes: 38 additions & 11 deletions lib/pg_query/treewalker.rb
Expand Up @@ -5,15 +5,17 @@ class ParserResult
# If you pass a block with 1 argument, you will get each node.
# If you pass a block with 4 arguments, you will get each parent_node, parent_field, node and location.
#
# If sufficient for the use case, the 1 argument block approach is recommended, since it's faster.
#
# Location uniquely identifies a given node within the parse tree. This is a stable identifier across
# multiple parser runs, assuming the same pg_query release and no modifications to the parse tree.
def walk!(&block)
if block.arity == 1
treewalker!(@tree) do |_, _, node, _|
treewalker!(@tree) do |node|
yield(node)
end
else
treewalker!(@tree) do |parent_node, parent_field, node, location|
treewalker_with_location!(@tree) do |parent_node, parent_field, node, location|
yield(parent_node, parent_field, node, location)
end
end
Expand All @@ -22,28 +24,53 @@ def walk!(&block)
private

def treewalker!(tree) # rubocop:disable Metrics/CyclomaticComplexity
nodes = [tree.dup]

loop do
parent_node = nodes.shift

case parent_node
when Google::Protobuf::MessageExts
parent_node.class.descriptor.each do |field_descriptor|
node = field_descriptor.get(parent_node)
next if node.nil?
yield(node) if node.is_a?(Google::Protobuf::MessageExts) || node.is_a?(Google::Protobuf::RepeatedField)
nodes << node
end
when Google::Protobuf::RepeatedField
parent_node.each do |node|
next if node.nil?
yield(node) if node.is_a?(Google::Protobuf::MessageExts) || node.is_a?(Google::Protobuf::RepeatedField)
nodes << node
end
end

break if nodes.empty?
end
end

def treewalker_with_location!(tree) # rubocop:disable Metrics/CyclomaticComplexity
nodes = [[tree.dup, []]]

loop do
parent_node, parent_location = nodes.shift

case parent_node
when Google::Protobuf::MessageExts
parent_node.to_h.keys.each do |parent_field|
node = parent_node[parent_field.to_s]
parent_node.class.descriptor.each do |field_descriptor|
parent_field = field_descriptor.name
node = parent_node[parent_field]
next if node.nil?
location = parent_location + [parent_field]
yield(parent_node, parent_field, node, location) if node.is_a?(Google::Protobuf::MessageExts) || node.is_a?(Google::Protobuf::RepeatedField)

nodes << [node, location] unless node.nil?
location = parent_location + [parent_field.to_sym]
yield(parent_node, parent_field.to_sym, node, location) if node.is_a?(Google::Protobuf::MessageExts) || node.is_a?(Google::Protobuf::RepeatedField)
nodes << [node, location]
end
when Google::Protobuf::RepeatedField
parent_node.each_with_index do |node, parent_field|
next if node.nil?
location = parent_location + [parent_field]
yield(parent_node, parent_field, node, location) if node.is_a?(Google::Protobuf::MessageExts) || node.is_a?(Google::Protobuf::RepeatedField)

nodes << [node, location] unless node.nil?
nodes << [node, location]
end
end

Expand All @@ -52,7 +79,7 @@ def treewalker!(tree) # rubocop:disable Metrics/CyclomaticComplexity
end

def find_tree_location(tree, searched_location)
treewalker! tree do |parent_node, parent_field, node, location|
treewalker_with_location! tree do |parent_node, parent_field, node, location|
next unless location == searched_location
yield(parent_node, parent_field, node)
end
Expand Down
2 changes: 1 addition & 1 deletion lib/pg_query/truncate.rb
Expand Up @@ -60,7 +60,7 @@ def truncate(max_length) # rubocop:disable Metrics/CyclomaticComplexity
def find_possible_truncations # rubocop:disable Metrics/CyclomaticComplexity
truncations = []

treewalker! @tree do |node, k, v, location|
treewalker_with_location! @tree do |node, k, v, location|
case k
when :target_list
next unless node.is_a?(PgQuery::SelectStmt) || node.is_a?(PgQuery::UpdateStmt) || node.is_a?(PgQuery::OnConflictClause)
Expand Down
2 changes: 1 addition & 1 deletion spec/lib/treewalker_spec.rb
@@ -1,6 +1,6 @@
require 'spec_helper'

describe PgQuery, '.treewalker' do
describe PgQuery, '#walk!' do
it 'walks nodes contained in repeated fields' do
locations = []
described_class.parse("SELECT to_timestamp($1)").walk! do |_, _, _, location|
Expand Down

0 comments on commit f23f1df

Please sign in to comment.