Skip to content

Commit

Permalink
Merge pull request #33262 from tensorflow/ggadde-1-15-cp2
Browse files Browse the repository at this point in the history
[r1.15 CherryPick]: Add saving of loaded/trained compatibility models in test and fix a c…
  • Loading branch information
goldiegadde committed Oct 14, 2019
2 parents 49c154e + 8d71a87 commit 46f50ff
Show file tree
Hide file tree
Showing 5 changed files with 23 additions and 2 deletions.
Expand Up @@ -18,6 +18,7 @@
from __future__ import division
from __future__ import print_function

import tempfile
from absl import app
from absl import flags

Expand Down Expand Up @@ -57,6 +58,10 @@ def train(fine_tuning):

model.fit_generator(generator=dataset.batch(1), epochs=5)

# This is testing that a model using a SavedModel can be re-exported again,
# e.g. to catch issues such as b/142231881.
tf.saved_model.save(model, tempfile.mkdtemp())


def main(argv):
del argv
Expand Down
Expand Up @@ -18,6 +18,7 @@
from __future__ import division
from __future__ import print_function

import tempfile
from absl import app
from absl import flags
import numpy as np
Expand All @@ -39,6 +40,10 @@ def main(argv):
tf.constant(np.random.uniform(size=[3, 19]).astype(np.float32)),
initial_state)

# This is testing that a model using a SavedModel can be re-exported again,
# e.g. to catch issues such as b/142231881.
tf.saved_model.save(cell, tempfile.mkdtemp())


if __name__ == "__main__":
app.run(main)
Expand Up @@ -18,6 +18,7 @@
from __future__ import division
from __future__ import print_function

import tempfile
from absl import app
from absl import flags

Expand Down Expand Up @@ -55,6 +56,10 @@ def _map_fn(features, labels):

model.fit_generator(generator=dataset.batch(10), epochs=5)

# This is testing that a model using a SavedModel can be re-exported again,
# e.g. to catch issues such as b/142231881.
tf.saved_model.save(model, tempfile.mkdtemp())


def main(argv):
del argv
Expand Down
Expand Up @@ -18,6 +18,7 @@
from __future__ import division
from __future__ import print_function

import tempfile
from absl import app
from absl import flags
import tensorflow.compat.v2 as tf
Expand All @@ -40,6 +41,9 @@ def main(argv):
sequence_length=10, first_word=tf.constant("<S>"))
_ = [d.numpy() for d in decoded]

# This is testing that a model using a SavedModel can be re-exported again,
# e.g. to catch issues such as b/142231881.
tf.saved_model.save(model, tempfile.mkdtemp())

if __name__ == "__main__":
app.run(main)
6 changes: 4 additions & 2 deletions tensorflow/python/saved_model/load.py
Expand Up @@ -425,11 +425,13 @@ def _list_functions_for_serialization(self, unused_serialization_cache):
# Overwrite this method to avoid the implementation of
# base class to re-wrap the polymorphic functions into
# another layer of `tf.function`.
return {
functions = {
"_create_resource": self._create_resource,
"_initialize": self._initialize,
"_destroy_resource": self._destroy_resource,
}
if self._destroy_resource:
functions.update(_destroy_resource=self._destroy_resource)
return functions


def _call_attribute(instance, *args, **kwargs):
Expand Down

0 comments on commit 46f50ff

Please sign in to comment.