aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/saved_model
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-08-21 18:22:15 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-21 18:25:59 -0700
commit708b30f4cb82271bb28cb70a1e0c89a1933f5b64 (patch)
tree22470a9314f7f4225b6d08170a3d7ea91b0216a1 /tensorflow/python/saved_model
parentd0cac47a767dd972516f75ce57f0d6185e3b6514 (diff)
Move from deprecated self.test_session() to self.session() when a graph is set.
self.test_session() has been deprecated in cl/208545396 as its behavior confuses readers of the test. Moving to self.session() instead. PiperOrigin-RevId: 209696110
Diffstat (limited to 'tensorflow/python/saved_model')
-rw-r--r--tensorflow/python/saved_model/loader_test.py18
-rw-r--r--tensorflow/python/saved_model/saved_model_test.py170
-rw-r--r--tensorflow/python/saved_model/simple_save_test.py4
3 files changed, 96 insertions, 96 deletions
diff --git a/tensorflow/python/saved_model/loader_test.py b/tensorflow/python/saved_model/loader_test.py
index 9a0b276a4b..b7e217a35b 100644
--- a/tensorflow/python/saved_model/loader_test.py
+++ b/tensorflow/python/saved_model/loader_test.py
@@ -79,13 +79,13 @@ class SavedModelLoaderTest(test.TestCase):
def test_load_function(self):
loader = loader_impl.SavedModelLoader(SIMPLE_ADD_SAVED_MODEL)
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
loader.load(sess, ["foo_graph"])
self.assertEqual(5, sess.graph.get_tensor_by_name("x:0").eval())
self.assertEqual(11, sess.graph.get_tensor_by_name("y:0").eval())
loader2 = loader_impl.SavedModelLoader(SAVED_MODEL_WITH_MAIN_OP)
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
loader2.load(sess, ["foo_graph"])
self.assertEqual(5, sess.graph.get_tensor_by_name("x:0").eval())
self.assertEqual(7, sess.graph.get_tensor_by_name("y:0").eval())
@@ -101,7 +101,7 @@ class SavedModelLoaderTest(test.TestCase):
with self.assertRaises(KeyError):
graph.get_tensor_by_name("z:0")
- with self.test_session(graph=graph) as sess:
+ with self.session(graph=graph) as sess:
# Check that x and y are not initialized
with self.assertRaises(errors.FailedPreconditionError):
sess.run(x)
@@ -110,7 +110,7 @@ class SavedModelLoaderTest(test.TestCase):
def test_load_with_import_scope(self):
loader = loader_impl.SavedModelLoader(SAVED_MODEL_WITH_MAIN_OP)
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
saver, _ = loader.load_graph(
sess.graph, ["foo_graph"], import_scope="baz")
@@ -126,14 +126,14 @@ class SavedModelLoaderTest(test.TestCase):
# Test combined load function.
loader = loader_impl.SavedModelLoader(SAVED_MODEL_WITH_MAIN_OP)
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
loader.load(sess, ["foo_graph"], import_scope="baa")
self.assertEqual(5, sess.graph.get_tensor_by_name("baa/x:0").eval())
self.assertEqual(7, sess.graph.get_tensor_by_name("baa/y:0").eval())
def test_restore_variables(self):
loader = loader_impl.SavedModelLoader(SAVED_MODEL_WITH_MAIN_OP)
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
x = variables.Variable(0, name="x")
y = variables.Variable(0, name="y")
z = x * y
@@ -151,7 +151,7 @@ class SavedModelLoaderTest(test.TestCase):
loader = loader_impl.SavedModelLoader(SAVED_MODEL_WITH_MAIN_OP)
graph = ops.Graph()
saver, _ = loader.load_graph(graph, ["foo_graph"])
- with self.test_session(graph=graph) as sess:
+ with self.session(graph=graph) as sess:
loader.restore_variables(sess, saver)
self.assertEqual(5, sess.graph.get_tensor_by_name("x:0").eval())
self.assertEqual(11, sess.graph.get_tensor_by_name("y:0").eval())
@@ -203,12 +203,12 @@ class SavedModelLoaderTest(test.TestCase):
builder.save()
loader = loader_impl.SavedModelLoader(path)
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
saver, _ = loader.load_graph(sess.graph, ["foo_graph"])
self.assertFalse(variables._all_saveable_objects())
self.assertIsNotNone(saver)
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
loader.load(sess, ["foo_graph"])
self.assertEqual(5, sess.graph.get_tensor_by_name("x:0").eval())
self.assertEqual(11, sess.graph.get_tensor_by_name("y:0").eval())
diff --git a/tensorflow/python/saved_model/saved_model_test.py b/tensorflow/python/saved_model/saved_model_test.py
index 00b669fc97..49d52d3bee 100644
--- a/tensorflow/python/saved_model/saved_model_test.py
+++ b/tensorflow/python/saved_model/saved_model_test.py
@@ -97,7 +97,7 @@ class SavedModelTest(test.TestCase):
self.assertEqual(expected_asset_tensor_name, asset.tensor_info.name)
def _validate_inputs_tensor_info_fail(self, builder, tensor_info):
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
self._init_and_validate_variable(sess, "v", 42)
foo_signature = signature_def_utils.build_signature_def({
@@ -110,7 +110,7 @@ class SavedModelTest(test.TestCase):
signature_def_map={"foo_key": foo_signature})
def _validate_inputs_tensor_info_accept(self, builder, tensor_info):
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
self._init_and_validate_variable(sess, "v", 42)
foo_signature = signature_def_utils.build_signature_def({
@@ -121,7 +121,7 @@ class SavedModelTest(test.TestCase):
signature_def_map={"foo_key": foo_signature})
def _validate_outputs_tensor_info_fail(self, builder, tensor_info):
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
self._init_and_validate_variable(sess, "v", 42)
foo_signature = signature_def_utils.build_signature_def(
@@ -133,7 +133,7 @@ class SavedModelTest(test.TestCase):
signature_def_map={"foo_key": foo_signature})
def _validate_outputs_tensor_info_accept(self, builder, tensor_info):
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
self._init_and_validate_variable(sess, "v", 42)
foo_signature = signature_def_utils.build_signature_def(
@@ -153,7 +153,7 @@ class SavedModelTest(test.TestCase):
def testBadSavedModelFileFormat(self):
export_dir = self._get_export_dir("test_bad_saved_model_file_format")
# Attempt to load a SavedModel from an export directory that does not exist.
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
with self.assertRaisesRegexp(IOError,
"SavedModel file does not exist at: %s" %
export_dir):
@@ -164,7 +164,7 @@ class SavedModelTest(test.TestCase):
path_to_pb = os.path.join(export_dir, constants.SAVED_MODEL_FILENAME_PB)
with open(path_to_pb, "w") as f:
f.write("invalid content")
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
with self.assertRaisesRegexp(IOError, "Cannot parse file.*%s" %
constants.SAVED_MODEL_FILENAME_PB):
loader.load(sess, ["foo"], export_dir)
@@ -178,7 +178,7 @@ class SavedModelTest(test.TestCase):
constants.SAVED_MODEL_FILENAME_PBTXT)
with open(path_to_pbtxt, "w") as f:
f.write("invalid content")
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
with self.assertRaisesRegexp(IOError, "Cannot parse file.*%s" %
constants.SAVED_MODEL_FILENAME_PBTXT):
loader.load(sess, ["foo"], export_dir)
@@ -187,7 +187,7 @@ class SavedModelTest(test.TestCase):
export_dir = self._get_export_dir("test_verify_session_graph_usage")
builder = saved_model_builder.SavedModelBuilder(export_dir)
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
self._init_and_validate_variable(sess, "v", 42)
builder.add_meta_graph_and_variables(sess, [tag_constants.TRAINING])
@@ -209,12 +209,12 @@ class SavedModelTest(test.TestCase):
# Expect an assertion error since add_meta_graph_and_variables() should be
# invoked before any add_meta_graph() calls.
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
self.assertRaises(AssertionError, builder.add_meta_graph, ["foo"])
# Expect an assertion error for multiple calls of
# add_meta_graph_and_variables() since weights should be saved exactly once.
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
self._init_and_validate_variable(sess, "v", 42)
builder.add_meta_graph_and_variables(sess, ["bar"])
self.assertRaises(AssertionError, builder.add_meta_graph_and_variables,
@@ -227,35 +227,35 @@ class SavedModelTest(test.TestCase):
# Graph with a single variable. SavedModel invoked to:
# - add with weights.
# - a single tag (from predefined constants).
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
self._init_and_validate_variable(sess, "v", 42)
builder.add_meta_graph_and_variables(sess, [tag_constants.TRAINING])
# Graph that updates the single variable. SavedModel invoked to:
# - simply add the model (weights are not updated).
# - a single tag (from predefined constants).
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
self._init_and_validate_variable(sess, "v", 43)
builder.add_meta_graph([tag_constants.SERVING])
# Graph that updates the single variable. SavedModel invoked to:
# - simply add the model (weights are not updated).
# - multiple tags (from predefined constants).
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
self._init_and_validate_variable(sess, "v", 45)
builder.add_meta_graph([tag_constants.SERVING, tag_constants.GPU])
# Graph that updates the single variable. SavedModel invoked to:
# - simply add the model (weights are not updated).
# - multiple tags (from predefined constants for serving on TPU).
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
self._init_and_validate_variable(sess, "v", 45)
builder.add_meta_graph([tag_constants.SERVING, tag_constants.TPU])
# Graph that updates the single variable. SavedModel is invoked:
# - to add the model (weights are not updated).
# - multiple custom tags.
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
self._init_and_validate_variable(sess, "v", 44)
builder.add_meta_graph(["foo", "bar"])
@@ -263,49 +263,49 @@ class SavedModelTest(test.TestCase):
builder.save()
# Restore the graph with a single predefined tag whose variables were saved.
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
loader.load(sess, [tag_constants.TRAINING], export_dir)
self.assertEqual(
42, ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)[0].eval())
# Restore the graph with a single predefined tag whose variables were not
# saved.
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
loader.load(sess, [tag_constants.SERVING], export_dir)
self.assertEqual(
42, ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)[0].eval())
# Restore the graph with multiple predefined tags whose variables were not
# saved.
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
loader.load(sess, [tag_constants.SERVING, tag_constants.GPU], export_dir)
self.assertEqual(
42, ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)[0].eval())
# Restore the graph with multiple predefined tags (for serving on TPU)
# whose variables were not saved.
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
loader.load(sess, [tag_constants.SERVING, tag_constants.TPU], export_dir)
self.assertEqual(
42, ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)[0].eval())
# Restore the graph with multiple tags. Provide duplicate tags to test set
# semantics.
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
loader.load(sess, ["foo", "bar", "foo"], export_dir)
self.assertEqual(
42, ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)[0].eval())
# Try restoring a graph with a non-existent tag. This should yield a runtime
# error.
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
self.assertRaises(RuntimeError, loader.load, sess, ["INVALID"],
export_dir)
# Try restoring a graph where a subset of the tags match. Since tag matching
# for meta graph defs follows "all" semantics, this should yield a runtime
# error.
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
self.assertRaises(RuntimeError, loader.load, sess, ["foo", "baz"],
export_dir)
@@ -315,7 +315,7 @@ class SavedModelTest(test.TestCase):
# Graph with two variables. SavedModel invoked to:
# - add with weights.
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
self._init_and_validate_variable(sess, "v1", 1)
self._init_and_validate_variable(sess, "v2", 2)
builder.add_meta_graph_and_variables(sess, ["foo"])
@@ -323,14 +323,14 @@ class SavedModelTest(test.TestCase):
# Graph with a single variable (subset of the variables from the previous
# graph whose weights were saved). SavedModel invoked to:
# - simply add the model (weights are not updated).
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
self._init_and_validate_variable(sess, "v2", 3)
builder.add_meta_graph(["bar"])
# Graph with a single variable (disjoint set of variables from the previous
# graph whose weights were saved). SavedModel invoked to:
# - simply add the model (weights are not updated).
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
self._init_and_validate_variable(sess, "v3", 4)
builder.add_meta_graph(["baz"])
@@ -338,7 +338,7 @@ class SavedModelTest(test.TestCase):
builder.save()
# Restore the graph with tag "foo", whose variables were saved.
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
loader.load(sess, ["foo"], export_dir)
collection_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
self.assertEqual(len(collection_vars), 2)
@@ -348,7 +348,7 @@ class SavedModelTest(test.TestCase):
# Restore the graph with tag "bar", whose variables were not saved. Only the
# subset of the variables added to the graph will be restored with the
# checkpointed value.
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
loader.load(sess, ["bar"], export_dir)
collection_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
self.assertEqual(len(collection_vars), 1)
@@ -357,7 +357,7 @@ class SavedModelTest(test.TestCase):
# Try restoring the graph with tag "baz", whose variables were not saved.
# Since this graph has a disjoint set of variables from the set that was
# saved, this should raise an error.
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
self.assertRaises(errors.NotFoundError, loader.load, sess, ["baz"],
export_dir)
@@ -366,12 +366,12 @@ class SavedModelTest(test.TestCase):
builder = saved_model_builder.SavedModelBuilder(export_dir)
# Graph with no variables.
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
constant_5_name = constant_op.constant(5.0).name
builder.add_meta_graph_and_variables(sess, ["foo"])
# Second graph with no variables
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
constant_6_name = constant_op.constant(6.0).name
builder.add_meta_graph(["bar"])
@@ -379,7 +379,7 @@ class SavedModelTest(test.TestCase):
builder.save()
# Restore the graph with tag "foo".
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
loader.load(sess, ["foo"], export_dir)
# Read the constant a from the graph.
a = ops.get_default_graph().get_tensor_by_name(constant_5_name)
@@ -388,7 +388,7 @@ class SavedModelTest(test.TestCase):
self.assertEqual(30.0, sess.run(c))
# Restore the graph with tag "bar".
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
loader.load(sess, ["bar"], export_dir)
# Read the constant a from the graph.
a = ops.get_default_graph().get_tensor_by_name(constant_6_name)
@@ -402,7 +402,7 @@ class SavedModelTest(test.TestCase):
# Graph with a single variable. SavedModel invoked to:
# - add with weights.
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
self._init_and_validate_variable(sess, "v", 42)
builder.add_meta_graph_and_variables(sess, ["foo"])
@@ -410,7 +410,7 @@ class SavedModelTest(test.TestCase):
builder.save(as_text=True)
# Restore the graph with tag "foo", whose variables were saved.
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
loader.load(sess, ["foo"], export_dir)
self.assertEqual(
42, ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)[0].eval())
@@ -426,13 +426,13 @@ class SavedModelTest(test.TestCase):
# Graph with a single variable. SavedModel invoked to:
# - add with weights.
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
self._init_and_validate_variable(sess, "v", 42)
builder.add_meta_graph_and_variables(sess, ["foo"])
# Graph with the same single variable. SavedModel invoked to:
# - simply add the model (weights are not updated).
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
self._init_and_validate_variable(sess, "v", 43)
builder.add_meta_graph(["bar"])
@@ -440,13 +440,13 @@ class SavedModelTest(test.TestCase):
builder.save(as_text=True)
# Restore the graph with tag "foo", whose variables were saved.
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
loader.load(sess, ["foo"], export_dir)
self.assertEqual(
42, ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)[0].eval())
# Restore the graph with tag "bar", whose variables were not saved.
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
loader.load(sess, ["bar"], export_dir)
self.assertEqual(
42, ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)[0].eval())
@@ -457,7 +457,7 @@ class SavedModelTest(test.TestCase):
# Graph with a single variable added to a collection. SavedModel invoked to:
# - add with weights.
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
v = variables.Variable(42, name="v")
ops.add_to_collection("foo_vars", v)
sess.run(variables.global_variables_initializer())
@@ -467,7 +467,7 @@ class SavedModelTest(test.TestCase):
# Graph with the same single variable added to a different collection.
# SavedModel invoked to:
# - simply add the model (weights are not updated).
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
v = variables.Variable(43, name="v")
ops.add_to_collection("bar_vars", v)
sess.run(variables.global_variables_initializer())
@@ -480,7 +480,7 @@ class SavedModelTest(test.TestCase):
# Restore the graph with tag "foo", whose variables were saved. The
# collection 'foo_vars' should contain a single element. The collection
# 'bar_vars' should not be found.
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
loader.load(sess, ["foo"], export_dir)
collection_foo_vars = ops.get_collection("foo_vars")
self.assertEqual(len(collection_foo_vars), 1)
@@ -493,7 +493,7 @@ class SavedModelTest(test.TestCase):
# reflect the new collection. The value of the variable in the
# collection-def corresponds to the saved value (from the previous graph
# with tag "foo").
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
loader.load(sess, ["bar"], export_dir)
collection_bar_vars = ops.get_collection("bar_vars")
self.assertEqual(len(collection_bar_vars), 1)
@@ -507,7 +507,7 @@ class SavedModelTest(test.TestCase):
# Graph with a single variable and a single entry in the signature def map.
# SavedModel is invoked to add with weights.
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
self._init_and_validate_variable(sess, "v", 42)
# Build and populate an empty SignatureDef for testing.
foo_signature = signature_def_utils.build_signature_def(dict(),
@@ -517,7 +517,7 @@ class SavedModelTest(test.TestCase):
# Graph with the same single variable and multiple entries in the signature
# def map. No weights are saved by SavedModel.
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
self._init_and_validate_variable(sess, "v", 43)
# Build and populate a different SignatureDef for testing.
bar_signature = signature_def_utils.build_signature_def(dict(),
@@ -539,7 +539,7 @@ class SavedModelTest(test.TestCase):
# Restore the graph with tag "foo". The single entry in the SignatureDef map
# corresponding to "foo_key" should exist.
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
foo_graph = loader.load(sess, ["foo"], export_dir)
self.assertEqual(
42, ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)[0].eval())
@@ -551,7 +551,7 @@ class SavedModelTest(test.TestCase):
# Restore the graph with tag "bar". The SignatureDef map should have two
# entries. One corresponding to "bar_key" and another corresponding to the
# new value of "foo_key".
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
bar_graph = loader.load(sess, ["bar"], export_dir)
self.assertEqual(
42, ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)[0].eval())
@@ -610,7 +610,7 @@ class SavedModelTest(test.TestCase):
export_dir = self._get_export_dir("test_assets")
builder = saved_model_builder.SavedModelBuilder(export_dir)
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
self._init_and_validate_variable(sess, "v", 42)
# Build an asset collection.
@@ -628,7 +628,7 @@ class SavedModelTest(test.TestCase):
# Save the SavedModel to disk.
builder.save()
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
foo_graph = loader.load(sess, ["foo"], export_dir)
self._validate_asset_collection(export_dir, foo_graph.collection_def,
"hello42.txt", "foo bar baz",
@@ -643,7 +643,7 @@ class SavedModelTest(test.TestCase):
export_dir = self._get_export_dir("test_assets_name_collision_diff_file")
builder = saved_model_builder.SavedModelBuilder(export_dir)
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
self._init_and_validate_variable(sess, "v", 42)
asset_collection = self._build_asset_collection(
@@ -660,7 +660,7 @@ class SavedModelTest(test.TestCase):
# Save the SavedModel to disk.
builder.save()
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
foo_graph = loader.load(sess, ["foo"], export_dir)
self._validate_asset_collection(export_dir, foo_graph.collection_def,
"hello42.txt", "foo bar bak",
@@ -674,7 +674,7 @@ class SavedModelTest(test.TestCase):
export_dir = self._get_export_dir("test_assets_name_collision_same_path")
builder = saved_model_builder.SavedModelBuilder(export_dir)
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
self._init_and_validate_variable(sess, "v", 42)
asset_collection = self._build_asset_collection(
@@ -689,7 +689,7 @@ class SavedModelTest(test.TestCase):
# Save the SavedModel to disk.
builder.save()
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
foo_graph = loader.load(sess, ["foo"], export_dir)
self._validate_asset_collection(export_dir, foo_graph.collection_def,
"hello42.txt", "foo bar baz",
@@ -709,7 +709,7 @@ class SavedModelTest(test.TestCase):
export_dir = self._get_export_dir("test_assets_name_collision_same_file")
builder = saved_model_builder.SavedModelBuilder(export_dir)
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
self._init_and_validate_variable(sess, "v", 42)
asset_collection = self._build_asset_collection(
@@ -726,7 +726,7 @@ class SavedModelTest(test.TestCase):
# Save the SavedModel to disk.
builder.save()
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
foo_graph = loader.load(sess, ["foo"], export_dir)
self._validate_asset_collection(export_dir, foo_graph.collection_def,
"hello42.txt", "foo bar baz",
@@ -746,7 +746,7 @@ class SavedModelTest(test.TestCase):
export_dir = self._get_export_dir("test_assets_name_collision_many_files")
builder = saved_model_builder.SavedModelBuilder(export_dir)
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
self._init_and_validate_variable(sess, "v", 42)
for i in range(5):
@@ -761,7 +761,7 @@ class SavedModelTest(test.TestCase):
# Save the SavedModel to disk.
builder.save()
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
foo_graph = loader.load(sess, ["foo"], export_dir)
for i in range(1, 5):
idx = str(i)
@@ -778,7 +778,7 @@ class SavedModelTest(test.TestCase):
export_dir = self._get_export_dir("test_main_op")
builder = saved_model_builder.SavedModelBuilder(export_dir)
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
# Add `v1` and `v2` variables to the graph.
v1 = variables.Variable(1, name="v1")
ops.add_to_collection("v", v1)
@@ -801,7 +801,7 @@ class SavedModelTest(test.TestCase):
# Save the SavedModel to disk.
builder.save()
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
loader.load(sess, ["foo"], export_dir)
self.assertEqual(1, ops.get_collection("v")[0].eval())
self.assertEqual(2, ops.get_collection("v")[1].eval())
@@ -813,7 +813,7 @@ class SavedModelTest(test.TestCase):
export_dir = self._get_export_dir("test_legacy_init_op")
builder = saved_model_builder.SavedModelBuilder(export_dir)
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
# Add `v1` and `v2` variables to the graph.
v1 = variables.Variable(1, name="v1")
ops.add_to_collection("v", v1)
@@ -835,7 +835,7 @@ class SavedModelTest(test.TestCase):
# Save the SavedModel to disk.
builder.save()
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
loader.load(sess, ["foo"], export_dir)
self.assertEqual(1, ops.get_collection("v")[0].eval())
self.assertEqual(2, ops.get_collection("v")[1].eval())
@@ -858,7 +858,7 @@ class SavedModelTest(test.TestCase):
builder = saved_model_builder.SavedModelBuilder(export_dir)
g = ops.Graph()
- with self.test_session(graph=g) as sess:
+ with self.session(graph=g) as sess:
# Initialize variable `v1` to 1.
v1 = variables.Variable(1, name="v1")
ops.add_to_collection("v", v1)
@@ -887,7 +887,7 @@ class SavedModelTest(test.TestCase):
export_dir = self._get_export_dir("test_train_op")
builder = saved_model_builder.SavedModelBuilder(export_dir)
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
# Add `v1` and `v2` variables to the graph.
v1 = variables.Variable(1, name="v1")
ops.add_to_collection("v", v1)
@@ -905,7 +905,7 @@ class SavedModelTest(test.TestCase):
# Save the SavedModel to disk.
builder.save()
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
loader.load(sess, ["foo"], export_dir)
self.assertEqual(3, ops.get_collection("v")[0].eval())
self.assertEqual(2, ops.get_collection("v")[1].eval())
@@ -916,7 +916,7 @@ class SavedModelTest(test.TestCase):
export_dir = self._get_export_dir("test_train_op_group")
builder = saved_model_builder.SavedModelBuilder(export_dir)
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
# Add `v1` and `v2` variables to the graph.
v1 = variables.Variable(1, name="v1")
ops.add_to_collection("v", v1)
@@ -934,7 +934,7 @@ class SavedModelTest(test.TestCase):
# Save the SavedModel to disk.
builder.save()
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
loader.load(sess, ["foo"], export_dir)
self.assertEqual(1, ops.get_collection("v")[0].eval())
self.assertEqual(2, ops.get_collection("v")[1].eval())
@@ -945,7 +945,7 @@ class SavedModelTest(test.TestCase):
export_dir = self._get_export_dir("test_train_op_after_variables")
builder = saved_model_builder.SavedModelBuilder(export_dir)
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
# Add `v1` and `v2` variables to the graph.
v1 = variables.Variable(1, name="v1")
ops.add_to_collection("v", v1)
@@ -964,12 +964,12 @@ class SavedModelTest(test.TestCase):
# Save the SavedModel to disk.
builder.save()
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
loader.load(sess, ["foo"], export_dir)
self.assertIsInstance(
ops.get_collection(constants.TRAIN_OP_KEY)[0], ops.Tensor)
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
loader.load(sess, ["pre_foo"], export_dir)
self.assertFalse(ops.get_collection(constants.TRAIN_OP_KEY))
@@ -977,7 +977,7 @@ class SavedModelTest(test.TestCase):
export_dir = self._get_export_dir("test_multiple_assets")
builder = saved_model_builder.SavedModelBuilder(export_dir)
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
self._init_and_validate_variable(sess, "v", 42)
# Build an asset collection specific to `foo` graph.
@@ -988,7 +988,7 @@ class SavedModelTest(test.TestCase):
builder.add_meta_graph_and_variables(
sess, ["foo"], assets_collection=asset_collection)
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
self._init_and_validate_variable(sess, "v", 42)
# Build an asset collection specific to `bar` graph.
@@ -1002,14 +1002,14 @@ class SavedModelTest(test.TestCase):
builder.save()
# Check assets restored for graph with tag "foo".
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
foo_graph = loader.load(sess, ["foo"], export_dir)
self._validate_asset_collection(export_dir, foo_graph.collection_def,
"foo.txt", "content_foo",
"asset_file_tensor:0")
# Check assets restored for graph with tag "bar".
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
bar_graph = loader.load(sess, ["bar"], export_dir)
self._validate_asset_collection(export_dir, bar_graph.collection_def,
"bar.txt", "content_bar",
@@ -1019,7 +1019,7 @@ class SavedModelTest(test.TestCase):
export_dir = self._get_export_dir("test_duplicate_assets")
builder = saved_model_builder.SavedModelBuilder(export_dir)
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
self._init_and_validate_variable(sess, "v", 42)
# Build an asset collection with `foo.txt` that has `foo` specific
@@ -1031,7 +1031,7 @@ class SavedModelTest(test.TestCase):
builder.add_meta_graph_and_variables(
sess, ["foo"], assets_collection=asset_collection)
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
self._init_and_validate_variable(sess, "v", 42)
# Build an asset collection with `foo.txt` that has `bar` specific
@@ -1046,14 +1046,14 @@ class SavedModelTest(test.TestCase):
builder.save()
# Check assets restored for graph with tag "foo".
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
foo_graph = loader.load(sess, ["foo"], export_dir)
self._validate_asset_collection(export_dir, foo_graph.collection_def,
"foo.txt", "content_foo",
"asset_file_tensor:0")
# Check assets restored for graph with tag "bar".
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
bar_graph = loader.load(sess, ["bar"], export_dir)
# Validate the assets for `bar` graph. `foo.txt` should contain the
@@ -1139,7 +1139,7 @@ class SavedModelTest(test.TestCase):
export_dir = self._get_export_dir("test_custom_saver")
builder = saved_model_builder.SavedModelBuilder(export_dir)
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
variables.Variable(1, name="v1")
sess.run(variables.global_variables_initializer())
custom_saver = training.Saver(name="my_saver")
@@ -1149,7 +1149,7 @@ class SavedModelTest(test.TestCase):
builder.save()
with ops.Graph().as_default() as graph:
- with self.test_session(graph=graph) as sess:
+ with self.session(graph=graph) as sess:
saved_graph = loader.load(sess, ["tag"], export_dir)
graph_ops = [x.name for x in graph.get_operations()]
self.assertTrue("my_saver/restore_all" in graph_ops)
@@ -1161,7 +1161,7 @@ class SavedModelTest(test.TestCase):
export_dir = self._get_export_dir("test_no_custom_saver")
builder = saved_model_builder.SavedModelBuilder(export_dir)
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
variables.Variable(1, name="v1")
sess.run(variables.global_variables_initializer())
training.Saver(name="my_saver")
@@ -1171,7 +1171,7 @@ class SavedModelTest(test.TestCase):
builder.save()
with ops.Graph().as_default() as graph:
- with self.test_session(graph=graph) as sess:
+ with self.session(graph=graph) as sess:
saved_graph = loader.load(sess, ["tag"], export_dir)
graph_ops = [x.name for x in graph.get_operations()]
self.assertTrue("my_saver/restore_all" in graph_ops)
@@ -1183,7 +1183,7 @@ class SavedModelTest(test.TestCase):
export_dir = self._get_export_dir("test_multiple_custom_savers")
builder = saved_model_builder.SavedModelBuilder(export_dir)
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
variables.Variable(1, name="v1")
sess.run(variables.global_variables_initializer())
builder.add_meta_graph_and_variables(sess, ["tag_0"])
@@ -1199,7 +1199,7 @@ class SavedModelTest(test.TestCase):
def _validate_custom_saver(tag_name, saver_name):
with ops.Graph().as_default() as graph:
- with self.test_session(graph=graph) as sess:
+ with self.session(graph=graph) as sess:
saved_graph = loader.load(sess, [tag_name], export_dir)
self.assertEqual(
saved_graph.saver_def.restore_op_name,
@@ -1214,7 +1214,7 @@ class SavedModelTest(test.TestCase):
builder = saved_model_builder.SavedModelBuilder(export_dir)
# Build a SavedModel with a variable, an asset, and a constant tensor.
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
self._init_and_validate_variable(sess, "v", 42)
asset_collection = self._build_asset_collection("foo.txt", "content_foo",
"asset_file_tensor")
@@ -1228,7 +1228,7 @@ class SavedModelTest(test.TestCase):
# Save the SavedModel to disk.
builder.save()
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
# Restore the SavedModel under an import_scope in a new graph/session.
graph_proto = loader.load(
sess, ["tag_name"], export_dir, import_scope="scope_name")
@@ -1281,7 +1281,7 @@ class SavedModelTest(test.TestCase):
# Restore the graph with a single predefined tag whose variables were saved
# without any device information.
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
loader.load(sess, [tag_constants.TRAINING], export_dir)
self.assertEqual(
42, ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)[0].eval())
diff --git a/tensorflow/python/saved_model/simple_save_test.py b/tensorflow/python/saved_model/simple_save_test.py
index b2fa40d4f1..18f82daada 100644
--- a/tensorflow/python/saved_model/simple_save_test.py
+++ b/tensorflow/python/saved_model/simple_save_test.py
@@ -60,7 +60,7 @@ class SimpleSaveTest(test.TestCase):
# Initialize input and output variables and save a prediction graph using
# the default parameters.
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
var_x = self._init_and_validate_variable(sess, "var_x", 1)
var_y = self._init_and_validate_variable(sess, "var_y", 2)
inputs = {"x": var_x}
@@ -69,7 +69,7 @@ class SimpleSaveTest(test.TestCase):
# Restore the graph with a valid tag and check the global variables and
# signature def map.
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
graph = loader.load(sess, [tag_constants.SERVING], export_dir)
collection_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)