Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make sure arrays are bytes aligned in joblib pickles #1254

Merged
merged 31 commits into from Feb 25, 2022
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
5b855cc
wip
lesteve Feb 2, 2022
232588b
fix
lesteve Feb 2, 2022
282d953
test
lesteve Feb 2, 2022
6b72bcf
Remove debug print
lesteve Feb 3, 2022
ef5ddcd
fix (peek not supported in io.BytesIO)
lesteve Feb 3, 2022
a851f5c
better variables
lesteve Feb 3, 2022
7cd90ed
handle case where .tell not supported
lesteve Feb 3, 2022
e0a994a
Fix memmap with old pickles loaded in mmap mode
lesteve Feb 3, 2022
c8302dc
Use module variable for array bytes alignment
lesteve Feb 3, 2022
4952732
Fix for Windows
lesteve Feb 4, 2022
36cfa85
Merge branch 'master' of https://github.com/joblib/joblib into memmap…
lesteve Feb 4, 2022
513e0d1
Fix test checking old non-aligned arrays with memmap
lesteve Feb 4, 2022
d7cbd3f
Add align_bytes_array attribute in NumpyArrayWrapper
lesteve Feb 8, 2022
d66f2bd
lint
lesteve Feb 8, 2022
7dab61b
Tweak
lesteve Feb 9, 2022
f943214
trigger CI
lesteve Feb 10, 2022
fe5036d
Trigger CI
lesteve Feb 17, 2022
0522f23
Create a numpy_array_alignment_bytes attribute to allow for future ch…
lesteve Feb 17, 2022
41dc185
Add test for edge case where aligned numpy arrays are read in a file …
lesteve Feb 17, 2022
b2b2f92
Merge branch 'master' into memmap-align
ogrisel Feb 23, 2022
73eb719
Tackled most comments
lesteve Feb 24, 2022
e6409e9
Add warning when loading an old pickle with non memory aligned array
lesteve Feb 24, 2022
a36af7c
Add changelog entry.
lesteve Feb 24, 2022
54c2773
tweak
lesteve Feb 24, 2022
dfccb94
Forgotten f-string
lesteve Feb 24, 2022
236e042
Windows fix
lesteve Feb 24, 2022
3995ed4
Trigger CI
lesteve Feb 24, 2022
4705aad
Apply suggestions from code review
lesteve Feb 25, 2022
2e380bb
tweak
lesteve Feb 25, 2022
eb69803
Update joblib/test/test_numpy_pickle.py
ogrisel Feb 25, 2022
188c089
Trigger CI
lesteve Feb 25, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
35 changes: 34 additions & 1 deletion joblib/numpy_pickle.py
Expand Up @@ -7,6 +7,8 @@
import pickle
import os
import warnings
import io

try:
from pathlib import Path
except ImportError:
Expand Down Expand Up @@ -95,6 +97,17 @@ def write_array(self, array, pickler):
# pickle protocol.
pickle.dump(array, pickler.file_handle, protocol=2)
else:
try:
current_pos = pickler.file_handle.tell()
alignment = current_pos % 8
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The numpy documentation mentions that some dtypes expect 16-byte alignment (e.g. float128).

Also, since SIMD-optimized compute kernels run more efficiently (fully vectorized) when the buffers are aligned to the size of their vector instructions, maybe we should directly go for 64-byte alignment (e.g. for AVX-512, which currently has the widest vector instructions).

In the ARM ecosystem there are also 512 bit wide vector instructions, e.g.:

https://www.fujitsu.com/global/products/computing/servers/supercomputer/a64fx/

But from what I read about SVE2, the vector size can vary dynamically in 128-bit (16-byte) increments.

So I have the feeling that padding to 16 bytes is a necessity to be safe (avoid crashes), but padding to 64 bytes (512 bits) can be a bit helpful for vectorized compute kernels to run more efficiently on such memmapped buffers. Going beyond that is probably useless.


if alignment != 0:
padding = b' ' * (8 - alignment)
pickler.file_handle.write(padding)
except io.UnsupportedOperation:
# TODO log something somewhere?
ogrisel marked this conversation as resolved.
Show resolved Hide resolved
pass

