Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make sure arrays are bytes aligned in joblib pickles #1254

Merged
merged 31 commits into from Feb 25, 2022
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
5b855cc
wip
lesteve Feb 2, 2022
232588b
fix
lesteve Feb 2, 2022
282d953
test
lesteve Feb 2, 2022
6b72bcf
Remove debug print
lesteve Feb 3, 2022
ef5ddcd
fix (peek not supported in io.BytesIO)
lesteve Feb 3, 2022
a851f5c
better variables
lesteve Feb 3, 2022
7cd90ed
handle case where .tell not supported
lesteve Feb 3, 2022
e0a994a
Fix memmap with old pickles loaded in mmap mode
lesteve Feb 3, 2022
c8302dc
Use module variable for array bytes alignment
lesteve Feb 3, 2022
4952732
Fix for Windows
lesteve Feb 4, 2022
36cfa85
Merge branch 'master' of https://github.com/joblib/joblib into memmap…
lesteve Feb 4, 2022
513e0d1
Fix test checking old non-aligned arrays with memmap
lesteve Feb 4, 2022
d7cbd3f
Add align_bytes_array attribute in NumpyArrayWrapper
lesteve Feb 8, 2022
d66f2bd
lint
lesteve Feb 8, 2022
7dab61b
Tweak
lesteve Feb 9, 2022
f943214
trigger CI
lesteve Feb 10, 2022
fe5036d
Trigger CI
lesteve Feb 17, 2022
0522f23
Create a numpy_array_alignment_bytes attribute to allow for future ch…
lesteve Feb 17, 2022
41dc185
Add test for edge case where aligned numpy arrays are read in a file …
lesteve Feb 17, 2022
b2b2f92
Merge branch 'master' into memmap-align
ogrisel Feb 23, 2022
73eb719
Tackled most comments
lesteve Feb 24, 2022
e6409e9
Add warning when loading an old pickle with non memory aligned array
lesteve Feb 24, 2022
a36af7c
Add changelog entry.
lesteve Feb 24, 2022
54c2773
tweak
lesteve Feb 24, 2022
dfccb94
Forgotten f-string
lesteve Feb 24, 2022
236e042
Windows fix
lesteve Feb 24, 2022
3995ed4
Trigger CI
lesteve Feb 24, 2022
4705aad
Apply suggestions from code review
lesteve Feb 25, 2022
2e380bb
tweak
lesteve Feb 25, 2022
eb69803
Update joblib/test/test_numpy_pickle.py
ogrisel Feb 25, 2022
188c089
Trigger CI
lesteve Feb 25, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
54 changes: 51 additions & 3 deletions joblib/numpy_pickle.py
Expand Up @@ -7,6 +7,7 @@
import pickle
import os
import warnings
import io
from pathlib import Path

from .compressor import lz4, LZ4_NOT_INSTALLED_ERROR
Expand Down Expand Up @@ -39,6 +40,8 @@
###############################################################################
# Utility objects for persistence.

NUMPY_ARRAY_ALIGNMENT_BYTES = 16

lesteve marked this conversation as resolved.
Show resolved Hide resolved

class NumpyArrayWrapper(object):
"""An object to be persisted instead of numpy arrays.
Expand Down Expand Up @@ -70,13 +73,20 @@ class NumpyArrayWrapper(object):
Default: False.
"""

def __init__(self, subclass, shape, order, dtype, allow_mmap=False):
def __init__(self, subclass, shape, order, dtype, allow_mmap=False,
array_bytes_aligned=True):
"""Constructor. Store the useful information for later."""
self.subclass = subclass
self.shape = shape
self.order = order
self.dtype = dtype
self.allow_mmap = allow_mmap
self.array_bytes_aligned = array_bytes_aligned

def _array_bytes_aligned(self):
# NumpyArrayWrapper instances stored in pickles written by joblib <= 1.1
# don't have an array_bytes_aligned attribute, so default to False
return getattr(self, 'array_bytes_aligned', False)
ogrisel marked this conversation as resolved.
Show resolved Hide resolved

