/
json_parser.py
180 lines (149 loc) · 5.82 KB
/
json_parser.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
'''Demonstration for parsing a JSON tree model file generated by XGBoost.
Support is experimental; the output schema is subject to change in the future.
'''
import json
import argparse
class Tree:
    '''A single tree built by XGBoost.

    Each node is described by two parallel rows: one row in ``nodes`` holding
    its structure and one row in ``stats`` holding its statistics.  The class
    constants below are the column positions inside those rows.
    '''
    # Column positions inside a row of ``nodes``
    _left, _right, _parent, _ind, _cond, _default_left = range(6)
    # Column positions inside a row of ``stats``
    _loss_chg, _sum_hess, _base_weight, _child_cnt = range(4)

    def __init__(self, tree_id: int, nodes, stats):
        '''Store the tree id together with its node rows and stat rows.'''
        self.tree_id = tree_id
        self.nodes = nodes
        self.stats = stats

    def loss_change(self, node_id: int):
        '''Loss gain recorded for the node.'''
        stat_row = self.stats[node_id]
        return stat_row[self._loss_chg]

    def sum_hessian(self, node_id: int):
        '''Sum of Hessian recorded for the node.'''
        stat_row = self.stats[node_id]
        return stat_row[self._sum_hess]

    def base_weight(self, node_id: int):
        '''Base weight recorded for the node.'''
        stat_row = self.stats[node_id]
        return stat_row[self._base_weight]

    def num_children(self, node_id: int):
        '''Child count recorded for the node.'''
        stat_row = self.stats[node_id]
        return stat_row[self._child_cnt]

    def split_index(self, node_id: int):
        '''Index of the feature the node splits on.'''
        node_row = self.nodes[node_id]
        return node_row[self._ind]

    def split_condition(self, node_id: int):
        '''Split value stored for the node.'''
        node_row = self.nodes[node_id]
        return node_row[self._cond]

    def parent(self, node_id: int):
        '''ID of the node's parent.'''
        node_row = self.nodes[node_id]
        return node_row[self._parent]

    def left_child(self, node_id: int):
        '''ID of the node's left child.'''
        node_row = self.nodes[node_id]
        return node_row[self._left]

    def right_child(self, node_id: int):
        '''ID of the node's right child.'''
        node_row = self.nodes[node_id]
        return node_row[self._right]

    def is_leaf(self, node_id: int):
        '''True when the node's left-child slot holds -1.'''
        node_row = self.nodes[node_id]
        return node_row[self._left] == -1

    def is_deleted(self, node_id: int):
        '''True when the node's split index holds the deleted marker.'''
        # std::numeric_limits<uint32_t>::max()
        deleted_marker = 4294967295
        node_row = self.nodes[node_id]
        return node_row[self._ind] == deleted_marker

    def __str__(self):
        '''Render one line per node reachable from node 0.

        Nodes are taken from the end of a stack, so after a split node its
        right child is rendered before its left child.  Children of leaf or
        deleted nodes are not followed.
        '''
        pending = [0]
        lines = []
        while pending:
            current = pending.pop()
            summary = {
                'node id': current,
                'gain': self.loss_change(current),
                'cover': self.sum_hessian(current),
            }
            lines.append(' ' + str(summary))
            if self.is_leaf(current) or self.is_deleted(current):
                continue
            # Left is pushed first so that right is popped first.
            pending.extend(
                (self.left_child(current), self.right_child(current)))
        return '\n'.join(lines)
class Model:
    '''Gradient boosted tree model.'''

    def __init__(self, model: dict):
        '''Construct the Model from JSON object.

        parameters
        ----------
        model: A dictionary loaded by json
        '''
        learner = model['learner']

        # Basic property of a model
        self.learner_model_shape = learner['learner_model_param']
        shape = self.learner_model_shape
        self.num_output_group = int(shape['num_class'])
        self.num_feature = int(shape['num_feature'])
        self.base_score = float(shape['base_score'])

        booster = learner['gradient_booster']['model']
        # A field encoding which output group a tree belongs
        self.tree_info = booster['tree_info']
        model_shape = booster['gbtree_model_param']
        # JSON representation of trees
        j_trees = booster['trees']

        self.num_trees = int(model_shape['num_trees'])
        self.leaf_size = int(model_shape['size_leaf_vector'])
        # Right now XGBoost doesn't support vector leaf yet
        assert self.leaf_size == 0, str(self.leaf_size)

        # Index by position up to the declared tree count rather than
        # iterating over whatever the list happens to contain.
        self.trees = [
            self._build_tree(position, j_trees[position])
            for position in range(self.num_trees)
        ]

    @staticmethod
    def _build_tree(position, j_tree):
        '''Turn the JSON object of one tree into a ``Tree``.

        ``position`` is the tree's index in the model's tree list; the id
        stored in the JSON object has to agree with it.
        '''
        tree_id = int(j_tree['id'])
        assert tree_id == position, (tree_id, position)

        # properties
        left_children = j_tree['left_children']
        right_children = j_tree['right_children']
        parents = j_tree['parents']
        split_conditions = j_tree['split_conditions']
        split_indices = j_tree['split_indices']
        default_left = j_tree['default_left']
        # stats
        base_weights = j_tree['base_weights']
        loss_changes = j_tree['loss_changes']
        sum_hessian = j_tree['sum_hessian']
        leaf_child_counts = j_tree['leaf_child_counts']

        # We resemble the structure used inside XGBoost, which is similar
        # to adjacency list: one row per node, sized by the left-children
        # column.
        node_ids = range(len(left_children))
        nodes = [
            [
                left_children[nid], right_children[nid],
                parents[nid], split_indices[nid],
                split_conditions[nid], default_left[nid],
            ]
            for nid in node_ids
        ]
        stats = [
            [
                loss_changes[nid], sum_hessian[nid],
                base_weights[nid], leaf_child_counts[nid],
            ]
            for nid in node_ids
        ]
        return Tree(tree_id, nodes, stats)

    def print_model(self):
        '''Print every tree, each preceded by its position in the model.'''
        for position, tree in enumerate(self.trees):
            print('tree_id:', position)
            print(tree)
if __name__ == '__main__':
    # Command-line entry point: load the JSON model file given by --model,
    # build a Model from it, and print every tree.
    parser = argparse.ArgumentParser(
        description='Demonstration for loading and printing XGBoost model.')
    parser.add_argument('--model',
                        type=str,
                        required=True,
                        help='Path to JSON model file.')
    args = parser.parse_args()

    # JSON text is UTF-8; without an explicit encoding, open() falls back to
    # the platform's locale encoding, which can mis-decode non-ASCII content
    # (e.g. feature names) on systems whose default is not UTF-8.
    with open(args.model, 'r', encoding='utf-8') as fd:
        model = json.load(fd)
    model = Model(model)
    model.print_model()