Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

check for main module in reducer override #8455

Draft
wants to merge 7 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
86 changes: 38 additions & 48 deletions distributed/protocol/pickle.py
Expand Up @@ -4,6 +4,8 @@
import io
import logging
import pickle
from copyreg import dispatch_table
from types import FunctionType

import cloudpickle
from packaging.version import parse as parse_version
Expand All @@ -19,32 +21,32 @@

class _DaskPickler(pickle.Pickler):
def reducer_override(self, obj):
# For some objects this causes segfaults otherwise, see
# https://github.com/dask/distributed/pull/7564#issuecomment-1438727339
if _always_use_pickle_for(obj):
return NotImplemented
try:
serialize = dask_serialize.dispatch(type(obj))
deserialize = dask_deserialize.dispatch(type(obj))
return deserialize, serialize(obj)
except TypeError:
return NotImplemented


def _always_use_pickle_for(x):
mod, _, _ = x.__class__.__module__.partition(".")
if mod == "numpy":
import numpy as np

return isinstance(x, np.ndarray)
elif mod == "pandas":
import pandas as pd

return isinstance(x, pd.core.generic.NDFrame)
elif mod == "builtins":
return isinstance(x, (str, bytes))
else:
return False
mod = inspect.getmodule(type(obj))

# If a thing is local scoped, use cloudpickle
# This check is not guaranteed and evaluates false positively for
# dynamically created types, e.g. numpy scalars
if getattr(mod, type(obj).__name__, None) is not None:
return pickle.loads, (cloudpickle.dumps(obj),)
if isinstance(obj, FunctionType):
module_name = pickle.whichmodule(obj, None)
if (
module_name == "__main__"
or CLOUDPICKLE_GE_20
and module_name in cloudpickle.list_registry_pickle_by_value()
):
return pickle.loads, (cloudpickle.dumps(obj),)
elif type(obj) is memoryview:
return memoryview, (pickle.PickleBuffer(obj),)

Check warning on line 40 in distributed/protocol/pickle.py

View check run for this annotation

Codecov / codecov/patch

distributed/protocol/pickle.py#L40

Added line #L40 was not covered by tests
elif type(obj) not in dispatch_table:
try:
serialize = dask_serialize.dispatch(type(obj))
deserialize = dask_deserialize.dispatch(type(obj))
rv = deserialize, serialize(obj)
return rv

Check warning on line 46 in distributed/protocol/pickle.py

View check run for this annotation

Codecov / codecov/patch

distributed/protocol/pickle.py#L44-L46

Added lines #L44 - L46 were not covered by tests
except Exception:
return NotImplemented
return NotImplemented


def dumps(x, *, buffer_callback=None, protocol=HIGHEST_PROTOCOL):
Expand All @@ -56,31 +58,19 @@
"""
buffers = []
dump_kwargs = {"protocol": protocol or HIGHEST_PROTOCOL}

if dump_kwargs["protocol"] >= 5 and buffer_callback is not None:
dump_kwargs["buffer_callback"] = buffers.append

try:
try:
result = pickle.dumps(x, **dump_kwargs)
except Exception:
f = io.BytesIO()
pickler = _DaskPickler(f, **dump_kwargs)
buffers.clear()
pickler.dump(x)
result = f.getvalue()

if not _always_use_pickle_for(x) and (
CLOUDPICKLE_GE_20
and getattr(inspect.getmodule(x), "__name__", None)
in cloudpickle.list_registry_pickle_by_value()
or (
len(result) < 1000
# Do this very last since it's expensive
and b"__main__" in result
)
):
buffers.clear()
result = cloudpickle.dumps(x, **dump_kwargs)
except Exception:
f = io.BytesIO()
pickler = _DaskPickler(f, **dump_kwargs)
pickler.dump(x)
result = f.getvalue()
except Exception as exc:
import traceback

traceback.print_tb(exc.__traceback__)
try:
buffers.clear()
result = cloudpickle.dumps(x, **dump_kwargs)
Expand Down
24 changes: 23 additions & 1 deletion distributed/protocol/tests/test_pickle.py
Expand Up @@ -20,7 +20,7 @@
loads,
)
from distributed.protocol.serialize import dask_deserialize, dask_serialize
from distributed.utils_test import save_sys_modules
from distributed.utils_test import popen, save_sys_modules


class MemoryviewHolder:
Expand Down Expand Up @@ -278,3 +278,25 @@ def test_nopickle_nested():
finally:
del dask_serialize._lookup[NoPickle]
del dask_deserialize._lookup[NoPickle]


@pytest.mark.slow()
def test_pickle_functions_in_main(tmp_path):
    """Functions defined in ``__main__`` must round-trip to a worker.

    Runs a throwaway script in a subprocess so that the task function
    genuinely lives in the ``__main__`` module (pytest's own module would
    not exercise that path).  The script prints ``success`` only if the
    submitted task completed, i.e. the function was serialized by value.
    """
    # NOTE(review): the script literal's indentation was mangled in this
    # copy (flush-left bodies under ``if __name__``); reconstructed here so
    # the embedded script is valid Python — confirm against the upstream PR.
    script = """
from dask.distributed import Client

if __name__ == "__main__":
    with Client(n_workers=1) as client:
        def func(df):
            return (df + 5)
        client.submit(func, 5).result()
    print("success")
"""
    script_path = tmp_path / "script.py"
    # pathlib idiom: write_text replaces the manual open/write pair.
    script_path.write_text(script)
    with popen([sys.executable, str(script_path)], capture_output=True) as proc:
        out, _ = proc.communicate(timeout=60)

    # Captured stdout is raw bytes; decode before scanning for the marker.
    lines = out.decode("utf-8").split("\n")

    assert "success" in lines