/
misc.py
393 lines (314 loc) · 11.9 KB
/
misc.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
# Copyright (c) 2018 The Regents of the University of Michigan
# All rights reserved.
# This software is licensed under the BSD 3-Clause License.
"""Miscellaneous utility functions."""
import argparse
import logging
import os
from collections.abc import MutableMapping
from contextlib import contextmanager
from functools import lru_cache, partial
from itertools import cycle, islice
import cloudpickle
from tqdm.contrib import tmap
from tqdm.contrib.concurrent import process_map, thread_map
try:
# If ipywidgets is installed, use "auto" tqdm to improve notebook support.
# Otherwise, use only text-based progress bars. This workaround can be
# removed after https://github.com/tqdm/tqdm/pull/1218.
import ipywidgets # noqa: F401
except ImportError:
from tqdm import tqdm
else:
from tqdm.auto import tqdm
def _positive_int(value):
"""Parse a command line argument as a positive integer.
Designed to be used in conjunction with :class:`argparse.ArgumentParser`.
Parameters
----------
value : str
The value to parse.
Returns
-------
int
The provided value, cast to an integer.
Raises
------
:class:`argparse.ArgumentTypeError`
If value cannot be cast to an integer or is negative.
"""
try:
ivalue = int(value)
if ivalue <= 0:
raise argparse.ArgumentTypeError("Value must be positive.")
except (TypeError, ValueError):
raise argparse.ArgumentTypeError(f"{value} must be a positive integer.")
return ivalue
@contextmanager
def redirect_log(job, filename="run.log", formatter=None, logger=None):
    """Redirect all messages logged via the logging interface to the given file.

    This method is a context manager. The logging handler is removed when
    exiting the context.

    Parameters
    ----------
    job : :class:`signac.contrib.job.Job`
        The signac job whose workspace will store the redirected logs.
    filename : str
        File name of the log. (Default value = "run.log")
    formatter : :class:`logging.Formatter`
        The logging formatter to use, uses a default formatter if None.
        (Default value = None)
    logger : :class:`logging.Logger`
        The instance of logger to which the new file log handler is added.
        Defaults to the default logger returned by :meth:`logging.getLogger` if
        this argument is not provided.
    """
    if formatter is None:
        formatter = logging.Formatter(
            "%(asctime)s %(name)-12s %(levelname)-8s %(message)s"
        )
    if logger is None:
        logger = logging.getLogger()
    # Use the provided filename. (Previously "run.log" was hard-coded here,
    # which silently ignored the filename argument.)
    filehandler = logging.FileHandler(filename=job.fn(filename))
    filehandler.setFormatter(formatter)
    logger.addHandler(filehandler)
    try:
        yield
    finally:
        logger.removeHandler(filehandler)
        # Close the handler so the underlying file descriptor is released.
        filehandler.close()
@contextmanager
def add_path_to_environment_pythonpath(path):
    """Insert the provided path into the environment PYTHONPATH variable.

    This method is a context manager. It restores the previous PYTHONPATH when
    exiting the context.

    Parameters
    ----------
    path : str
        Path to add to PYTHONPATH.
    """
    path = os.path.realpath(path)
    pythonpath = os.environ.get("PYTHONPATH")
    if pythonpath:
        # Iterate over the individual path entries. (Previously the raw string
        # was iterated character by character, so this check never matched and
        # the path was always prepended again.)
        for path_ in pythonpath.split(":"):
            if os.path.isabs(path_) and os.path.realpath(path_) == path:
                yield  # Path is already in PYTHONPATH, nothing to do here.
                return
        try:
            # Prepend the given path to the existing PYTHONPATH.
            tmp_path = [path] + pythonpath.split(":")
            os.environ["PYTHONPATH"] = ":".join(tmp_path)
            yield
        finally:
            os.environ["PYTHONPATH"] = pythonpath
    else:
        try:
            # The PYTHONPATH was previously not set; set it to the given path.
            os.environ["PYTHONPATH"] = path
            yield
        finally:
            del os.environ["PYTHONPATH"]
@contextmanager
def add_cwd_to_environment_pythonpath():
    """Temporarily add the current working directory to PYTHONPATH.

    Thin wrapper around :func:`add_path_to_environment_pythonpath` using
    :func:`os.getcwd` as the path; the previous PYTHONPATH is restored on
    exit.
    """
    cwd = os.getcwd()
    with add_path_to_environment_pythonpath(cwd):
        yield
@contextmanager
def switch_to_directory(root=None):
    """Temporarily switch into the given root directory (if not None).

    This method is a context manager. It switches to the previous working
    directory when exiting the context.

    Parameters
    ----------
    root : str
        Current working directory to use for within the context. (Default value
        = None)
    """
    if root is None:
        # No directory given: run the body in the current directory.
        yield
        return
    original_directory = os.getcwd()
    try:
        os.chdir(root)
        yield
    finally:
        os.chdir(original_directory)
class TrackGetItemDict(dict):
    """A dict that tracks which keys have been accessed.

    Keys accessed with ``__getitem__`` or :meth:`get` are recorded and exposed
    through the read-only property :attr:`~.keys_used`.
    """

    def __init__(self, *args, **kwargs):
        # The tracking set must exist before the dict is populated.
        self._keys_used = set()
        super().__init__(*args, **kwargs)

    def __getitem__(self, key):
        """Get item by key, recording the key as accessed."""
        self._keys_used.add(key)
        return super().__getitem__(key)

    def get(self, key, default=None):
        """Return the value for key if key is in the dictionary, else default.

        If default is not given, it defaults to ``None``, so that this method
        never raises a :class:`KeyError`. The key is recorded as accessed even
        when it is absent.
        """
        self._keys_used.add(key)
        return super().get(key, default)

    @property
    def keys_used(self):
        """Return all keys that have been accessed."""
        # Return a copy so callers cannot mutate the internal tracking set.
        return set(self._keys_used)
