Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ability to use custom hash functions in hashing.hash and Memory #1232

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
43 changes: 41 additions & 2 deletions joblib/hashing.py
Expand Up @@ -15,9 +15,48 @@
import io
import decimal

try:
import xxhash
except ImportError:
xxhash = None
Comment on lines +18 to +21
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would probably remove this part. The rationale is that, since this is not the default, users would mainly discover it through the docs, and we could explain there how to register it when they create a Memory object.

Otherwise, this results in an extra library import in every process, which might be costly (not sure how heavy this lib is?).



Pickler = pickle._Pickler

_HASHES = {}


def register_hash(hash_name, hash, force=False):
    """Register a new hash function.

    Parameters
    -----------
    hash_name: str.
        The name of the hash function.
    hash:
        A hashlib compliant hash function: a callable taking no argument
        and returning an object exposing `update` and `hexdigest`.
    force: bool, optional
        If True, replace a hash function already registered under
        `hash_name` instead of raising. Defaults to False.

    Raises
    ------
    ValueError
        If `hash_name` is not a string, if `hash` is not a callable
        producing an object with `update` and `hexdigest` methods, or if
        `hash_name` is already registered and `force` is False.
    """
    if not isinstance(hash_name, str):
        raise ValueError("Hash name should be a string, "
                         "'{}' given.".format(hash_name))

    # Build a single probe instance; a non-callable `hash` is mapped to
    # None so that it is reported with the same ValueError as any other
    # object lacking the hashlib interface.
    probe = hash() if callable(hash) else None
    if not hasattr(probe, 'update') or not hasattr(probe, 'hexdigest'):
        raise ValueError("Hash function instance must implement `update` "
                         "and `hexdigest` methods.")

    if hash_name in _HASHES and not force:
        raise ValueError("Hash function '{}' already registered."
                         .format(hash_name))

    # The registry is mutated in place, so no `global` statement is needed.
    _HASHES[hash_name] = hash


# Populate the registry with the hashlib algorithms used by default.
register_hash(hash_name='md5', hash=hashlib.md5)
register_hash(hash_name='sha1', hash=hashlib.sha1)
# xxhash is optional: the guarded import above leaves the name set to None
# when the package is missing, in which case nothing is registered for it.
if xxhash is not None:
    register_hash(hash_name='xxh3_64', hash=xxhash.xxh3_64)


class _ConsistentSet(object):
""" Class used to ensure the hash of Sets is preserved
Expand Down Expand Up @@ -56,7 +95,7 @@ def __init__(self, hash_name='md5'):
protocol = 3
Pickler.__init__(self, self.stream, protocol=protocol)
# Initialise the hash obj
self._hash = hashlib.new(hash_name)
self._hash = _HASHES[hash_name]()

def hash(self, obj, return_digest=True):
try:
Expand Down Expand Up @@ -254,7 +293,7 @@ def hash(obj, hash_name='md5', coerce_mmap=False):
coerce_mmap: boolean
Make no difference between np.memmap and np.ndarray
"""
valid_hash_names = ('md5', 'sha1')
valid_hash_names = tuple(_HASHES.keys())
if hash_name not in valid_hash_names:
raise ValueError("Valid options for 'hash_name' are {}. "
"Got hash_name={!r} instead."
Expand Down
31 changes: 27 additions & 4 deletions joblib/memory.py
Expand Up @@ -402,6 +402,11 @@ class MemorizedFunc(Logger):
of compression. Note that compressed arrays cannot be
read by memmapping.

hash_name: {'md5', 'sha1'}, optional
The name of the hash function to use to hash arguments.
Defaults to 'md5'. Additionally, if `xxhash` is installed
'xxh3_64' is valid.

verbose: int, optional
The verbosity flag, controls messages that are issued as
the function is evaluated.
Expand All @@ -411,10 +416,16 @@ class MemorizedFunc(Logger):
# ------------------------------------------------------------------------

def __init__(self, func, location, backend='local', ignore=None,
mmap_mode=None, compress=False, verbose=1, timestamp=None):
mmap_mode=None, compress=False, hash_name='md5',
verbose=1, timestamp=None):
Logger.__init__(self)
self.mmap_mode = mmap_mode
self.compress = compress
if hash_name not in hashing._HASHES:
raise ValueError("Valid options for 'hash_name' are {}. "
"Got hash_name={!r} instead."
.format(hash_name, hash_name))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
.format(hash_name, hash_name))
.format(hashing._HASHES, hash_name))

self.hash_name = hash_name
self.func = func

if ignore is None:
Expand Down Expand Up @@ -630,7 +641,8 @@ def check_call_in_cache(self, *args, **kwargs):

def _get_argument_hash(self, *args, **kwargs):
return hashing.hash(filter_args(self.func, self.ignore, args, kwargs),
coerce_mmap=(self.mmap_mode is not None))
coerce_mmap=(self.mmap_mode is not None),
hash_name=self.hash_name)

