Skip to content

Commit

Permalink
Simplify code by using attribute instead of argument
Browse files Browse the repository at this point in the history
  • Loading branch information
marcandre committed Apr 9, 2019
1 parent 1ec76b9 commit f95803a
Showing 1 changed file with 45 additions and 46 deletions.
91 changes: 45 additions & 46 deletions lib/rubocop/node_pattern.rb
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ class Compiler
REST = '...'.freeze
CAPTURED_REST = '$...'.freeze

attr_reader :match_code
attr_reader :match_code, :tokens

SEQ_HEAD_INDEX = -1

Expand All @@ -138,15 +138,15 @@ def initialize(str, node_var = 'node0')
end

def run(node_var)
tokens = Compiler.tokens(@string)
@tokens = Compiler.tokens(@string)

@match_code = compile_expr(tokens, node_var, false)
@match_code = compile_expr(node_var, false)

fail_due_to('unbalanced pattern') unless tokens.empty?
end

# rubocop:disable Metrics/MethodLength, Metrics/AbcSize
def compile_expr(tokens, cur_node, seq_head)
def compile_expr(cur_node, seq_head)
# read a single pattern-matching expression from the token stream,
# return Ruby code which performs the corresponding matching operation
# on 'cur_node' (which is Ruby code which evaluates to an AST node)
Expand All @@ -155,16 +155,16 @@ def compile_expr(tokens, cur_node, seq_head)
# contains an arbitrary number of sub-expressions
token = tokens.shift
case token
when '(' then compile_seq(tokens, cur_node, seq_head)
when '{' then compile_union(tokens, cur_node, seq_head)
when '[' then compile_intersect(tokens, cur_node, seq_head)
when '!' then compile_negation(tokens, cur_node, seq_head)
when '$' then compile_capture(tokens, cur_node, seq_head)
when '^' then compile_ascend(tokens, cur_node)
when '(' then compile_seq(cur_node, seq_head)
when '{' then compile_union(cur_node, seq_head)
when '[' then compile_intersect(cur_node, seq_head)
when '!' then compile_negation(cur_node, seq_head)
when '$' then compile_capture(cur_node, seq_head)
when '^' then compile_ascend(cur_node)
when WILDCARD then compile_wildcard(cur_node, token[1..-1], seq_head)
when FUNCALL then compile_funcall(tokens, cur_node, token, seq_head)
when FUNCALL then compile_funcall(cur_node, token, seq_head)
when LITERAL then compile_literal(cur_node, token, seq_head)
when PREDICATE then compile_predicate(tokens, cur_node, token, seq_head)
when PREDICATE then compile_predicate(cur_node, token, seq_head)
when NODE then compile_nodetype(cur_node, token)
when PARAM then compile_param(cur_node, token[1..-1], seq_head)
when CLOSING then fail_due_to("#{token} in invalid position")
Expand All @@ -174,7 +174,7 @@ def compile_expr(tokens, cur_node, seq_head)
end
# rubocop:enable Metrics/MethodLength, Metrics/AbcSize

def compile_seq(tokens, cur_node, seq_head)
def compile_seq(cur_node, seq_head)
fail_due_to('empty parentheses') if tokens.first == ')'
fail_due_to('parentheses at sequence head') if seq_head

Expand All @@ -183,7 +183,7 @@ def compile_seq(tokens, cur_node, seq_head)
# to be safe, cache the node in a temp variable and then use the
# temp variable as 'cur_node'
with_temp_node(cur_node) do |init, temp_node|
terms = compile_seq_terms(tokens, temp_node)
terms = compile_seq_terms(temp_node)
terms.unshift(compile_guard_clause(temp_node))

join_terms(init, terms, " &&\n")
Expand All @@ -194,24 +194,24 @@ def compile_guard_clause(cur_node)
"#{cur_node}.is_a?(RuboCop::AST::Node)"
end

def compile_seq_terms(tokens, cur_node)
def compile_seq_terms(cur_node)
ret =
compile_seq_terms_with_size(tokens, cur_node) do |token, terms, index|
compile_seq_terms_with_size(cur_node) do |token, terms, index|
capture = next_capture if token == CAPTURED_REST
if capture || token == REST
index = 0 if index == SEQ_HEAD_INDEX # Consider ($...) as (_ $...)
return compile_ellipsis(tokens, cur_node, terms, index, capture)
return compile_ellipsis(cur_node, terms, index, capture)
end
end
ret << "(#{cur_node}.children.size == #{ret.size - 1})"
end

def compile_seq_terms_with_size(tokens, cur_node)
def compile_seq_terms_with_size(cur_node)
index = SEQ_HEAD_INDEX
terms = []
until tokens.first == ')'
yield tokens.first, terms, index
term = compile_expr_with_index(tokens, cur_node, index)
term = compile_expr_with_index(cur_node, index)
index += 1
terms << term
end
Expand All @@ -220,21 +220,21 @@ def compile_seq_terms_with_size(tokens, cur_node)
terms
end

def compile_expr_with_index(tokens, cur_node, index)
def compile_expr_with_index(cur_node, index)
if index == SEQ_HEAD_INDEX
# in 'sequence head' position; some expressions are compiled
# differently at 'sequence head' (notably 'node type' expressions)
# grep for seq_head to see where it makes a difference
compile_expr(tokens, cur_node, true)
compile_expr(cur_node, true)
else
child_node = "#{cur_node}.children[#{index}]"
compile_expr(tokens, child_node, false)
compile_expr(child_node, false)
end
end

