Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

allow pickling / depickling of odict_keys, odict_values, odict_items #429

Merged
merged 5 commits into from Jul 1, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
4 changes: 4 additions & 0 deletions CHANGES.md
Expand Up @@ -17,6 +17,10 @@ dev
_is_parametrized_type_hint to limit false positives.
([PR #409](https://github.com/cloudpipe/cloudpickle/pull/409))

- Support pickling / depickling of OrderedDict KeysView, ValuesView, and
ItemsView, following similar strategy for vanilla Python dictionaries.
([PR #423](https://github.com/cloudpipe/cloudpickle/pull/423))

- Suppressed a source of non-determinism when pickling dynamically defined
functions and handles the deprecation of co_lnotab in Python 3.10+.
([PR #428](https://github.com/cloudpipe/cloudpickle/pull/428))
Expand Down
22 changes: 16 additions & 6 deletions cloudpickle/cloudpickle.py
Expand Up @@ -55,6 +55,7 @@
import warnings

from .compat import pickle
from collections import OrderedDict
from typing import Generic, Union, Tuple, Callable
from pickle import _getattribute
from importlib._bootstrap import _find_spec
Expand Down Expand Up @@ -855,13 +856,22 @@ def _get_bases(typ):
return getattr(typ, bases_attr)


def _make_dict_keys(obj):
return dict.fromkeys(obj).keys()
def _make_dict_keys(obj, is_ordered=False):
if is_ordered:
return OrderedDict.fromkeys(obj).keys()
else:
return dict.fromkeys(obj).keys()


def _make_dict_values(obj):
return {i: _ for i, _ in enumerate(obj)}.values()
def _make_dict_values(obj, is_ordered=False):
if is_ordered:
return OrderedDict((i, _) for i, _ in enumerate(obj)).values()
else:
return {i: _ for i, _ in enumerate(obj)}.values()


def _make_dict_items(obj):
return obj.items()
def _make_dict_items(obj, is_ordered=False):
if is_ordered:
return OrderedDict(obj).items()
else:
return obj.items()
23 changes: 22 additions & 1 deletion cloudpickle/cloudpickle_fast.py
Expand Up @@ -23,7 +23,7 @@
import typing

from enum import Enum
from collections import ChainMap
from collections import ChainMap, OrderedDict

from .compat import pickle, Pickler
from .cloudpickle import (
Expand Down Expand Up @@ -437,6 +437,24 @@ def _dict_items_reduce(obj):
return _make_dict_items, (dict(obj), )


def _odict_keys_reduce(obj):
# Safer not to ship the full dict as sending the rest might
# be unintended and could potentially cause leaking of
# sensitive information
return _make_dict_keys, (list(obj), True)


def _odict_values_reduce(obj):
# Safer not to ship the full dict as sending the rest might
# be unintended and could potentially cause leaking of
# sensitive information
return _make_dict_values, (list(obj), True)


def _odict_items_reduce(obj):
return _make_dict_items, (dict(obj), True)


# COLLECTIONS OF OBJECTS STATE SETTERS
# ------------------------------------
# state setters are called at unpickling time, once the object is created and
Expand Down Expand Up @@ -513,6 +531,9 @@ class CloudPickler(Pickler):
_dispatch_table[_collections_abc.dict_keys] = _dict_keys_reduce
_dispatch_table[_collections_abc.dict_values] = _dict_values_reduce
_dispatch_table[_collections_abc.dict_items] = _dict_items_reduce
_dispatch_table[type(OrderedDict().keys())] = _odict_keys_reduce
_dispatch_table[type(OrderedDict().values())] = _odict_values_reduce
_dispatch_table[type(OrderedDict().items())] = _odict_items_reduce


dispatch_table = ChainMap(_dispatch_table, copyreg.dispatch_table)
Expand Down
18 changes: 18 additions & 0 deletions tests/cloudpickle_test.py
Expand Up @@ -229,6 +229,24 @@ def test_dict_items(self):
self.assertEqual(results, items)
assert isinstance(results, _collections_abc.dict_items)

def test_odict_keys(self):
keys = collections.OrderedDict([("a", 1), ("b", 2)]).keys()
results = pickle_depickle(keys)
self.assertEqual(results, keys)
assert type(keys) == type(results)

def test_odict_values(self):
values = collections.OrderedDict([("a", 1), ("b", 2)]).values()
results = pickle_depickle(values)
self.assertEqual(list(results), list(values))
assert type(values) == type(results)

def test_odict_items(self):
items = collections.OrderedDict([("a", 1), ("b", 2)]).items()
results = pickle_depickle(items)
self.assertEqual(results, items)
assert type(items) == type(results)

def test_sliced_and_non_contiguous_memoryview(self):
buffer_obj = memoryview(b"Hello!" * 3)[2:15:2]
self.assertEqual(pickle_depickle(buffer_obj, protocol=self.protocol),
Expand Down