def write_array(self, array, pickler):
"""Write array bytes to pickler file handle.
Expand All @@ -92,6 +102,13 @@ def write_array(self, array, pickler):
# pickle protocol.
pickle.dump(array, pickler.file_handle, protocol=2)
else:
if self._array_bytes_aligned():
current_pos = pickler.file_handle.tell()
alignment = current_pos % NUMPY_ARRAY_ALIGNMENT_BYTES
if alignment != 0:
padding = b' ' * (NUMPY_ARRAY_ALIGNMENT_BYTES - alignment)
pickler.file_handle.write(padding)

for chunk in pickler.np.nditer(array,
flags=['external_loop',
'buffered',
Expand All @@ -118,6 +135,21 @@ def read_array(self, unpickler):
# The array contained Python objects. We need to unpickle the data.
array = pickle.load(unpickler.file_handle)
else:
if self._array_bytes_aligned():
try:
current_pos = unpickler.file_handle.tell()
alignment = current_pos % NUMPY_ARRAY_ALIGNMENT_BYTES
if alignment != 0:
padding_length = (
NUMPY_ARRAY_ALIGNMENT_BYTES - alignment)
unpickler.file_handle.seek(
current_pos + padding_length)
except io.UnsupportedOperation as exc:
raise RuntimeError(
'Trying to read a joblib pickle with bytes aligned '
'numpy arrays in a file_handle '
'that does not support .tell') from exc
lesteve marked this conversation as resolved.
Show resolved Hide resolved

# This is not a real file. We have to read it the
# memory-intensive way.
# crc32 module fails on reads greater than 2 ** 32 bytes,
Expand Down Expand Up @@ -150,7 +182,15 @@ def read_array(self, unpickler):

def read_mmap(self, unpickler):
"""Read an array using numpy memmap."""
offset = unpickler.file_handle.tell()
current_pos = unpickler.file_handle.tell()
offset = current_pos
alignment = current_pos % NUMPY_ARRAY_ALIGNMENT_BYTES

if self._array_bytes_aligned():
if alignment != 0:
padding_length = NUMPY_ARRAY_ALIGNMENT_BYTES - alignment
offset += padding_length

lesteve marked this conversation as resolved.
Show resolved Hide resolved
if unpickler.mmap_mode == 'w+':
unpickler.mmap_mode = 'r+'

Expand Down Expand Up @@ -239,9 +279,17 @@ def _create_array_wrapper(self, array):
order = 'F' if (array.flags.f_contiguous and
not array.flags.c_contiguous) else 'C'
allow_mmap = not self.buffered and not array.dtype.hasobject

try:
self.file_handle.tell()
array_bytes_aligned = True
except io.UnsupportedOperation:
array_bytes_aligned = False

wrapper = NumpyArrayWrapper(type(array),
array.shape, order, array.dtype,
allow_mmap=allow_mmap)
allow_mmap=allow_mmap,
array_bytes_aligned=array_bytes_aligned)

return wrapper

Expand Down
77 changes: 75 additions & 2 deletions joblib/test/test_numpy_pickle.py
Expand Up @@ -368,7 +368,7 @@ def test_compressed_pickle_dump_and_load(tmpdir):
assert result == expected


def _check_pickle(filename, expected_list):
def _check_pickle(filename, expected_list, mmap_mode=None):
"""Helper function to test joblib pickle content.

Note: currently only pickles containing an iterable are supported
Expand All @@ -388,7 +388,7 @@ def _check_pickle(filename, expected_list):
warnings.filterwarnings(
'ignore', module='numpy',
message='The compiler package is deprecated')
result_list = numpy_pickle.load(filename)
result_list = numpy_pickle.load(filename, mmap_mode=mmap_mode)
filename_base = os.path.basename(filename)
expected_nb_warnings = 1 if ("_0.9" in filename_base or
"_0.8.4" in filename_base) else 0
Expand Down Expand Up @@ -465,6 +465,27 @@ def test_joblib_pickle_across_python_versions():
_check_pickle(fname, expected_list)


