From d885714e710b7a467b5c6b69d6001b972630fcaf Mon Sep 17 00:00:00 2001 From: Marc-Andre Lafortune Date: Fri, 11 Sep 2020 13:17:48 -0400 Subject: [PATCH] NodePattern: Add support for multiple terms unions --- docs/modules/ROOT/pages/node_pattern.adoc | 11 ++- lib/rubocop/ast/node_pattern/builder.rb | 18 ++++ .../compiler/sequence_subcompiler.rb | 92 ++++++++++++++++++- lib/rubocop/ast/node_pattern/lexer.rex | 2 +- lib/rubocop/ast/node_pattern/node.rb | 68 +++++++++++++- lib/rubocop/ast/node_pattern/parser.rb | 11 ++- lib/rubocop/ast/node_pattern/parser.y | 19 +++- spec/rubocop/ast/node_pattern/parser_spec.rb | 13 +++ spec/rubocop/ast/node_pattern_spec.rb | 82 +++++++++++++++++ 9 files changed, 304 insertions(+), 12 deletions(-) diff --git a/docs/modules/ROOT/pages/node_pattern.adoc b/docs/modules/ROOT/pages/node_pattern.adoc index acb44db39..1b65c16f5 100644 --- a/docs/modules/ROOT/pages/node_pattern.adoc +++ b/docs/modules/ROOT/pages/node_pattern.adoc @@ -173,7 +173,7 @@ You can add `+...+` before the closing bracket to allow for additional parameter This will match both our examples, but not `sum(1.0, 2)` or `sum(2)`, since the first node in the brackets is found, but not the second (`int`). -== `{}` for "OR" +== `{}` for "OR" (union) Lets make it a bit more complex and introduce floats: @@ -185,7 +185,14 @@ $ ruby-parse -e '1.0' (float 1.0) ---- -* `({int float} _)` - int or float types, no matter the value +* `({int | float} _)` - int or float types, no matter the value + +Branches of the union can contain more than one term: + +* `(array {int int | range})` - matches an array with two integers or a single range element + +If all the branches have a single term, you can omit the `|`, so `{int | float}` can be +simplified to `{int float}`. 
== `[]` for "AND" diff --git a/lib/rubocop/ast/node_pattern/builder.rb b/lib/rubocop/ast/node_pattern/builder.rb index 5d4212f02..11fd05c13 100644 --- a/lib/rubocop/ast/node_pattern/builder.rb +++ b/lib/rubocop/ast/node_pattern/builder.rb @@ -31,6 +31,24 @@ def emit_call(type, selector, args = nil) n(type, [selector, *arg_nodes]) end + def emit_union(begin_t, pattern_lists, end_t) + if pattern_lists.size == 1 # {a b c} => [[a, b, c]] => [a, b, c] + children = pattern_lists.first + raise NodePattern::Invalid, 'A union can not be empty' if children.empty? + else # { a b | c } => [[a, b], [c]] => [s(:subsequence, a, b), c] + children = pattern_lists.map do |list| + emit_subsequence(list) + end + end + emit_list(:union, begin_t, children, end_t) + end + + def emit_subsequence(node_list) + return node_list.first if node_list.size == 1 # Don't put a single child in a subsequence + + emit_list(:subsequence, nil, node_list, nil) + end + private def n(type, *args) diff --git a/lib/rubocop/ast/node_pattern/compiler/sequence_subcompiler.rb b/lib/rubocop/ast/node_pattern/compiler/sequence_subcompiler.rb index 27e86d5ba..94970e371 100644 --- a/lib/rubocop/ast/node_pattern/compiler/sequence_subcompiler.rb +++ b/lib/rubocop/ast/node_pattern/compiler/sequence_subcompiler.rb @@ -16,6 +16,9 @@ class Compiler # rubocop:disable Metrics/ClassLength class SequenceSubcompiler < Subcompiler DELTA = 1 + POSITIVE = :positive?.to_proc + private_constant :POSITIVE + # Calls `compile_sequence`; the actual `compile` method # will be used for the different terms of the sequence. # The only case of re-entrant call to `compile` is `visit_capture` @@ -98,6 +101,21 @@ def visit_any_order end end + def visit_union + return visit_other_type if node.arity == 1 + + # The way we implement complex unions is by "forking", i.e. + # making a copy of the present subcompiler to compile each branch + # of the union. + # We then use the resulting state of the subcompilers to + # reset ourselves. 
+ forks = compile_union_forks + preserve_union_start(forks) + merge_forks!(forks) + expr = forks.values.join(" || \n") + "(#{expr})" + end + def compile_case(when_branches, else_code) <<~RUBY case @@ -229,6 +247,12 @@ def sync yield code end + # @api private + attr_reader :in_sync, :cur_index + + public :in_sync + protected :cur_index, :compile_terms, :sync + # @return [Array] total arities (as Ranges) of remaining children nodes # E.g. For sequence `(_ _? <_ _>)`, arities are: 1, 0..1, 2 # and remaining arities are: 3..4, 2..3, 2..2, 0..0 @@ -259,7 +283,22 @@ def compile_min_check end def compile_remaining - "#{@seq_var}.children.size - #{@cur_index_var}" + offset = case @cur_index + when :seq_head + ' + 1' + when :variadic_mode + " - #{@cur_index_var}" + when 0 + '' + when POSITIVE + " - #{@cur_index}" + else + # odd compiling condition, result may not be expected + # E.g: `(... {a | b c})` => the b c branch can never match + return - (@cur_index + DELTA) + end + + "#{@seq_var}.children.size #{offset}" end def compile_max_matched @@ -320,15 +359,58 @@ def compile_loop(term) end def compile_child_nb_guard(arity_range) - # The -1 are because of seq_head case arity_range.max when Float::INFINITY - "#{@seq_var}.children.size >= #{arity_range.begin - 1}" + "#{compile_remaining} >= #{arity_range.begin}" when arity_range.begin - "#{@seq_var}.children.size == #{arity_range.begin - 1}" + "#{compile_remaining} == #{arity_range.begin}" + else + "(#{arity_range.begin}..#{arity_range.max}).cover?(#{compile_remaining})" + end + end + + # @return [Hash] of {subcompiler => code} + def compile_union_forks + compiler.each_union(node.children).map do |child| + subsequence_terms = child.is_a?(Node::Subsequence) ? child.children : [child] + fork = dup + code = fork.compile_terms(subsequence_terms, @remaining_arity) + @in_sync = false if @cur_index != :variadic_mode + [fork, code] + end.to_h # we could avoid map if RUBY_VERSION >= 2.6... 
+ end + + # Modifies in place `forks` to ensure that `cur_{child|index}_var` are ok + def preserve_union_start(forks) + return if @cur_index != :variadic_mode || forks.size <= 1 + + compiler.with_temp_variables do |union_reset| + cur = "(#{union_reset} = [#{@cur_child_var}, #{@cur_index_var}]) && " + reset = "(#{@cur_child_var}, #{@cur_index_var} = #{union_reset}) && " + forks.transform_values! do |code| + code = "#{cur}#{code}" + cur = reset + code + end + end + end + + # Modifies in place `forks` + # Syncs our state + def merge_forks!(forks) + sub_compilers = forks.keys + if !node.variadic? # e.g {a b | c d} + @cur_index = sub_compilers.first.cur_index # all cur_index should be equivalent + elsif use_index_from_end + # nothing to do else - "(#{arity_range.begin - 1}..#{arity_range.max - 1}).cover?(#{@seq_var}.children.size)" + # can't use index from end, so we must sync all forks + @cur_index = :variadic_mode + forks.each do |sub, code| + sub.sync { |sync_code| forks[sub] = "#{code} && #{sync_code}" } + end end + @in_sync = sub_compilers.all?(&:in_sync) end end # rubocop:enable Metrics/ClassLength diff --git a/lib/rubocop/ast/node_pattern/lexer.rex b/lib/rubocop/ast/node_pattern/lexer.rex index 03e2b6261..29079b4fd 100644 --- a/lib/rubocop/ast/node_pattern/lexer.rex +++ b/lib/rubocop/ast/node_pattern/lexer.rex @@ -20,7 +20,7 @@ rules /[-+]?\d+\.\d+/ { emit :tNUMBER, &:to_f } /[-+]?\d+/ { emit :tNUMBER, &:to_i } /#{Regexp.union( - %w"( ) { } [ ] < > $ ! ^ ` ... + * ? ," + %w"( ) { | } [ ] < > $ ! ^ ` ... + * ? ," )}/o { emit ss.matched, &:to_sym } /%([A-Z:][a-zA-Z_:]+)/ { emit :tPARAM_CONST } /%([a-z_]+)/ { emit :tPARAM_NAMED } diff --git a/lib/rubocop/ast/node_pattern/node.rb b/lib/rubocop/ast/node_pattern/node.rb index 86f9ab758..97fc2aead 100644 --- a/lib/rubocop/ast/node_pattern/node.rb +++ b/lib/rubocop/ast/node_pattern/node.rb @@ -61,6 +61,10 @@ def arity_range a.is_a?(Range) ? 
a : INT_TO_RANGE[a] end + def with(type: @type, children: @children, location: @location) + self.class.new(type, children, { location: location }) + end + INT_TO_RANGE = Hash.new { |h, k| h[k] = k..k } private_constant :INT_TO_RANGE @@ -184,6 +188,66 @@ def arity end end + # Circumvent broken `Range#minmax` for infinity ranges in 2.6- + module MapMinMax + if RUBY_VERSION >= '2.7' + def map_min_max(enum) + enum.map(&:minmax) + end + else + def map_min_max(enum) + enum.map { |r| [r.min, r.max] } + end + end + end + + # A list (potentially empty) of nodes; part of a Union + class Subsequence < Node + include ForbidInSeqHead + include MapMinMax + + def arity + min, max = map_min_max(children.map(&:arity_range)).transpose.map(&:sum) + min == max ? min || 0 : min..max # Note: || 0 for empty case, where min == max == nil. + end + + def in_sequence_head + super if children.empty? + + return unless (replace = children.first.in_sequence_head) + + [with(children: [*replace, *children[1..-1]])] + end + end + + # Node class for `{ ... }` + class Union < Node + include MapMinMax + + def arity + minima, maxima = map_min_max(children.map(&:arity_range)).transpose + min = minima.min + max = maxima.max + min == max ? 
min : min..max + end + + def in_sequence_head + return unless children.any?(&:in_sequence_head) + + new_children = children.map do |child| + next child unless (replace = child.in_sequence_head) + + if replace.size > 1 + Subsequence.new(:subsequence, replace, location: child.loc) + else + replace.first + end + end + + [with(children: new_children)] + end + end + # Registry MAP = Hash.new(Node).merge!( sequence: Sequence, @@ -192,7 +256,9 @@ def arity capture: Capture, predicate: Predicate, any_order: AnyOrder, - function_call: FunctionCall + function_call: FunctionCall, + subsequence: Subsequence, + union: Union ).freeze end end diff --git a/lib/rubocop/ast/node_pattern/parser.rb b/lib/rubocop/ast/node_pattern/parser.rb index 7a72d662e..899ae4b93 100644 --- a/lib/rubocop/ast/node_pattern/parser.rb +++ b/lib/rubocop/ast/node_pattern/parser.rb @@ -43,9 +43,18 @@ def inspect private - def_delegators :@builder, :emit_list, :emit_unary_op, :emit_atom, :emit_capture, :emit_call + def_delegators :@builder, :emit_list, :emit_unary_op, :emit_atom, :emit_capture, + :emit_call, :emit_union def_delegators :@lexer, :next_token + def enforce_unary(node) + return node if node.arity == 1 + + detail = node.loc&.expression&.source || node.to_s + raise NodePattern::Invalid, 'parse error, expected unary node pattern ' \ + "but got expression matching multiple elements: #{detail}" + end + # Overrides Racc::Parser's method: def on_error(token, val, _vstack) detail = token_to_str(token) || '?' 
diff --git a/lib/rubocop/ast/node_pattern/parser.y b/lib/rubocop/ast/node_pattern/parser.y index 14387069e..56de0abea 100644 --- a/lib/rubocop/ast/node_pattern/parser.y +++ b/lib/rubocop/ast/node_pattern/parser.y @@ -4,8 +4,12 @@ token tSYMBOL tNUMBER tSTRING tWILDCARD tPARAM_NAMED tPARAM_CONST tPARAM_NUMBER tFUNCTION_CALL tPREDICATE tNODE_TYPE tARG_LIST tUNIFY rule node_pattern # @return Node + : node_pattern_no_union + | union { enforce_unary(val[0]) } + ; + + node_pattern_no_union # @return Node : '(' variadic_pattern_list ')' { emit_list :sequence, *val } - | '{' node_pattern_list '}' { emit_list :union, *val } | '[' node_pattern_list ']' { emit_list :intersection, *val } | '!' node_pattern { emit_unary_op :negation, *val } | '^' node_pattern { emit_unary_op :ascend, *val } @@ -28,8 +32,13 @@ rule | tUNIFY { emit_atom :unify, *val } ; + union # @return Node + : '{' separated_variadic_patterns '}' { emit_union(*val) } + ; + variadic_pattern # @return Node - : node_pattern + : node_pattern_no_union + | union | node_pattern repetition { main, repeat_t = val @@ -84,4 +93,10 @@ rule : variadic_pattern { val } | variadic_pattern_list variadic_pattern { val[0] << val[1] } ; + + separated_variadic_patterns # @return Array> + : { [[]] } + | separated_variadic_patterns variadic_pattern { val[0].last << val[1]; val[0] } + | separated_variadic_patterns '|' { val[0] << [] } + ; end diff --git a/spec/rubocop/ast/node_pattern/parser_spec.rb b/spec/rubocop/ast/node_pattern/parser_spec.rb index 28dc6d305..05877d154 100644 --- a/spec/rubocop/ast/node_pattern/parser_spec.rb +++ b/spec/rubocop/ast/node_pattern/parser_spec.rb @@ -48,6 +48,19 @@ ) end + it 'expands ... in sequence head deep inside unions' do + rest = s(:rest, :'...') + expect_parsing( + s(:sequence, s(:union, + s(:node_type, :a), + s(:subsequence, s(:node_type, :b), rest), + s(:subsequence, s(:wildcard), rest, s(:node_type, :c)), + s(:subsequence, s(:wildcard), s(:capture, rest)))), + '({a | b ... | ... 
c | $...})', + '' + ) + end + it 'generates specialized nodes' do source_file = Parser::Source::Buffer.new('(spec)', source: '($_)') ast = parser.parse(source_file) diff --git a/spec/rubocop/ast/node_pattern_spec.rb b/spec/rubocop/ast/node_pattern_spec.rb index 4e6fe1c81..d08f171f4 100644 --- a/spec/rubocop/ast/node_pattern_spec.rb +++ b/spec/rubocop/ast/node_pattern_spec.rb @@ -663,6 +663,70 @@ def not_match_codes(*codes) it { expect(pattern).to match_code(node) } end + + context 'variadic' do + context 'with fixed terms' do + it 'works for cases with fixed arity before and after union' do + expect('(_ { int | sym _ str | } const)').to match_codes( + '[X]', '[42, X]', '[:foo, //, "bar", X]' + ).and not_match_codes( + '[42]', '[4.2, X]', '["bar", //, :foo, X]' + ) + end + + it 'works for cases with variadic terms after union' do + expect('(_ { int | sym _ str | } const+)').to match_codes( + '[X]', '[42, X, Y, Z]', '[:foo, //, "bar", X]' + ).and not_match_codes( + '[42]', '[4.2, X]', '["bar", //, :foo, X]' + ) + end + + it 'works for cases with variadic terms before and after union' do + expect('(_ const ? { int | sym _ str | } const+)').to match_codes( + '[X]', '[FOO, 42, X, Y, Z]', '[:foo, //, "bar", X]', '[X, Y, Z]' + ).and not_match_codes( + '[42]', '[4.2, X]', '["bar", //, :foo, X]', '[FOO BAR, 42]' + ) + end + end + + context 'with variadic terms' do + it 'works for cases with fixed arity before and after union' do + expect('(_ { sym+ _ str | int* } const)').to match_codes( + '[X]', '[42, 666, X]', '[:foo, :foo2, //, "bar", X]' + ).and not_match_codes( + '[42]', '[4.2, X]', '["bar", //, :foo, X]' + ) + end + + it 'works for cases with variadic terms after union' do + expect('(_ { sym+ _ str | int* } const+)').to match_codes( + '[X]', '[42, 666, X, Y, Z]', '[:foo, :foo2, //, "bar", X]' + ).and not_match_codes( + '[42]', '[4.2, X]', '["bar", //, :foo, X]' + ) + end + + it 'works for cases with variadic terms before and after union' do + expect('(_ const ? 
{ sym+ _ str | int* } const+)').to match_codes( + '[X]', '[FOO, 42, 666, X, Y, Z]', '[:foo, :foo2, //, "bar", X]', '[X, Y, Z]' + ).and not_match_codes( + '[42]', '[4.2, X]', '["bar", //, :foo, X]', '[FOO BAR, 42]' + ) + end + end + + context 'multiple' do + it 'works for complex cases' do + expect('(_ const ? { sym+ int+ | int+ sym+ } { str+ | regexp+ } ... )').to match_codes( + '[X, :foo, :bar, 42, "a", Y]', '[42, 666, :foo, //]' + ).and not_match_codes( + '[42, :almost, X]', '[X, 42, :foo, 42, //]', '[X, :foo, //, :foo, X]' + ) + end + end + end end describe 'captures on a wildcard' do @@ -1845,6 +1909,24 @@ def withargs(foo, bar, qux) it_behaves_like 'invalid' end + context 'with empty union subsequence in seq head' do + let(:pattern) { '({foo|})' } + + it_behaves_like 'invalid' + end + + context 'with unsupported subsequence in seq head within union' do + let(:pattern) { '({foo bar+})' } + + it_behaves_like 'invalid' + end + + context 'with variadic unions where not supported' do + let(:pattern) { '(_ [_ {foo | ...}])' } + + it_behaves_like 'invalid' + end + context 'with empty intersection' do let(:pattern) { '[]' }