forked from rubocop/rubocop
/
multiple_comparison.rb
127 lines (106 loc) · 3.63 KB
/
multiple_comparison.rb
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
# frozen_string_literal: true
module RuboCop
  module Cop
    module Style
      # This cop checks against comparing a variable with multiple items, where
      # `Array#include?` could be used instead to avoid code repetition.
      # It accepts comparisons of multiple method calls to avoid unnecessary method calls
      # by default. It can be configured by `AllowMethodComparison` option.
      #
      # @example
      #   # bad
      #   a = 'a'
      #   foo if a == 'a' || a == 'b' || a == 'c'
      #
      #   # good
      #   a = 'a'
      #   foo if ['a', 'b', 'c'].include?(a)
      #   foo if a == b.lightweight || a == b.heavyweight
      #
      # @example AllowMethodComparison: true (default)
      #   # good
      #   foo if a == b.lightweight || a == b.heavyweight
      #
      #   # good - any method-call operand exempts the whole chain
      #   foo if a == 'a' || a == b.lightweight
      #
      # @example AllowMethodComparison: false
      #   # bad
      #   foo if a == b.lightweight || a == b.heavyweight
      #
      #   # good
      #   foo if [b.lightweight, b.heavyweight].include?(a)
      class MultipleComparison < Base
        extend AutoCorrector

        MSG = 'Avoid comparing a variable with multiple items ' \
              'in a conditional, use `Array#include?` instead.'

        def on_new_investigation
          reset_comparison
        end

        # Only the outermost `||` of a chain is inspected; inner `or` nodes are
        # visited as well but return early because they are not the root.
        def on_or(node)
          root_of_or_node = root_of_or_node(node)
          return unless node == root_of_or_node

          # The collected elements belong to a single chain. Without this reset,
          # operands from earlier (offending or not) chains in the same file
          # would leak into this chain's autocorrection.
          reset_comparison
          return unless nested_variable_comparison?(root_of_or_node)
          # A method-call operand was seen while AllowMethodComparison is on;
          # flagging would either contradict the option or drop that operand.
          return if @allowed_method_comparison

          # Snapshot the elements first: calling variables_in_node again
          # re-appends every operand to @compared_elements.
          elements = @compared_elements.join(', ')
          variable = variables_in_node(node).first

          add_offense(node) do |corrector|
            corrector.replace(node, "[#{elements}].include?(#{variable})")
          end
        end

        private

        def_node_matcher :simple_double_comparison?, '(send $lvar :== $lvar)'

        def_node_matcher :simple_comparison_lhs?, <<~PATTERN
          (send $lvar :== $_)
        PATTERN

        def_node_matcher :simple_comparison_rhs?, <<~PATTERN
          (send $_ :== $lvar)
        PATTERN

        # Clears the per-chain state collected by variables_in_simple_node.
        def reset_comparison
          @compared_elements = []
          @allowed_method_comparison = false
        end

        # True when +node+ is an `||` chain made only of `==` comparisons that
        # all involve the same single local variable.
        def nested_variable_comparison?(node)
          return false unless nested_comparison?(node)

          variables_in_node(node).count == 1
        end

        # Returns the unique local-variable names compared anywhere in +node+.
        # Side effect: appends the non-variable operands to @compared_elements.
        def variables_in_node(node)
          if node.or_type?
            node.node_parts
                .flat_map { |node_part| variables_in_node(node_part) }
                .uniq
          else
            variables_in_simple_node(node)
          end
        end

        # Returns the variable name(s) in a single `==` node and records the
        # other operand's source. A method-call operand under
        # AllowMethodComparison sets a flag rather than being dropped, so the
        # chain is skipped instead of being autocorrected lossily.
        def variables_in_simple_node(node)
          simple_double_comparison?(node) do |var1, var2|
            return [variable_name(var1), variable_name(var2)]
          end

          if (var, obj = simple_comparison_lhs?(node)) || (obj, var = simple_comparison_rhs?(node))
            @allowed_method_comparison = true if allow_method_comparison? && obj.send_type?
            @compared_elements << obj.source
            return [variable_name(var)]
          end

          []
        end

        # Name (Symbol) of an `lvar` node.
        def variable_name(node)
          node.children[0]
        end

        # True when +node+ is an `or` whose operands are all comparisons
        # (possibly further `or` chains).
        def nested_comparison?(node)
          if node.or_type?
            node.node_parts.all? { |node_part| comparison? node_part }
          else
            false
          end
        end

        def comparison?(node)
          simple_comparison_lhs?(node) || simple_comparison_rhs?(node) || nested_comparison?(node)
        end

        # Walks up through consecutive `or` parents to the outermost one.
        def root_of_or_node(or_node)
          return or_node unless or_node.parent

          if or_node.parent.or_type?
            root_of_or_node(or_node.parent)
          else
            or_node
          end
        end

        def allow_method_comparison?
          cop_config.fetch('AllowMethodComparison', true)
        end
      end
    end
  end
end