/
multiple_comparison.rb
94 lines (79 loc) · 2.51 KB
/
multiple_comparison.rb
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
# frozen_string_literal: true
module RuboCop
  module Cop
    module Style
      # This cop checks against comparing a variable with multiple items, where
      # `Array#include?` could be used instead to avoid code repetition.
      # It accepts comparisons against method calls, including safe-navigation
      # calls such as `b&.foo`. Rewriting those with `Array#include?` would
      # evaluate every call up front instead of short-circuiting.
      #
      # @example
      #   # bad
      #   a = 'a'
      #   foo if a == 'a' || a == 'b' || a == 'c'
      #
      #   # good
      #   a = 'a'
      #   foo if ['a', 'b', 'c'].include?(a)
      #   foo if a == b.lightweight || a == b.heavyweight
      #   foo if a == b&.lightweight || a == b&.heavyweight
      class MultipleComparison < Base
        MSG = 'Avoid comparing a variable with multiple items ' \
              'in a conditional, use `Array#include?` instead.'

        # Registers an offense on the outermost `||` of a chain when every
        # operand is an `==` comparison and exactly one distinct local
        # variable appears across all of them.
        #
        # @param node [RuboCop::AST::Node] an `or` node
        def on_or(node)
          root = root_of_or_node(node)
          # Every `or` in a chain is visited; report only once, on the root.
          return unless node == root
          return unless nested_variable_comparison?(root)

          add_offense(node)
        end

        private

        # Matches `lvar == lvar`, capturing both local-variable nodes.
        def_node_matcher :simple_double_comparison?, '(send $lvar :== $lvar)'

        # Matches `lvar == value` or `value == lvar`, capturing the local
        # variable node. `value` must not be a method call: both plain
        # (`send`) and safe-navigation (`csend`) calls are excluded, so that
        # comparisons against calls are accepted as documented above.
        def_node_matcher :simple_comparison?, <<~PATTERN
          {(send $lvar :== !{send csend})
           (send !{send csend} :== $lvar)}
        PATTERN

        # True when `node` is an `or` chain consisting solely of comparisons
        # and exactly one distinct local variable appears in them.
        def nested_variable_comparison?(node)
          return false unless nested_comparison?(node)

          variables_in_node(node).count == 1
        end

        # Returns the local-variable names compared in `node`, recursing into
        # nested `or` nodes. Names are de-duplicated at each `or` level.
        def variables_in_node(node)
          return variables_in_simple_node(node) unless node.or_type?

          node.node_parts
              .flat_map { |node_part| variables_in_node(node_part) }
              .uniq
        end

        # Returns the local-variable names in a single comparison node.
        # - `lvar == lvar` yields both names.
        # - A one-sided comparison yields one name.
        # - Anything else yields an empty array.
        def variables_in_simple_node(node)
          simple_double_comparison?(node) do |var1, var2|
            return [variable_name(var1), variable_name(var2)]
          end
          simple_comparison?(node) do |var|
            return [variable_name(var)]
          end
          []
        end

        # The first child of an `lvar` node, i.e. the variable's name.
        def variable_name(node)
          node.children[0]
        end

        # True when `node` is an `or` node whose direct parts are all
        # comparisons: simple ones, or further nested `or` chains of them.
        def nested_comparison?(node)
          node.or_type? && node.node_parts.all? { |node_part| comparison?(node_part) }
        end

        # True when `node` is a simple comparison or a nested `or` chain of
        # comparisons.
        def comparison?(node)
          simple_comparison?(node) || nested_comparison?(node)
        end

        # Walks up through consecutive `or` parents and returns the topmost
        # `or` node of the chain containing `or_node`.
        def root_of_or_node(or_node)
          parent = or_node.parent
          return or_node unless parent&.or_type?

          root_of_or_node(parent)
        end
      end
    end
  end
end