def compile_ellipsis(tokens, cur_node, terms, index, capture = nil)
def compile_ellipsis(cur_node, terms, index, capture = nil)
tokens.shift # drop ellipsis
tail = compile_seq_tail(tokens, cur_node)
tail = compile_seq_tail(cur_node)
terms << "(#{cur_node}.children.size >= #{index + tail.size})"
terms.concat tail
if capture
Expand All @@ -244,37 +244,36 @@ def compile_ellipsis(tokens, cur_node, terms, index, capture = nil)
terms
end

def compile_seq_tail(tokens, cur_node)
def compile_seq_tail(cur_node)
child_node = "#{cur_node}.children[%<revindex>i]"
terms = []
until tokens.first == ')'
terms << compile_expr(tokens, child_node, false)
terms << compile_expr(child_node, false)
end
tokens.shift # drop ')'
# E.g. for terms.size == 3, we want to replace the three [%<revindex>i]
# with [-3], [-2] and [-1]
terms.map.with_index { |term, i| format term, revindex: i - terms.size }
end

def compile_union(tokens, cur_node, seq_head)
def compile_union(cur_node, seq_head)
fail_due_to('empty union') if tokens.first == '}'

with_temp_node(cur_node) do |init, temp_node|
terms = union_terms(tokens, temp_node, seq_head)
terms = union_terms(temp_node, seq_head)
join_terms(init, terms, ' || ')
end
end

def union_terms(tokens, temp_node, seq_head)
def union_terms(temp_node, seq_head)
# we need to ensure that each branch of the {} contains the same
# number of captures (since only one branch of the {} can actually
# match, the same variables are used to hold the captures for each
# branch)
compile_expr_with_captures(tokens,
temp_node, seq_head) do |term, before, after|
compile_expr_with_captures(temp_node, seq_head) do |term, before, after|
terms = [term]
until tokens.first == '}'
terms << compile_expr_with_capture_check(tokens, temp_node,
terms << compile_expr_with_capture_check(temp_node,
seq_head, before, after)
end
tokens.shift
Expand All @@ -283,50 +282,50 @@ def union_terms(tokens, temp_node, seq_head)
end
end

def compile_expr_with_captures(tokens, temp_node, seq_head)
def compile_expr_with_captures(temp_node, seq_head)
captures_before = @captures
expr = compile_expr(tokens, temp_node, seq_head)
expr = compile_expr(temp_node, seq_head)

yield expr, captures_before, @captures
end

def compile_expr_with_capture_check(tokens, temp_node, seq_head, before,
def compile_expr_with_capture_check(temp_node, seq_head, before,
after)
@captures = before
expr = compile_expr(tokens, temp_node, seq_head)
expr = compile_expr(temp_node, seq_head)
if @captures != after
fail_due_to('each branch of {} must have same # of captures')
end

expr
end

def compile_intersect(tokens, cur_node, seq_head)
def compile_intersect(cur_node, seq_head)
fail_due_to('empty intersection') if tokens.first == ']'

with_temp_node(cur_node) do |init, temp_node|
terms = []
until tokens.first == ']'
terms << compile_expr(tokens, temp_node, seq_head)
terms << compile_expr(temp_node, seq_head)
end
tokens.shift

join_terms(init, terms, ' && ')
end
end

def compile_capture(tokens, cur_node, seq_head)
def compile_capture(cur_node, seq_head)
"(#{next_capture} = #{cur_node}#{'.type' if seq_head}; " \
"#{compile_expr(tokens, cur_node, seq_head)})"
"#{compile_expr(cur_node, seq_head)})"
end

def compile_negation(tokens, cur_node, seq_head)
"(!#{compile_expr(tokens, cur_node, seq_head)})"
def compile_negation(cur_node, seq_head)
"(!#{compile_expr(cur_node, seq_head)})"
end

def compile_ascend(tokens, cur_node)
def compile_ascend(cur_node)
"(#{cur_node}.parent && " \
"#{compile_expr(tokens, "#{cur_node}.parent", false)})"
"#{compile_expr("#{cur_node}.parent", false)})"
end

def compile_wildcard(cur_node, name, seq_head)
Expand All @@ -349,7 +348,7 @@ def compile_literal(cur_node, literal, seq_head)
"(#{cur_node}#{'.type' if seq_head} == #{literal})"
end

def compile_predicate(tokens, cur_node, predicate, seq_head)
def compile_predicate(cur_node, predicate, seq_head)
if predicate.end_with?('(') # is there an arglist?
args = compile_args(tokens)
predicate = predicate[0..-2] # drop the trailing (
Expand All @@ -359,7 +358,7 @@ def compile_predicate(tokens, cur_node, predicate, seq_head)
end
end

def compile_funcall(tokens, cur_node, method, seq_head)
def compile_funcall(cur_node, method, seq_head)
# call a method in the context which this pattern-matching
# code is used in. pass target value as an argument
method = method[1..-1] # drop the leading #
Expand Down

0 comments on commit f95803a

Please sign in to comment.