Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make sure arrays are bytes aligned in joblib pickles #1254

Merged
merged 31 commits into from Feb 25, 2022
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
5b855cc
wip
lesteve Feb 2, 2022
232588b
fix
lesteve Feb 2, 2022
282d953
test
lesteve Feb 2, 2022
6b72bcf
Remove debug print
lesteve Feb 3, 2022
ef5ddcd
fix (peek not supported in io.BytesIO)
lesteve Feb 3, 2022
a851f5c
better variables
lesteve Feb 3, 2022
7cd90ed
handle case where .tell not supported
lesteve Feb 3, 2022
e0a994a
Fix memmap with old pickles loaded in mmap mode
lesteve Feb 3, 2022
c8302dc
Use module variable for array bytes alignment
lesteve Feb 3, 2022
4952732
Fix for Windows
lesteve Feb 4, 2022
36cfa85
Merge branch 'master' of https://github.com/joblib/joblib into memmap…
lesteve Feb 4, 2022
513e0d1
Fix test checking old non-aligned arrays with memmap
lesteve Feb 4, 2022
d7cbd3f
Add align_bytes_array attribute in NumpyArrayWrapper
lesteve Feb 8, 2022
d66f2bd
lint
lesteve Feb 8, 2022
7dab61b
Tweak
lesteve Feb 9, 2022
f943214
trigger CI
lesteve Feb 10, 2022
fe5036d
Trigger CI
lesteve Feb 17, 2022
0522f23
Create a numpy_array_alignment_bytes attribute to allow for future ch…
lesteve Feb 17, 2022
41dc185
Add test for edge case where aligned numpy arrays are read in a file …
lesteve Feb 17, 2022
b2b2f92
Merge branch 'master' into memmap-align
ogrisel Feb 23, 2022
73eb719
Tackled most comments
lesteve Feb 24, 2022
e6409e9
Add warning when loading an old pickle with non memory aligned array
lesteve Feb 24, 2022
a36af7c
Add changelog entry.
lesteve Feb 24, 2022
54c2773
tweak
lesteve Feb 24, 2022
dfccb94
Forgotten f-string
lesteve Feb 24, 2022
236e042
Windows fix
lesteve Feb 24, 2022
3995ed4
Trigger CI
lesteve Feb 24, 2022
4705aad
Apply suggestions from code review
lesteve Feb 25, 2022
2e380bb
tweak
lesteve Feb 25, 2022
eb69803
Update joblib/test/test_numpy_pickle.py
ogrisel Feb 25, 2022
188c089
Trigger CI
lesteve Feb 25, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
42 changes: 41 additions & 1 deletion joblib/numpy_pickle.py
Expand Up @@ -7,6 +7,7 @@
import pickle
import os
import warnings
import io
from pathlib import Path

from .compressor import lz4, LZ4_NOT_INSTALLED_ERROR
Expand Down Expand Up @@ -39,6 +40,8 @@
###############################################################################
# Utility objects for persistence.

NUMPY_ARRAY_ALIGNMENT_BYTES = 16

lesteve marked this conversation as resolved.
Show resolved Hide resolved

