diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2017-10-04 19:28:04 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-10-04 19:38:17 -0700 |
commit | a3e5b1628322102914a46a5fbfca2db5cb8b9e11 (patch) | |
tree | 82eb8b3bd4f33e9721bb7507f470ef8bc6aa8889 /tensorflow/python/saved_model | |
parent | 2c3bf9eff79156e32512e8d6da2179cd044167b8 (diff) |
Avoids adding duplicate legacy_init_op to the saved_model's exported meta graph.
Previously, when the user restores graph from one meta graph generated from
saved_model and then re-generates another saved model, the re-generated model
will be invalid because it will contain duplicate legacy_init_ops.
PiperOrigin-RevId: 171099152
Diffstat (limited to 'tensorflow/python/saved_model')
-rw-r--r-- | tensorflow/python/saved_model/builder_impl.py | 7 | ||||
-rw-r--r-- | tensorflow/python/saved_model/saved_model_test.py | 30 |
2 files changed, 35 insertions, 2 deletions
diff --git a/tensorflow/python/saved_model/builder_impl.py b/tensorflow/python/saved_model/builder_impl.py index 73a3f9075d..16651ffebc 100644 --- a/tensorflow/python/saved_model/builder_impl.py +++ b/tensorflow/python/saved_model/builder_impl.py @@ -140,11 +140,16 @@ class SavedModelBuilder(object): Raises: TypeError if legacy init op is not of type `Operation`. + AssertionError if the graph already contains one or more legacy init ops. """ if legacy_init_op is not None: if not isinstance(legacy_init_op, ops.Operation): raise TypeError("legacy_init_op needs to be an Operation: %r" % legacy_init_op) + if ops.get_collection(constants.LEGACY_INIT_OP_KEY): + raise AssertionError( + "graph already contains one or more legacy init ops under the " + "collection {}.".format(constants.LEGACY_INIT_OP_KEY)) ops.add_to_collection(constants.LEGACY_INIT_OP_KEY, legacy_init_op) def _add_main_op(self, main_op): @@ -258,7 +263,7 @@ class SavedModelBuilder(object): Raises: AssertionError: If the variables for the SavedModel have not been saved - yet. + yet, or if the graph already contains one or more legacy init ops. """ if not self._has_saved_variables: raise AssertionError( diff --git a/tensorflow/python/saved_model/saved_model_test.py b/tensorflow/python/saved_model/saved_model_test.py index 5639e6855d..c6d2c32293 100644 --- a/tensorflow/python/saved_model/saved_model_test.py +++ b/tensorflow/python/saved_model/saved_model_test.py @@ -1,4 +1,4 @@ -## Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -637,6 +637,34 @@ class SavedModelTest(test.TestCase): # the legacy_init_op, following a restore. self.assertEqual(3, ops.get_collection("v")[2].eval()) + def testLegacyInitOpWithNonEmptyCollection(self): + export_dir = os.path.join(test.get_temp_dir(), + "test_legacy_init_op_with_non_empty_collection") + builder = saved_model_builder.SavedModelBuilder(export_dir) + + with self.test_session(graph=ops.Graph()) as sess: + # Initialize variable `v1` to 1. + v1 = variables.Variable(1, name="v1") + ops.add_to_collection("v", v1) + + # Initialize another variable `v2` to 42. + v2 = variables.Variable(42, name="v2", trainable=False, collections=[]) + ops.add_to_collection("v", v2) + + # Set up an assignment op to be run as part of the legacy_init_op. + assign_v2 = state_ops.assign(v2, v1) + legacy_init_op = control_flow_ops.group(assign_v2, name="legacy_init_op") + + sess.run(variables.global_variables_initializer()) + + ops.add_to_collection(constants.LEGACY_INIT_OP_KEY, + control_flow_ops.no_op()) + # AssertionError should be raised since the LEGACY_INIT_OP_KEY collection + # is not empty and we don't support multiple init ops. + with self.assertRaises(AssertionError): + builder.add_meta_graph_and_variables( + sess, ["foo"], legacy_init_op=legacy_init_op) + def testMultipleAssets(self): export_dir = os.path.join(test.get_temp_dir(), "test_multiple_assets") builder = saved_model_builder.SavedModelBuilder(export_dir) |