From 1d567c10fca1668fea2a7157a7889dcec4b18e1c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lo=C3=AFc=20Est=C3=A8ve?= Date: Wed, 3 Nov 2021 13:19:28 +0100 Subject: [PATCH 1/4] Add test for cross-architecture-pickle --- sklearn/tree/tests/test_tree.py | 50 +++++++++++++++++++++++++++++++++ 1 file changed, 50 insertions(+) diff --git a/sklearn/tree/tests/test_tree.py b/sklearn/tree/tests/test_tree.py index cee55d2c40d8d..6efcd9e4bc60b 100644 --- a/sklearn/tree/tests/test_tree.py +++ b/sklearn/tree/tests/test_tree.py @@ -5,6 +5,8 @@ import pickle from itertools import product import struct +import io +import copyreg import pytest import numpy as np @@ -13,6 +15,9 @@ from scipy.sparse import csr_matrix from scipy.sparse import coo_matrix +import joblib +from joblib.numpy_pickle import NumpyPickler + from sklearn.random_projection import _sparse_random_matrix from sklearn.dummy import DummyRegressor @@ -2179,3 +2184,48 @@ def test_n_features_deprecated(Tree): with pytest.warns(FutureWarning, match=depr_msg): Tree().fit(X, y).n_features_ + + +def test_different_endianness_pickle(): + X, y = datasets.make_classification(random_state=0) + + clf = DecisionTreeClassifier(random_state=0, max_depth=3) + clf.fit(X, y) + + def reduce_ndarray(arr): + return arr.byteswap().newbyteorder().__reduce__() + + def get_pickle_non_native_endianness(): + f = io.BytesIO() + p = pickle.Pickler(f) + p.dispatch_table = copyreg.dispatch_table.copy() + p.dispatch_table[np.ndarray] = reduce_ndarray + + p.dump(clf.tree_) + f.seek(0) + return f + + pickle.load(get_pickle_non_native_endianness()) + + +def test_different_endianness_joblib_pickle(): + X, y = datasets.make_classification(random_state=0) + + clf = DecisionTreeClassifier(random_state=0, max_depth=3) + clf.fit(X, y) + + class MyNumpyPickler(NumpyPickler): + def save(self, obj): + if isinstance(obj, np.ndarray): + obj = obj.byteswap().newbyteorder() + super().save(obj) + + def get_joblib_pickle_non_native_endianness(): + f = io.BytesIO() + p = MyNumpyPickler(f) + + p.dump(clf.tree_) + f.seek(0) + return f + + joblib.load(get_joblib_pickle_non_native_endianness()) From 0ee5488fb23bdb8fee03f400a9c38355c4676f78 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lo=C3=AFc=20Est=C3=A8ve?= Date: Wed, 3 Nov 2021 14:18:39 +0100 Subject: [PATCH 2/4] Tweak tests. --- sklearn/tree/tests/test_tree.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/sklearn/tree/tests/test_tree.py b/sklearn/tree/tests/test_tree.py index 6efcd9e4bc60b..b8dcc185e1d3d 100644 --- a/sklearn/tree/tests/test_tree.py +++ b/sklearn/tree/tests/test_tree.py @@ -2191,6 +2191,7 @@ def test_different_endianness_pickle(): clf = DecisionTreeClassifier(random_state=0, max_depth=3) clf.fit(X, y) + score = clf.score(X, y) def reduce_ndarray(arr): return arr.byteswap().newbyteorder().__reduce__() @@ -2201,11 +2202,13 @@ def get_pickle_non_native_endianness(): p.dispatch_table = copyreg.dispatch_table.copy() p.dispatch_table[np.ndarray] = reduce_ndarray - p.dump(clf.tree_) + p.dump(clf) f.seek(0) return f - pickle.load(get_pickle_non_native_endianness()) + new_clf = pickle.load(get_pickle_non_native_endianness()) + new_score = new_clf.score(X, y) + assert np.isclose(score, new_score) def test_different_endianness_joblib_pickle(): @@ -2213,8 +2216,9 @@ def test_different_endianness_joblib_pickle(): clf = DecisionTreeClassifier(random_state=0, max_depth=3) clf.fit(X, y) + score = clf.score(X, y) - class MyNumpyPickler(NumpyPickler): + class NonNativeEndiannessNumpyPickler(NumpyPickler): def save(self, obj): if isinstance(obj, np.ndarray): obj = obj.byteswap().newbyteorder() @@ -2222,10 +2226,12 @@ def save(self, obj): def get_joblib_pickle_non_native_endianness(): f = io.BytesIO() - p = MyNumpyPickler(f) + p = NonNativeEndiannessNumpyPickler(f) - p.dump(clf.tree_) + p.dump(clf) f.seek(0) return f - joblib.load(get_joblib_pickle_non_native_endianness()) + new_clf = joblib.load(get_joblib_pickle_non_native_endianness()) + new_score = new_clf.score(X, y) + assert np.isclose(score, new_score) From ecb5021e06b8141e743b8394abdf54ea06f3015d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lo=C3=AFc=20Est=C3=A8ve?= Date: Wed, 3 Nov 2021 15:00:56 +0100 Subject: [PATCH 3/4] Skip test for joblib < 1.1 --- sklearn/tree/tests/test_tree.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/sklearn/tree/tests/test_tree.py b/sklearn/tree/tests/test_tree.py index b8dcc185e1d3d..141a18bf2cb0e 100644 --- a/sklearn/tree/tests/test_tree.py +++ b/sklearn/tree/tests/test_tree.py @@ -37,6 +37,7 @@ from sklearn.utils.estimator_checks import check_sample_weights_invariance from sklearn.utils.validation import check_random_state +from sklearn.utils import parse_version from sklearn.exceptions import NotFittedError @@ -2212,6 +2213,11 @@ def get_pickle_non_native_endianness(): def test_different_endianness_joblib_pickle(): + if parse_version(joblib.__version__) < parse_version("1.1"): + pytest.skip( + "joblib >= 1.1 is needed to load numpy arrays in native endianness" + ) + X, y = datasets.make_classification(random_state=0) clf = DecisionTreeClassifier(random_state=0, max_depth=3) From 0a1abdacaf011fa423f06a422778d203c120c7dd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lo=C3=AFc=20Est=C3=A8ve?= Date: Wed, 3 Nov 2021 15:03:30 +0100 Subject: [PATCH 4/4] Use pytest.mark.skipif --- sklearn/tree/tests/test_tree.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/sklearn/tree/tests/test_tree.py b/sklearn/tree/tests/test_tree.py index 141a18bf2cb0e..90da2e04d6a9f 100644 --- a/sklearn/tree/tests/test_tree.py +++ b/sklearn/tree/tests/test_tree.py @@ -2212,12 +2212,11 @@ def get_pickle_non_native_endianness(): assert np.isclose(score, new_score) +@pytest.mark.skipif( + parse_version(joblib.__version__) < parse_version("1.1"), + reason="joblib >= 1.1 is needed to load numpy arrays in native endianness", +) def test_different_endianness_joblib_pickle(): - if parse_version(joblib.__version__) < parse_version("1.1"): - pytest.skip( - "joblib >= 1.1 is needed to load numpy arrays in native endianness" - ) - X, y = datasets.make_classification(random_state=0) clf = DecisionTreeClassifier(random_state=0, max_depth=3)