aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/saved_model/saved_model_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/saved_model/saved_model_test.py')
-rw-r--r--tensorflow/python/saved_model/saved_model_test.py28
1 files changed, 20 insertions, 8 deletions
diff --git a/tensorflow/python/saved_model/saved_model_test.py b/tensorflow/python/saved_model/saved_model_test.py
index fb4732aca2..00b669fc97 100644
--- a/tensorflow/python/saved_model/saved_model_test.py
+++ b/tensorflow/python/saved_model/saved_model_test.py
@@ -846,9 +846,19 @@ class SavedModelTest(test.TestCase):
def testLegacyInitOpWithNonEmptyCollection(self):
export_dir = self._get_export_dir(
"test_legacy_init_op_with_non_empty_collection")
+ self._testInitOpsWithNonEmptyCollection(
+ export_dir, constants.LEGACY_INIT_OP_KEY)
+
+ def testMainOpWithNonEmptyCollection(self):
+ export_dir = self._get_export_dir(
+ "test_main_op_with_non_empty_collection")
+ self._testInitOpsWithNonEmptyCollection(export_dir, constants.MAIN_OP_KEY)
+
+ def _testInitOpsWithNonEmptyCollection(self, export_dir, key):
builder = saved_model_builder.SavedModelBuilder(export_dir)
- with self.test_session(graph=ops.Graph()) as sess:
+ g = ops.Graph()
+ with self.test_session(graph=g) as sess:
# Initialize variable `v1` to 1.
v1 = variables.Variable(1, name="v1")
ops.add_to_collection("v", v1)
@@ -857,19 +867,21 @@ class SavedModelTest(test.TestCase):
v2 = variables.Variable(42, name="v2", trainable=False, collections=[])
ops.add_to_collection("v", v2)
- # Set up an assignment op to be run as part of the legacy_init_op.
+ # Set up an assignment op to be run as part of the init op.
assign_v2 = state_ops.assign(v2, v1)
- legacy_init_op = control_flow_ops.group(assign_v2, name="legacy_init_op")
+ init_op = control_flow_ops.group(assign_v2, name="init_op")
sess.run(variables.global_variables_initializer())
- ops.add_to_collection(constants.LEGACY_INIT_OP_KEY,
- control_flow_ops.no_op())
- # AssertionError should be raised since the LEGACY_INIT_OP_KEY collection
+ ops.add_to_collection(key, control_flow_ops.no_op())
+ # ValueError should be raised since the LEGACY_INIT_OP_KEY collection
# is not empty and we don't support multiple init ops.
- with self.assertRaises(AssertionError):
+ with self.assertRaisesRegexp(ValueError, "Graph already contains"):
builder.add_meta_graph_and_variables(
- sess, ["foo"], legacy_init_op=legacy_init_op)
+ sess, ["foo"], legacy_init_op=init_op)
+ # We shouldn't be able to add as MAIN_OP, either.
+ with self.assertRaisesRegexp(ValueError, "Graph already contains"):
+ builder.add_meta_graph_and_variables(sess, ["foo"], main_op=init_op)
def testTrainOp(self):
export_dir = self._get_export_dir("test_train_op")