/
__init__.py
173 lines (143 loc) · 4.85 KB
/
__init__.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
"""
Plot multidimensional values against each other.
"""
from numpy import array as _array
from matplotlib.colors import Normalize as _Normalize
from matplotlib.gridspec import GridSpec as _GridSpec
from matplotlib.pyplot import figure as _figure
# Resolve the package version once, at import time, from the package-local
# _version module; the helper is then deleted so that get_versions itself is
# not left behind in this module's namespace.
from ._version import get_versions
__version__ = get_versions()['version']
del get_versions
def spaceplots(
    inputs, outputs, input_names=None, output_names=None, limits=None, **kwargs
):
    """
    Plot multidimensional values against each other.

    One figure is produced per output column. Each figure is a triangular
    grid with a histogram of every input on the diagonal and, below it,
    scatter plots of every pair of inputs coloured by that output.

    Parameters
    ----------
    inputs : 2-D array, shape (num_samples, num_inputs)
        Sample locations in input space.
    outputs : 2-D array, shape (num_samples, num_outputs)
        Output values for each sample; one figure is made per column.
    input_names : sequence, optional
        Axis labels, one per input column.
    output_names : sequence, optional
        Figure titles, one per output column.
    limits : array-like, shape (num_outputs, 2), optional
        ``[lower, upper]`` colour-scale limits for each output. Either entry
        may be ``None`` to fall back to the data minimum/maximum.
    **kwargs
        Passed through to the per-output plotting routine (for example
        ``scatter_args`` and ``histogram_args``).

    Returns
    -------
    iterator of matplotlib figures
        Figures are created lazily, one per iteration.

    Raises
    ------
    RuntimeError
        If the argument shapes are inconsistent. This is raised when
        ``spaceplots`` is called, not when the returned iterator is first
        advanced.
    """
    num_samples, num_inputs = inputs.shape
    if input_names is None:
        input_names = [None] * num_inputs
    elif len(input_names) != num_inputs:
        raise RuntimeError("Input data and names don't match")

    out_num_samples, num_outputs = outputs.shape
    if out_num_samples != num_samples:
        raise RuntimeError("Inputs and outputs don't match")
    if output_names is None:
        output_names = [None] * num_outputs
    elif len(output_names) != num_outputs:
        raise RuntimeError("Output data and names don't match")

    if limits is None:
        limits = [[None, None]] * num_outputs
    else:
        # Accept any array-like (e.g. nested lists), not only ndarrays, and
        # reject 1-D input with a clear error instead of an IndexError.
        limits = _array(limits)
        if limits.ndim != 2 or limits.shape[1] != 2:
            raise RuntimeError(
                "There must be an upper and lower limit for each output"
            )
        if limits.shape[0] != num_outputs:
            raise RuntimeError("Output data and limits don't match")

    # Everything above runs eagerly at call time, so bad arguments fail at
    # the call site. Only figure creation is deferred until iteration.
    return (
        _subspace_plot(
            inputs, outputs[:, out_index], input_names=input_names,
            output_name=output_names[out_index],
            min_output=limits[out_index][0],
            max_output=limits[out_index][1], **kwargs
        )
        for out_index in range(num_outputs)
    )
