diff --git a/sortedcontainers/sorteddict.py b/sortedcontainers/sorteddict.py index 164b58b8..910f2608 100644 --- a/sortedcontainers/sorteddict.py +++ b/sortedcontainers/sorteddict.py @@ -19,6 +19,8 @@ import sys import warnings +from itertools import chain + from .sortedlist import SortedList, recursive_repr from .sortedset import SortedSet @@ -27,9 +29,11 @@ ############################################################################### try: - from collections.abc import ItemsView, KeysView, ValuesView, Sequence + from collections.abc import ( + ItemsView, KeysView, Mapping, ValuesView, Sequence + ) except ImportError: - from collections import ItemsView, KeysView, ValuesView, Sequence + from collections import ItemsView, KeysView, Mapping, ValuesView, Sequence ############################################################################### # END Python 2/3 Shims @@ -298,6 +302,25 @@ def __setitem__(self, key, value): _setitem = __setitem__ + def __or__(self, other): + if not isinstance(other, Mapping): + return NotImplemented + items = chain(self.items(), other.items()) + return self.__class__(self._key, items) + + + def __ror__(self, other): + if not isinstance(other, Mapping): + return NotImplemented + items = chain(other.items(), self.items()) + return self.__class__(self._key, items) + + + def __ior__(self, other): + self._update(other) + return self + + def copy(self): """Return a shallow copy of the sorted dict. diff --git a/tests/test_coverage_sorteddict.py b/tests/test_coverage_sorteddict.py index 5eaa8307..9b62a7f3 100644 --- a/tests/test_coverage_sorteddict.py +++ b/tests/test_coverage_sorteddict.py @@ -491,3 +491,38 @@ def test_ref_counts(): del temp del_count = len(gc.get_objects()) assert start_count == del_count + +class CustomOr: + def __or__(self, other): + return NotImplemented + + def __ror__(self, other): + return self + +def test_or(): + mapping = [(val, pos) for pos, val in enumerate(string.ascii_lowercase)] + temp1 = SortedDict(mapping[:13]) + temp2 = SortedDict(mapping[13:]) + temp3 = temp1 | temp2 + assert temp3 == dict(mapping) + +def test_or_not_implemented(): + SortedDict() | CustomOr() + +def test_ror(): + mapping = [(val, pos) for pos, val in enumerate(string.ascii_lowercase)] + temp1 = dict(mapping[:13]) + temp2 = SortedDict(mapping[13:]) + temp3 = temp1 | temp2 + assert temp3 == dict(mapping) + +def test_ror_not_implemented(): + with pytest.raises(TypeError): + CustomOr() | SortedDict() + +def test_ior(): + mapping = [(val, pos) for pos, val in enumerate(string.ascii_lowercase)] + temp1 = SortedDict(mapping[:13]) + temp2 = SortedDict(mapping[13:]) + temp1 |= temp2 + assert temp1 == dict(mapping)