From 1bf52774afca43e3878b5f8256f437484f218ea2 Mon Sep 17 00:00:00 2001 From: Zecong Hu Date: Wed, 16 Feb 2022 18:43:48 -0500 Subject: [PATCH] Fix #440: Incorrect pickles for subclasses of generic classes (#448) * Fix #440: Incorrect pickles for subclasses of generic classes * Update CHANGES * Empty Commit to trigger CI Co-authored-by: Pierre Glaser --- CHANGES.md | 3 +++ cloudpickle/cloudpickle.py | 6 +++++- tests/cloudpickle_test.py | 41 ++++++++++++++++++++++++++++++++++++++ 3 files changed, 49 insertions(+), 1 deletion(-) diff --git a/CHANGES.md b/CHANGES.md index 972cf3f0..8a97012a 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -5,6 +5,9 @@ and `abc.abstractstaticmethod`. ([PR #450](https://github.com/cloudpipe/cloudpickle/pull/450)) +- Support for pickling subclasses of generic classes. + ([PR #448](https://github.com/cloudpipe/cloudpickle/pull/448)) + 2.0.0 ===== diff --git a/cloudpickle/cloudpickle.py b/cloudpickle/cloudpickle.py index 80ccdea6..6fb4462d 100644 --- a/cloudpickle/cloudpickle.py +++ b/cloudpickle/cloudpickle.py @@ -910,8 +910,12 @@ def _typevar_reduce(obj): def _get_bases(typ): - if hasattr(typ, '__orig_bases__'): + if '__orig_bases__' in getattr(typ, '__dict__', {}): # For generic types (see PEP 560) + # Note that simply checking `hasattr(typ, '__orig_bases__')` is not + # correct. Subclasses of a fully-parameterized generic class does not + # have `__orig_bases__` defined, but `hasattr(typ, '__orig_bases__')` + # will return True because it's defined in the base class. bases_attr = '__orig_bases__' else: # For regular class objects diff --git a/tests/cloudpickle_test.py b/tests/cloudpickle_test.py index f356681d..181a859e 100644 --- a/tests/cloudpickle_test.py +++ b/tests/cloudpickle_test.py @@ -2335,6 +2335,47 @@ def check_generic(generic, origin, type_value, use_args): assert check_generic(C[int], C, int, use_args) == "ok" assert worker.run(check_generic, C[int], C, int, use_args) == "ok" + def test_generic_subclass(self): + T = typing.TypeVar('T') + + class Base(typing.Generic[T]): + pass + + class DerivedAny(Base): + pass + + class LeafAny(DerivedAny): + pass + + class DerivedInt(Base[int]): + pass + + class LeafInt(DerivedInt): + pass + + class DerivedT(Base[T]): + pass + + class LeafT(DerivedT[T]): + pass + + klasses = [ + Base, DerivedAny, LeafAny, DerivedInt, LeafInt, DerivedT, LeafT + ] + for klass in klasses: + assert pickle_depickle(klass, protocol=self.protocol) is klass + + with subprocess_worker(protocol=self.protocol) as worker: + + def check_mro(klass, expected_mro): + assert klass.mro() == expected_mro + return "ok" + + for klass in klasses: + mro = klass.mro() + assert check_mro(klass, mro) + assert worker.run(check_mro, klass, mro) == "ok" + def test_locally_defined_class_with_type_hints(self): with subprocess_worker(protocol=self.protocol) as worker: for type_ in _all_types_to_test():