Fixing byte-order consistency/mismatch for cross-endian platforms
- Addressing joblib#1123
pradghos committed May 5, 2021
1 parent 754433f commit 452d4cd
Showing 4 changed files with 29 additions and 1 deletion.
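Every loader touched by this commit applies the same pattern: after an array is
deserialized, check whether its dtype's byte order disagrees with the host
platform and, if so, convert the array to native order. A minimal standalone
sketch of that normalization step (not joblib's code; the non-native dtype is
an arbitrary example, and ndarray.newbyteorder is the numpy API current at the
time of this commit):

    import sys
    import numpy as np

    # Simulate an array unpickled from data written on an opposite-endian
    # machine: its dtype records an explicitly non-native byte order.
    non_native = '>i4' if sys.byteorder == 'little' else '<i4'
    array = np.arange(4, dtype=non_native)

    # byteswap() reverses the underlying bytes; newbyteorder('=') rewrites
    # the dtype to native order. Together they convert the array to the
    # native representation without changing its logical values.
    fixed = array.byteswap().newbyteorder('=')

    assert fixed.dtype.byteorder == '='
    np.testing.assert_array_equal(fixed, np.arange(4))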
5 changes: 5 additions & 0 deletions joblib/numpy_pickle.py
@@ -20,6 +20,7 @@
from .numpy_pickle_utils import Unpickler, Pickler
from .numpy_pickle_utils import _read_fileobject, _write_fileobject
from .numpy_pickle_utils import _read_bytes, BUFFER_SIZE
from .numpy_pickle_utils import _is_numpy_array_byte_order_mismatch
from .numpy_pickle_compat import load_compatibility
from .numpy_pickle_compat import NDArrayWrapper
# For compatibility with old versions of joblib, we need ZNDArrayWrapper
@@ -147,6 +148,10 @@ def read_array(self, unpickler):
            else:
                array.shape = self.shape

        # Detect byte order mismatch and swap as needed.
        if _is_numpy_array_byte_order_mismatch(array):
            array = array.byteswap().newbyteorder('=')

        return array

    def read_mmap(self, unpickler):
6 changes: 5 additions & 1 deletion joblib/numpy_pickle_compat.py
@@ -9,7 +9,7 @@

from .numpy_pickle_utils import _ZFILE_PREFIX
from .numpy_pickle_utils import Unpickler

from .numpy_pickle_utils import _is_numpy_array_byte_order_mismatch

def hex_str(an_int):
    """Convert an int to a hexadecimal string."""
@@ -105,6 +105,10 @@ def read(self, unpickler):
            kwargs["allow_pickle"] = True
        array = unpickler.np.load(filename, **kwargs)

        # Detect byte order mismatch and swap as needed.
        if _is_numpy_array_byte_order_mismatch(array):
            array = array.byteswap().newbyteorder('=')

        # Reconstruct subclasses. This does not work with old
        # versions of numpy.
        if (hasattr(array, '__array_prepare__') and
12 changes: 12 additions & 0 deletions joblib/numpy_pickle_utils.py
@@ -6,6 +6,7 @@

import pickle
import io
import sys
import warnings
import contextlib

@@ -47,6 +48,17 @@ def _get_prefixes_max_len():
    prefixes += [len(_ZFILE_PREFIX)]
    return max(prefixes)

def _is_numpy_array_byte_order_mismatch(array):
    """Check whether a numpy array's byte order mismatches the host."""
    return ((sys.byteorder == 'big' and
             (array.dtype.byteorder == '<' or
              (array.dtype.byteorder == '|' and array.dtype.fields and
               all(e[0].byteorder == '<'
                   for e in array.dtype.fields.values())))) or
            (sys.byteorder == 'little' and
             (array.dtype.byteorder == '>' or
              (array.dtype.byteorder == '|' and array.dtype.fields and
               all(e[0].byteorder == '>'
                   for e in array.dtype.fields.values())))))

###############################################################################
# Cache file utilities
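A quick illustration of what the new helper reports, assuming a joblib build
that includes this change (the dtypes are arbitrary examples):

    import sys
    import numpy as np
    from joblib.numpy_pickle_utils import _is_numpy_array_byte_order_mismatch

    native, swapped = (('<i4', '>i4') if sys.byteorder == 'little'
                       else ('>i4', '<i4'))

    # Plain dtypes: only the explicitly opposite-endian one mismatches.
    assert not _is_numpy_array_byte_order_mismatch(np.zeros(3, dtype=native))
    assert _is_numpy_array_byte_order_mismatch(np.zeros(3, dtype=swapped))

    # Structured dtypes report byteorder '|', so the helper inspects the
    # fields: a record whose fields are all opposite-endian also counts.
    rec = np.zeros(3, dtype=[('a', swapped), ('b', swapped)])
    assert _is_numpy_array_byte_order_mismatch(rec)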
7 changes: 7 additions & 0 deletions joblib/test/test_numpy_pickle.py
@@ -30,6 +30,7 @@

from joblib.numpy_pickle_utils import _IO_BUFFER_SIZE
from joblib.numpy_pickle_utils import _detect_compressor
from joblib.numpy_pickle_utils import _is_numpy_array_byte_order_mismatch
from joblib.compressor import (_COMPRESSORS, _LZ4_PREFIX, CompressorWrapper,
                               LZ4_NOT_INSTALLED_ERROR, BinaryZlibFile)

@@ -355,6 +356,10 @@ def test_compressed_pickle_dump_and_load(tmpdir):
    result_list = numpy_pickle.load(fname)
    for result, expected in zip(result_list, expected_list):
        if isinstance(expected, np.ndarray):

            if _is_numpy_array_byte_order_mismatch(expected):
                expected = expected.byteswap().newbyteorder('=')

            assert result.dtype == expected.dtype
            np.testing.assert_equal(result, expected)
        else:
@@ -394,6 +399,8 @@ def _check_pickle(filename, expected_list):
"pickle file.".format(filename))
for result, expected in zip(result_list, expected_list):
if isinstance(expected, np.ndarray):
if _is_numpy_array_byte_order_mismatch(expected):
expected = expected.byteswap().newbyteorder('=')
assert result.dtype == expected.dtype
np.testing.assert_equal(result, expected)
else:
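The normalization of the expected arrays mirrors what the loaders now do to
the arrays they return, so both sides of the comparison end up in native byte
order. A condensed, self-contained variant of the check (names are
illustrative; it assumes a joblib build with this fix):

    import sys
    import numpy as np
    from joblib import numpy_pickle

    def roundtrip_matches(path):
        # Dump an array stored with an explicitly non-native dtype.
        non_native = '>i4' if sys.byteorder == 'little' else '<i4'
        expected = np.arange(10, dtype=non_native)
        numpy_pickle.dump([expected], path)
        result = numpy_pickle.load(path)[0]
        # The loader returns a native-order array, so normalize the
        # expected side the same way before comparing.
        if expected.dtype.byteorder in ('<', '>'):
            expected = expected.byteswap().newbyteorder('=')
        return (result.dtype == expected.dtype
                and np.array_equal(result, expected))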