def roundrobin(*iterables):
    """Round robin iterator.

    Cycles through a sequence of iterables, taking one item from each iterable
    until all iterables are exhausted.
    """
    # From: https://docs.python.org/3/library/itertools.html#itertools-recipes
    # roundrobin('ABC', 'D', 'EF') --> A D E B F C
    # Recipe credited to George Sakkis
    active_count = len(iterables)
    # Cycle over the bound __next__ methods of each iterator.
    advance_funcs = cycle(iter(iterable).__next__ for iterable in iterables)
    while active_count:
        try:
            for advance in advance_funcs:
                yield advance()
        except StopIteration:
            # Remove the iterator we just exhausted from the cycle.
            active_count -= 1
            advance_funcs = cycle(islice(advance_funcs, active_count))
class _hashable_dict(dict):
def __hash__(self):
return hash(tuple(sorted(self.items())))
def _to_hashable(obj):
"""Create a hash of passed type.
Parameters
----------
obj
Object to make hashable. Lists are converted to tuples, and hashes are
defined for dicts.
Returns
-------
object
Hashable object.
"""
if type(obj) is list:
return tuple(_to_hashable(_) for _ in obj)
elif type(obj) is dict:
return _hashable_dict(obj)
else:
return obj
def _cached_partial(func, *args, maxsize=None, **kwargs):
r"""Cache the results of a partial.
Useful for wrapping functions that must only be evaluated lazily, one time.
Parameters
----------
func : callable
The function to call.
\*args
Positional arguments bound to the function.
maxsize : int
The maximum size of the LRU cache, or None for no limit. (Default value
= None)
\*\*kwargs
Keyword arguments bound to the function.
Returns
-------
callable
Function with bound arguments and cached return values.
"""
return lru_cache(maxsize=maxsize)(partial(func, *args, **kwargs))
class _bidict(MutableMapping):
r"""A bidirectional dictionary.
The attribute ``inverse`` contains the inverse mapping, where the inverse
values are stored as a :class:`list` of keys with that value.
Both keys and values must be hashable.
A key is associated with exactly one value.
A value is associated with one or more keys.
The inverse mapping should not be modified directly.
The list of inverse values (keys) must be insertion-ordered.
"""
# Based on: https://stackoverflow.com/a/21894086
def __init__(self, *args, **kwargs):
self._data = dict(*args, **kwargs)
self.inverse = {}
for key, value in self._data.items():
self.inverse.setdefault(value, []).append(key)
def __getitem__(self, key):
"""Get a value from the provided key."""
return self._data[key]
def __setitem__(self, key, value):
"""Assign a value to the provided key."""
if key in self._data:
old_value = self._data[key]
self.inverse[old_value].remove(key)
if len(self.inverse[old_value]) == 0:
del self.inverse[old_value]
self._data[key] = value
self.inverse.setdefault(value, []).append(key)
def __delitem__(self, key):
"""Delete the provided key."""
value = self._data[key]
self.inverse[value].remove(key)
if len(self.inverse[value]) == 0:
del self.inverse[value]
del self._data[key]
def __iter__(self):
yield from self._data
def __len__(self):
return len(self._data)
def _run_cloudpickled_func(func, *args):
    """Execute a function that was serialized with cloudpickle.

    The set of functions that can be pickled by the built-in pickle module is
    very limited, which prevents the usage of various useful cases such as
    locally-defined functions or functions that internally call class methods.
    This function circumvents that difficulty: the caller pickles the target
    function with cloudpickle a priori and binds the resulting bytes as the
    first argument of a partial application of this function. Each of the
    remaining arguments is likewise a cloudpickled payload and is deserialized
    before the call.
    """
    deserialized_args = [cloudpickle.loads(arg) for arg in args]
    return cloudpickle.loads(func)(*deserialized_args)
def _get_parallel_executor(parallelization="none"):
    """Get an executor for the desired parallelization strategy.

    This executor shows a progress bar while executing a function over an
    iterable in parallel. The returned callable has signature ``func,
    iterable, **kwargs``. The iterable must have a length (generators are not
    supported). The keyword argument ``chunksize`` is used for chunking the
    iterable in supported parallelization modes
    (see :meth:`concurrent.futures.Executor.map`). All other ``**kwargs`` are
    passed to the tqdm progress bar.

    Parameters
    ----------
    parallelization : str
        Parallelization mode. Allowed values are "thread", "process", or
        "none". (Default value = "none")

    Returns
    -------
    callable
        A callable with signature ``func, iterable, **kwargs``.
    """
    if parallelization == "thread":

        def parallel_executor(func, iterable, **kwargs):
            return thread_map(func, iterable, tqdm_class=tqdm, **kwargs)

    elif parallelization == "process":

        def parallel_executor(func, iterable, **kwargs):
            # The tqdm progress bar requires a total. We compute the total in
            # advance because a map iterable (which has no total) is passed to
            # process_map.
            if "total" not in kwargs:
                kwargs["total"] = len(iterable)
            # The top-level function called on each process cannot be a local
            # function, it must be a module-level function. Creating a partial
            # here allows us to use the passed function "func" regardless of
            # whether it is a local function.
            pickled_func = partial(_run_cloudpickled_func, cloudpickle.dumps(func))
            return process_map(
                pickled_func,
                map(cloudpickle.dumps, iterable),
                tqdm_class=tqdm,
                **kwargs,
            )

    else:

        def parallel_executor(func, iterable, **kwargs):
            # Chunk size only applies to the thread/process parallel executors.
            kwargs.pop("chunksize", None)
            return list(tmap(func, iterable, tqdm_class=tqdm, **kwargs))

    return parallel_executor