Skip to content

Commit

Permalink
Add named parameters to NodePattern
Browse files Browse the repository at this point in the history
  • Loading branch information
marcandre committed Jun 26, 2020
1 parent 4377a64 commit 5d1659e
Show file tree
Hide file tree
Showing 3 changed files with 209 additions and 25 deletions.
23 changes: 23 additions & 0 deletions docs/modules/ROOT/pages/node_pattern.adoc
Original file line number Diff line number Diff line change
Expand Up @@ -413,6 +413,29 @@ NOTE: `Array#===` will never match a single node element (so don't pass arrays),
but `Set#===` is an alias to `Set#include?` (Ruby 2.5+ only), and so can be
very useful to match within many possible literals / Nodes.

== `%param_name` for named parameters

Arguments can be passed as named parameters. They will be matched using `===`
(see `%` above).

Contrary to positional arguments, defaults values can be passed to
`def_node_matcher` and `def_node_search`:

[source,ruby]
----
def_node_matcher :interesting_call?, '(send _ %method ...)',
method: Set[:transform_values, :transform_keys,
:transform_values!, :transform_keys!,
:to_h].freeze
# Usage:
interesting_call?(node) # use the default methods
interesting_call?(node, method: /^transform/) # match anything starting with 'transform'
----

Named parameters as arguments to custom methods are also supported.

== `nil` or `nil?`

Take a special attention to nil behavior:
Expand Down
85 changes: 63 additions & 22 deletions lib/rubocop/ast/node_pattern.rb
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,10 @@ module AST
# # for consistency, %0 is the 'root node' which is
# # passed as the 1st argument to #match, where the
# # matching process starts
# '(send _ %named)' # arguments can also be passed as named
# # parameters (see `%1`)
# # Note that the macros `def_node_pattern` and
# # `def_node_search` accept default values for these.
# '^^send' # each ^ ascends one level in the AST
# # so this matches against the grandparent node
# '`send' # descends any number of level in the AST
Expand Down Expand Up @@ -125,10 +129,11 @@ class Compiler
NUMBER = /-?\d+(?:\.\d+)?/.freeze
STRING = /".+?"/.freeze
METHOD_NAME = /\#?#{IDENTIFIER}[!?]?\(?/.freeze
KEYWORD_NAME = /%[a-z_]+/.freeze
PARAM_NUMBER = /%\d*/.freeze

SEPARATORS = /\s+/.freeze
TOKENS = Regexp.union(META, PARAM_NUMBER, NUMBER,
TOKENS = Regexp.union(META, KEYWORD_NAME, PARAM_NUMBER, NUMBER,
METHOD_NAME, SYMBOL, STRING)

TOKEN = /\G(?:#{SEPARATORS}|#{TOKENS}|.)/.freeze
Expand All @@ -140,6 +145,7 @@ class Compiler
FUNCALL = /\A\##{METHOD_NAME}/.freeze
LITERAL = /\A(?:#{SYMBOL}|#{NUMBER}|#{STRING})\Z/.freeze
PARAM = /\A#{PARAM_NUMBER}\Z/.freeze
KEYWORD = /\A#{KEYWORD_NAME}\Z/.freeze
CLOSING = /\A(?:\)|\}|\])\Z/.freeze

REST = '...'
Expand Down Expand Up @@ -198,6 +204,7 @@ def initialize(str, node_var = 'node0')
@captures = 0 # number of captures seen
@unify = {} # named wildcard -> temp variable
@params = 0 # highest % (param) number seen
@keywords = Set[] # keyword parameters seen
run(node_var)
end

Expand Down Expand Up @@ -237,6 +244,7 @@ def compile_expr(token = tokens.shift)
when LITERAL then compile_literal(token)
when PREDICATE then compile_predicate(token)
when NODE then compile_nodetype(token)
when KEYWORD then compile_keyword(token[1..-1])
when PARAM then compile_param(token[1..-1])
when CLOSING then fail_due_to("#{token} in invalid position")
when nil then fail_due_to('pattern ended prematurely')
Expand Down Expand Up @@ -620,6 +628,10 @@ def compile_param(number)
"#{get_param(number)} === #{CUR_ELEMENT}"
end

def compile_keyword(keyword)
"#{get_keyword(keyword)} === #{CUR_ELEMENT}"
end

def compile_args(tokens)
index = tokens.find_index { |token| token == ')' }

Expand All @@ -631,12 +643,13 @@ def compile_args(tokens)
end

def compile_arg(token)
name = token[1..-1]
case token
when WILDCARD then
name = token[1..-1]
when WILDCARD
access_unify(name) || fail_due_to('invalid in arglist: ' + token)
when LITERAL then token
when PARAM then get_param(token[1..-1])
when KEYWORD then get_keyword(name)
when PARAM then get_param(name)
when CLOSING then fail_due_to("#{token} in invalid position")
when nil then fail_due_to('pattern ended prematurely')
else fail_due_to("invalid token in arglist: #{token.inspect}")
Expand All @@ -655,6 +668,11 @@ def get_param(number)
number.zero? ? @root : "param#{number}"
end

def get_keyword(name)
@keywords << name
name
end

def emit_yield_capture(when_no_capture = '')
yield_val = if @captures.zero?
when_no_capture
Expand All @@ -680,9 +698,15 @@ def emit_param_list
(1..@params).map { |n| "param#{n}" }.join(',')
end

def emit_trailing_params
def emit_keyword_list(forwarding: false)
pattern = "%<keyword>s: #{'%<keyword>s' if forwarding}"
@keywords.map { |k| format(pattern, keyword: k) }.join(',')
end

def emit_trailing_params(forwarding: false)
params = emit_param_list
params.empty? ? '' : ",#{params}"
keywords = emit_keyword_list(forwarding: forwarding)
[params, keywords].reject(&:empty?).map { |p| ", #{p}" }.join
end

def emit_method_code
Expand Down Expand Up @@ -759,21 +783,32 @@ def self.tokens(pattern)
pattern.scan(TOKEN).reject { |token| token =~ /\A#{SEPARATORS}\Z/ }
end

def def_helper(base, src)
def def_helper(base, method_name, **defaults)
location = caller_locations(3, 1).first
unless defaults.empty?
base.send :define_method, method_name do |*args, **values|
send method_name, *args, **defaults, **values
end
method_name = :"without_defaults_#{method_name}"
end
src = yield method_name
base.class_eval(src, location.path, location.lineno)
end

def def_node_matcher(base, method_name)
def_helper(base, <<~RUBY)
def #{method_name}(node = self#{emit_trailing_params})
#{emit_method_code}
end
RUBY
def def_node_matcher(base, method_name, **defaults)
def_helper(base, method_name, **defaults) do |name|
<<~RUBY
def #{name}(node = self#{emit_trailing_params})
#{emit_method_code}
end
RUBY
end
end

def def_node_search(base, method_name)
def_helper(base, emit_node_search(method_name))
def def_node_search(base, method_name, **defaults)
def_helper(base, method_name, **defaults) do |name|
emit_node_search(name)
end
end

def emit_node_search(method_name)
Expand All @@ -782,7 +817,7 @@ def emit_node_search(method_name)
else
prelude = <<~RUBY
return enum_for(:#{method_name},
node0#{emit_trailing_params}) unless block_given?
node0#{emit_trailing_params(forwarding: true)}) unless block_given?
RUBY
on_match = emit_yield_capture('node')
end
Expand Down Expand Up @@ -814,8 +849,9 @@ module Macros
# yield to the block (passing any captures as block arguments).
# If the node matches, and no block is provided, the new method will
# return the captures, or `true` if there were none.
def def_node_matcher(method_name, pattern_str)
Compiler.new(pattern_str, 'node').def_node_matcher(self, method_name)
def def_node_matcher(method_name, pattern_str, **keyword_defaults)
Compiler.new(pattern_str, 'node')
.def_node_matcher(self, method_name, **keyword_defaults)
end

# Define a method which recurses over the descendants of an AST node,
Expand All @@ -824,8 +860,9 @@ def def_node_matcher(method_name, pattern_str)
# If the method name ends with '?', the new method will return `true`
# as soon as it finds a descendant which matches. Otherwise, it will
# yield all descendants which match.
def def_node_search(method_name, pattern_str)
Compiler.new(pattern_str, 'node').def_node_search(self, method_name)
def def_node_search(method_name, pattern_str, **keyword_defaults)
Compiler.new(pattern_str, 'node')
.def_node_search(self, method_name, **keyword_defaults)
end
end

Expand All @@ -839,11 +876,15 @@ def initialize(str)
instance_eval(src, __FILE__, __LINE__ + 1)
end

def match(*args)
def match(*args, **rest)
# If we're here, it's because the singleton method has not been defined,
# either because we've been dup'ed or serialized through YAML
initialize(pattern)
match(*args)
if rest.empty?
match(*args)
else
match(*args, **rest)
end
end

def marshal_load(pattern)
Expand Down
126 changes: 123 additions & 3 deletions spec/rubocop/ast/node_pattern_spec.rb
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,15 @@

let(:node) { root_node }
let(:params) { [] }
let(:keyword_params) { {} }
let(:instance) { described_class.new(pattern) }
let(:result) { instance.match(node, *params) }
let(:result) do
if keyword_params.empty? # Avoid bug in Ruby < 2.6
instance.match(node, *params)
else
instance.match(node, *params, **keyword_params)
end
end

shared_examples 'matching' do
include RuboCop::AST::Sexp
Expand Down Expand Up @@ -1122,6 +1129,39 @@
end
end

context 'with a named argument' do
let(:pattern) { '(send (int equal?(%param)) ...)' }
let(:ruby) { '1 + 2' }

context 'for which the predicate is true' do
let(:keyword_params) { { param: 1 } }

it_behaves_like 'matching'
end

context 'for which the predicate is false' do
let(:keyword_params) { { param: 2 } }

it_behaves_like 'nonmatching'
end

context 'when not given' do
let(:keyword_params) { {} }

it 'raises an error' do
expect { result }.to raise_error(ArgumentError)
end
end

context 'with extra arguments' do
let(:keyword_params) { { param: 1, extra: 2 } }

it 'raises an error' do
expect { result }.to raise_error(ArgumentError)
end
end
end

context 'with multiple arguments' do
let(:pattern) { '(str between?(%1, %2))' }
let(:ruby) { '"c"' }
Expand Down Expand Up @@ -1160,6 +1200,35 @@
end
end

context 'as named parameters' do
let(:pattern) { '%foo' }
let(:matcher) { Object.new }
let(:keyword_params) { { foo: matcher } }
let(:ruby) { '10' }

context 'when provided as argument to match' do
before { expect(matcher).to receive(:===).with(s(:int, 10)).and_return true } # rubocop:todo RSpec/ExpectInHook

it_behaves_like 'matching'
end

context 'when extra are provided' do
let(:keyword_params) { { foo: matcher, bar: matcher } }

it 'raises an ArgumentError' do
expect { result }.to raise_error(ArgumentError)
end
end

context 'when not provided' do
let(:keyword_params) { {} }

it 'raises an ArgumentError' do
expect { result }.to raise_error(ArgumentError)
end
end
end

context 'in a nested sequence' do
let(:pattern) { '(send (send _ %2) %1)' }
let(:params) { %i[inc dec] }
Expand Down Expand Up @@ -1791,14 +1860,27 @@ def withargs(foo, bar, qux)
end)
end

let(:keyword_defaults) { {} }
let(:method_name) { :my_matcher }
let(:line_no) { __LINE__ + 2 }
let(:defined_class) do
MyClass.public_send helper_name, method_name, pattern
MyClass.public_send helper_name, method_name, pattern, **keyword_defaults
MyClass
end
let(:ruby) { ':hello' }
let(:result) { defined_class.new.send(method_name, node, *params) }
let(:result) do
if keyword_params.empty? # Avoid bug in Ruby < 2.7
defined_class.new.send(method_name, node, *params)
else
defined_class.new.send(method_name, node, *params, **keyword_params)
end
end

if Set[1] === 1 # rubocop:disable Style/CaseEquality
let(:hello_matcher) { Set[:hello, :foo] }
else
let(:hello_matcher) { Set[:hello, :foo].method(:include?).to_proc }
end

context 'with a pattern without captures' do
let(:pattern) { '(sym _)' }
Expand Down Expand Up @@ -1932,6 +2014,44 @@ def withargs(foo, bar, qux)
expect(result.is_a?(Enumerator)).to be(true)
expect(result.to_a).to match_array %i[hello world]
end

context 'when the pattern contains keyword_params' do
let(:pattern) { '(sym $%foo)' }
let(:keyword_params) { { foo: hello_matcher } }

it 'returns an enumerator yielding the captures' do
expect(result.is_a?(Enumerator)).to be(true)
expect(result.to_a).to match_array %i[hello]
end

# rubocop:disable RSpec/NestedGroups
context 'when helper is called with default keyword_params' do
let(:keyword_defaults) { { foo: :world } }

it 'is overriden when calling the matcher' do
expect(result.is_a?(Enumerator)).to be(true)
expect(result.to_a).to match_array %i[hello]
end

context 'and no value is given to the matcher' do
let(:keyword_params) { {} }

it 'uses the defaults' do
expect(result.is_a?(Enumerator)).to be(true)
expect(result.to_a).to match_array %i[world]
end
end

context 'some defaults are not params' do
let(:keyword_defaults) { { bar: :world } }

it 'raises an error' do
expect { result }.to raise_error(ArgumentError)
end
end
end
# rubocop:enable RSpec/NestedGroups
end
end

context 'when called on non-matching code' do
Expand Down

0 comments on commit 5d1659e

Please sign in to comment.