aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/saved_model
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-10-04 19:28:04 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-10-04 19:38:17 -0700
commita3e5b1628322102914a46a5fbfca2db5cb8b9e11 (patch)
tree82eb8b3bd4f33e9721bb7507f470ef8bc6aa8889 /tensorflow/python/saved_model
parent2c3bf9eff79156e32512e8d6da2179cd044167b8 (diff)
Avoids adding duplicate legacy_init_op to the saved_model's exported meta graph.
Previously, when the user restores graph from one meta graph generated from saved_model and then re-generates another saved model, the re-generated model will be invalid because it will contain duplicate legacy_init_ops. PiperOrigin-RevId: 171099152
Diffstat (limited to 'tensorflow/python/saved_model')
-rw-r--r--tensorflow/python/saved_model/builder_impl.py7
-rw-r--r--tensorflow/python/saved_model/saved_model_test.py30
2 files changed, 35 insertions, 2 deletions
diff --git a/tensorflow/python/saved_model/builder_impl.py b/tensorflow/python/saved_model/builder_impl.py
index 73a3f9075d..16651ffebc 100644
--- a/tensorflow/python/saved_model/builder_impl.py
+++ b/tensorflow/python/saved_model/builder_impl.py
@@ -140,11 +140,16 @@ class SavedModelBuilder(object):
Raises:
TypeError if legacy init op is not of type `Operation`.
+ AssertionError if the graph already contains one or more legacy init ops.
"""
if legacy_init_op is not None:
if not isinstance(legacy_init_op, ops.Operation):
raise TypeError("legacy_init_op needs to be an Operation: %r" %
legacy_init_op)
+ if ops.get_collection(constants.LEGACY_INIT_OP_KEY):
+ raise AssertionError(
+ "graph already contains one or more legacy init ops under the "
+ "collection {}.".format(constants.LEGACY_INIT_OP_KEY))
ops.add_to_collection(constants.LEGACY_INIT_OP_KEY, legacy_init_op)
def _add_main_op(self, main_op):
@@ -258,7 +263,7 @@ class SavedModelBuilder(object):
Raises:
AssertionError: If the variables for the SavedModel have not been saved
- yet.
+ yet, or if the graph already contains one or more legacy init ops.
"""
if not self._has_saved_variables:
raise AssertionError(
diff --git a/tensorflow/python/saved_model/saved_model_test.py b/tensorflow/python/saved_model/saved_model_test.py
index 5639e6855d..c6d2c32293 100644
--- a/tensorflow/python/saved_model/saved_model_test.py
+++ b/tensorflow/python/saved_model/saved_model_test.py
@@ -1,4 +1,4 @@
-## Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -637,6 +637,34 @@ class SavedModelTest(test.TestCase):
# the legacy_init_op, following a restore.
self.assertEqual(3, ops.get_collection("v")[2].eval())
+ def testLegacyInitOpWithNonEmptyCollection(self):
+ export_dir = os.path.join(test.get_temp_dir(),
+ "test_legacy_init_op_with_non_empty_collection")
+ builder = saved_model_builder.SavedModelBuilder(export_dir)
+
+ with self.test_session(graph=ops.Graph()) as sess:
+ # Initialize variable `v1` to 1.
+ v1 = variables.Variable(1, name="v1")
+ ops.add_to_collection("v", v1)
+
+ # Initialize another variable `v2` to 42.
+ v2 = variables.Variable(42, name="v2", trainable=False, collections=[])
+ ops.add_to_collection("v", v2)
+
+ # Set up an assignment op to be run as part of the legacy_init_op.
+ assign_v2 = state_ops.assign(v2, v1)
+ legacy_init_op = control_flow_ops.group(assign_v2, name="legacy_init_op")
+
+ sess.run(variables.global_variables_initializer())
+
+ ops.add_to_collection(constants.LEGACY_INIT_OP_KEY,
+ control_flow_ops.no_op())
+ # AssertionError should be raised since the LEGACY_INIT_OP_KEY collection
+ # is not empty and we don't support multiple init ops.
+ with self.assertRaises(AssertionError):
+ builder.add_meta_graph_and_variables(
+ sess, ["foo"], legacy_init_op=legacy_init_op)
+
def testMultipleAssets(self):
export_dir = os.path.join(test.get_temp_dir(), "test_multiple_assets")
builder = saved_model_builder.SavedModelBuilder(export_dir)