diff --git a/tensorflow/examples/saved_model/integration_tests/use_model_in_sequential_keras.py b/tensorflow/examples/saved_model/integration_tests/use_model_in_sequential_keras.py index 2446ff91fb0562..9fccc9ce472e39 100644 --- a/tensorflow/examples/saved_model/integration_tests/use_model_in_sequential_keras.py +++ b/tensorflow/examples/saved_model/integration_tests/use_model_in_sequential_keras.py @@ -18,6 +18,7 @@ from __future__ import division from __future__ import print_function +import tempfile from absl import app from absl import flags @@ -57,6 +58,10 @@ def train(fine_tuning): model.fit_generator(generator=dataset.batch(1), epochs=5) + # This is testing that a model using a SavedModel can be re-exported again, + # e.g. to catch issues such as b/142231881. + tf.saved_model.save(model, tempfile.mkdtemp()) + def main(argv): del argv diff --git a/tensorflow/examples/saved_model/integration_tests/use_rnn_cell.py b/tensorflow/examples/saved_model/integration_tests/use_rnn_cell.py index e9f251376efb9b..8a0173c8aa7d75 100644 --- a/tensorflow/examples/saved_model/integration_tests/use_rnn_cell.py +++ b/tensorflow/examples/saved_model/integration_tests/use_rnn_cell.py @@ -18,6 +18,7 @@ from __future__ import division from __future__ import print_function +import tempfile from absl import app from absl import flags import numpy as np @@ -39,6 +40,10 @@ def main(argv): tf.constant(np.random.uniform(size=[3, 19]).astype(np.float32)), initial_state) + # This is testing that a model using a SavedModel can be re-exported again, + # e.g. to catch issues such as b/142231881. + tf.saved_model.save(cell, tempfile.mkdtemp()) + if __name__ == "__main__": app.run(main) diff --git a/tensorflow/examples/saved_model/integration_tests/use_text_embedding_in_dataset.py b/tensorflow/examples/saved_model/integration_tests/use_text_embedding_in_dataset.py index a21922219ce733..b22102b90d35d3 100644 --- a/tensorflow/examples/saved_model/integration_tests/use_text_embedding_in_dataset.py +++ b/tensorflow/examples/saved_model/integration_tests/use_text_embedding_in_dataset.py @@ -18,6 +18,7 @@ from __future__ import division from __future__ import print_function +import tempfile from absl import app from absl import flags @@ -55,6 +56,10 @@ def _map_fn(features, labels): model.fit_generator(generator=dataset.batch(10), epochs=5) + # This is testing that a model using a SavedModel can be re-exported again, + # e.g. to catch issues such as b/142231881. + tf.saved_model.save(model, tempfile.mkdtemp()) + def main(argv): del argv diff --git a/tensorflow/examples/saved_model/integration_tests/use_text_rnn_model.py b/tensorflow/examples/saved_model/integration_tests/use_text_rnn_model.py index 9178ff5581f556..ad7dea6ed6e69c 100644 --- a/tensorflow/examples/saved_model/integration_tests/use_text_rnn_model.py +++ b/tensorflow/examples/saved_model/integration_tests/use_text_rnn_model.py @@ -18,6 +18,7 @@ from __future__ import division from __future__ import print_function +import tempfile from absl import app from absl import flags import tensorflow.compat.v2 as tf @@ -40,6 +41,9 @@ def main(argv): sequence_length=10, first_word=tf.constant("")) _ = [d.numpy() for d in decoded] + # This is testing that a model using a SavedModel can be re-exported again, + # e.g. to catch issues such as b/142231881. + tf.saved_model.save(model, tempfile.mkdtemp()) if __name__ == "__main__": app.run(main) diff --git a/tensorflow/python/saved_model/load.py b/tensorflow/python/saved_model/load.py index 88f0f819ea7183..d8f146b38da7bf 100644 --- a/tensorflow/python/saved_model/load.py +++ b/tensorflow/python/saved_model/load.py @@ -425,11 +425,13 @@ def _list_functions_for_serialization(self, unused_serialization_cache): # Overwrite this method to avoid the implementation of # base class to re-wrap the polymorphic functions into # another layer of `tf.function`. - return { + functions = { "_create_resource": self._create_resource, "_initialize": self._initialize, - "_destroy_resource": self._destroy_resource, } + if self._destroy_resource: + functions.update(_destroy_resource=self._destroy_resource) + return functions def _call_attribute(instance, *args, **kwargs):