for chunk in pickler.np.nditer(array,
flags=['external_loop',
'buffered',
Expand All @@ -121,6 +134,21 @@ def read_array(self, unpickler):
# The array contained Python objects. We need to unpickle the data.
array = pickle.load(unpickler.file_handle)
else:
try:
current_pos = unpickler.file_handle.tell()
alignment = current_pos % 8
ogrisel marked this conversation as resolved.
Show resolved Hide resolved

# peek not supported in io.BytesIO ...
current_byte = unpickler.file_handle.read(1)
unpickler.file_handle.seek(current_pos)

if alignment != 0 and current_byte == b' ':
padding_length = 8 - alignment
unpickler.file_handle.seek(current_pos + padding_length)
except io.UnsupportedOperation:
# TODO log something somewhere?
pass

# This is not a real file. We have to read it the
# memory-intensive way.
# crc32 module fails on reads greater than 2 ** 32 bytes,
Expand Down Expand Up @@ -153,7 +181,12 @@ def read_array(self, unpickler):

def read_mmap(self, unpickler):
"""Read an array using numpy memmap."""
offset = unpickler.file_handle.tell()
current_pos = unpickler.file_handle.tell()
offset = current_pos
alignment = current_pos % 8
# Do I need to check whether current byte is b' '?
if alignment != 0:
offset += 8 - alignment
if unpickler.mmap_mode == 'w+':
unpickler.mmap_mode = 'r+'

Expand Down
46 changes: 46 additions & 0 deletions joblib/test/test_numpy_pickle.py
Expand Up @@ -1056,3 +1056,49 @@ def test_lz4_compression_without_lz4(tmpdir):
with raises(ValueError) as excinfo:
numpy_pickle.dump(data, fname + '.lz4')
excinfo.match(msg)


@with_numpy
@parametrize('protocol', range(0, pickle.HIGHEST_PROTOCOL + 1))
def test_memmap_alignment_padding(tmpdir, protocol):
    """Check that arrays loaded with ``mmap_mode='r'`` are 8-byte aligned.

    For every pickle protocol, dump a single array, a list of arrays and a
    dict of arrays with ``numpy_pickle.dump``, reload each with
    ``numpy_pickle.load(..., mmap_mode='r')`` and verify that every reloaded
    array is an ``np.memmap`` equal to the original whose data pointer is a
    multiple of 8 and whose ``flags.aligned`` is set.
    """
    fname = tmpdir.join('test.mmap').strpath

    def check_aligned_memmap(expected, memmap):
        # Shared assertions: type, content, raw address and numpy's own
        # alignment flag.
        assert isinstance(memmap, np.memmap)
        np.testing.assert_array_equal(expected, memmap)
        assert memmap.ctypes.data % 8 == 0
        assert memmap.flags.aligned

    # Single array at the top level of the pickle.
    a = np.random.randn(2)
    numpy_pickle.dump(a, fname, protocol=protocol)
    check_aligned_memmap(a, numpy_pickle.load(fname, mmap_mode='r'))

    # Several small arrays stored one after another in the same file.
    array_list = [
        np.random.randn(2), np.random.randn(2),
        np.random.randn(2), np.random.randn(2)
    ]
    numpy_pickle.dump(array_list, fname, protocol=protocol)
    l_reloaded = numpy_pickle.load(fname, mmap_mode='r')
    for idx, memmap in enumerate(l_reloaded):
        check_aligned_memmap(array_list[idx], memmap)

    # Arrays of different sizes stored as dict values.
    array_dict = {
        'a1': np.random.randn(100),
        'a2': np.random.randn(200),
        'a3': np.random.randn(300),
        'a4': np.random.randn(400)
    }
    numpy_pickle.dump(array_dict, fname, protocol=protocol)
    d_reloaded = numpy_pickle.load(fname, mmap_mode='r')
    for key, memmap in d_reloaded.items():
        check_aligned_memmap(array_dict[key], memmap)