diff --git a/tensorflow/python/saved_model/simple_save_test.py b/tensorflow/python/saved_model/simple_save_test.py index 21c2e9df2fae9f..21be3677aa8496 100644 --- a/tensorflow/python/saved_model/simple_save_test.py +++ b/tensorflow/python/saved_model/simple_save_test.py @@ -21,7 +21,6 @@ import os from tensorflow.python.framework import ops -from tensorflow.python.framework import test_util from tensorflow.python.ops import variables from tensorflow.python.platform import test from tensorflow.python.saved_model import loader @@ -32,7 +31,7 @@ class SimpleSaveTest(test.TestCase): - def _init_and_validate_variable(self, sess, variable_name, variable_value): + def _init_and_validate_variable(self, variable_name, variable_value): v = variables.Variable(variable_value, name=variable_name) self.evaluate(variables.global_variables_initializer()) self.assertEqual(variable_value, self.evaluate(v)) @@ -54,50 +53,54 @@ def _check_tensor_info(self, actual_tensor_info, expected_tensor): self.assertEqual(actual_tensor_info.tensor_shape.dim[i].size, expected_tensor.shape[i]) - @test_util.run_deprecated_v1 def testSimpleSave(self): """Test simple_save that uses the default parameters.""" export_dir = os.path.join(test.get_temp_dir(), "test_simple_save") - # Initialize input and output variables and save a prediction graph using - # the default parameters. - with self.session(graph=ops.Graph()) as sess: - var_x = self._init_and_validate_variable(sess, "var_x", 1) - var_y = self._init_and_validate_variable(sess, "var_y", 2) - inputs = {"x": var_x} - outputs = {"y": var_y} - simple_save.simple_save(sess, export_dir, inputs, outputs) - - # Restore the graph with a valid tag and check the global variables and - # signature def map. - with self.session(graph=ops.Graph()) as sess: - graph = loader.load(sess, [tag_constants.SERVING], export_dir) - collection_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) - - # Check value and metadata of the saved variables. - self.assertEqual(len(collection_vars), 2) - self.assertEqual(1, collection_vars[0].eval()) - self.assertEqual(2, collection_vars[1].eval()) - self._check_variable_info(collection_vars[0], var_x) - self._check_variable_info(collection_vars[1], var_y) - - # Check that the appropriate signature_def_map is created with the - # default key and method name, and the specified inputs and outputs. - signature_def_map = graph.signature_def - self.assertEqual(1, len(signature_def_map)) - self.assertEqual(signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY, - list(signature_def_map.keys())[0]) - - signature_def = signature_def_map[ - signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY] - self.assertEqual(signature_constants.PREDICT_METHOD_NAME, - signature_def.method_name) - - self.assertEqual(1, len(signature_def.inputs)) - self._check_tensor_info(signature_def.inputs["x"], var_x) - self.assertEqual(1, len(signature_def.outputs)) - self._check_tensor_info(signature_def.outputs["y"], var_y) + # Force the test to run in graph mode. + # This tests a deprecated v1 API that both requires a session and uses + # functionality that does not work with eager tensors (such as + # build_tensor_info as called by predict_signature_def). + with ops.Graph().as_default(): + # Initialize input and output variables and save a prediction graph using + # the default parameters. + with self.session(graph=ops.Graph()) as sess: + var_x = self._init_and_validate_variable("var_x", 1) + var_y = self._init_and_validate_variable("var_y", 2) + inputs = {"x": var_x} + outputs = {"y": var_y} + simple_save.simple_save(sess, export_dir, inputs, outputs) + + # Restore the graph with a valid tag and check the global variables and + # signature def map. + with self.session(graph=ops.Graph()) as sess: + graph = loader.load(sess, [tag_constants.SERVING], export_dir) + collection_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) + + # Check value and metadata of the saved variables. + self.assertEqual(len(collection_vars), 2) + self.assertEqual(1, collection_vars[0].eval()) + self.assertEqual(2, collection_vars[1].eval()) + self._check_variable_info(collection_vars[0], var_x) + self._check_variable_info(collection_vars[1], var_y) + + # Check that the appropriate signature_def_map is created with the + # default key and method name, and the specified inputs and outputs. + signature_def_map = graph.signature_def + self.assertEqual(1, len(signature_def_map)) + self.assertEqual(signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY, + list(signature_def_map.keys())[0]) + + signature_def = signature_def_map[ + signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY] + self.assertEqual(signature_constants.PREDICT_METHOD_NAME, + signature_def.method_name) + + self.assertEqual(1, len(signature_def.inputs)) + self._check_tensor_info(signature_def.inputs["x"], var_x) + self.assertEqual(1, len(signature_def.outputs)) + self._check_tensor_info(signature_def.outputs["y"], var_y) if __name__ == "__main__":