From 452d4cdf8c197db24954a62cd573fb227e7cac73 Mon Sep 17 00:00:00 2001 From: Pradipta Ghosh Date: Wed, 5 May 2021 02:49:49 -0700 Subject: [PATCH] Fixing byte-order consistency/missmatch for cross-endian platform - Addressing https://github.com/joblib/joblib/issues/1123 --- joblib/numpy_pickle.py | 5 +++++ joblib/numpy_pickle_compat.py | 6 +++++- joblib/numpy_pickle_utils.py | 12 ++++++++++++ joblib/test/test_numpy_pickle.py | 7 +++++++ 4 files changed, 29 insertions(+), 1 deletion(-) diff --git a/joblib/numpy_pickle.py b/joblib/numpy_pickle.py index 93e5537ea..c9f098228 100644 --- a/joblib/numpy_pickle.py +++ b/joblib/numpy_pickle.py @@ -20,6 +20,7 @@ from .numpy_pickle_utils import Unpickler, Pickler from .numpy_pickle_utils import _read_fileobject, _write_fileobject from .numpy_pickle_utils import _read_bytes, BUFFER_SIZE +from .numpy_pickle_utils import _is_numpy_array_byte_order_mismatch from .numpy_pickle_compat import load_compatibility from .numpy_pickle_compat import NDArrayWrapper # For compatibility with old versions of joblib, we need ZNDArrayWrapper @@ -147,6 +148,10 @@ def read_array(self, unpickler): else: array.shape = self.shape + # Detect byte order mis-match and swap as needed. + if _is_numpy_array_byte_order_mismatch(array): + array = array.byteswap().newbyteorder('=') + return array def read_mmap(self, unpickler): diff --git a/joblib/numpy_pickle_compat.py b/joblib/numpy_pickle_compat.py index 6541a066a..32bcc4ed1 100644 --- a/joblib/numpy_pickle_compat.py +++ b/joblib/numpy_pickle_compat.py @@ -9,7 +9,7 @@ from .numpy_pickle_utils import _ZFILE_PREFIX from .numpy_pickle_utils import Unpickler - +from .numpy_pickle_utils import _is_numpy_array_byte_order_mismatch def hex_str(an_int): """Convert an int to an hexadecimal string.""" @@ -105,6 +105,10 @@ def read(self, unpickler): kwargs["allow_pickle"] = True array = unpickler.np.load(filename, **kwargs) + # Detect byte order mis-match and swap as needed. + if _is_numpy_array_byte_order_mismatch(array): + array = array.byteswap().newbyteorder('=') + # Reconstruct subclasses. This does not work with old # versions of numpy if (hasattr(array, '__array_prepare__') and diff --git a/joblib/numpy_pickle_utils.py b/joblib/numpy_pickle_utils.py index a50105547..a34c5bf3f 100644 --- a/joblib/numpy_pickle_utils.py +++ b/joblib/numpy_pickle_utils.py @@ -6,6 +6,7 @@ import pickle import io +import sys import warnings import contextlib @@ -47,6 +48,17 @@ def _get_prefixes_max_len(): prefixes += [len(_ZFILE_PREFIX)] return max(prefixes) +def _is_numpy_array_byte_order_mismatch(array): + """Check if numpy array is having byte order mis-match""" + return \ + (sys.byteorder == 'big' and \ + (array.dtype.byteorder == '<' or \ + (array.dtype.byteorder == '|' and array.dtype.fields and \ + all(e[0].byteorder == '<' for e in array.dtype.fields.values())))) or \ + (sys.byteorder == 'little' and \ + (array.dtype.byteorder == '>' or \ + (array.dtype.byteorder == '|' and array.dtype.fields and \ + all(e[0].byteorder == '>' for e in array.dtype.fields.values())))) ############################################################################### # Cache file utilities diff --git a/joblib/test/test_numpy_pickle.py b/joblib/test/test_numpy_pickle.py index db130b1f4..9a8c6cb4e 100644 --- a/joblib/test/test_numpy_pickle.py +++ b/joblib/test/test_numpy_pickle.py @@ -30,6 +30,7 @@ from joblib.numpy_pickle_utils import _IO_BUFFER_SIZE from joblib.numpy_pickle_utils import _detect_compressor +from joblib.numpy_pickle_utils import _is_numpy_array_byte_order_mismatch from joblib.compressor import (_COMPRESSORS, _LZ4_PREFIX, CompressorWrapper, LZ4_NOT_INSTALLED_ERROR, BinaryZlibFile) @@ -355,6 +356,10 @@ def test_compressed_pickle_dump_and_load(tmpdir): result_list = numpy_pickle.load(fname) for result, expected in zip(result_list, expected_list): if isinstance(expected, np.ndarray): + + if _is_numpy_array_byte_order_mismatch(expected): + expected = expected.byteswap().newbyteorder('=') + assert result.dtype == expected.dtype np.testing.assert_equal(result, expected) else: @@ -394,6 +399,8 @@ def _check_pickle(filename, expected_list): "pickle file.".format(filename)) for result, expected in zip(result_list, expected_list): if isinstance(expected, np.ndarray): + if _is_numpy_array_byte_order_mismatch(expected): + expected = expected.byteswap().newbyteorder('=') assert result.dtype == expected.dtype np.testing.assert_equal(result, expected) else: