Fixing byte-order consistency/mismatch for cross-endian platform (#1181)

Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org>
pradghos and ogrisel committed Jun 11, 2021
1 parent e96152f commit 0fa2cb9
Showing 5 changed files with 84 additions and 2 deletions.
6 changes: 6 additions & 0 deletions CHANGES.rst
@@ -4,6 +4,12 @@ Latest changes
Development version
-------------------

- Fix a byte order inconsistency during deserialization with joblib.load
  in cross-endian environments: numpy arrays are now always loaded in the
  system byte order, independently of the byte order of the system that
  serialized the pickle.
https://github.com/joblib/joblib/pull/1181

- Fix joblib.Memory bug with the ``ignore`` parameter when the cached function
is a decorated function.
https://github.com/joblib/joblib/pull/1165
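
The changelog entry above can be illustrated with a short round-trip sketch. It assumes numpy and a joblib checkout that includes this patch; the temporary path handling and the array contents are incidental. An array whose dtype uses the opposite byte order stands in for a pickle produced on a machine of the other endianness:

    import os
    import sys
    import tempfile

    import numpy as np
    import joblib

    # dtype spelled with the byte order opposite to the host's, e.g. '>i4'
    # on a little-endian machine -- what a cross-endian pickle would contain.
    swapped = '>' if sys.byteorder == 'little' else '<'
    arr = np.arange(5, dtype=np.dtype(swapped + 'i4'))
    assert not arr.dtype.isnative

    with tempfile.TemporaryDirectory() as tmp:
        path = os.path.join(tmp, 'arr.pkl')
        joblib.dump(arr, path)
        loaded = joblib.load(path)

    # With this fix, the loaded array uses the host byte order and the
    # values are unchanged.
    assert loaded.dtype.isnative
    np.testing.assert_array_equal(loaded, arr)
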
4 changes: 3 additions & 1 deletion joblib/numpy_pickle.py
Expand Up @@ -20,6 +20,7 @@
from .numpy_pickle_utils import Unpickler, Pickler
from .numpy_pickle_utils import _read_fileobject, _write_fileobject
from .numpy_pickle_utils import _read_bytes, BUFFER_SIZE
from .numpy_pickle_utils import _ensure_native_byte_order
from .numpy_pickle_compat import load_compatibility
from .numpy_pickle_compat import NDArrayWrapper
# For compatibility with old versions of joblib, we need ZNDArrayWrapper
@@ -147,7 +148,8 @@ def read_array(self, unpickler):
else:
array.shape = self.shape

return array
# Detect byte order mis-match and swap as needed.
return _ensure_native_byte_order(array)

def read_mmap(self, unpickler):
"""Read an array using numpy memmap."""
5 changes: 4 additions & 1 deletion joblib/numpy_pickle_compat.py
@@ -9,7 +9,7 @@

from .numpy_pickle_utils import _ZFILE_PREFIX
from .numpy_pickle_utils import Unpickler

from .numpy_pickle_utils import _ensure_native_byte_order

def hex_str(an_int):
"""Convert an int to an hexadecimal string."""
@@ -105,6 +105,9 @@ def read(self, unpickler):
kwargs["allow_pickle"] = True
array = unpickler.np.load(filename, **kwargs)

# Detect byte order mis-match and swap as needed.
array = _ensure_native_byte_order(array)

# Reconstruct subclasses. This does not work with old
# versions of numpy
if (hasattr(array, '__array_prepare__') and
25 changes: 25 additions & 0 deletions joblib/numpy_pickle_utils.py
@@ -6,6 +6,7 @@

import pickle
import io
import sys
import warnings
import contextlib

@@ -48,6 +49,30 @@ def _get_prefixes_max_len():
return max(prefixes)


def _is_numpy_array_byte_order_mismatch(array):
"""Check if numpy array is having byte order mis-match"""
return ((sys.byteorder == 'big' and
(array.dtype.byteorder == '<' or
(array.dtype.byteorder == '|' and array.dtype.fields and
all(e[0].byteorder == '<'
for e in array.dtype.fields.values())))) or
(sys.byteorder == 'little' and
(array.dtype.byteorder == '>' or
(array.dtype.byteorder == '|' and array.dtype.fields and
all(e[0].byteorder == '>'
for e in array.dtype.fields.values())))))


def _ensure_native_byte_order(array):
"""Use the byte order of the host while preserving values
Does nothing if array already uses the system byte order.
"""
if _is_numpy_array_byte_order_mismatch(array):
array = array.byteswap().newbyteorder('=')
return array


###############################################################################
# Cache file utilities
def _detect_compressor(fileobj):
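
For illustration, a small usage sketch of the two helpers added above. They are private utilities in joblib.numpy_pickle_utils rather than public API, so this only demonstrates the intended behavior; the array contents and field names are arbitrary, and the checks should hold on a host of either endianness:

    import sys

    import numpy as np

    from joblib.numpy_pickle_utils import (
        _ensure_native_byte_order,
        _is_numpy_array_byte_order_mismatch,
    )

    # Byte-order code opposite to the host's.
    swapped = '>' if sys.byteorder == 'little' else '<'

    # Plain dtype: the non-native byte order is detected, then swapped away.
    arr = np.arange(4, dtype=np.dtype(swapped + 'f8'))
    assert _is_numpy_array_byte_order_mismatch(arr)
    fixed = _ensure_native_byte_order(arr)
    assert not _is_numpy_array_byte_order_mismatch(fixed)
    np.testing.assert_array_equal(fixed, arr)  # values are preserved

    # Structured dtype: its top-level byteorder is '|', so the check looks
    # at the byte order of the individual fields instead.
    rec = np.array([(1, 2.0), (3, 4.0)],
                   dtype=[('a', swapped + 'i8'), ('b', swapped + 'f8')])
    assert _is_numpy_array_byte_order_mismatch(rec)
    assert not _is_numpy_array_byte_order_mismatch(_ensure_native_byte_order(rec))
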
46 changes: 46 additions & 0 deletions joblib/test/test_numpy_pickle.py
@@ -5,6 +5,7 @@
import random
import re
import io
import sys
import warnings
import gzip
import zlib
@@ -30,6 +31,8 @@

from joblib.numpy_pickle_utils import _IO_BUFFER_SIZE
from joblib.numpy_pickle_utils import _detect_compressor
from joblib.numpy_pickle_utils import _is_numpy_array_byte_order_mismatch
from joblib.numpy_pickle_utils import _ensure_native_byte_order
from joblib.compressor import (_COMPRESSORS, _LZ4_PREFIX, CompressorWrapper,
LZ4_NOT_INSTALLED_ERROR, BinaryZlibFile)

@@ -355,6 +358,7 @@ def test_compressed_pickle_dump_and_load(tmpdir):
result_list = numpy_pickle.load(fname)
for result, expected in zip(result_list, expected_list):
if isinstance(expected, np.ndarray):
expected = _ensure_native_byte_order(expected)
assert result.dtype == expected.dtype
np.testing.assert_equal(result, expected)
else:
@@ -394,6 +398,7 @@ def _check_pickle(filename, expected_list):
"pickle file.".format(filename))
for result, expected in zip(result_list, expected_list):
if isinstance(expected, np.ndarray):
expected = _ensure_native_byte_order(expected)
assert result.dtype == expected.dtype
np.testing.assert_equal(result, expected)
else:
@@ -457,6 +462,47 @@ def test_joblib_pickle_across_python_versions():
_check_pickle(fname, expected_list)


@with_numpy
def test_numpy_array_byte_order_mismatch_detection():
# List of numpy arrays with big endian byteorder.
be_arrays = [np.array([(1, 2.0), (3, 4.0)],
dtype=[('', '>i8'), ('', '>f8')]),
np.arange(3, dtype=np.dtype('>i8')),
np.arange(3, dtype=np.dtype('>f8'))]

# Verify the byteorder mismatch is correctly detected.
for array in be_arrays:
if sys.byteorder == 'big':
assert not _is_numpy_array_byte_order_mismatch(array)
else:
assert _is_numpy_array_byte_order_mismatch(array)
converted = _ensure_native_byte_order(array)
if converted.dtype.fields:
for f in converted.dtype.fields.values():
                assert f[0].byteorder == '='
else:
assert converted.dtype.byteorder == "="

# List of numpy arrays with little endian byteorder.
le_arrays = [np.array([(1, 2.0), (3, 4.0)],
dtype=[('', '<i8'), ('', '<f8')]),
np.arange(3, dtype=np.dtype('<i8')),
np.arange(3, dtype=np.dtype('<f8'))]

# Verify the byteorder mismatch is correctly detected.
for array in le_arrays:
if sys.byteorder == 'little':
assert not _is_numpy_array_byte_order_mismatch(array)
else:
assert _is_numpy_array_byte_order_mismatch(array)
converted = _ensure_native_byte_order(array)
if converted.dtype.fields:
for f in converted.dtype.fields.values():
                assert f[0].byteorder == '='
else:
assert converted.dtype.byteorder == "="


@parametrize('compress_tuple', [('zlib', 3), ('gzip', 3)])
def test_compress_tuple_argument(tmpdir, compress_tuple):
# Verify the tuple is correctly taken into account.