def _setup_axes(
    *, input_names, histogram_labels=False, constrained_layout=True
):
    """
    Create a figure holding a lower-triangular grid of axes.

    The grid is ``num_inputs`` x ``num_inputs`` with
    ``num_inputs = len(input_names)``. Diagonal cells ``axes[i, i]`` are
    meant for histograms. Cells ``axes[y, x]`` with ``x < y`` are meant for
    scatter plots of input ``x`` (horizontal) against input ``y``
    (vertical). Cells above the diagonal stay ``None``.

    Parameters
    ----------
    input_names : sequence
        One label per input; its length sets the grid size.
    histogram_labels : bool
        Whether the diagonal axes show tick labels on their right side.
    constrained_layout : bool
        Forwarded to ``matplotlib.pyplot.figure``.

    Returns
    -------
    fig : matplotlib Figure
    axes : ndarray of object, shape (num_inputs, num_inputs)
    grid : matplotlib GridSpec
        Returned so callers can place extra axes in the empty upper
        triangle.
    """
    num_inputs = len(input_names)
    last = num_inputs - 1
    fig = _figure(constrained_layout=constrained_layout)
    axes = _array(
        [[None] * num_inputs for _ in range(num_inputs)], dtype=object
    )
    grid = _GridSpec(nrows=num_inputs, ncols=num_inputs, figure=fig)
    # Inward-pointing ticks on all four sides, applied to every subplot.
    common_tick_args = dict(
        top=True, bottom=True, left=True, right=True, direction='in',
    )

    for i in range(num_inputs):
        axes[i, i] = fig.add_subplot(grid[i, i])
        axes[i, i].tick_params(
            labelbottom=False, labelleft=False, labelright=histogram_labels,
            **common_tick_args
        )

    # Row 0 has no cells below the diagonal, so start at row 1.
    for y in range(1, num_inputs):
        for x in range(y):
            # Share x with the diagonal axis of column x. Share y with the
            # first cell of the row, which the x == 0 iteration creates.
            plot_args = dict(sharex=axes[x, x])
            if x != 0:
                plot_args["sharey"] = axes[y, 0]
            ax = fig.add_subplot(grid[y, x], **plot_args)
            axes[y, x] = ax
            # Only the left column and bottom row show tick labels. Every
            # cell, including the bottom-left corner, gets the common style.
            ax.tick_params(
                labelleft=(x == 0), labelbottom=(y == last),
                **common_tick_args
            )
            if x == 0:
                ax.set_ylabel(input_names[y])
            if y == last:
                ax.set_xlabel(input_names[x])
    return fig, axes, grid
def _subspace_plot(
    inputs, output, *, input_names, output_name, scatter_args=None,
    histogram_args=None, min_output=None, max_output=None
):
    """
    Draw one output against every input and every pair of inputs.

    Parameters
    ----------
    inputs : 2-D array, shape (num_samples, num_inputs)
    output : 1-D array, shape (num_samples,)
        Values used to colour the scatter plots.
    input_names : sequence
        Axis labels, one per input column.
    output_name : str or None
        Figure title; no title is set when ``None``.
    scatter_args, histogram_args : dict, optional
        Extra keyword arguments for ``Axes.scatter`` / ``Axes.hist``.
    min_output, max_output : float, optional
        Colour-scale limits; default to ``min(output)`` / ``max(output)``.

    Returns
    -------
    matplotlib Figure
    """
    if scatter_args is None:
        scatter_args = {}
    if histogram_args is None:
        histogram_args = {}
    if min_output is None:
        min_output = min(output)
    if max_output is None:
        max_output = max(output)

    # see https://matplotlib.org/examples/pylab_examples/multi_image.html
    _, num_inputs = inputs.shape
    fig, axes, grid = _setup_axes(input_names=input_names)
    if output_name is not None:
        fig.suptitle(output_name)

    # One shared Normalize gives every scatter plot the same colour scale,
    # so a single colour bar describes all of them.
    norm = _Normalize(min_output, max_output)

    for i in range(num_inputs):
        _plot_hist(inputs[:, i], axis=axes[i, i], **histogram_args)

    scatter_plots = []
    for y_index in range(num_inputs):
        for x_index in range(y_index):
            # axes[row, column]: with x_index < y_index this is one of the
            # lower-triangle cells that _setup_axes creates.
            scatter_plots.append(_plot_scatter(
                x=inputs[:, x_index], y=inputs[:, y_index], z=output,
                axis=axes[y_index, x_index], norm=norm, **scatter_args
            ))

    # With fewer than two inputs there is no scatter plot to take the colour
    # mapping from, and grid[0, 1:] would be an empty (invalid) slice.
    if scatter_plots:
        cbar_ax = fig.add_subplot(grid[0, 1:])
        fig.colorbar(
            scatter_plots[0], cax=cbar_ax, orientation='horizontal',
        )
        cbar_ax.set_aspect(1/20)
    return fig
def _plot_hist(values, *, axis, **kwargs):
    """
    Draw a histogram of ``values`` on ``axis``.

    Extra keyword arguments go straight to ``axis.hist``, and whatever that
    call returns is handed back unchanged.
    """
    draw = axis.hist
    return draw(values, **kwargs)
def _plot_scatter(*, x, y, z, axis, norm, **kwargs):
    """
    Draw ``y`` against ``x`` on ``axis``, colouring each point by ``z``.

    ``norm`` is handed to ``axis.scatter`` to scale ``z``. Extra keyword
    arguments are forwarded as well (repeating ``x``, ``y``, ``c`` or
    ``norm`` is a ``TypeError``), and the call's return value is passed back.
    """
    return axis.scatter(
        norm=norm,
        c=z,
        y=y,
        x=x,
        **kwargs
    )