diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-08-21 18:22:15 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-21 18:25:59 -0700 |
commit | 708b30f4cb82271bb28cb70a1e0c89a1933f5b64 (patch) | |
tree | 22470a9314f7f4225b6d08170a3d7ea91b0216a1 /tensorflow/python/saved_model | |
parent | d0cac47a767dd972516f75ce57f0d6185e3b6514 (diff) |
Move from deprecated self.test_session() to self.session() when a graph is set.
self.test_session() has been deprecated in cl/208545396 as its behavior confuses readers of the test. Moving to self.session() instead.
PiperOrigin-RevId: 209696110
Diffstat (limited to 'tensorflow/python/saved_model')
-rw-r--r-- | tensorflow/python/saved_model/loader_test.py | 18 | ||||
-rw-r--r-- | tensorflow/python/saved_model/saved_model_test.py | 170 | ||||
-rw-r--r-- | tensorflow/python/saved_model/simple_save_test.py | 4 |
3 files changed, 96 insertions, 96 deletions
diff --git a/tensorflow/python/saved_model/loader_test.py b/tensorflow/python/saved_model/loader_test.py index 9a0b276a4b..b7e217a35b 100644 --- a/tensorflow/python/saved_model/loader_test.py +++ b/tensorflow/python/saved_model/loader_test.py @@ -79,13 +79,13 @@ class SavedModelLoaderTest(test.TestCase): def test_load_function(self): loader = loader_impl.SavedModelLoader(SIMPLE_ADD_SAVED_MODEL) - with self.test_session(graph=ops.Graph()) as sess: + with self.session(graph=ops.Graph()) as sess: loader.load(sess, ["foo_graph"]) self.assertEqual(5, sess.graph.get_tensor_by_name("x:0").eval()) self.assertEqual(11, sess.graph.get_tensor_by_name("y:0").eval()) loader2 = loader_impl.SavedModelLoader(SAVED_MODEL_WITH_MAIN_OP) - with self.test_session(graph=ops.Graph()) as sess: + with self.session(graph=ops.Graph()) as sess: loader2.load(sess, ["foo_graph"]) self.assertEqual(5, sess.graph.get_tensor_by_name("x:0").eval()) self.assertEqual(7, sess.graph.get_tensor_by_name("y:0").eval()) @@ -101,7 +101,7 @@ class SavedModelLoaderTest(test.TestCase): with self.assertRaises(KeyError): graph.get_tensor_by_name("z:0") - with self.test_session(graph=graph) as sess: + with self.session(graph=graph) as sess: # Check that x and y are not initialized with self.assertRaises(errors.FailedPreconditionError): sess.run(x) @@ -110,7 +110,7 @@ class SavedModelLoaderTest(test.TestCase): def test_load_with_import_scope(self): loader = loader_impl.SavedModelLoader(SAVED_MODEL_WITH_MAIN_OP) - with self.test_session(graph=ops.Graph()) as sess: + with self.session(graph=ops.Graph()) as sess: saver, _ = loader.load_graph( sess.graph, ["foo_graph"], import_scope="baz") @@ -126,14 +126,14 @@ class SavedModelLoaderTest(test.TestCase): # Test combined load function. loader = loader_impl.SavedModelLoader(SAVED_MODEL_WITH_MAIN_OP) - with self.test_session(graph=ops.Graph()) as sess: + with self.session(graph=ops.Graph()) as sess: loader.load(sess, ["foo_graph"], import_scope="baa") self.assertEqual(5, sess.graph.get_tensor_by_name("baa/x:0").eval()) self.assertEqual(7, sess.graph.get_tensor_by_name("baa/y:0").eval()) def test_restore_variables(self): loader = loader_impl.SavedModelLoader(SAVED_MODEL_WITH_MAIN_OP) - with self.test_session(graph=ops.Graph()) as sess: + with self.session(graph=ops.Graph()) as sess: x = variables.Variable(0, name="x") y = variables.Variable(0, name="y") z = x * y @@ -151,7 +151,7 @@ class SavedModelLoaderTest(test.TestCase): loader = loader_impl.SavedModelLoader(SAVED_MODEL_WITH_MAIN_OP) graph = ops.Graph() saver, _ = loader.load_graph(graph, ["foo_graph"]) - with self.test_session(graph=graph) as sess: + with self.session(graph=graph) as sess: loader.restore_variables(sess, saver) self.assertEqual(5, sess.graph.get_tensor_by_name("x:0").eval()) self.assertEqual(11, sess.graph.get_tensor_by_name("y:0").eval()) @@ -203,12 +203,12 @@ class SavedModelLoaderTest(test.TestCase): builder.save() loader = loader_impl.SavedModelLoader(path) - with self.test_session(graph=ops.Graph()) as sess: + with self.session(graph=ops.Graph()) as sess: saver, _ = loader.load_graph(sess.graph, ["foo_graph"]) self.assertFalse(variables._all_saveable_objects()) self.assertIsNotNone(saver) - with self.test_session(graph=ops.Graph()) as sess: + with self.session(graph=ops.Graph()) as sess: loader.load(sess, ["foo_graph"]) self.assertEqual(5, sess.graph.get_tensor_by_name("x:0").eval()) self.assertEqual(11, sess.graph.get_tensor_by_name("y:0").eval()) diff --git a/tensorflow/python/saved_model/saved_model_test.py b/tensorflow/python/saved_model/saved_model_test.py index 00b669fc97..49d52d3bee 100644 --- a/tensorflow/python/saved_model/saved_model_test.py +++ b/tensorflow/python/saved_model/saved_model_test.py @@ -97,7 +97,7 @@ class SavedModelTest(test.TestCase): self.assertEqual(expected_asset_tensor_name, asset.tensor_info.name) def _validate_inputs_tensor_info_fail(self, builder, tensor_info): - with self.test_session(graph=ops.Graph()) as sess: + with self.session(graph=ops.Graph()) as sess: self._init_and_validate_variable(sess, "v", 42) foo_signature = signature_def_utils.build_signature_def({ @@ -110,7 +110,7 @@ class SavedModelTest(test.TestCase): signature_def_map={"foo_key": foo_signature}) def _validate_inputs_tensor_info_accept(self, builder, tensor_info): - with self.test_session(graph=ops.Graph()) as sess: + with self.session(graph=ops.Graph()) as sess: self._init_and_validate_variable(sess, "v", 42) foo_signature = signature_def_utils.build_signature_def({ @@ -121,7 +121,7 @@ class SavedModelTest(test.TestCase): signature_def_map={"foo_key": foo_signature}) def _validate_outputs_tensor_info_fail(self, builder, tensor_info): - with self.test_session(graph=ops.Graph()) as sess: + with self.session(graph=ops.Graph()) as sess: self._init_and_validate_variable(sess, "v", 42) foo_signature = signature_def_utils.build_signature_def( @@ -133,7 +133,7 @@ class SavedModelTest(test.TestCase): signature_def_map={"foo_key": foo_signature}) def _validate_outputs_tensor_info_accept(self, builder, tensor_info): - with self.test_session(graph=ops.Graph()) as sess: + with self.session(graph=ops.Graph()) as sess: self._init_and_validate_variable(sess, "v", 42) foo_signature = signature_def_utils.build_signature_def( @@ -153,7 +153,7 @@ class SavedModelTest(test.TestCase): def testBadSavedModelFileFormat(self): export_dir = self._get_export_dir("test_bad_saved_model_file_format") # Attempt to load a SavedModel from an export directory that does not exist. - with self.test_session(graph=ops.Graph()) as sess: + with self.session(graph=ops.Graph()) as sess: with self.assertRaisesRegexp(IOError, "SavedModel file does not exist at: %s" % export_dir): @@ -164,7 +164,7 @@ class SavedModelTest(test.TestCase): path_to_pb = os.path.join(export_dir, constants.SAVED_MODEL_FILENAME_PB) with open(path_to_pb, "w") as f: f.write("invalid content") - with self.test_session(graph=ops.Graph()) as sess: + with self.session(graph=ops.Graph()) as sess: with self.assertRaisesRegexp(IOError, "Cannot parse file.*%s" % constants.SAVED_MODEL_FILENAME_PB): loader.load(sess, ["foo"], export_dir) @@ -178,7 +178,7 @@ class SavedModelTest(test.TestCase): constants.SAVED_MODEL_FILENAME_PBTXT) with open(path_to_pbtxt, "w") as f: f.write("invalid content") - with self.test_session(graph=ops.Graph()) as sess: + with self.session(graph=ops.Graph()) as sess: with self.assertRaisesRegexp(IOError, "Cannot parse file.*%s" % constants.SAVED_MODEL_FILENAME_PBTXT): loader.load(sess, ["foo"], export_dir) @@ -187,7 +187,7 @@ class SavedModelTest(test.TestCase): export_dir = self._get_export_dir("test_verify_session_graph_usage") builder = saved_model_builder.SavedModelBuilder(export_dir) - with self.test_session(graph=ops.Graph()) as sess: + with self.session(graph=ops.Graph()) as sess: self._init_and_validate_variable(sess, "v", 42) builder.add_meta_graph_and_variables(sess, [tag_constants.TRAINING]) @@ -209,12 +209,12 @@ class SavedModelTest(test.TestCase): # Expect an assertion error since add_meta_graph_and_variables() should be # invoked before any add_meta_graph() calls. - with self.test_session(graph=ops.Graph()) as sess: + with self.session(graph=ops.Graph()) as sess: self.assertRaises(AssertionError, builder.add_meta_graph, ["foo"]) # Expect an assertion error for multiple calls of # add_meta_graph_and_variables() since weights should be saved exactly once. - with self.test_session(graph=ops.Graph()) as sess: + with self.session(graph=ops.Graph()) as sess: self._init_and_validate_variable(sess, "v", 42) builder.add_meta_graph_and_variables(sess, ["bar"]) self.assertRaises(AssertionError, builder.add_meta_graph_and_variables, @@ -227,35 +227,35 @@ class SavedModelTest(test.TestCase): # Graph with a single variable. SavedModel invoked to: # - add with weights. # - a single tag (from predefined constants). - with self.test_session(graph=ops.Graph()) as sess: + with self.session(graph=ops.Graph()) as sess: self._init_and_validate_variable(sess, "v", 42) builder.add_meta_graph_and_variables(sess, [tag_constants.TRAINING]) # Graph that updates the single variable. SavedModel invoked to: # - simply add the model (weights are not updated). # - a single tag (from predefined constants). - with self.test_session(graph=ops.Graph()) as sess: + with self.session(graph=ops.Graph()) as sess: self._init_and_validate_variable(sess, "v", 43) builder.add_meta_graph([tag_constants.SERVING]) # Graph that updates the single variable. SavedModel invoked to: # - simply add the model (weights are not updated). # - multiple tags (from predefined constants). - with self.test_session(graph=ops.Graph()) as sess: + with self.session(graph=ops.Graph()) as sess: self._init_and_validate_variable(sess, "v", 45) builder.add_meta_graph([tag_constants.SERVING, tag_constants.GPU]) # Graph that updates the single variable. SavedModel invoked to: # - simply add the model (weights are not updated). # - multiple tags (from predefined constants for serving on TPU). - with self.test_session(graph=ops.Graph()) as sess: + with self.session(graph=ops.Graph()) as sess: self._init_and_validate_variable(sess, "v", 45) builder.add_meta_graph([tag_constants.SERVING, tag_constants.TPU]) # Graph that updates the single variable. SavedModel is invoked: # - to add the model (weights are not updated). # - multiple custom tags. - with self.test_session(graph=ops.Graph()) as sess: + with self.session(graph=ops.Graph()) as sess: self._init_and_validate_variable(sess, "v", 44) builder.add_meta_graph(["foo", "bar"]) @@ -263,49 +263,49 @@ class SavedModelTest(test.TestCase): builder.save() # Restore the graph with a single predefined tag whose variables were saved. - with self.test_session(graph=ops.Graph()) as sess: + with self.session(graph=ops.Graph()) as sess: loader.load(sess, [tag_constants.TRAINING], export_dir) self.assertEqual( 42, ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)[0].eval()) # Restore the graph with a single predefined tag whose variables were not # saved. - with self.test_session(graph=ops.Graph()) as sess: + with self.session(graph=ops.Graph()) as sess: loader.load(sess, [tag_constants.SERVING], export_dir) self.assertEqual( 42, ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)[0].eval()) # Restore the graph with multiple predefined tags whose variables were not # saved. - with self.test_session(graph=ops.Graph()) as sess: + with self.session(graph=ops.Graph()) as sess: loader.load(sess, [tag_constants.SERVING, tag_constants.GPU], export_dir) self.assertEqual( 42, ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)[0].eval()) # Restore the graph with multiple predefined tags (for serving on TPU) # whose variables were not saved. - with self.test_session(graph=ops.Graph()) as sess: + with self.session(graph=ops.Graph()) as sess: loader.load(sess, [tag_constants.SERVING, tag_constants.TPU], export_dir) self.assertEqual( 42, ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)[0].eval()) # Restore the graph with multiple tags. Provide duplicate tags to test set # semantics. - with self.test_session(graph=ops.Graph()) as sess: + with self.session(graph=ops.Graph()) as sess: loader.load(sess, ["foo", "bar", "foo"], export_dir) self.assertEqual( 42, ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)[0].eval()) # Try restoring a graph with a non-existent tag. This should yield a runtime # error. - with self.test_session(graph=ops.Graph()) as sess: + with self.session(graph=ops.Graph()) as sess: self.assertRaises(RuntimeError, loader.load, sess, ["INVALID"], export_dir) # Try restoring a graph where a subset of the tags match. Since tag matching # for meta graph defs follows "all" semantics, this should yield a runtime # error. - with self.test_session(graph=ops.Graph()) as sess: + with self.session(graph=ops.Graph()) as sess: self.assertRaises(RuntimeError, loader.load, sess, ["foo", "baz"], export_dir) @@ -315,7 +315,7 @@ class SavedModelTest(test.TestCase): # Graph with two variables. SavedModel invoked to: # - add with weights. - with self.test_session(graph=ops.Graph()) as sess: + with self.session(graph=ops.Graph()) as sess: self._init_and_validate_variable(sess, "v1", 1) self._init_and_validate_variable(sess, "v2", 2) builder.add_meta_graph_and_variables(sess, ["foo"]) @@ -323,14 +323,14 @@ class SavedModelTest(test.TestCase): # Graph with a single variable (subset of the variables from the previous # graph whose weights were saved). SavedModel invoked to: # - simply add the model (weights are not updated). - with self.test_session(graph=ops.Graph()) as sess: + with self.session(graph=ops.Graph()) as sess: self._init_and_validate_variable(sess, "v2", 3) builder.add_meta_graph(["bar"]) # Graph with a single variable (disjoint set of variables from the previous # graph whose weights were saved). SavedModel invoked to: # - simply add the model (weights are not updated). - with self.test_session(graph=ops.Graph()) as sess: + with self.session(graph=ops.Graph()) as sess: self._init_and_validate_variable(sess, "v3", 4) builder.add_meta_graph(["baz"]) @@ -338,7 +338,7 @@ class SavedModelTest(test.TestCase): builder.save() # Restore the graph with tag "foo", whose variables were saved. - with self.test_session(graph=ops.Graph()) as sess: + with self.session(graph=ops.Graph()) as sess: loader.load(sess, ["foo"], export_dir) collection_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) self.assertEqual(len(collection_vars), 2) @@ -348,7 +348,7 @@ class SavedModelTest(test.TestCase): # Restore the graph with tag "bar", whose variables were not saved. Only the # subset of the variables added to the graph will be restored with the # checkpointed value. - with self.test_session(graph=ops.Graph()) as sess: + with self.session(graph=ops.Graph()) as sess: loader.load(sess, ["bar"], export_dir) collection_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) self.assertEqual(len(collection_vars), 1) @@ -357,7 +357,7 @@ class SavedModelTest(test.TestCase): # Try restoring the graph with tag "baz", whose variables were not saved. # Since this graph has a disjoint set of variables from the set that was # saved, this should raise an error. - with self.test_session(graph=ops.Graph()) as sess: + with self.session(graph=ops.Graph()) as sess: self.assertRaises(errors.NotFoundError, loader.load, sess, ["baz"], export_dir) @@ -366,12 +366,12 @@ class SavedModelTest(test.TestCase): builder = saved_model_builder.SavedModelBuilder(export_dir) # Graph with no variables. - with self.test_session(graph=ops.Graph()) as sess: + with self.session(graph=ops.Graph()) as sess: constant_5_name = constant_op.constant(5.0).name builder.add_meta_graph_and_variables(sess, ["foo"]) # Second graph with no variables - with self.test_session(graph=ops.Graph()) as sess: + with self.session(graph=ops.Graph()) as sess: constant_6_name = constant_op.constant(6.0).name builder.add_meta_graph(["bar"]) @@ -379,7 +379,7 @@ class SavedModelTest(test.TestCase): builder.save() # Restore the graph with tag "foo". - with self.test_session(graph=ops.Graph()) as sess: + with self.session(graph=ops.Graph()) as sess: loader.load(sess, ["foo"], export_dir) # Read the constant a from the graph. a = ops.get_default_graph().get_tensor_by_name(constant_5_name) @@ -388,7 +388,7 @@ class SavedModelTest(test.TestCase): self.assertEqual(30.0, sess.run(c)) # Restore the graph with tag "bar". - with self.test_session(graph=ops.Graph()) as sess: + with self.session(graph=ops.Graph()) as sess: loader.load(sess, ["bar"], export_dir) # Read the constant a from the graph. a = ops.get_default_graph().get_tensor_by_name(constant_6_name) @@ -402,7 +402,7 @@ class SavedModelTest(test.TestCase): # Graph with a single variable. SavedModel invoked to: # - add with weights. - with self.test_session(graph=ops.Graph()) as sess: + with self.session(graph=ops.Graph()) as sess: self._init_and_validate_variable(sess, "v", 42) builder.add_meta_graph_and_variables(sess, ["foo"]) @@ -410,7 +410,7 @@ class SavedModelTest(test.TestCase): builder.save(as_text=True) # Restore the graph with tag "foo", whose variables were saved. - with self.test_session(graph=ops.Graph()) as sess: + with self.session(graph=ops.Graph()) as sess: loader.load(sess, ["foo"], export_dir) self.assertEqual( 42, ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)[0].eval()) @@ -426,13 +426,13 @@ class SavedModelTest(test.TestCase): # Graph with a single variable. SavedModel invoked to: # - add with weights. - with self.test_session(graph=ops.Graph()) as sess: + with self.session(graph=ops.Graph()) as sess: self._init_and_validate_variable(sess, "v", 42) builder.add_meta_graph_and_variables(sess, ["foo"]) # Graph with the same single variable. SavedModel invoked to: # - simply add the model (weights are not updated). - with self.test_session(graph=ops.Graph()) as sess: + with self.session(graph=ops.Graph()) as sess: self._init_and_validate_variable(sess, "v", 43) builder.add_meta_graph(["bar"]) @@ -440,13 +440,13 @@ class SavedModelTest(test.TestCase): builder.save(as_text=True) # Restore the graph with tag "foo", whose variables were saved. - with self.test_session(graph=ops.Graph()) as sess: + with self.session(graph=ops.Graph()) as sess: loader.load(sess, ["foo"], export_dir) self.assertEqual( 42, ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)[0].eval()) # Restore the graph with tag "bar", whose variables were not saved. - with self.test_session(graph=ops.Graph()) as sess: + with self.session(graph=ops.Graph()) as sess: loader.load(sess, ["bar"], export_dir) self.assertEqual( 42, ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)[0].eval()) @@ -457,7 +457,7 @@ class SavedModelTest(test.TestCase): # Graph with a single variable added to a collection. SavedModel invoked to: # - add with weights. - with self.test_session(graph=ops.Graph()) as sess: + with self.session(graph=ops.Graph()) as sess: v = variables.Variable(42, name="v") ops.add_to_collection("foo_vars", v) sess.run(variables.global_variables_initializer()) @@ -467,7 +467,7 @@ class SavedModelTest(test.TestCase): # Graph with the same single variable added to a different collection. # SavedModel invoked to: # - simply add the model (weights are not updated). - with self.test_session(graph=ops.Graph()) as sess: + with self.session(graph=ops.Graph()) as sess: v = variables.Variable(43, name="v") ops.add_to_collection("bar_vars", v) sess.run(variables.global_variables_initializer()) @@ -480,7 +480,7 @@ class SavedModelTest(test.TestCase): # Restore the graph with tag "foo", whose variables were saved. The # collection 'foo_vars' should contain a single element. The collection # 'bar_vars' should not be found. - with self.test_session(graph=ops.Graph()) as sess: + with self.session(graph=ops.Graph()) as sess: loader.load(sess, ["foo"], export_dir) collection_foo_vars = ops.get_collection("foo_vars") self.assertEqual(len(collection_foo_vars), 1) @@ -493,7 +493,7 @@ class SavedModelTest(test.TestCase): # reflect the new collection. The value of the variable in the # collection-def corresponds to the saved value (from the previous graph # with tag "foo"). - with self.test_session(graph=ops.Graph()) as sess: + with self.session(graph=ops.Graph()) as sess: loader.load(sess, ["bar"], export_dir) collection_bar_vars = ops.get_collection("bar_vars") self.assertEqual(len(collection_bar_vars), 1) @@ -507,7 +507,7 @@ class SavedModelTest(test.TestCase): # Graph with a single variable and a single entry in the signature def map. # SavedModel is invoked to add with weights. - with self.test_session(graph=ops.Graph()) as sess: + with self.session(graph=ops.Graph()) as sess: self._init_and_validate_variable(sess, "v", 42) # Build and populate an empty SignatureDef for testing. foo_signature = signature_def_utils.build_signature_def(dict(), @@ -517,7 +517,7 @@ class SavedModelTest(test.TestCase): # Graph with the same single variable and multiple entries in the signature # def map. No weights are saved by SavedModel. - with self.test_session(graph=ops.Graph()) as sess: + with self.session(graph=ops.Graph()) as sess: self._init_and_validate_variable(sess, "v", 43) # Build and populate a different SignatureDef for testing. bar_signature = signature_def_utils.build_signature_def(dict(), @@ -539,7 +539,7 @@ class SavedModelTest(test.TestCase): # Restore the graph with tag "foo". The single entry in the SignatureDef map # corresponding to "foo_key" should exist. - with self.test_session(graph=ops.Graph()) as sess: + with self.session(graph=ops.Graph()) as sess: foo_graph = loader.load(sess, ["foo"], export_dir) self.assertEqual( 42, ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)[0].eval()) @@ -551,7 +551,7 @@ class SavedModelTest(test.TestCase): # Restore the graph with tag "bar". The SignatureDef map should have two # entries. One corresponding to "bar_key" and another corresponding to the # new value of "foo_key". - with self.test_session(graph=ops.Graph()) as sess: + with self.session(graph=ops.Graph()) as sess: bar_graph = loader.load(sess, ["bar"], export_dir) self.assertEqual( 42, ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)[0].eval()) @@ -610,7 +610,7 @@ class SavedModelTest(test.TestCase): export_dir = self._get_export_dir("test_assets") builder = saved_model_builder.SavedModelBuilder(export_dir) - with self.test_session(graph=ops.Graph()) as sess: + with self.session(graph=ops.Graph()) as sess: self._init_and_validate_variable(sess, "v", 42) # Build an asset collection. @@ -628,7 +628,7 @@ class SavedModelTest(test.TestCase): # Save the SavedModel to disk. builder.save() - with self.test_session(graph=ops.Graph()) as sess: + with self.session(graph=ops.Graph()) as sess: foo_graph = loader.load(sess, ["foo"], export_dir) self._validate_asset_collection(export_dir, foo_graph.collection_def, "hello42.txt", "foo bar baz", @@ -643,7 +643,7 @@ class SavedModelTest(test.TestCase): export_dir = self._get_export_dir("test_assets_name_collision_diff_file") builder = saved_model_builder.SavedModelBuilder(export_dir) - with self.test_session(graph=ops.Graph()) as sess: + with self.session(graph=ops.Graph()) as sess: self._init_and_validate_variable(sess, "v", 42) asset_collection = self._build_asset_collection( @@ -660,7 +660,7 @@ class SavedModelTest(test.TestCase): # Save the SavedModel to disk. builder.save() - with self.test_session(graph=ops.Graph()) as sess: + with self.session(graph=ops.Graph()) as sess: foo_graph = loader.load(sess, ["foo"], export_dir) self._validate_asset_collection(export_dir, foo_graph.collection_def, "hello42.txt", "foo bar bak", @@ -674,7 +674,7 @@ class SavedModelTest(test.TestCase): export_dir = self._get_export_dir("test_assets_name_collision_same_path") builder = saved_model_builder.SavedModelBuilder(export_dir) - with self.test_session(graph=ops.Graph()) as sess: + with self.session(graph=ops.Graph()) as sess: self._init_and_validate_variable(sess, "v", 42) asset_collection = self._build_asset_collection( @@ -689,7 +689,7 @@ class SavedModelTest(test.TestCase): # Save the SavedModel to disk. builder.save() - with self.test_session(graph=ops.Graph()) as sess: + with self.session(graph=ops.Graph()) as sess: foo_graph = loader.load(sess, ["foo"], export_dir) self._validate_asset_collection(export_dir, foo_graph.collection_def, "hello42.txt", "foo bar baz", @@ -709,7 +709,7 @@ class SavedModelTest(test.TestCase): export_dir = self._get_export_dir("test_assets_name_collision_same_file") builder = saved_model_builder.SavedModelBuilder(export_dir) - with self.test_session(graph=ops.Graph()) as sess: + with self.session(graph=ops.Graph()) as sess: self._init_and_validate_variable(sess, "v", 42) asset_collection = self._build_asset_collection( @@ -726,7 +726,7 @@ class SavedModelTest(test.TestCase): # Save the SavedModel to disk. builder.save() - with self.test_session(graph=ops.Graph()) as sess: + with self.session(graph=ops.Graph()) as sess: foo_graph = loader.load(sess, ["foo"], export_dir) self._validate_asset_collection(export_dir, foo_graph.collection_def, "hello42.txt", "foo bar baz", @@ -746,7 +746,7 @@ class SavedModelTest(test.TestCase): export_dir = self._get_export_dir("test_assets_name_collision_many_files") builder = saved_model_builder.SavedModelBuilder(export_dir) - with self.test_session(graph=ops.Graph()) as sess: + with self.session(graph=ops.Graph()) as sess: self._init_and_validate_variable(sess, "v", 42) for i in range(5): @@ -761,7 +761,7 @@ class SavedModelTest(test.TestCase): # Save the SavedModel to disk. builder.save() - with self.test_session(graph=ops.Graph()) as sess: + with self.session(graph=ops.Graph()) as sess: foo_graph = loader.load(sess, ["foo"], export_dir) for i in range(1, 5): idx = str(i) @@ -778,7 +778,7 @@ class SavedModelTest(test.TestCase): export_dir = self._get_export_dir("test_main_op") builder = saved_model_builder.SavedModelBuilder(export_dir) - with self.test_session(graph=ops.Graph()) as sess: + with self.session(graph=ops.Graph()) as sess: # Add `v1` and `v2` variables to the graph. v1 = variables.Variable(1, name="v1") ops.add_to_collection("v", v1) @@ -801,7 +801,7 @@ class SavedModelTest(test.TestCase): # Save the SavedModel to disk. builder.save() - with self.test_session(graph=ops.Graph()) as sess: + with self.session(graph=ops.Graph()) as sess: loader.load(sess, ["foo"], export_dir) self.assertEqual(1, ops.get_collection("v")[0].eval()) self.assertEqual(2, ops.get_collection("v")[1].eval()) @@ -813,7 +813,7 @@ class SavedModelTest(test.TestCase): export_dir = self._get_export_dir("test_legacy_init_op") builder = saved_model_builder.SavedModelBuilder(export_dir) - with self.test_session(graph=ops.Graph()) as sess: + with self.session(graph=ops.Graph()) as sess: # Add `v1` and `v2` variables to the graph. v1 = variables.Variable(1, name="v1") ops.add_to_collection("v", v1) @@ -835,7 +835,7 @@ class SavedModelTest(test.TestCase): # Save the SavedModel to disk. builder.save() - with self.test_session(graph=ops.Graph()) as sess: + with self.session(graph=ops.Graph()) as sess: loader.load(sess, ["foo"], export_dir) self.assertEqual(1, ops.get_collection("v")[0].eval()) self.assertEqual(2, ops.get_collection("v")[1].eval()) @@ -858,7 +858,7 @@ class SavedModelTest(test.TestCase): builder = saved_model_builder.SavedModelBuilder(export_dir) g = ops.Graph() - with self.test_session(graph=g) as sess: + with self.session(graph=g) as sess: # Initialize variable `v1` to 1. v1 = variables.Variable(1, name="v1") ops.add_to_collection("v", v1) @@ -887,7 +887,7 @@ class SavedModelTest(test.TestCase): export_dir = self._get_export_dir("test_train_op") builder = saved_model_builder.SavedModelBuilder(export_dir) - with self.test_session(graph=ops.Graph()) as sess: + with self.session(graph=ops.Graph()) as sess: # Add `v1` and `v2` variables to the graph. v1 = variables.Variable(1, name="v1") ops.add_to_collection("v", v1) @@ -905,7 +905,7 @@ class SavedModelTest(test.TestCase): # Save the SavedModel to disk. builder.save() - with self.test_session(graph=ops.Graph()) as sess: + with self.session(graph=ops.Graph()) as sess: loader.load(sess, ["foo"], export_dir) self.assertEqual(3, ops.get_collection("v")[0].eval()) self.assertEqual(2, ops.get_collection("v")[1].eval()) @@ -916,7 +916,7 @@ class SavedModelTest(test.TestCase): export_dir = self._get_export_dir("test_train_op_group") builder = saved_model_builder.SavedModelBuilder(export_dir) - with self.test_session(graph=ops.Graph()) as sess: + with self.session(graph=ops.Graph()) as sess: # Add `v1` and `v2` variables to the graph. v1 = variables.Variable(1, name="v1") ops.add_to_collection("v", v1) @@ -934,7 +934,7 @@ class SavedModelTest(test.TestCase): # Save the SavedModel to disk. builder.save() - with self.test_session(graph=ops.Graph()) as sess: + with self.session(graph=ops.Graph()) as sess: loader.load(sess, ["foo"], export_dir) self.assertEqual(1, ops.get_collection("v")[0].eval()) self.assertEqual(2, ops.get_collection("v")[1].eval()) @@ -945,7 +945,7 @@ class SavedModelTest(test.TestCase): export_dir = self._get_export_dir("test_train_op_after_variables") builder = saved_model_builder.SavedModelBuilder(export_dir) - with self.test_session(graph=ops.Graph()) as sess: + with self.session(graph=ops.Graph()) as sess: # Add `v1` and `v2` variables to the graph. v1 = variables.Variable(1, name="v1") ops.add_to_collection("v", v1) @@ -964,12 +964,12 @@ class SavedModelTest(test.TestCase): # Save the SavedModel to disk. builder.save() - with self.test_session(graph=ops.Graph()) as sess: + with self.session(graph=ops.Graph()) as sess: loader.load(sess, ["foo"], export_dir) self.assertIsInstance( ops.get_collection(constants.TRAIN_OP_KEY)[0], ops.Tensor) - with self.test_session(graph=ops.Graph()) as sess: + with self.session(graph=ops.Graph()) as sess: loader.load(sess, ["pre_foo"], export_dir) self.assertFalse(ops.get_collection(constants.TRAIN_OP_KEY)) @@ -977,7 +977,7 @@ class SavedModelTest(test.TestCase): export_dir = self._get_export_dir("test_multiple_assets") builder = saved_model_builder.SavedModelBuilder(export_dir) - with self.test_session(graph=ops.Graph()) as sess: + with self.session(graph=ops.Graph()) as sess: self._init_and_validate_variable(sess, "v", 42) # Build an asset collection specific to `foo` graph. @@ -988,7 +988,7 @@ class SavedModelTest(test.TestCase): builder.add_meta_graph_and_variables( sess, ["foo"], assets_collection=asset_collection) - with self.test_session(graph=ops.Graph()) as sess: + with self.session(graph=ops.Graph()) as sess: self._init_and_validate_variable(sess, "v", 42) # Build an asset collection specific to `bar` graph. @@ -1002,14 +1002,14 @@ class SavedModelTest(test.TestCase): builder.save() # Check assets restored for graph with tag "foo". - with self.test_session(graph=ops.Graph()) as sess: + with self.session(graph=ops.Graph()) as sess: foo_graph = loader.load(sess, ["foo"], export_dir) self._validate_asset_collection(export_dir, foo_graph.collection_def, "foo.txt", "content_foo", "asset_file_tensor:0") # Check assets restored for graph with tag "bar". - with self.test_session(graph=ops.Graph()) as sess: + with self.session(graph=ops.Graph()) as sess: bar_graph = loader.load(sess, ["bar"], export_dir) self._validate_asset_collection(export_dir, bar_graph.collection_def, "bar.txt", "content_bar", @@ -1019,7 +1019,7 @@ class SavedModelTest(test.TestCase): export_dir = self._get_export_dir("test_duplicate_assets") builder = saved_model_builder.SavedModelBuilder(export_dir) - with self.test_session(graph=ops.Graph()) as sess: + with self.session(graph=ops.Graph()) as sess: self._init_and_validate_variable(sess, "v", 42) # Build an asset collection with `foo.txt` that has `foo` specific @@ -1031,7 +1031,7 @@ class SavedModelTest(test.TestCase): builder.add_meta_graph_and_variables( sess, ["foo"], assets_collection=asset_collection) - with self.test_session(graph=ops.Graph()) as sess: + with self.session(graph=ops.Graph()) as sess: self._init_and_validate_variable(sess, "v", 42) # Build an asset collection with `foo.txt` that has `bar` specific @@ -1046,14 +1046,14 @@ class SavedModelTest(test.TestCase): builder.save() # Check assets restored for graph with tag "foo". - with self.test_session(graph=ops.Graph()) as sess: + with self.session(graph=ops.Graph()) as sess: foo_graph = loader.load(sess, ["foo"], export_dir) self._validate_asset_collection(export_dir, foo_graph.collection_def, "foo.txt", "content_foo", "asset_file_tensor:0") # Check assets restored for graph with tag "bar". - with self.test_session(graph=ops.Graph()) as sess: + with self.session(graph=ops.Graph()) as sess: bar_graph = loader.load(sess, ["bar"], export_dir) # Validate the assets for `bar` graph. `foo.txt` should contain the @@ -1139,7 +1139,7 @@ class SavedModelTest(test.TestCase): export_dir = self._get_export_dir("test_custom_saver") builder = saved_model_builder.SavedModelBuilder(export_dir) - with self.test_session(graph=ops.Graph()) as sess: + with self.session(graph=ops.Graph()) as sess: variables.Variable(1, name="v1") sess.run(variables.global_variables_initializer()) custom_saver = training.Saver(name="my_saver") @@ -1149,7 +1149,7 @@ class SavedModelTest(test.TestCase): builder.save() with ops.Graph().as_default() as graph: - with self.test_session(graph=graph) as sess: + with self.session(graph=graph) as sess: saved_graph = loader.load(sess, ["tag"], export_dir) graph_ops = [x.name for x in graph.get_operations()] self.assertTrue("my_saver/restore_all" in graph_ops) @@ -1161,7 +1161,7 @@ class SavedModelTest(test.TestCase): export_dir = self._get_export_dir("test_no_custom_saver") builder = saved_model_builder.SavedModelBuilder(export_dir) - with self.test_session(graph=ops.Graph()) as sess: + with self.session(graph=ops.Graph()) as sess: variables.Variable(1, name="v1") sess.run(variables.global_variables_initializer()) training.Saver(name="my_saver") @@ -1171,7 +1171,7 @@ class SavedModelTest(test.TestCase): builder.save() with ops.Graph().as_default() as graph: - with self.test_session(graph=graph) as sess: + with self.session(graph=graph) as sess: saved_graph = loader.load(sess, ["tag"], export_dir) graph_ops = [x.name for x in graph.get_operations()] self.assertTrue("my_saver/restore_all" in graph_ops) @@ -1183,7 +1183,7 @@ class SavedModelTest(test.TestCase): export_dir = self._get_export_dir("test_multiple_custom_savers") builder = saved_model_builder.SavedModelBuilder(export_dir) - with self.test_session(graph=ops.Graph()) as sess: + with self.session(graph=ops.Graph()) as sess: variables.Variable(1, name="v1") sess.run(variables.global_variables_initializer()) builder.add_meta_graph_and_variables(sess, ["tag_0"]) @@ -1199,7 +1199,7 @@ class SavedModelTest(test.TestCase): def _validate_custom_saver(tag_name, saver_name): with ops.Graph().as_default() as graph: - with self.test_session(graph=graph) as sess: + with self.session(graph=graph) as sess: saved_graph = loader.load(sess, [tag_name], export_dir) self.assertEqual( saved_graph.saver_def.restore_op_name, @@ -1214,7 +1214,7 @@ class SavedModelTest(test.TestCase): builder = saved_model_builder.SavedModelBuilder(export_dir) # Build a SavedModel with a variable, an asset, and a constant tensor. - with self.test_session(graph=ops.Graph()) as sess: + with self.session(graph=ops.Graph()) as sess: self._init_and_validate_variable(sess, "v", 42) asset_collection = self._build_asset_collection("foo.txt", "content_foo", "asset_file_tensor") @@ -1228,7 +1228,7 @@ class SavedModelTest(test.TestCase): # Save the SavedModel to disk. builder.save() - with self.test_session(graph=ops.Graph()) as sess: + with self.session(graph=ops.Graph()) as sess: # Restore the SavedModel under an import_scope in a new graph/session. graph_proto = loader.load( sess, ["tag_name"], export_dir, import_scope="scope_name") @@ -1281,7 +1281,7 @@ class SavedModelTest(test.TestCase): # Restore the graph with a single predefined tag whose variables were saved # without any device information. - with self.test_session(graph=ops.Graph()) as sess: + with self.session(graph=ops.Graph()) as sess: loader.load(sess, [tag_constants.TRAINING], export_dir) self.assertEqual( 42, ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)[0].eval()) diff --git a/tensorflow/python/saved_model/simple_save_test.py b/tensorflow/python/saved_model/simple_save_test.py index b2fa40d4f1..18f82daada 100644 --- a/tensorflow/python/saved_model/simple_save_test.py +++ b/tensorflow/python/saved_model/simple_save_test.py @@ -60,7 +60,7 @@ class SimpleSaveTest(test.TestCase): # Initialize input and output variables and save a prediction graph using # the default parameters. - with self.test_session(graph=ops.Graph()) as sess: + with self.session(graph=ops.Graph()) as sess: var_x = self._init_and_validate_variable(sess, "var_x", 1) var_y = self._init_and_validate_variable(sess, "var_y", 2) inputs = {"x": var_x} @@ -69,7 +69,7 @@ class SimpleSaveTest(test.TestCase): # Restore the graph with a valid tag and check the global variables and # signature def map. - with self.test_session(graph=ops.Graph()) as sess: + with self.session(graph=ops.Graph()) as sess: graph = loader.load(sess, [tag_constants.SERVING], export_dir) collection_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) |