Skip to content

Commit

Permalink
NodePattern: Add support for multiple terms unions
Browse files Browse the repository at this point in the history
  • Loading branch information
marcandre committed Sep 11, 2020
1 parent 9516b3f commit 6e4e354
Show file tree
Hide file tree
Showing 9 changed files with 288 additions and 12 deletions.
11 changes: 9 additions & 2 deletions docs/modules/ROOT/pages/node_pattern.adoc
Expand Up @@ -173,7 +173,7 @@ You can add `+...+` before the closing bracket to allow for additional parameter
This will match both our examples, but not `sum(1.0, 2)` or `sum(2)`,
since the first node in the brackets is found, but not the second (`int`).

== `{}` for "OR"
== `{}` for "OR" (union)

Let's make it a bit more complex and introduce floats:

Expand All @@ -185,7 +185,14 @@ $ ruby-parse -e '1.0'
(float 1.0)
----

* `({int float} _)` - int or float types, no matter the value
* `({int | float} _)` - int or float types, no matter the value

Branches of the union can contain more than one term:

* `(array {int int | range})` - matches an array with two integers or a single range element

If all the branches have a single term, you can omit the `|`, so `{int | float}` can be
simplified to `{int float}`.

== `[]` for "AND"

Expand Down
18 changes: 18 additions & 0 deletions lib/rubocop/ast/node_pattern/builder.rb
Expand Up @@ -31,6 +31,24 @@ def emit_call(type, selector, args = nil)
n(type, [selector, *arg_nodes])
end

# Builds the `:union` node for `{ ... }`.
#
# @param begin_t token for the opening brace, forwarded to `emit_list`
# @param pattern_lists [Array<Array>] one list of terms per `|`-separated branch
# @param end_t token for the closing brace, forwarded to `emit_list`
# @raise [NodePattern::Invalid] if there is a single branch and it has no terms
def emit_union(begin_t, pattern_lists, end_t)
  branches =
    if pattern_lists.size == 1
      # No `|` was used: {a b c} => [[a, b, c]] => each term is its own branch
      terms = pattern_lists.first
      raise NodePattern::Invalid, 'A union can not be empty' if terms.empty?

      terms
    else
      # { a b | c } => [[a, b], [c]] => [s(:subsequence, a, b), c]
      pattern_lists.map { |terms| emit_subsequence(terms) }
    end

  emit_list(:union, begin_t, branches, end_t)
end

# Wraps the terms of one union branch in a `:subsequence` node.
# A lone term is returned as is, so it never gets wrapped on its own.
#
# @param node_list [Array] the terms of the branch (may be empty)
def emit_subsequence(node_list)
  if node_list.size == 1
    node_list.first
  else
    emit_list(:subsequence, nil, node_list, nil)
  end
end

private

def n(type, *args)
Expand Down
92 changes: 87 additions & 5 deletions lib/rubocop/ast/node_pattern/compiler/sequence_subcompiler.rb
Expand Up @@ -16,6 +16,9 @@ class Compiler
# rubocop:disable Metrics/ClassLength
class SequenceSubcompiler < Subcompiler
DELTA = 1
POSITIVE = :positive?.to_proc
private_constant :POSITIVE

# Calls `compile_sequence`; the actual `compile` method
# will be used for the different terms of the sequence.
# The only case of re-entrant call to `compile` is `visit_capture`
Expand Down Expand Up @@ -98,6 +101,21 @@ def visit_any_order
end
end

# Compiles a union appearing as a term of a sequence.
#
# A union of arity 1 is delegated to `visit_other_type`. Any other union
# is compiled by "forking": a copy of this subcompiler compiles each branch,
# and the state the copies end up in is then used to reset this subcompiler.
#
# @return [String] the branches' code joined with `||`, wrapped in parentheses
def visit_union
  if node.arity == 1
    visit_other_type
  else
    branches = compile_union_forks
    preserve_union_start(branches)
    merge_forks!(branches)
    "(#{branches.values.join(" || \n")})"
  end
end

def compile_case(when_branches, else_code)
<<~RUBY
case
Expand Down Expand Up @@ -229,6 +247,12 @@ def sync
yield code
end

# @api private
attr_reader :in_sync, :cur_index

public :in_sync
protected :cur_index, :compile_terms, :sync

# @return [Array<Range>] total arities (as Ranges) of remaining children nodes
# E.g. For sequence `(_ _? <_ _>)`, arities are: 1, 0..1, 2
# and remaining arities are: 3..4, 2..3, 2..2, 0..0
Expand Down Expand Up @@ -259,7 +283,22 @@ def compile_min_check
end