def _get_output_identifiers(self, *args, **kwargs):
"""Return the func identifier and input parameter hash of a result."""
Expand Down Expand Up @@ -893,6 +905,11 @@ class Memory(Logger):
of compression. Note that compressed arrays cannot be
read by memmapping.

hash_name: {'md5', 'sha1'}, optional
The name of the hash function to use to hash arguments.
Defaults to 'md5'. Additionally, if `xxhash` is installed
'xxh3_64' is valid.

verbose: int, optional
Verbosity flag, controls the debug messages that are issued
as functions are evaluated.
Expand All @@ -914,8 +931,8 @@ class Memory(Logger):
# ------------------------------------------------------------------------

def __init__(self, location=None, backend='local', cachedir=None,
mmap_mode=None, compress=False, verbose=1, bytes_limit=None,
backend_options=None):
mmap_mode=None, compress=False, hash_name='md5', verbose=1,
bytes_limit=None, backend_options=None):
# XXX: Bad explanation of the None value of cachedir
Logger.__init__(self)
self._verbose = verbose
Expand All @@ -924,6 +941,11 @@ def __init__(self, location=None, backend='local', cachedir=None,
self.bytes_limit = bytes_limit
self.backend = backend
self.compress = compress
if hash_name not in hashing._HASHES:
raise ValueError("Valid options for 'hash_name' are {}. "
"Got hash_name={!r} instead."
.format(hash_name, hash_name))
self.hash_name = hash_name
if backend_options is None:
backend_options = {}
self.backend_options = backend_options
Expand Down Expand Up @@ -1011,6 +1033,7 @@ def cache(self, func=None, ignore=None, verbose=None, mmap_mode=False):
backend=self.backend,
ignore=ignore, mmap_mode=mmap_mode,
compress=self.compress,
hash_name=self.hash_name,
verbose=verbose, timestamp=self.timestamp)

def clear(self, warn=True):
Expand Down
22 changes: 21 additions & 1 deletion joblib/test/test_hashing.py
Expand Up @@ -20,7 +20,7 @@
from decimal import Decimal
import pytest

from joblib.hashing import hash
from joblib.hashing import hash, register_hash, _HASHES
from joblib.func_inspect import filter_args
from joblib.memory import Memory
from joblib.testing import raises, skipif, fixture, parametrize
Expand Down Expand Up @@ -495,3 +495,23 @@ def test_wrong_hash_name():
with raises(ValueError, match=msg):
data = {'foo': 'bar'}
hash(data, hash_name='invalid')


def test_right_regist_hash():
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
def test_right_regist_hash():
def test_right_register_hash():

hash_name = 'my_hash'
assert hash_name not in _HASHES
register_hash(hash_name, hashlib.sha256)
assert _HASHES[hash_name] == hashlib.sha256


def test_wrong_register_hash():
    """register_hash must raise ValueError for each kind of invalid call."""
    # A name that is not a string is rejected.
    not_a_string = "Hash name should be a string"
    with raises(ValueError, match=not_a_string):
        register_hash(0, hashlib.md5)

    # `int()` instances expose neither `update` nor `hexdigest`.
    bad_interface = "Hash function instance must implement"
    with raises(ValueError, match=bad_interface):
        register_hash('test_hash', int)

    # Re-registering an existing name without `force` is rejected.
    duplicate = "Hash function 'md5' already registered."
    with raises(ValueError, match=duplicate):
        register_hash('md5', hashlib.md5)
30 changes: 30 additions & 0 deletions joblib/test/test_memory.py
Expand Up @@ -386,6 +386,36 @@ def test_argument_change(tmpdir):
assert func() == 1


def test_memory_invalid_hash_name(tmpdir):
    """Memory must raise ValueError when given an unknown hash_name."""
    expected_msg = "Valid options for 'hash_name' are"
    with raises(ValueError, match=expected_msg):
        Memory(tmpdir.strpath, hash_name='not_valid')


def test_memorized_func_invalid_hash_name(tmpdir):
    """MemorizedFunc must raise ValueError when given an unknown hash_name."""
    expected_msg = "Valid options for 'hash_name' are"
    with raises(ValueError, match=expected_msg):
        MemorizedFunc(int, tmpdir.strpath, hash_name='not_valid')


def test_memory_custom_hash(tmpdir):
    """Check that caching works with a non-default hash_name ('sha1')."""
    accumulator = list()

    def n(ls=None):
        # Record every real (non-cached) execution of the function.
        accumulator.append(1)
        return ls

    memory = Memory(location=tmpdir.strpath,
                    verbose=0, hash_name='sha1')
    cached_n = memory.cache(n)

    # Each distinct argument must trigger exactly one real call; repeated
    # calls with the same argument must not add to the accumulator.
    for n_distinct, value in enumerate((1, 2, 3), start=1):
        for _ in range(3):
            assert cached_n(value) == value
            assert len(accumulator) == n_distinct


@with_numpy
@parametrize('mmap_mode', [None, 'r'])
def test_memory_numpy(tmpdir, mmap_mode):
Expand Down