class NumpyArrayWrapper(object):
"""An object to be persisted instead of numpy arrays.
Expand Down Expand Up @@ -92,6 +95,17 @@ def write_array(self, array, pickler):
# pickle protocol.
pickle.dump(array, pickler.file_handle, protocol=2)
else:
try:
current_pos = pickler.file_handle.tell()
alignment = current_pos % NUMPY_ARRAY_ALIGNMENT_BYTES

if alignment != 0:
padding = b' ' * (NUMPY_ARRAY_ALIGNMENT_BYTES - alignment)
pickler.file_handle.write(padding)
except io.UnsupportedOperation:
# TODO log something somewhere?
ogrisel marked this conversation as resolved.
Show resolved Hide resolved
pass

for chunk in pickler.np.nditer(array,
flags=['external_loop',
'buffered',
Expand All @@ -118,6 +132,21 @@ def read_array(self, unpickler):
# The array contained Python objects. We need to unpickle the data.
array = pickle.load(unpickler.file_handle)
else:
try:
current_pos = unpickler.file_handle.tell()
alignment = current_pos % NUMPY_ARRAY_ALIGNMENT_BYTES

# peek not supported in io.BytesIO ...
current_byte = unpickler.file_handle.read(1)
unpickler.file_handle.seek(current_pos)

if alignment != 0 and current_byte == b' ':
padding_length = NUMPY_ARRAY_ALIGNMENT_BYTES - alignment
unpickler.file_handle.seek(current_pos + padding_length)
except io.UnsupportedOperation:
# TODO log something somewhere?
pass

# This is not a real file. We have to read it the
# memory-intensive way.
# crc32 module fails on reads greater than 2 ** 32 bytes,
Expand Down Expand Up @@ -150,7 +179,18 @@ def read_array(self, unpickler):

def read_mmap(self, unpickler):
"""Read an array using numpy memmap."""
offset = unpickler.file_handle.tell()
current_pos = unpickler.file_handle.tell()
offset = current_pos
alignment = current_pos % NUMPY_ARRAY_ALIGNMENT_BYTES

# peek not supported in io.BytesIO ...
current_byte = unpickler.file_handle.read(1)
unpickler.file_handle.seek(current_pos)

if alignment != 0 and current_byte == b' ':
padding_length = NUMPY_ARRAY_ALIGNMENT_BYTES - alignment
offset += padding_length

lesteve marked this conversation as resolved.
Show resolved Hide resolved
if unpickler.mmap_mode == 'w+':
unpickler.mmap_mode = 'r+'

Expand Down
73 changes: 73 additions & 0 deletions joblib/test/test_numpy_pickle.py
Expand Up @@ -465,6 +465,27 @@ def test_joblib_pickle_across_python_versions():
_check_pickle(fname, expected_list)


@with_numpy
def test_joblib_pickle_across_python_versions_with_mmap():
expected_list = [np.arange(5, dtype=np.dtype('<i8')),
np.arange(5, dtype=np.dtype('<f8')),
np.array([1, 'abc', {'a': 1, 'b': 2}], dtype='O'),
np.arange(256, dtype=np.uint8).tobytes(),
# np.matrix is a subclass of np.ndarray, here we want
# to verify this type of object is correctly unpickled
# among versions.
np.matrix([0, 1, 2], dtype=np.dtype('<i8')),
u"C'est l'\xe9t\xe9 !"]

test_data_dir = os.path.dirname(os.path.abspath(data.__file__))

pickle_filenames = [
os.path.join(test_data_dir, fn)
for fn in os.listdir(test_data_dir) if fn.endswith('.pkl')]
for fname in pickle_filenames:
_check_pickle(fname, expected_list)
ogrisel marked this conversation as resolved.
Show resolved Hide resolved


@with_numpy
def test_numpy_array_byte_order_mismatch_detection():
# List of numpy arrays with big endian byteorder.
Expand Down Expand Up @@ -1054,3 +1075,55 @@ def test_lz4_compression_without_lz4(tmpdir):
with raises(ValueError) as excinfo:
numpy_pickle.dump(data, fname + '.lz4')
excinfo.match(msg)


@with_numpy
@parametrize('protocol', range(0, pickle.HIGHEST_PROTOCOL + 1))
lesteve marked this conversation as resolved.
Show resolved Hide resolved
def test_memmap_alignment_padding(tmpdir, protocol):
# Test that memmaped arrays returned by numpy.load are correctly aligned
fname = tmpdir.join('test.mmap').strpath

a = np.random.randn(2)
numpy_pickle.dump(a, fname, protocol=protocol)
memmap = numpy_pickle.load(fname, mmap_mode='r')
assert isinstance(memmap, np.memmap)
np.testing.assert_array_equal(a, memmap)
assert (
memmap.ctypes.data % numpy_pickle.NUMPY_ARRAY_ALIGNMENT_BYTES == 0)
assert memmap.flags.aligned

array_list = [
np.random.randn(2), np.random.randn(2),
np.random.randn(2), np.random.randn(2)
]

# On Windows OSError 22 if reusing the same path for memmap ...
fname = tmpdir.join('test1.mmap').strpath
numpy_pickle.dump(array_list, fname, protocol=protocol)
l_reloaded = numpy_pickle.load(fname, mmap_mode='r')

for idx, memmap in enumerate(l_reloaded):
assert isinstance(memmap, np.memmap)
np.testing.assert_array_equal(array_list[idx], memmap)
assert (
memmap.ctypes.data % numpy_pickle.NUMPY_ARRAY_ALIGNMENT_BYTES == 0)
assert memmap.flags.aligned

array_dict = {
'a1': np.random.randn(100),
'a2': np.random.randn(200),
'a3': np.random.randn(300),
'a4': np.random.randn(400)
}
lesteve marked this conversation as resolved.
Show resolved Hide resolved

# On Windows OSError 22 if reusing the same path for memmap ...
fname = tmpdir.join('test2.mmap').strpath
numpy_pickle.dump(array_dict, fname, protocol=protocol)
d_reloaded = numpy_pickle.load(fname, mmap_mode='r')

for key, memmap in d_reloaded.items():
assert isinstance(memmap, np.memmap)
np.testing.assert_array_equal(array_dict[key], memmap)
assert (
memmap.ctypes.data % numpy_pickle.NUMPY_ARRAY_ALIGNMENT_BYTES == 0)
assert memmap.flags.aligned