diff --git a/lib/rubocop/ast/node_pattern.rb b/lib/rubocop/ast/node_pattern.rb index 5cb4d3daf..d23a10230 100644 --- a/lib/rubocop/ast/node_pattern.rb +++ b/lib/rubocop/ast/node_pattern.rb @@ -754,6 +754,50 @@ def substitute_cur_node(code, cur_node, first_cur_node: cur_node) def self.tokens(pattern) pattern.scan(TOKEN).reject { |token| token =~ /\A#{SEPARATORS}\Z/ } end + + def def_helper(base, src) + location = caller_locations(3, 1).first + base.class_eval(src, location.path, location.lineno) + end + + def def_node_matcher(base, method_name) + def_helper(base, <<~RUBY) + def #{method_name}(node = self#{emit_trailing_params}) + #{emit_method_code} + end + RUBY + end + + def def_node_search(base, method_name) + def_helper(base, emit_node_search(method_name)) + end + + def emit_node_search(method_name) + if method_name.to_s.end_with?('?') + on_match = 'return true' + else + prelude = <<~RUBY + return enum_for(:#{method_name}, + node0#{emit_trailing_params}) unless block_given? + RUBY + on_match = emit_yield_capture('node') + end + emit_node_search_body(method_name, prelude: prelude, on_match: on_match) + end + + def emit_node_search_body(method_name, prelude:, on_match:) + <<~RUBY + def #{method_name}(node0#{emit_trailing_params}) + #{prelude} + node0.each_node do |node| + if #{match_code} + #{on_match} + end + end + nil + end + RUBY + end end private_constant :Compiler @@ -767,13 +811,7 @@ module Macros # If the node matches, and no block is provided, the new method will # return the captures, or `true` if there were none. def def_node_matcher(method_name, pattern_str) - compiler = Compiler.new(pattern_str, 'node') - src = "def #{method_name}(node = self" \ - "#{compiler.emit_trailing_params});" \ - "#{compiler.emit_method_code};end" - - location = caller_locations(1, 1).first - class_eval(src, location.path, location.lineno) + Compiler.new(pattern_str, 'node').def_node_matcher(self, method_name) end # Define a method which recurses over the descendants of an AST node, @@ -783,49 +821,7 @@ def def_node_matcher(method_name, pattern_str) # as soon as it finds a descendant which matches. Otherwise, it will # yield all descendants which match. def def_node_search(method_name, pattern_str) - compiler = Compiler.new(pattern_str, 'node') - location = caller_locations(1, 1).first - called_from = [location.path, location.lineno] - - if method_name.to_s.end_with?('?') - node_search_first(method_name, compiler, called_from) - else - node_search_all(method_name, compiler, called_from) - end - end - - def node_search_first(method_name, compiler, called_from) - node_search(method_name, compiler, 'return true', '', called_from) - end - - def node_search_all(method_name, compiler, called_from) - yield_code = compiler.emit_yield_capture('node') - prelude = "return enum_for(:#{method_name}, node0" \ - "#{compiler.emit_trailing_params}) unless block_given?" - - node_search(method_name, compiler, yield_code, prelude, called_from) - end - - def node_search(method_name, compiler, on_match, prelude, called_from) - src = node_search_body(method_name, compiler.emit_trailing_params, - prelude, compiler.match_code, on_match) - filename, lineno = *called_from - class_eval(src, filename, lineno.to_i) - end - - def node_search_body(method_name, trailing_params, prelude, match_code, - on_match) - <<~RUBY - def #{method_name}(node0#{trailing_params}) - #{prelude} - node0.each_node do |node| - if #{match_code} - #{on_match} - end - end - nil - end - RUBY + Compiler.new(pattern_str, 'node').def_node_search(self, method_name) end end