Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Test decision tree pickle for different endianness #21539

Merged
merged 4 commits into from Nov 4, 2021
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
61 changes: 61 additions & 0 deletions sklearn/tree/tests/test_tree.py
Expand Up @@ -5,6 +5,8 @@
import pickle
from itertools import product
import struct
import io
import copyreg

import pytest
import numpy as np
Expand All @@ -13,6 +15,9 @@
from scipy.sparse import csr_matrix
from scipy.sparse import coo_matrix

import joblib
from joblib.numpy_pickle import NumpyPickler

from sklearn.random_projection import _sparse_random_matrix

from sklearn.dummy import DummyRegressor
Expand All @@ -32,6 +37,7 @@

from sklearn.utils.estimator_checks import check_sample_weights_invariance
from sklearn.utils.validation import check_random_state
from sklearn.utils import parse_version

from sklearn.exceptions import NotFittedError

Expand Down Expand Up @@ -2179,3 +2185,58 @@ def test_n_features_deprecated(Tree):

with pytest.warns(FutureWarning, match=depr_msg):
Tree().fit(X, y).n_features_


def test_different_endianness_pickle():
X, y = datasets.make_classification(random_state=0)

clf = DecisionTreeClassifier(random_state=0, max_depth=3)
clf.fit(X, y)
score = clf.score(X, y)

def reduce_ndarray(arr):
return arr.byteswap().newbyteorder().__reduce__()

def get_pickle_non_native_endianness():
f = io.BytesIO()
p = pickle.Pickler(f)
p.dispatch_table = copyreg.dispatch_table.copy()
p.dispatch_table[np.ndarray] = reduce_ndarray

p.dump(clf)
f.seek(0)
return f

new_clf = pickle.load(get_pickle_non_native_endianness())
new_score = new_clf.score(X, y)
assert np.isclose(score, new_score)


@pytest.mark.skipif(
parse_version(joblib.__version__) < parse_version("1.1"),
reason="joblib >= 1.1 is needed to load numpy arrays in native endianness",
)
def test_different_endianness_joblib_pickle():
X, y = datasets.make_classification(random_state=0)

clf = DecisionTreeClassifier(random_state=0, max_depth=3)
clf.fit(X, y)
score = clf.score(X, y)

class NonNativeEndiannessNumpyPickler(NumpyPickler):
def save(self, obj):
if isinstance(obj, np.ndarray):
obj = obj.byteswap().newbyteorder()
super().save(obj)

def get_joblib_pickle_non_native_endianness():
f = io.BytesIO()
p = NonNativeEndiannessNumpyPickler(f)

p.dump(clf)
f.seek(0)
return f

new_clf = joblib.load(get_joblib_pickle_non_native_endianness())
new_score = new_clf.score(X, y)
assert np.isclose(score, new_score)