diff --git a/CHANGES.md b/CHANGES.md index 4eda950fd..0270c0386 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -8,6 +8,10 @@ dev https://www.python.org/dev/peps/pep-0563/ ([PR #400](https://github.com/cloudpipe/cloudpickle/pull/400)) +- Stricter parametrized type detection heuristics in + _is_parametrized_type_hint to limit false positives. + ([PR #409](https://github.com/cloudpipe/cloudpickle/pull/409)) + 1.6.0 ===== diff --git a/cloudpickle/cloudpickle.py b/cloudpickle/cloudpickle.py index d1b38f826..20e9a9550 100644 --- a/cloudpickle/cloudpickle.py +++ b/cloudpickle/cloudpickle.py @@ -455,15 +455,31 @@ def _extract_class_dict(cls): if sys.version_info[:2] < (3, 7): # pragma: no branch def _is_parametrized_type_hint(obj): - # This is very cheap but might generate false positives. + # This is very cheap but might generate false positives. So try to + # narrow it down is good as possible. + type_module = getattr(type(obj), '__module__', None) + from_typing_extensions = type_module == 'typing_extensions' + from_typing = type_module == 'typing' + # general typing Constructs is_typing = getattr(obj, '__origin__', None) is not None # typing_extensions.Literal - is_literal = getattr(obj, '__values__', None) is not None + is_literal = ( + (getattr(obj, '__values__', None) is not None) + and from_typing_extensions + ) # typing_extensions.Final - is_final = getattr(obj, '__type__', None) is not None + is_final = ( + (getattr(obj, '__type__', None) is not None) + and from_typing_extensions + ) + + # typing.ClassVar + is_classvar = ( + (getattr(obj, '__type__', None) is not None) and from_typing + ) # typing.Union/Tuple for old Python 3.5 is_union = getattr(obj, '__union_params__', None) is not None @@ -472,8 +488,8 @@ def _is_parametrized_type_hint(obj): getattr(obj, '__result__', None) is not None and getattr(obj, '__args__', None) is not None ) - return any((is_typing, is_literal, is_final, is_union, is_tuple, - is_callable)) + return any((is_typing, is_literal, is_final, is_classvar, is_union, + is_tuple, is_callable)) def _create_parametrized_type_hint(origin, args): return origin[args] diff --git a/tests/cloudpickle_test.py b/tests/cloudpickle_test.py index 4e7eac466..845f27962 100644 --- a/tests/cloudpickle_test.py +++ b/tests/cloudpickle_test.py @@ -2301,6 +2301,26 @@ def reduce_myclass(x): finally: copyreg.dispatch_table.pop(MyClass) + def test_literal_misdetection(self): + # see https://github.com/cloudpipe/cloudpickle/issues/403 + class MyClass: + @property + def __values__(self): + return () + + o = MyClass() + pickle_depickle(o, protocol=self.protocol) + + def test_final_or_classvar_misdetection(self): + # see https://github.com/cloudpipe/cloudpickle/issues/403 + class MyClass: + @property + def __type__(self): + return int + + o = MyClass() + pickle_depickle(o, protocol=self.protocol) + class Protocol2CloudPickleTest(CloudPickleTest):