diff --git a/tensorflow/python/checkpoint/async_checkpoint_helper.py b/tensorflow/python/checkpoint/async_checkpoint_helper.py index 45868da149a0dc..aa801c4bf9a973 100644 --- a/tensorflow/python/checkpoint/async_checkpoint_helper.py +++ b/tensorflow/python/checkpoint/async_checkpoint_helper.py @@ -263,10 +263,10 @@ def _ensure_initialized(self): # custom __getattr__ code, see b/152031870 for context. for t in all_trackables: # Special case 1: TPU Embedding, populate object_map here - # Special case 1: Handle TPU Embedding by addnig a dummy instance to the - # object map. Also add TPUEmbedding to separate list for special handling - # with values copy. - if hasattr(t, _TPU_EMBEDDING_ATTR): + # Special case 1: Handle TPU Embedding by addnig a dummy instance to the + # object map. Also add TPUEmbedding to separate list for special handling + # with values copy. + if hasattr(type(t), _TPU_EMBEDDING_ATTR): self._handle_tpu_embedding(t) # Special case 2: handle slot variables. The object_map is populated later # when the variable values are being copied to host CPU for the first @@ -414,9 +414,9 @@ def _handle_tpu_embedding(self, tpu_embedding): Raises: AttributeError: if the input trackable is not TPUEmbedding type. """ - if not hasattr( - tpu_embedding, _TPU_EMBEDDING_ATTR - ) or not callable(tpu_embedding._create_copy_for_async_checkpoint): # pylint: disable=protected-access + if not hasattr(type(tpu_embedding), _TPU_EMBEDDING_ATTR) or not callable( + tpu_embedding._create_copy_for_async_checkpoint # pylint: disable=protected-access + ): raise AttributeError( "Expecting TPUEmbedding type; got %s" % type(tpu_embedding) )