# Builds the expression for the number of children still to be matched,
# derived from the current position `@cur_index`.
#
# Returns a String of Ruby code (`<seq>.children.size <offset>`), except for
# the fallback branch, which returns an Integer computed right away.
#
# NOTE(review): this listing is a diff rendered without +/- markers; the bare
# string literal directly under `def` looks like the pre-change body. As
# written its value is simply discarded -- confirm against the repository.
def compile_remaining
"#{@seq_var}.children.size - #{@cur_index_var}"
offset = case @cur_index
when :seq_head
' + 1'
when :variadic_mode
# The position is only known at runtime, in `@cur_index_var`
" - #{@cur_index_var}"
when 0
''
when POSITIVE
# POSITIVE is `:positive?.to_proc`, so this matches any positive index
" - #{@cur_index}"
else
# odd compiling condition, result may not be expected
# E.g: `(... {a | b c})` => the b c branch can never match
return - (@cur_index + DELTA)
end

"#{@seq_var}.children.size #{offset}"
end

def compile_max_matched
Expand Down Expand Up @@ -320,15 +359,58 @@ def compile_loop(term)
end

# Builds the code of a guard comparing the number of remaining children
# (see `compile_remaining`) with `arity_range`:
# `>=` when the range has no upper bound, `==` when it is a single value,
# and `cover?` otherwise.
#
# NOTE(review): this listing is a diff rendered without +/- markers; each
# branch shows two string literals and only the last one is the value of the
# branch. The first of each pair (and the "-1" remark) look like the
# pre-change lines -- confirm against the repository.
#
# @param arity_range [Range]
# @return [String]
def compile_child_nb_guard(arity_range)
# The -1 are because of seq_head
case arity_range.max
when Float::INFINITY
"#{@seq_var}.children.size >= #{arity_range.begin - 1}"
"#{compile_remaining} >= #{arity_range.begin}"
when arity_range.begin
"#{@seq_var}.children.size == #{arity_range.begin - 1}"
"#{compile_remaining} == #{arity_range.begin}"
else
"(#{arity_range.begin - 1}..#{arity_range.max - 1}).cover?(#{@seq_var}.children.size)"
"(#{arity_range.begin}..#{arity_range.max}).cover?(#{compile_remaining})"
end
end

# Compiles every branch of the union with its own copy of this subcompiler.
#
# Each branch is turned into a list of terms (the children of a
# `Node::Subsequence`, or the branch itself otherwise) and compiled by a
# `dup` of `self`, so that branches do not see each other's state changes.
#
# @return [Hash] of {subcompiler => code}
def compile_union_forks
compiler.each_union(node.children).to_h do |child|
subsequence_terms = child.is_a?(Node::Subsequence) ? child.children : [child]
# Fork before touching our own state: the copy starts from the current state
fork = dup
code = fork.compile_terms(subsequence_terms, @remaining_arity)
# Only affects `self` (and thus forks created for the following branches)
@in_sync = false if @cur_index != :variadic_mode
[fork, code]
end
end

# Modifies in place `forks` to ensure that `cur_{child|index}_var` are ok
#
# Does nothing unless we are in variadic mode with more than one fork.
# Otherwise the code of the first fork is prefixed with an assignment saving
# the current child and index variables in a temporary variable, and the code
# of every following fork is prefixed with an assignment restoring them, so
# each branch starts matching from the same position.
#
# @param forks [Hash] of {subcompiler => code}
def preserve_union_start(forks)
return if @cur_index != :variadic_mode || forks.size <= 1

compiler.with_temp_variables do |union_reset|
cur = "(#{union_reset} = [#{@cur_child_var}, #{@cur_index_var}]) && "
reset = "(#{@cur_child_var}, #{@cur_index_var} = #{union_reset}) && "
forks.transform_values! do |code|
code = "#{cur}#{code}"
# From the second fork on, use the restoring prefix instead of the saving one
cur = reset
code
end
end
end

