forked from shap/shap
-
Notifications
You must be signed in to change notification settings - Fork 0
/
_optimizer.py
92 lines (78 loc) · 3.55 KB
/
_optimizer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import queue
import numpy as np
import warnings
import copy
from ._action import Action
from ..utils._exceptions import InvalidAction, ConvergenceError
class ActionOptimizer:
    """Search for the cheapest set of actions that makes a model return a truthy value.

    Actions are organized into mutually exclusive groups: a candidate set never
    contains more than one action from the same group. Candidate sets are
    evaluated in order of increasing total ``cost``, and every group is kept
    sorted by cost, so -- assuming all costs are non-negative -- the first set
    that satisfies the model is a cheapest one.

    Parameters
    ----------
    model : callable
        Called as ``model(*args)`` after a candidate set of actions has been
        applied to (a copy of) ``args``; a truthy return value means the
        candidate is accepted.
    actions : iterable
        Each element is either a single ``Action`` (a group of one) or a list of
        ``Action`` objects forming one mutually exclusive group. The optimizer
        stores shallow copies, so the caller's ``Action`` objects are not
        modified.
    """

    def __init__(self, model, actions):
        """Store ``model`` and sort ``actions`` into cost-ordered exclusive groups.

        Raises
        ------
        InvalidAction
            If an element of ``actions`` is neither an ``Action`` nor a list
            whose items are all ``Action`` instances.
        """
        self.model = model
        warnings.warn(
            "Note that ActionOptimizer is still in an alpha state and is subject to API changes.",
            stacklevel=2,  # attribute the warning to the code constructing the optimizer
        )

        # actions go into mutually exclusive groups, each sorted by increasing cost
        self.action_groups = []
        for group in actions:
            if isinstance(group, Action):
                members = [group]
            elif isinstance(group, list) and all(isinstance(v, Action) for v in group):
                members = group
            else:
                raise InvalidAction(
                    "A passed action was not an Action or list of actions!"
                )

            # copy before tagging so the caller's Action objects are left untouched
            group_index = len(self.action_groups)
            copies = sorted((copy.copy(v) for v in members), key=lambda a: a.cost)
            for i, v in enumerate(copies):
                v._group_index = group_index
                v._grouped_index = i
            self.action_groups.append(copies)

    def __call__(self, *args, max_evals=10000):
        """Return the cheapest list of actions whose application satisfies ``model``.

        Parameters
        ----------
        *args
            Arguments passed to every action and then to ``model``. Each
            candidate is applied to a fresh ``copy.deepcopy`` of ``args``, so the
            caller's ``args`` are not modified by the actions.
        max_evals : int
            Maximum number of candidate sets to evaluate (one model call each).

        Returns
        -------
        list or None
            The accepted actions (the copies held by this optimizer), or ``None``
            if every combination was evaluated without success.

        Raises
        ------
        ConvergenceError
            If no solution is found within ``max_evals`` evaluations.
        """
        # Entries are (total_cost, tie_breaker, actions). The unique tie-breaker
        # keeps equal-cost entries in FIFO order and guarantees the queue never
        # has to compare two lists of actions, which may not be orderable.
        q = queue.PriorityQueue()

        # Each group contributes at most one action, so a candidate is fully
        # identified by its (group, position-in-group) pairs. Without this, the
        # same set is re-evaluated once per order in which its groups were added.
        # NOTE: a set is therefore evaluated in the first-discovered order only;
        # this assumes actions from different groups commute.
        seen = set()

        def enqueue(candidate):
            key = frozenset((a._group_index, a._grouped_index) for a in candidate)
            if key in seen:
                return
            seen.add(key)
            # len(seen) grows by exactly one per enqueued entry -> unique, increasing
            q.put((sum(a.cost for a in candidate), len(seen), candidate))

        # init our queue with the least costly action of every (non-empty) group
        for group in self.action_groups:
            if group:
                enqueue([group[0]])

        nevals = 0
        while not q.empty():
            # see if we have exceeded our evaluation budget
            nevals += 1
            if nevals > max_evals:
                raise ConvergenceError(
                    f"Failed to find a solution with max_evals={max_evals}! Try reducing the number of actions or increasing max_evals."
                )

            # get the next cheapest set of actions and apply it to a private copy
            _, _, actions = q.get()
            args_tmp = copy.deepcopy(args)
            for a in actions:
                a(*args_tmp)

            # if the model is now satisfied we are done
            if self.model(*args_tmp):
                return actions

            # otherwise queue every follow-on set: for each group, either add its
            # cheapest action (group not used yet) or swap the current action for
            # the next more expensive one in the same group
            for group_index, group in enumerate(self.action_groups):
                pos = next(
                    (j for j, a in enumerate(actions) if a._group_index == group_index),
                    None,
                )
                next_ind = 0 if pos is None else actions[pos]._grouped_index + 1
                if next_ind >= len(group):
                    continue  # empty group, or no more expensive action left in it
                if pos is None:
                    enqueue(actions + [group[next_ind]])
                else:
                    upgraded = list(actions)
                    upgraded[pos] = group[next_ind]
                    enqueue(upgraded)

        return None