From 8862e70f2a63356c47b26eb1d19168e08faeab3b Mon Sep 17 00:00:00 2001 From: Zecong Hu Date: Fri, 17 Sep 2021 22:27:24 -0400 Subject: [PATCH] Fix #440: Incorrect pickles for subclasses of generic classes --- cloudpickle/cloudpickle.py | 6 +++++- tests/cloudpickle_test.py | 41 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 46 insertions(+), 1 deletion(-) diff --git a/cloudpickle/cloudpickle.py b/cloudpickle/cloudpickle.py index 80ccdea6..6fb4462d 100644 --- a/cloudpickle/cloudpickle.py +++ b/cloudpickle/cloudpickle.py @@ -910,8 +910,12 @@ def _typevar_reduce(obj): def _get_bases(typ): - if hasattr(typ, '__orig_bases__'): + if '__orig_bases__' in getattr(typ, '__dict__', {}): # For generic types (see PEP 560) + # Note that simply checking `hasattr(typ, '__orig_bases__')` is not + # correct. Subclasses of a fully-parameterized generic class does not + # have `__orig_bases__` defined, but `hasattr(typ, '__orig_bases__')` + # will return True because it's defined in the base class. bases_attr = '__orig_bases__' else: # For regular class objects diff --git a/tests/cloudpickle_test.py b/tests/cloudpickle_test.py index b96fb155..bce0c368 100644 --- a/tests/cloudpickle_test.py +++ b/tests/cloudpickle_test.py @@ -2223,6 +2223,47 @@ def check_generic(generic, origin, type_value, use_args): assert check_generic(C[int], C, int, use_args) == "ok" assert worker.run(check_generic, C[int], C, int, use_args) == "ok" + def test_generic_subclass(self): + T = typing.TypeVar('T') + + class Base(typing.Generic[T]): + pass + + class DerivedAny(Base): + pass + + class LeafAny(DerivedAny): + pass + + class DerivedInt(Base[int]): + pass + + class LeafInt(DerivedInt): + pass + + class DerivedT(Base[T]): + pass + + class LeafT(DerivedT[T]): + pass + + klasses = [ + Base, DerivedAny, LeafAny, DerivedInt, LeafInt, DerivedT, LeafT + ] + for klass in klasses: + assert pickle_depickle(klass, protocol=self.protocol) is klass + + with subprocess_worker(protocol=self.protocol) as worker: + + def check_mro(klass, expected_mro): + assert klass.mro() == expected_mro + return "ok" + + for klass in klasses: + mro = klass.mro() + assert check_mro(klass, mro) + assert worker.run(check_mro, klass, mro) == "ok" + def test_locally_defined_class_with_type_hints(self): with subprocess_worker(protocol=self.protocol) as worker: for type_ in _all_types_to_test():