forked from rubocop/rubocop
/
multiple_comparison.rb
159 lines (132 loc) · 4.49 KB
/
multiple_comparison.rb
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
# frozen_string_literal: true
module RuboCop
module Cop
module Style
# This cop checks against comparing a variable with multiple items, where
# `Array#include?`, `Set#include?` or a `case` could be used instead
# to avoid code repetition.
# By default, comparisons against method calls are accepted, because rewriting
# them with `Array#include?` would evaluate every call up front instead of
# short-circuiting. This can be changed with the `AllowMethodComparison` option.
#
# @example
# # bad
# a = 'a'
# foo if a == 'a' || a == 'b' || a == 'c'
#
# # good
# a = 'a'
# foo if ['a', 'b', 'c'].include?(a)
#
# VALUES = Set['a', 'b', 'c'].freeze
# # elsewhere...
# foo if VALUES.include?(a)
#
# case a
# when 'a', 'b', 'c' then foo
# # ...
# end
#
# # accepted (but consider `case` as above)
# foo if a == b.lightweight || a == b.heavyweight
#
# @example AllowMethodComparison: true (default)
# # good
# foo if a == b.lightweight || a == b.heavyweight
#
# @example AllowMethodComparison: false
# # bad
# foo if a == b.lightweight || a == b.heavyweight
#
# # good
# foo if [b.lightweight, b.heavyweight].include?(a)
class MultipleComparison < Base
extend AutoCorrector
MSG = 'Avoid comparing a variable with multiple items ' \
'in a conditional, use `Array#include?` instead.'
# Investigation hook: forgets the previously flagged `or` node. Because
# `switch_comparison?` returns true while `@last_comparison` is nil, the first
# `or` node visited afterwards always triggers `reset_comparison`.
def on_new_investigation
@last_comparison = nil
end
# Visits every `or` node but only reports on the outermost node of an `||`
# chain, so that the whole chain is flagged and rewritten in one go.
#
# Nothing is reported when the chain is not made purely of comparisons of a
# single local variable, or when a method-call comparison was seen and such
# comparisons are allowed.
def on_or(node)
  reset_comparison if switch_comparison?(node)

  chain_root = root_of_or_node(node)
  return if node != chain_root
  # The validation below also fills `@compared_elements` and may raise the
  # `@allowed_method_comparison` flag, hence it has to run before the flag is read.
  return if !nested_variable_comparison?(chain_root) || @allowed_method_comparison

  add_offense(node) do |corrector|
    items = @compared_elements.join(', ')
    variable = variables_in_node(node).first
    corrector.replace(node, "[#{items}].include?(#{variable})")
  end

  @last_comparison = node
end
private
# @!method simple_double_comparison?(node)
#   Matches `x == y` where both operands are local variables and captures
#   both `lvar` nodes.
def_node_matcher :simple_double_comparison?, '(send $lvar :== $lvar)'
# @!method simple_comparison_lhs?(node)
#   Matches `x == <anything>` with a local variable on the left-hand side.
#   Captures the `lvar` node first, then the node it is compared with.
def_node_matcher :simple_comparison_lhs?, <<~PATTERN
(send $lvar :== $_)
PATTERN
# @!method simple_comparison_rhs?(node)
#   Matches `<anything> == x` with a local variable on the right-hand side.
#   Captures the compared node first, then the `lvar` node.
def_node_matcher :simple_comparison_rhs?, <<~PATTERN
(send $_ :== $lvar)
PATTERN
# True when `node` is an `||` chain consisting solely of comparisons and
# exactly one distinct local variable takes part in them. The variables are
# only collected once the chain shape has been confirmed.
def nested_variable_comparison?(node)
  nested_comparison?(node) && variables_in_node(node).size == 1
end
# Collects the local variable names compared anywhere inside `node`.
# `or` nodes are walked recursively and their combined result is
# de-duplicated; any other node is handed to `variables_in_simple_node`.
def variables_in_node(node)
  return variables_in_simple_node(node) unless node.or_type?

  node.node_parts.flat_map { |operand| variables_in_node(operand) }.uniq
end
# Extracts the local variable name(s) taking part in a single `==` comparison.
#
# Side effects for `var == value` / `value == var` comparisons: the source of
# the compared value is appended to `@compared_elements`, and
# `@allowed_method_comparison` is raised when that value is a method call and
# `AllowMethodComparison` is enabled.
#
# @param node [RuboCop::AST::Node] one operand of an `||` chain
# @return [Array] both names for `var == var`, a single name for a comparison
#   of a variable with any other value, or an empty array when nothing matches
def variables_in_simple_node(node)
  simple_double_comparison?(node) do |var1, var2|
    return [variable_name(var1), variable_name(var2)]
  end

  if (var, obj = simple_comparison_lhs?(node)) || (obj, var = simple_comparison_rhs?(node))
    # Safe-navigation calls (`csend`, e.g. `b&.heavyweight`) are method calls
    # as well; collecting them into an array would evaluate them eagerly, so
    # they have to be covered by `AllowMethodComparison` just like `send`.
    if allow_method_comparison? && (obj.send_type? || obj.csend_type?)
      @allowed_method_comparison = true
    end
    @compared_elements << obj.source
    return [variable_name(var)]
  end

  []
end
# Returns the first child of `node`; for the `lvar` nodes captured by the
# matchers above this is what identifies the variable.
def variable_name(node)
  node.children.first
end
# True when `node` is an `or` node whose every part is itself a comparison
# (see `comparison?`); false for any other kind of node.
def nested_comparison?(node)
  node.or_type? && node.node_parts.all? { |operand| comparison?(operand) }
end
# Truthy when `node` is a comparison of a local variable with something
# (variable on either side) or an `||` chain made of such comparisons.
# The checks run in that order and stop at the first one that succeeds.
def comparison?(node)
  lhs_match = simple_comparison_lhs?(node)
  return lhs_match if lhs_match

  simple_comparison_rhs?(node) || nested_comparison?(node)
end
# Climbs from `or_node` through directly enclosing `or` nodes and returns the
# outermost one. A node without a parent, or whose parent is not an `or`
# node, is returned as is.
def root_of_or_node(or_node)
  outermost = or_node
  outermost = outermost.parent while outermost.parent&.or_type?
  outermost
end
# Tells whether `node` starts a new comparison, i.e. whether the state
# gathered for the previously flagged chain must be discarded.
#
# That is the case when nothing has been flagged yet, or when `node` is not
# one of the nodes nested inside the last flagged chain.
#
# @param node [RuboCop::AST::Node] the `or` node being visited
# @return [Boolean]
def switch_comparison?(node)
  return true if @last_comparison.nil?

  # Membership has to be checked by identity. `none?(node)` would go through
  # `==`, which for AST nodes compares type and children only, so an unrelated
  # chain that merely looks like a sub-chain of the last offense would be
  # mistaken for a descendant and inherit its stale state.
  @last_comparison.descendants.none? { |descendant| descendant.equal?(node) }
end
# Starts a fresh comparison: empties the list of collected element sources and
# clears the flag that makes `on_or` skip chains containing method calls.
def reset_comparison
@compared_elements = []
@allowed_method_comparison = false
end
# Reads the `AllowMethodComparison` setting of this cop; an absent setting
# counts as enabled.
def allow_method_comparison?
  cop_config.fetch('AllowMethodComparison') { true }
end
end
end
end
end