# Modifies in place `forks`
# Syncs our state
#
# Brings `@cur_index` and `@in_sync` of this subcompiler in line with the
# state reached by the forks:
# * non variadic union: take over the index of the first fork;
# * variadic union and `use_index_from_end`: leave the index untouched;
# * otherwise: switch to `:variadic_mode` and append to the code of each
#   fork the code yielded by its `sync`.
#
# @param forks [Hash] of {subcompiler => code}
def merge_forks!(forks)
sub_compilers = forks.keys
if !node.variadic? # e.g {a b | c d}
@cur_index = sub_compilers.first.cur_index # all cur_index should be equivalent
elsif use_index_from_end
# nothing to do
else
# can't use index from end, so we must sync all forks
@cur_index = :variadic_mode
forks.each do |sub, code|
# Only existing keys are reassigned here, no key is added while iterating
sub.sync { |sync_code| forks[sub] = "#{code} && #{sync_code}" }
end
end
# We are in sync only if every fork is
@in_sync = sub_compilers.all?(&:in_sync)
end
end
# rubocop:enable Metrics/ClassLength
Expand Down
2 changes: 1 addition & 1 deletion lib/rubocop/ast/node_pattern/lexer.rex
Expand Up @@ -20,7 +20,7 @@ rules
/[-+]?\d+\.\d+/ { emit :tNUMBER, &:to_f }
/[-+]?\d+/ { emit :tNUMBER, &:to_i }
/#{Regexp.union(
%w"( ) { } [ ] < > $ ! ^ ` ... + * ? ,"
%w"( ) { | } [ ] < > $ ! ^ ` ... + * ? ,"
)}/o { emit ss.matched, &:to_sym }
/%([A-Z:][a-zA-Z_:]+)/ { emit :tPARAM_CONST }
/%([a-z_]+)/ { emit :tPARAM_NAMED }
Expand Down
52 changes: 51 additions & 1 deletion lib/rubocop/ast/node_pattern/node.rb
Expand Up @@ -61,6 +61,10 @@ def arity_range
a.is_a?(Range) ? a : INT_TO_RANGE[a]
end

# Builds a new node of the same class, with the given attributes replaced.
# Any attribute not passed keeps the value held by the receiver.
#
# @param type the type of the new node
# @param children the children of the new node
# @param location the location of the new node
def with(type: @type, children: @children, location: @location)
  properties = { location: location }
  self.class.new(type, children, properties)
end

INT_TO_RANGE = Hash.new { |h, k| h[k] = k..k }
private_constant :INT_TO_RANGE

Expand Down Expand Up @@ -184,6 +188,50 @@ def arity
end
end

# A list (potentially empty) of nodes; part of a Union
class Subsequence < Node
  include ForbidInSeqHead

  # Total arity of the subsequence: the sum of the arities of its children.
  #
  # Uses `Range#min` / `Range#max` rather than `Range#minmax`: before
  # Ruby 2.7 the latter is `Enumerable#minmax`, which iterates the range
  # and thus never returns for `0..Float::INFINITY`.
  #
  # @return [Integer, Range] an Integer when the arity is fixed
  #   (0 if there are no children), otherwise a Range
  def arity
    ranges = children.map(&:arity_range)
    min = ranges.sum(&:min)
    max = ranges.sum(&:max)
    min == max ? min : min..max
  end

  # Expansion to use when this subsequence starts a sequence.
  #
  # @return [Array<Node>, nil] a one element array holding a copy of this
  #   subsequence whose first child was replaced by its own expansion,
  #   or nil if the first child needs no expansion
  def in_sequence_head
    # NOTE(review): presumably the `ForbidInSeqHead` implementation raises,
    # which is why execution is not expected to continue past this line
    # for an empty subsequence -- confirm.
    super if children.empty?

    return unless (replace = children.first.in_sequence_head)

    [with(children: [*replace, *children[1..-1]])]
  end
end

# Node class for `{ ... }`
class Union < Node
  # Arity of the union: from the smallest minimum to the largest maximum
  # among the arities of its branches.
  #
  # Uses `Range#min` / `Range#max` rather than `Range#minmax`: before
  # Ruby 2.7 the latter is `Enumerable#minmax`, which iterates the range
  # and thus never returns for `0..Float::INFINITY`.
  #
  # @return [Integer, Range] an Integer if all branches have the same
  #   fixed arity, otherwise a Range
  def arity
    ranges = children.map(&:arity_range)
    min = ranges.map(&:min).min
    max = ranges.map(&:max).max
    min == max ? min : min..max
  end

  # Expansion to use when this union starts a sequence.
  #
  # @return [Array<Node>, nil] a one element array holding a copy of this
  #   union where each branch needing it was replaced by its expansion,
  #   or nil if no branch needs to be expanded
  def in_sequence_head
    return unless children.any?(&:in_sequence_head)

    new_children = children.map do |child|
      next child unless (replace = child.in_sequence_head)

      if replace.size > 1
        # NOTE(review): `with` passes the location under the `location:` key
        # while this call uses `loc:` -- confirm the constructor honours it.
        Subsequence.new(:subsequence, replace, loc: child.loc)
      else
        replace.first
      end
    end

    [with(children: new_children)]
  end