@with_numpy
def test_joblib_pickle_across_python_versions_with_mmap():
    """Check every reference pickle of the test data dir with mmap_mode='r'.

    Each file ending in '.pkl' found next to the ``data`` module is passed to
    ``_check_pickle`` together with the list of expected objects.
    """
    int64_le = np.dtype('<i8')
    float64_le = np.dtype('<f8')
    expected_list = [
        np.arange(5, dtype=int64_le),
        np.arange(5, dtype=float64_le),
        np.array([1, 'abc', {'a': 1, 'b': 2}], dtype='O'),
        np.arange(256, dtype=np.uint8).tobytes(),
        # np.matrix derives from np.ndarray: make sure this kind of object
        # is unpickled correctly across versions.
        np.matrix([0, 1, 2], dtype=int64_le),
        u"C'est l'\xe9t\xe9 !",
    ]

    test_data_dir = os.path.dirname(os.path.abspath(data.__file__))

    for entry in os.listdir(test_data_dir):
        if not entry.endswith('.pkl'):
            continue
        pickle_path = os.path.join(test_data_dir, entry)
        _check_pickle(pickle_path, expected_list, mmap_mode='r')


@with_numpy
def test_numpy_array_byte_order_mismatch_detection():
# List of numpy arrays with big endian byteorder.
Expand Down Expand Up @@ -1054,3 +1075,55 @@ def test_lz4_compression_without_lz4(tmpdir):
with raises(ValueError) as excinfo:
numpy_pickle.dump(data, fname + '.lz4')
excinfo.match(msg)


@with_numpy
@parametrize('protocol', range(0, pickle.HIGHEST_PROTOCOL + 1))
lesteve marked this conversation as resolved.
Show resolved Hide resolved
def test_memmap_alignment_padding(tmpdir, protocol):
# Test that memmaps returned by numpy_pickle.load are correctly aligned
fname = tmpdir.join('test.mmap').strpath

a = np.random.randn(2)
numpy_pickle.dump(a, fname, protocol=protocol)
memmap = numpy_pickle.load(fname, mmap_mode='r')
assert isinstance(memmap, np.memmap)
np.testing.assert_array_equal(a, memmap)
assert (
memmap.ctypes.data % numpy_pickle.NUMPY_ARRAY_ALIGNMENT_BYTES == 0)
assert memmap.flags.aligned

array_list = [
np.random.randn(2), np.random.randn(2),
np.random.randn(2), np.random.randn(2)
]

# On Windows, reusing the same path for a memmap raises OSError 22,
# so dump to a new file name.
fname = tmpdir.join('test1.mmap').strpath
numpy_pickle.dump(array_list, fname, protocol=protocol)
l_reloaded = numpy_pickle.load(fname, mmap_mode='r')

for idx, memmap in enumerate(l_reloaded):
assert isinstance(memmap, np.memmap)
np.testing.assert_array_equal(array_list[idx], memmap)
assert (
memmap.ctypes.data % numpy_pickle.NUMPY_ARRAY_ALIGNMENT_BYTES == 0)
assert memmap.flags.aligned

array_dict = {
'a1': np.random.randn(100),
'a2': np.random.randn(200),
'a3': np.random.randn(300),
'a4': np.random.randn(400)
}
lesteve marked this conversation as resolved.
Show resolved Hide resolved

# On Windows, reusing the same path for a memmap raises OSError 22,
# so dump to a new file name.
fname = tmpdir.join('test2.mmap').strpath
numpy_pickle.dump(array_dict, fname, protocol=protocol)
d_reloaded = numpy_pickle.load(fname, mmap_mode='r')

for key, memmap in d_reloaded.items():
assert isinstance(memmap, np.memmap)
np.testing.assert_array_equal(array_dict[key], memmap)
assert (
memmap.ctypes.data % numpy_pickle.NUMPY_ARRAY_ALIGNMENT_BYTES == 0)
assert memmap.flags.aligned