Make function pickling deterministic by default
Co-authored-by: Christopher J. Markiewicz <markiewicz@stanford.edu>
ogrisel and effigies committed Jun 30, 2021
1 parent 2a79b69 commit afdea9e
Showing 4 changed files with 64 additions and 10 deletions.
4 changes: 4 additions & 0 deletions CHANGES.md
@@ -15,6 +15,10 @@ dev
_is_parametrized_type_hint to limit false positives.
([PR #409](https://github.com/cloudpipe/cloudpickle/pull/409))

- Suppressed a source of non-determinism when pickling dynamically defined
functions.
([PR #428](https://github.com/cloudpipe/cloudpickle/pull/428))

1.6.0
=====

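For context (not part of the diff): the non-determinism mentioned in the new changelog entry comes from Python's per-process string hash randomization. A set of global names can iterate in a different order in every interpreter, so the bytes emitted for a dynamically defined function could change from run to run. The following standalone sketch, which only assumes a Python 3.7+ interpreter and uses no cloudpickle code, shows the effect.

import os
import subprocess
import sys

# Child program: print the iteration order of a set of names and of a dict
# used as an insertion-ordered set (dict.fromkeys keeps first-seen order).
SNIPPET = (
    "names = ['alpha', 'beta', 'gamma', 'delta']\n"
    "print(list(set(names)))\n"
    "print(list(dict.fromkeys(names)))\n"
)

for seed in ("0", "1", "2"):
    env = dict(os.environ, PYTHONHASHSEED=seed)
    result = subprocess.run([sys.executable, "-c", SNIPPET], env=env,
                            capture_output=True, text=True, check=True)
    set_order, dict_order = result.stdout.splitlines()
    # The set order typically changes with the seed; the dict order never does.
    print("PYTHONHASHSEED=%s  set: %s  dict: %s" % (seed, set_order, dict_order))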
7 changes: 5 additions & 2 deletions cloudpickle/cloudpickle.py
@@ -236,7 +236,10 @@ def _extract_code_globals(co):
out_names = _extract_code_globals_cache.get(co)
if out_names is None:
names = co.co_names
out_names = {names[oparg] for _, oparg in _walk_global_ops(co)}
# We use a dict with None values instead of a set to get a
# deterministic order (assuming Python 3.6+) and avoid introducing
    # non-deterministic pickle bytes as a result.
out_names = {names[oparg]: None for _, oparg in _walk_global_ops(co)}

# Declaring a function inside another one using the "def ..."
# syntax generates a constant code object corresponding to the one
@@ -247,7 +250,7 @@ def _extract_code_globals(co):
if co.co_consts:
for const in co.co_consts:
if isinstance(const, types.CodeType):
out_names |= _extract_code_globals(const)
out_names.update(_extract_code_globals(const))

_extract_code_globals_cache[co] = out_names

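The change above swaps the set comprehension for a dict whose values are all None: since Python 3.7, dict keys are guaranteed to preserve insertion order, so the dict behaves like an ordered set, and the set union operator (|=) becomes dict.update(). A minimal standalone sketch of the same pattern follows; collect_names is a hypothetical helper, not part of cloudpickle.

def collect_names(groups):
    # Dict used as an insertion-ordered set: keys matter, values are all None.
    out_names = {}
    for group in groups:
        # Order-preserving equivalent of `out_names |= {...}` on sets.
        out_names.update({name: None for name in group})
    return out_names

names = collect_names([["b", "a"], ["c", "a"]])
print(list(names))                      # ['b', 'a', 'c'] -- stable first-seen order
print(set(names) == {"a", "b", "c"})    # True: callers can still treat it as a set

Callers that previously relied on set semantics, like the updated test below, only need to wrap the result in set() or iterate over the keys.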
33 changes: 31 additions & 2 deletions tests/cloudpickle_test.py
@@ -10,6 +10,7 @@
import logging
import math
from operator import itemgetter, attrgetter
import pickletools
import platform
import random
import shutil
@@ -50,13 +51,15 @@
from cloudpickle.cloudpickle import _lookup_module_and_qualname

from .testutils import subprocess_pickle_echo
from .testutils import subprocess_pickle_string
from .testutils import assert_run_python_script
from .testutils import subprocess_worker

from _cloudpickle_testpkg import relative_imports_factory


_TEST_GLOBAL_VARIABLE = "default_value"
_TEST_GLOBAL_VARIABLE2 = "another_value"


class RaiserOnPickle(object):
@@ -2095,8 +2098,8 @@ def inner_function():
return _TEST_GLOBAL_VARIABLE
return inner_function

globals_ = cloudpickle.cloudpickle._extract_code_globals(
function_factory.__code__)
globals_ = set(cloudpickle.cloudpickle._extract_code_globals(
function_factory.__code__).keys())
assert globals_ == {'_TEST_GLOBAL_VARIABLE'}

depickled_factory = pickle_depickle(function_factory,
@@ -2330,6 +2333,32 @@ def __type__(self):
o = MyClass()
pickle_depickle(o, protocol=self.protocol)

@pytest.mark.skipif(
sys.version_info < (3, 7),
reason="Determinism can only be guaranteed for Python 3.7+"
)
def test_deterministic_pickle_bytes_for_function(self):
# Ensure that functions with references to several global names are
# pickled to fixed bytes that do not depend on the PYTHONHASHSEED of
# the Python process.
vals = set()

def func_with_globals():
return _TEST_GLOBAL_VARIABLE + _TEST_GLOBAL_VARIABLE2

for i in range(5):
vals.add(
subprocess_pickle_string(func_with_globals,
protocol=self.protocol,
add_env={"PYTHONHASHSEED": str(i)}))
if len(vals) > 1:
# Print additional debug info on stdout with dis:
for val in vals:
pickletools.dis(val)
pytest.fail(
f"Expected a single deterministic payload, got {len(vals)}/5"
)


class Protocol2CloudPickleTest(CloudPickleTest):

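The test added above pins PYTHONHASHSEED in child processes and asserts that the same function always pickles to the same payload. The same property can be checked outside the test suite with the sketch below, which assumes a cloudpickle release containing this change is installed; the module-level names in the child program are made up for illustration and are not the ones used by the test.

import os
import subprocess
import sys

# Child program: pickle a dynamically defined function that references two
# module-level globals and write the raw payload to stdout.
CHILD = """
import sys
import cloudpickle

_GLOBAL_A = "default_value"
_GLOBAL_B = "another_value"

def func_with_globals():
    return _GLOBAL_A + _GLOBAL_B

sys.stdout.buffer.write(cloudpickle.dumps(func_with_globals, protocol=2))
"""

payloads = set()
for seed in range(5):
    env = dict(os.environ, PYTHONHASHSEED=str(seed))
    result = subprocess.run([sys.executable, "-c", CHILD], env=env,
                            capture_output=True, check=True)
    payloads.add(result.stdout)

# With this commit, all five child processes should emit identical bytes.
print(len(payloads) == 1)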
30 changes: 24 additions & 6 deletions tests/testutils.py
@@ -2,7 +2,6 @@
import os
import os.path as op
import tempfile
import base64
from subprocess import Popen, check_output, PIPE, STDOUT, CalledProcessError
from cloudpickle.compat import pickle
from contextlib import contextmanager
@@ -38,15 +37,16 @@ def _make_cwd_env():
return cloudpickle_repo_folder, env


def subprocess_pickle_echo(input_data, protocol=None, timeout=TIMEOUT):
"""Echo function with a child Python process
def subprocess_pickle_string(input_data, protocol=None, timeout=TIMEOUT,
add_env=None):
"""Retrieve pickle string of an object generated by a child Python process
    Pickle the input data into a buffer, send it to a subprocess via
    stdin, have the subprocess unpickle and re-pickle that data, and
    return the raw re-pickled bytes that it sends back via stdout.
>>> subprocess_pickle_echo([1, 'a', None])
[1, 'a', None]
>>> testutils.subprocess_pickle_string([1, 'a', None], protocol=2)
b'\x80\x02]q\x00(K\x01X\x01\x00\x00\x00aq\x01Ne.'
"""
    # then run pickle_echo(protocol=protocol) in __main__:
@@ -56,6 +56,8 @@ def subprocess_pickle_echo(input_data, protocol=None, timeout=TIMEOUT):
# which is deprecated in python 3.8
cmd = [sys.executable, '-W ignore', __file__, "--protocol", str(protocol)]
cwd, env = _make_cwd_env()
if add_env:
env.update(add_env)
proc = Popen(cmd, stdin=PIPE, stdout=PIPE, stderr=PIPE, cwd=cwd, env=env,
bufsize=4096)
pickle_string = dumps(input_data, protocol=protocol)
@@ -67,14 +69,30 @@ def subprocess_pickle_echo(input_data, protocol=None, timeout=TIMEOUT):
message = "Subprocess returned %d: " % proc.returncode
message += err.decode('utf-8')
raise RuntimeError(message)
return loads(out)
return out
except TimeoutExpired as e:
proc.kill()
out, err = proc.communicate()
message = u"\n".join([out.decode('utf-8'), err.decode('utf-8')])
raise RuntimeError(message) from e


def subprocess_pickle_echo(input_data, protocol=None, timeout=TIMEOUT,
add_env=None):
"""Echo function with a child Python process
Pickle the input data into a buffer, send it to a subprocess via
stdin, expect the subprocess to unpickle, re-pickle that data back
and send it back to the parent process via stdout for final unpickling.
>>> subprocess_pickle_echo([1, 'a', None])
[1, 'a', None]
"""
out = subprocess_pickle_string(input_data,
protocol=protocol,
timeout=timeout,
add_env=add_env)
return loads(out)


def _read_all_bytes(stream_in, chunk_size=4096):
all_data = b""
while True:

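Not part of the diff: a hypothetical usage sketch of the refactored helpers, assuming it is run from the cloudpickle repository root so that the tests package is importable. subprocess_pickle_string returns the raw payload produced by the child process, while subprocess_pickle_echo keeps the original round-trip behaviour on top of it.

from tests.testutils import subprocess_pickle_echo, subprocess_pickle_string

data = [1, 'a', None]

# Raw pickle bytes re-pickled by a child process with a pinned hash seed.
payload = subprocess_pickle_string(data, protocol=2,
                                   add_env={"PYTHONHASHSEED": "0"})
print(payload)

# Full round trip: the child unpickles and re-pickles, the parent unpickles.
print(subprocess_pickle_echo(data, protocol=2) == data)   # True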