end

# Registry
MAP = Hash.new(Node).merge!(
sequence: Sequence,
Expand All @@ -192,7 +240,9 @@ def arity
capture: Capture,
predicate: Predicate,
any_order: AnyOrder,
function_call: FunctionCall
function_call: FunctionCall,
subsequence: Subsequence,
union: Union
).freeze
end
end
Expand Down
11 changes: 10 additions & 1 deletion lib/rubocop/ast/node_pattern/parser.rb
Expand Up @@ -46,9 +46,18 @@ def inspect

private

def_delegators :@builder, :emit_list, :emit_unary_op, :emit_atom, :emit_capture, :emit_call
def_delegators :@builder, :emit_list, :emit_unary_op, :emit_atom, :emit_capture,
:emit_call, :emit_union
def_delegators :@lexer, :next_token

# Makes sure that `node` matches exactly one element.
#
# @param node the node to check; must respond to `arity` and `loc`
# @return the given node, untouched, when its arity is 1
# @raise [NodePattern::Invalid] otherwise; the message quotes the source of
#   the node when it is available, else its `to_s`
def enforce_unary(node)
  unless node.arity == 1
    detail = node.loc&.expression&.source || node.to_s
    raise NodePattern::Invalid, 'parse error, expected unary node pattern ' \
                                "but got expression matching multiple elements: #{detail}"
  end

  node
end

# Overrides Racc::Parser's method:
def on_error(token, val, _vstack)
detail = token_to_str(token) || '?'
Expand Down
19 changes: 17 additions & 2 deletions lib/rubocop/ast/node_pattern/parser.y
Expand Up @@ -4,8 +4,12 @@ token tSYMBOL tNUMBER tSTRING tWILDCARD tPARAM_NAMED tPARAM_CONST tPARAM_NUMBER
tFUNCTION_CALL tPREDICATE tNODE_TYPE tARG_LIST tUNIFY
rule
node_pattern # @return Node
: node_pattern_no_union
| union { enforce_unary(val[0]) }
;

node_pattern_no_union # @return Node
: '(' variadic_pattern_list ')' { emit_list :sequence, *val }
| '{' node_pattern_list '}' { emit_list :union, *val }
| '[' node_pattern_list ']' { emit_list :intersection, *val }
| '!' node_pattern { emit_unary_op :negation, *val }
| '^' node_pattern { emit_unary_op :ascend, *val }
Expand All @@ -28,8 +32,13 @@ rule
| tUNIFY { emit_atom :unify, *val }
;

union # @return Node
: '{' separated_variadic_patterns '}' { emit_union(*val) }
;

variadic_pattern # @return Node
: node_pattern
: node_pattern_no_union
| union
| node_pattern repetition
{
main, repeat_t = val
Expand Down Expand Up @@ -84,4 +93,10 @@ rule
: variadic_pattern { val }
| variadic_pattern_list variadic_pattern { val[0] << val[1] }
;

separated_variadic_patterns # @return Array<Array<Node>>
: { [[]] }
| separated_variadic_patterns variadic_pattern { val[0].last << val[1]; val[0] }
| separated_variadic_patterns '|' { val[0] << [] }
;
end
13 changes: 13 additions & 0 deletions spec/rubocop/ast/node_pattern/parser_spec.rb
Expand Up @@ -48,6 +48,19 @@
)
end

# A `...` found where the head of the sequence is expected is expanded into a
# wildcard followed by the rest node, even when nested in a union branch
# (`... c`) or in a capture (`$...`); branches not starting with `...`
# (`a`, `b ...`) are left as written.
it 'expands ... in sequence head deep inside unions' do
rest = s(:rest, :'...')
expect_parsing(
s(:sequence, s(:union,
s(:node_type, :a),
s(:subsequence, s(:node_type, :b), rest),
s(:subsequence, s(:wildcard), rest, s(:node_type, :c)),
s(:subsequence, s(:wildcard), s(:capture, rest)))),
'({a | b ... | ... c | $...})',
''
)
end

it 'generates specialized nodes' do
source_file = Parser::Source::Buffer.new('(spec)', source: '($_)')
ast = parser.parse(source_file)
Expand Down

0 comments on commit 6e4e354

Please sign in to comment.