aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/saved_model
diff options
context:
space:
mode:
author    Karmel Allison <karmel@google.com>  2018-06-06 16:06:06 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>  2018-06-06 16:08:43 -0700
commit 64204dd0addea52368400eea6c67616c312b594d (patch)
tree   dc6c18f535319001341239c9087f459929ffc5fd /tensorflow/python/saved_model
parent 617405d989a13839a585c82f9d09f03cbd080d0e (diff)
Allow SavedModelBuilder to use custom Savers, and pass custom Savers included
in Estimator model functions through to the Builder when saving.

PiperOrigin-RevId: 199546645
Diffstat (limited to 'tensorflow/python/saved_model')
-rw-r--r--  tensorflow/python/saved_model/BUILD                 |  1
-rw-r--r--  tensorflow/python/saved_model/builder_impl.py       | 46
-rw-r--r--  tensorflow/python/saved_model/saved_model_test.py   | 75
3 files changed, 102 insertions, 20 deletions
diff --git a/tensorflow/python/saved_model/BUILD b/tensorflow/python/saved_model/BUILD
index 2609a5d222..81786fbf43 100644
--- a/tensorflow/python/saved_model/BUILD
+++ b/tensorflow/python/saved_model/BUILD
@@ -149,6 +149,7 @@ py_test(
"//tensorflow/python:saver_test_utils",
"//tensorflow/python:state_ops",
"//tensorflow/python:test_ops",
+ "//tensorflow/python:training",
"//tensorflow/python:util",
"//tensorflow/python:variables",
],
diff --git a/tensorflow/python/saved_model/builder_impl.py b/tensorflow/python/saved_model/builder_impl.py
index 24a13c0f33..e58be804c2 100644
--- a/tensorflow/python/saved_model/builder_impl.py
+++ b/tensorflow/python/saved_model/builder_impl.py
@@ -270,6 +270,18 @@ class SavedModelBuilder(object):
self._add_train_op(train_op)
+ def _maybe_create_saver(self, saver=None):
+ """Creates a sharded saver if one does not already exist."""
+ if not saver:
+ # Initialize a saver to generate a sharded output for all saveables in the
+ # current scope.
+ saver = tf_saver.Saver(
+ variables._all_saveable_objects(), # pylint: disable=protected-access
+ sharded=True,
+ write_version=saver_pb2.SaverDef.V2,
+ allow_empty=True)
+ return saver
+
def add_meta_graph(self,
tags,
signature_def_map=None,
@@ -277,7 +289,8 @@ class SavedModelBuilder(object):
legacy_init_op=None,
clear_devices=False,
main_op=None,
- strip_default_attrs=False):
+ strip_default_attrs=False,
+ saver=None):
# pylint: disable=line-too-long
"""Adds the current meta graph to the SavedModel.
@@ -302,6 +315,9 @@ class SavedModelBuilder(object):
strip_default_attrs: Boolean. If `True`, default-valued attributes will be
removed from the NodeDefs. For a detailed guide, see
[Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes).
+ saver: An instance of tf.train.Saver that will be used to export the
+ metagraph. If None, a sharded Saver that restores all variables will
+ be used.
Raises:
AssertionError: If the variables for the SavedModel have not been saved
@@ -320,18 +336,11 @@ class SavedModelBuilder(object):
# Add assets and ops
self._add_collections(assets_collection, legacy_init_op, main_op, None)
- # Initialize a saver to generate a sharded output for all saveables in the
- # current scope.
- saver = tf_saver.Saver(
- variables._all_saveable_objects(), # pylint: disable=protected-access
- sharded=True,
- write_version=saver_pb2.SaverDef.V2,
- allow_empty=True)
+ saver = self._maybe_create_saver(saver)
# The graph almost certainly previously contained at least one Saver, and
# possibly several (e.g. one for loading a pretrained embedding, and another
- # for the model weights). However, a *new* Saver was just created that
- # includes all of the variables. Removing the preexisting ones was the
+ # for the model weights). Removing the preexisting ones was the
# motivation for the clear_extraneous_savers option, but it turns out that
# there are edge cases where that option breaks the graph. Until that is
# resolved, we just leave the option set to False for now.
@@ -350,7 +359,8 @@ class SavedModelBuilder(object):
legacy_init_op=None,
clear_devices=False,
main_op=None,
- strip_default_attrs=False):
+ strip_default_attrs=False,
+ saver=None):
# pylint: disable=line-too-long
"""Adds the current meta graph to the SavedModel and saves variables.
@@ -377,6 +387,9 @@ class SavedModelBuilder(object):
strip_default_attrs: Boolean. If `True`, default-valued attributes will be
removed from the NodeDefs. For a detailed guide, see
[Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes).
+ saver: An instance of tf.train.Saver that will be used to export the
+ metagraph and save variables. If None, a sharded Saver that restores
+ all variables will be used.
"""
# pylint: enable=line-too-long
@@ -403,13 +416,7 @@ class SavedModelBuilder(object):
compat.as_text(variables_dir),
compat.as_text(constants.VARIABLES_FILENAME))
- # Initialize a saver to generate a sharded output for all saveables in the
- # current scope.
- saver = tf_saver.Saver(
- variables._all_saveable_objects(), # pylint: disable=protected-access
- sharded=True,
- write_version=saver_pb2.SaverDef.V2,
- allow_empty=True)
+ saver = self._maybe_create_saver(saver)
# Save the variables. Also, disable writing the checkpoint state proto. The
# file is not used during SavedModel loading. In addition, since a
@@ -421,8 +428,7 @@ class SavedModelBuilder(object):
# The graph almost certainly previously contained at least one Saver, and
# possibly several (e.g. one for loading a pretrained embedding, and another
- # for the model weights). However, a *new* Saver was just created that
- # includes all of the variables. Removing the preexisting ones was the
+ # for the model weights). Removing the preexisting ones was the
# motivation for the clear_extraneous_savers option, but it turns out that
# there are edge cases where that option breaks the graph. Until that is
# resolved, we just leave the option set to False for now.
diff --git a/tensorflow/python/saved_model/saved_model_test.py b/tensorflow/python/saved_model/saved_model_test.py
index 7302c77ad5..effb38283b 100644
--- a/tensorflow/python/saved_model/saved_model_test.py
+++ b/tensorflow/python/saved_model/saved_model_test.py
@@ -44,6 +44,7 @@ from tensorflow.python.saved_model import main_op
from tensorflow.python.saved_model import signature_def_utils
from tensorflow.python.saved_model import tag_constants
from tensorflow.python.training import saver_test_utils
+from tensorflow.python.training import training
from tensorflow.python.util import compat
SAVED_MODEL_PATH = ("cc/saved_model/testdata/half_plus_two/00000123")
@@ -1122,6 +1123,80 @@ class SavedModelTest(test.TestCase):
self.assertEqual(b"k1", v1.keys().eval())
self.assertEqual(3.0, v1.values().eval())
+ def testCustomSaver(self):
+ export_dir = self._get_export_dir("test_custom_saver")
+ builder = saved_model_builder.SavedModelBuilder(export_dir)
+
+ with self.test_session(graph=ops.Graph()) as sess:
+ variables.Variable(1, name="v1")
+ sess.run(variables.global_variables_initializer())
+ custom_saver = training.Saver(name="my_saver")
+ builder.add_meta_graph_and_variables(sess, ["tag"], saver=custom_saver)
+
+ # Save the SavedModel to disk.
+ builder.save()
+
+ with ops.Graph().as_default() as graph:
+ with self.test_session(graph=graph) as sess:
+ saved_graph = loader.load(sess, ["tag"], export_dir)
+ graph_ops = [x.name for x in graph.get_operations()]
+ self.assertTrue("my_saver/restore_all" in graph_ops)
+ self.assertFalse("save/restore_all" in graph_ops)
+ self.assertEqual(
+ saved_graph.saver_def.restore_op_name, "my_saver/restore_all")
+
+ def testNoCustomSaver(self):
+ export_dir = self._get_export_dir("test_no_custom_saver")
+ builder = saved_model_builder.SavedModelBuilder(export_dir)
+
+ with self.test_session(graph=ops.Graph()) as sess:
+ variables.Variable(1, name="v1")
+ sess.run(variables.global_variables_initializer())
+ training.Saver(name="my_saver")
+ builder.add_meta_graph_and_variables(sess, ["tag"])
+
+ # Save the SavedModel to disk.
+ builder.save()
+
+ with ops.Graph().as_default() as graph:
+ with self.test_session(graph=graph) as sess:
+ saved_graph = loader.load(sess, ["tag"], export_dir)
+ graph_ops = [x.name for x in graph.get_operations()]
+ self.assertTrue("my_saver/restore_all" in graph_ops)
+ self.assertTrue("save/restore_all" in graph_ops)
+ self.assertEqual(
+ saved_graph.saver_def.restore_op_name, "save/restore_all")
+
+ def testMultipleCustomSavers(self):
+ export_dir = self._get_export_dir("test_multiple_custom_savers")
+ builder = saved_model_builder.SavedModelBuilder(export_dir)
+
+ with self.test_session(graph=ops.Graph()) as sess:
+ variables.Variable(1, name="v1")
+ sess.run(variables.global_variables_initializer())
+ builder.add_meta_graph_and_variables(sess, ["tag_0"])
+
+ saver_1 = training.Saver()
+ builder.add_meta_graph(["tag_1"], saver=saver_1)
+
+ saver_2 = training.Saver()
+ builder.add_meta_graph(["tag_2"], saver=saver_2)
+
+ # Save the SavedModel to disk.
+ builder.save()
+
+ def _validate_custom_saver(tag_name, saver_name):
+ with ops.Graph().as_default() as graph:
+ with self.test_session(graph=graph) as sess:
+ saved_graph = loader.load(sess, [tag_name], export_dir)
+ self.assertEqual(
+ saved_graph.saver_def.restore_op_name,
+ saver_name)
+
+ _validate_custom_saver("tag_0", "save/restore_all")
+ _validate_custom_saver("tag_1", "save_1/restore_all")
+ _validate_custom_saver("tag_2", "save_2/restore_all")
+
def testClearDevices(self):
export_dir = self._get_export_dir("test_clear_devices")
builder = saved_model_builder.SavedModelBuilder(export_dir)