diff options
author | Karmel Allison <karmel@google.com> | 2018-06-06 16:06:06 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-06-06 16:08:43 -0700 |
commit | 64204dd0addea52368400eea6c67616c312b594d (patch) | |
tree | dc6c18f535319001341239c9087f459929ffc5fd /tensorflow/python/saved_model | |
parent | 617405d989a13839a585c82f9d09f03cbd080d0e (diff) |
Allow SavedModelBuilder to use custom Savers, and pass custom Savers included
in Estimator model functions through to the Builder when saving.
PiperOrigin-RevId: 199546645
Diffstat (limited to 'tensorflow/python/saved_model')
-rw-r--r-- | tensorflow/python/saved_model/BUILD | 1 | ||||
-rw-r--r-- | tensorflow/python/saved_model/builder_impl.py | 46 | ||||
-rw-r--r-- | tensorflow/python/saved_model/saved_model_test.py | 75 |
3 files changed, 102 insertions, 20 deletions
diff --git a/tensorflow/python/saved_model/BUILD b/tensorflow/python/saved_model/BUILD index 2609a5d222..81786fbf43 100644 --- a/tensorflow/python/saved_model/BUILD +++ b/tensorflow/python/saved_model/BUILD @@ -149,6 +149,7 @@ py_test( "//tensorflow/python:saver_test_utils", "//tensorflow/python:state_ops", "//tensorflow/python:test_ops", + "//tensorflow/python:training", "//tensorflow/python:util", "//tensorflow/python:variables", ], diff --git a/tensorflow/python/saved_model/builder_impl.py b/tensorflow/python/saved_model/builder_impl.py index 24a13c0f33..e58be804c2 100644 --- a/tensorflow/python/saved_model/builder_impl.py +++ b/tensorflow/python/saved_model/builder_impl.py @@ -270,6 +270,18 @@ class SavedModelBuilder(object): self._add_train_op(train_op) + def _maybe_create_saver(self, saver=None): + """Creates a sharded saver if one does not already exist.""" + if not saver: + # Initialize a saver to generate a sharded output for all saveables in the + # current scope. + saver = tf_saver.Saver( + variables._all_saveable_objects(), # pylint: disable=protected-access + sharded=True, + write_version=saver_pb2.SaverDef.V2, + allow_empty=True) + return saver + def add_meta_graph(self, tags, signature_def_map=None, @@ -277,7 +289,8 @@ class SavedModelBuilder(object): legacy_init_op=None, clear_devices=False, main_op=None, - strip_default_attrs=False): + strip_default_attrs=False, + saver=None): # pylint: disable=line-too-long """Adds the current meta graph to the SavedModel. @@ -302,6 +315,9 @@ class SavedModelBuilder(object): strip_default_attrs: Boolean. If `True`, default-valued attributes will be removed from the NodeDefs. For a detailed guide, see [Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes). + saver: An instance of tf.train.Saver that will be used to export the + metagraph. If None, a sharded Saver that restores all variables will + be used. Raises: AssertionError: If the variables for the SavedModel have not been saved @@ -320,18 +336,11 @@ class SavedModelBuilder(object): # Add assets and ops self._add_collections(assets_collection, legacy_init_op, main_op, None) - # Initialize a saver to generate a sharded output for all saveables in the - # current scope. - saver = tf_saver.Saver( - variables._all_saveable_objects(), # pylint: disable=protected-access - sharded=True, - write_version=saver_pb2.SaverDef.V2, - allow_empty=True) + saver = self._maybe_create_saver(saver) # The graph almost certainly previously contained at least one Saver, and # possibly several (e.g. one for loading a pretrained embedding, and another - # for the model weights). However, a *new* Saver was just created that - # includes all of the variables. Removing the preexisting ones was the + # for the model weights). Removing the preexisting ones was the # motivation for the clear_extraneous_savers option, but it turns out that # there are edge cases where that option breaks the graph. Until that is # resolved, we just leave the option set to False for now. @@ -350,7 +359,8 @@ class SavedModelBuilder(object): legacy_init_op=None, clear_devices=False, main_op=None, - strip_default_attrs=False): + strip_default_attrs=False, + saver=None): # pylint: disable=line-too-long """Adds the current meta graph to the SavedModel and saves variables. @@ -377,6 +387,9 @@ class SavedModelBuilder(object): strip_default_attrs: Boolean. If `True`, default-valued attributes will be removed from the NodeDefs. For a detailed guide, see [Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes). + saver: An instance of tf.train.Saver that will be used to export the + metagraph and save variables. If None, a sharded Saver that restores + all variables will be used. """ # pylint: enable=line-too-long @@ -403,13 +416,7 @@ class SavedModelBuilder(object): compat.as_text(variables_dir), compat.as_text(constants.VARIABLES_FILENAME)) - # Initialize a saver to generate a sharded output for all saveables in the - # current scope. - saver = tf_saver.Saver( - variables._all_saveable_objects(), # pylint: disable=protected-access - sharded=True, - write_version=saver_pb2.SaverDef.V2, - allow_empty=True) + saver = self._maybe_create_saver(saver) # Save the variables. Also, disable writing the checkpoint state proto. The # file is not used during SavedModel loading. In addition, since a @@ -421,8 +428,7 @@ class SavedModelBuilder(object): # The graph almost certainly previously contained at least one Saver, and # possibly several (e.g. one for loading a pretrained embedding, and another - # for the model weights). However, a *new* Saver was just created that - # includes all of the variables. Removing the preexisting ones was the + # for the model weights). Removing the preexisting ones was the # motivation for the clear_extraneous_savers option, but it turns out that # there are edge cases where that option breaks the graph. Until that is # resolved, we just leave the option set to False for now. diff --git a/tensorflow/python/saved_model/saved_model_test.py b/tensorflow/python/saved_model/saved_model_test.py index 7302c77ad5..effb38283b 100644 --- a/tensorflow/python/saved_model/saved_model_test.py +++ b/tensorflow/python/saved_model/saved_model_test.py @@ -44,6 +44,7 @@ from tensorflow.python.saved_model import main_op from tensorflow.python.saved_model import signature_def_utils from tensorflow.python.saved_model import tag_constants from tensorflow.python.training import saver_test_utils +from tensorflow.python.training import training from tensorflow.python.util import compat SAVED_MODEL_PATH = ("cc/saved_model/testdata/half_plus_two/00000123") @@ -1122,6 +1123,80 @@ class SavedModelTest(test.TestCase): self.assertEqual(b"k1", v1.keys().eval()) self.assertEqual(3.0, v1.values().eval()) + def testCustomSaver(self): + export_dir = self._get_export_dir("test_custom_saver") + builder = saved_model_builder.SavedModelBuilder(export_dir) + + with self.test_session(graph=ops.Graph()) as sess: + variables.Variable(1, name="v1") + sess.run(variables.global_variables_initializer()) + custom_saver = training.Saver(name="my_saver") + builder.add_meta_graph_and_variables(sess, ["tag"], saver=custom_saver) + + # Save the SavedModel to disk. + builder.save() + + with ops.Graph().as_default() as graph: + with self.test_session(graph=graph) as sess: + saved_graph = loader.load(sess, ["tag"], export_dir) + graph_ops = [x.name for x in graph.get_operations()] + self.assertTrue("my_saver/restore_all" in graph_ops) + self.assertFalse("save/restore_all" in graph_ops) + self.assertEqual( + saved_graph.saver_def.restore_op_name, "my_saver/restore_all") + + def testNoCustomSaver(self): + export_dir = self._get_export_dir("test_no_custom_saver") + builder = saved_model_builder.SavedModelBuilder(export_dir) + + with self.test_session(graph=ops.Graph()) as sess: + variables.Variable(1, name="v1") + sess.run(variables.global_variables_initializer()) + training.Saver(name="my_saver") + builder.add_meta_graph_and_variables(sess, ["tag"]) + + # Save the SavedModel to disk. + builder.save() + + with ops.Graph().as_default() as graph: + with self.test_session(graph=graph) as sess: + saved_graph = loader.load(sess, ["tag"], export_dir) + graph_ops = [x.name for x in graph.get_operations()] + self.assertTrue("my_saver/restore_all" in graph_ops) + self.assertTrue("save/restore_all" in graph_ops) + self.assertEqual( + saved_graph.saver_def.restore_op_name, "save/restore_all") + + def testMultipleCustomSavers(self): + export_dir = self._get_export_dir("test_multiple_custom_savers") + builder = saved_model_builder.SavedModelBuilder(export_dir) + + with self.test_session(graph=ops.Graph()) as sess: + variables.Variable(1, name="v1") + sess.run(variables.global_variables_initializer()) + builder.add_meta_graph_and_variables(sess, ["tag_0"]) + + saver_1 = training.Saver() + builder.add_meta_graph(["tag_1"], saver=saver_1) + + saver_2 = training.Saver() + builder.add_meta_graph(["tag_2"], saver=saver_2) + + # Save the SavedModel to disk. + builder.save() + + def _validate_custom_saver(tag_name, saver_name): + with ops.Graph().as_default() as graph: + with self.test_session(graph=graph) as sess: + saved_graph = loader.load(sess, [tag_name], export_dir) + self.assertEqual( + saved_graph.saver_def.restore_op_name, + saver_name) + + _validate_custom_saver("tag_0", "save/restore_all") + _validate_custom_saver("tag_1", "save_1/restore_all") + _validate_custom_saver("tag_2", "save_2/restore_all") + def testClearDevices(self): export_dir = self._get_export_dir("test_clear_devices") builder = saved_model_builder.SavedModelBuilder(export_dir) |