diff --git a/joblib/hashing.py b/joblib/hashing.py index 24aeb559d..67bcefd4e 100644 --- a/joblib/hashing.py +++ b/joblib/hashing.py @@ -15,9 +15,48 @@ import io import decimal +try: + import xxhash +except ImportError: + xxhash = None + Pickler = pickle._Pickler +_HASHES = {} + + +def register_hash(hash_name, hash, force=False): + """Register a new hash function. + + Parameters + ----------- + hash_name: str. + The name of the hash function. + hash: + A hashlib compliant hash function. + """ + global _HASHES + if not isinstance(hash_name, str): + raise ValueError("Hash name should be a string, " + "'{}' given.".format(hash_name)) + + if not hasattr(hash(), 'update') or not hasattr(hash(), 'hexdigest'): + raise ValueError("Hash function instance must implement `update` " + "and `hexdigest` methods.") + + if hash_name in _HASHES and not force: + raise ValueError("Hash function '{}' already registered." + .format(hash_name)) + + _HASHES[hash_name] = hash + + +register_hash('md5', hashlib.md5) +register_hash('sha1', hashlib.sha1) +if xxhash: + register_hash('xxh3_64', xxhash.xxh3_64) + class _ConsistentSet(object): """ Class used to ensure the hash of Sets is preserved @@ -56,7 +95,7 @@ def __init__(self, hash_name='md5'): protocol = 3 Pickler.__init__(self, self.stream, protocol=protocol) # Initialise the hash obj - self._hash = hashlib.new(hash_name) + self._hash = _HASHES[hash_name]() def hash(self, obj, return_digest=True): try: @@ -254,7 +293,7 @@ def hash(obj, hash_name='md5', coerce_mmap=False): coerce_mmap: boolean Make no difference between np.memmap and np.ndarray """ - valid_hash_names = ('md5', 'sha1') + valid_hash_names = tuple(_HASHES.keys()) if hash_name not in valid_hash_names: raise ValueError("Valid options for 'hash_name' are {}. " "Got hash_name={!r} instead." diff --git a/joblib/memory.py b/joblib/memory.py index b660f1479..63d873dc1 100644 --- a/joblib/memory.py +++ b/joblib/memory.py @@ -402,6 +402,11 @@ class MemorizedFunc(Logger): of compression. Note that compressed arrays cannot be read by memmapping. + hash_name: {'md5', 'sha1'}, optional + The name of the hash function to use to hash arguments. + Defaults to 'md5'. Additionally, if `xxhash` is installed + 'xxh3_64' is valid. + verbose: int, optional The verbosity flag, controls messages that are issued as the function is evaluated. @@ -411,10 +416,16 @@ class MemorizedFunc(Logger): # ------------------------------------------------------------------------ def __init__(self, func, location, backend='local', ignore=None, - mmap_mode=None, compress=False, verbose=1, timestamp=None): + mmap_mode=None, compress=False, hash_name='md5', + verbose=1, timestamp=None): Logger.__init__(self) self.mmap_mode = mmap_mode self.compress = compress + if hash_name not in hashing._HASHES: + raise ValueError("Valid options for 'hash_name' are {}. " + "Got hash_name={!r} instead." + .format(hash_name, hash_name)) + self.hash_name = hash_name self.func = func if ignore is None: @@ -630,7 +641,8 @@ def check_call_in_cache(self, *args, **kwargs): def _get_argument_hash(self, *args, **kwargs): return hashing.hash(filter_args(self.func, self.ignore, args, kwargs), - coerce_mmap=(self.mmap_mode is not None)) + coerce_mmap=(self.mmap_mode is not None), + hash_name=self.hash_name) def _get_output_identifiers(self, *args, **kwargs): """Return the func identifier and input parameter hash of a result.""" @@ -893,6 +905,11 @@ class Memory(Logger): of compression. Note that compressed arrays cannot be read by memmapping. + hash_name: {'md5', 'sha1'}, optional + The name of the hash function to use to hash arguments. + Defaults to 'md5'. Additionally, if `xxhash` is installed + 'xxh3_64' is valid. + verbose: int, optional Verbosity flag, controls the debug messages that are issued as functions are evaluated. @@ -914,8 +931,8 @@ class Memory(Logger): # ------------------------------------------------------------------------ def __init__(self, location=None, backend='local', cachedir=None, - mmap_mode=None, compress=False, verbose=1, bytes_limit=None, - backend_options=None): + mmap_mode=None, compress=False, hash_name='md5', verbose=1, + bytes_limit=None, backend_options=None): # XXX: Bad explanation of the None value of cachedir Logger.__init__(self) self._verbose = verbose @@ -924,6 +941,11 @@ def __init__(self, location=None, backend='local', cachedir=None, self.bytes_limit = bytes_limit self.backend = backend self.compress = compress + if hash_name not in hashing._HASHES: + raise ValueError("Valid options for 'hash_name' are {}. " + "Got hash_name={!r} instead." + .format(hash_name, hash_name)) + self.hash_name = hash_name if backend_options is None: backend_options = {} self.backend_options = backend_options @@ -1011,6 +1033,7 @@ def cache(self, func=None, ignore=None, verbose=None, mmap_mode=False): backend=self.backend, ignore=ignore, mmap_mode=mmap_mode, compress=self.compress, + hash_name=self.hash_name, verbose=verbose, timestamp=self.timestamp) def clear(self, warn=True): diff --git a/joblib/test/test_hashing.py b/joblib/test/test_hashing.py index 5682d8f21..35461a850 100644 --- a/joblib/test/test_hashing.py +++ b/joblib/test/test_hashing.py @@ -20,7 +20,7 @@ from decimal import Decimal import pytest -from joblib.hashing import hash +from joblib.hashing import hash, register_hash, _HASHES from joblib.func_inspect import filter_args from joblib.memory import Memory from joblib.testing import raises, skipif, fixture, parametrize @@ -495,3 +495,23 @@ def test_wrong_hash_name(): with raises(ValueError, match=msg): data = {'foo': 'bar'} hash(data, hash_name='invalid') + + +def test_right_regist_hash(): + hash_name = 'my_hash' + assert hash_name not in _HASHES + register_hash(hash_name, hashlib.sha256) + assert _HASHES[hash_name] == hashlib.sha256 + + +def test_wrong_register_hash(): + with raises(ValueError, match="Hash name should be a string"): + register_hash(0, hashlib.md5) + + with raises( + ValueError, match="Hash function instance must implement"): + register_hash('test_hash', int) + + with raises( + ValueError, match="Hash function 'md5' already registered."): + register_hash('md5', hashlib.md5) diff --git a/joblib/test/test_memory.py b/joblib/test/test_memory.py index 6f749667d..87f781368 100644 --- a/joblib/test/test_memory.py +++ b/joblib/test/test_memory.py @@ -386,6 +386,36 @@ def test_argument_change(tmpdir): assert func() == 1 +def test_memory_invalid_hash_name(tmpdir): + with raises(ValueError, match="Valid options for 'hash_name' are"): + Memory(tmpdir.strpath, hash_name='not_valid') + + +def test_memorized_func_invalid_hash_name(tmpdir): + with raises(ValueError, match="Valid options for 'hash_name' are"): + MemorizedFunc(int, tmpdir.strpath, hash_name='not_valid') + + +def test_memory_custom_hash(tmpdir): + " Test memory with a function with numpy arrays." + accumulator = list() + + def n(ls=None): + accumulator.append(1) + return ls + + memory = Memory(location=tmpdir.strpath, + verbose=0, hash_name='sha1') + cached_n = memory.cache(n) + + vals = (1, 2, 3) + for i in range(3): + a = vals[i - 1] + for _ in range(3): + assert cached_n(a) == a + assert len(accumulator) == i + 1 + + @with_numpy @parametrize('mmap_mode', [None, 'r']) def test_memory_numpy(tmpdir, mmap_mode):