author     Katherine Wu <kathywu@google.com>          2018-06-19 17:10:22 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>  2018-06-19 17:13:01 -0700
commit     da861da63df724339e0148ff43192de05770a3c8 (patch)
tree       3f2fa0aeb7e5305608a3c1a02ee2c700ee6580eb /tensorflow/python/saved_model
parent     e1a7a2ded90fbbdfc3a41954a332a04c73dd62c6 (diff)
Refactor loader.load function into a class that splits the graph loading and variable restoration steps.
PiperOrigin-RevId: 201268712
Diffstat (limited to 'tensorflow/python/saved_model')
-rw-r--r--  tensorflow/python/saved_model/BUILD           |  24
-rw-r--r--  tensorflow/python/saved_model/loader_impl.py  | 176
-rw-r--r--  tensorflow/python/saved_model/loader_test.py  | 217
3 files changed, 386 insertions, 31 deletions
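
For orientation, a minimal usage sketch of the refactored API introduced by this change. The module-level loader.load() entry point is kept as a thin wrapper around the new SavedModelLoader class; the export directory path and the "serve" tag below are placeholders, not part of this change.

    # Sketch only: assumes a SavedModel exported under /tmp/my_saved_model with
    # the tag set ["serve"]; both names are hypothetical placeholders.
    from tensorflow.python.client import session
    from tensorflow.python.framework import ops
    from tensorflow.python.saved_model import loader_impl

    loader = loader_impl.SavedModelLoader("/tmp/my_saved_model")
    with session.Session(graph=ops.Graph()) as sess:
      # One-shot load, equivalent to the pre-existing
      # loader_impl.load(sess, ["serve"], "/tmp/my_saved_model").
      meta_graph_def = loader.load(sess, ["serve"])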
diff --git a/tensorflow/python/saved_model/BUILD b/tensorflow/python/saved_model/BUILD
index 81786fbf43..076f2d8760 100644
--- a/tensorflow/python/saved_model/BUILD
+++ b/tensorflow/python/saved_model/BUILD
@@ -87,6 +87,30 @@ py_library(
"//tensorflow/python:platform",
"//tensorflow/python:training",
"//tensorflow/python:util",
+ "//tensorflow/python:variables",
+ ],
+)
+
+py_test(
+ name = "loader_test",
+ size = "small",
+ srcs = ["loader_test.py"],
+ srcs_version = "PY2AND3",
+ visibility = ["//visibility:private"],
+ deps = [
+ ":builder",
+ ":loader",
+ ":signature_def_utils",
+ ":utils",
+ "//tensorflow/python:client",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:control_flow_ops",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:lib",
+ "//tensorflow/python:state_ops",
+ "//tensorflow/python:training",
+ "//tensorflow/python:variables",
],
)
diff --git a/tensorflow/python/saved_model/loader_impl.py b/tensorflow/python/saved_model/loader_impl.py
index d1bd8d47ae..e5f649fdab 100644
--- a/tensorflow/python/saved_model/loader_impl.py
+++ b/tensorflow/python/saved_model/loader_impl.py
@@ -28,6 +28,7 @@ from tensorflow.core.protobuf import meta_graph_pb2
from tensorflow.core.protobuf import saved_model_pb2
from tensorflow.python.framework import ops
from tensorflow.python.lib.io import file_io
+from tensorflow.python.ops import variables
from tensorflow.python.platform import tf_logging
from tensorflow.python.saved_model import constants
from tensorflow.python.training import saver as tf_saver
@@ -207,11 +208,56 @@ def load(sess, tags, export_dir, import_scope=None, **saver_kwargs):
Raises:
RuntimeError: MetaGraphDef associated with the tags cannot be found.
"""
- with sess.graph.as_default():
- # Build the SavedModel protocol buffer and find requested meta graph def.
- saved_model = _parse_saved_model(export_dir)
+ loader = SavedModelLoader(export_dir)
+ return loader.load(sess, tags, import_scope, **saver_kwargs)
+
+
+class SavedModelLoader(object):
+ """Load graphs and restore variable values from a `SavedModel`."""
+
+ def __init__(self, export_dir):
+ """Creates a `SavedModelLoader`.
+
+ Args:
+ export_dir: Directory in which the SavedModel protocol buffer and
+ variables to be loaded are located.
+ """
+ self._export_dir = export_dir
+ self._variables_path = os.path.join(
+ compat.as_bytes(export_dir),
+ compat.as_bytes(constants.VARIABLES_DIRECTORY),
+ compat.as_bytes(constants.VARIABLES_FILENAME))
+ self._saved_model = _parse_saved_model(export_dir)
+
+ @property
+ def export_dir(self):
+ """Directory containing the SavedModel."""
+ return self._export_dir
+
+ @property
+ def variables_path(self):
+ """Path to variable checkpoint files."""
+ return self._variables_path
+
+ @property
+ def saved_model(self):
+ """SavedModel object parsed from the export directory."""
+ return self._saved_model
+
+ def get_meta_graph_def_from_tags(self, tags):
+ """Return MetaGraphDef with the exact specified tags.
+
+ Args:
+ tags: A list or set of string tags that identify the MetaGraphDef.
+
+ Returns:
+ MetaGraphDef with the same tags.
+
+ Raises:
+ RuntimeError: if no metagraphs were found with the associated tags.
+ """
found_match = False
- for meta_graph_def in saved_model.meta_graphs:
+ for meta_graph_def in self._saved_model.meta_graphs:
if set(meta_graph_def.meta_info_def.tags) == set(tags):
meta_graph_def_to_load = meta_graph_def
found_match = True
@@ -223,32 +269,100 @@ def load(sess, tags, export_dir, import_scope=None, **saver_kwargs):
" could not be found in SavedModel. To inspect available tag-sets in"
" the SavedModel, please use the SavedModel CLI: `saved_model_cli`"
)
+ return meta_graph_def_to_load
- # Build a saver by importing the meta graph def to load.
- saver = tf_saver.import_meta_graph(
- meta_graph_def_to_load, import_scope=import_scope, **saver_kwargs)
-
- if saver:
- # Build the checkpoint path where the variables are located.
- variables_path = os.path.join(
- compat.as_bytes(export_dir),
- compat.as_bytes(constants.VARIABLES_DIRECTORY),
- compat.as_bytes(constants.VARIABLES_FILENAME))
-
- # Restore the variables using the built saver in the provided session.
- saver.restore(sess, variables_path)
- else:
- tf_logging.info("The specified SavedModel has no variables; no "
- "checkpoints were restored.")
-
- # Get asset tensors, if any.
- asset_tensors_dictionary = _get_asset_tensors(
- export_dir, meta_graph_def_to_load, import_scope=import_scope)
-
- main_op_tensor = (
- _get_main_op_tensor(meta_graph_def_to_load) or
- (_get_legacy_init_op_tensor(meta_graph_def_to_load)))
- if main_op_tensor is not None:
- sess.run(fetches=[main_op_tensor], feed_dict=asset_tensors_dictionary)
+ def load_graph(self, graph, tags, import_scope=None, **saver_kwargs):
+ """Load ops and nodes from SavedModel MetaGraph into graph.
- return meta_graph_def_to_load
+ Args:
+ graph: tf.Graph object.
+ tags: a set of string tags identifying a MetaGraphDef.
+ import_scope: Optional `string` -- if specified, prepend this string
+ followed by '/' to all loaded tensor names. This scope is applied to
+ tensor instances loaded into the passed session, but it is *not* written
+ through to the static `MetaGraphDef` protocol buffer that is returned.
+ **saver_kwargs: keyword arguments to pass to tf.train.import_meta_graph.
+
+ Returns:
+ Saver defined by the MetaGraph, which can be used to restore the variable
+ values.
+ """
+ meta_graph_def = self.get_meta_graph_def_from_tags(tags)
+ with graph.as_default():
+ return tf_saver.import_meta_graph(
+ meta_graph_def, import_scope=import_scope, **saver_kwargs)
+
+ def restore_variables(self, sess, saver, import_scope=None):
+ """Restore SavedModel variable values into the session.
+
+ Args:
+ sess: tf.Session to restore variable values.
+      saver: a tf.train.Saver object. Can be None if there are no variables in
+        the graph. This may be the saver returned by the load_graph() function,
+        or a default `tf.train.Saver()`.
+ import_scope: Optional `string` -- if specified, prepend this string
+ followed by '/' to all loaded tensor names. This scope is applied to
+ tensor instances loaded into the passed session, but it is *not* written
+ through to the static `MetaGraphDef` protocol buffer that is returned.
+
+ Raises:
+ ValueError: if no saver was passed to the saver argument, and there are
+ variables in the graph.
+ """
+ with sess.graph.as_default():
+ if (saver is None and
+ not variables._all_saveable_objects(scope=import_scope)): # pylint: disable=protected-access
+ tf_logging.info("The specified SavedModel has no variables; no "
+ "checkpoints were restored.")
+ elif isinstance(saver, tf_saver.Saver):
+ saver.restore(sess, self._variables_path)
+ else:
+ raise ValueError(
+ "No tf.train.Saver object was passed to the function "
+ "SavedModelLoader.restore_variables. Since there are variables in "
+ "the graph, a saver is required.")
+
+ def run_init_ops(self, sess, tags, import_scope=None):
+ """Run initialization ops defined in the `MetaGraphDef`.
+
+ Args:
+ sess: tf.Session to restore variable values.
+ tags: a set of string tags identifying a MetaGraphDef.
+ import_scope: Optional `string` -- if specified, prepend this string
+ followed by '/' to all loaded tensor names. This scope is applied to
+ tensor instances loaded into the passed session, but it is *not* written
+ through to the static `MetaGraphDef` protocol buffer that is returned.
+ """
+ meta_graph_def = self.get_meta_graph_def_from_tags(tags)
+ with sess.graph.as_default():
+ # Get asset tensors, if any.
+ asset_tensors_dictionary = _get_asset_tensors(
+ self._export_dir, meta_graph_def, import_scope=import_scope)
+
+ main_op_tensor = (
+ _get_main_op_tensor(meta_graph_def) or
+ (_get_legacy_init_op_tensor(meta_graph_def)))
+ if main_op_tensor is not None:
+ sess.run(fetches=[main_op_tensor], feed_dict=asset_tensors_dictionary)
+
+ def load(self, sess, tags, import_scope=None, **saver_kwargs):
+ """Load the MetaGraphDef graph and restore variable values into the session.
+
+ Args:
+ sess: tf.Session to restore variable values.
+ tags: a set of string tags identifying a MetaGraphDef.
+ import_scope: Optional `string` -- if specified, prepend this string
+ followed by '/' to all loaded tensor names. This scope is applied to
+ tensor instances loaded into the passed session, but it is *not* written
+ through to the static `MetaGraphDef` protocol buffer that is returned.
+ **saver_kwargs: keyword arguments to pass to tf.train.import_meta_graph.
+
+ Returns:
+      `MetaGraphDef` proto of the graph that was loaded.
+ """
+ with sess.graph.as_default():
+ saver = self.load_graph(sess.graph, tags, import_scope,
+ **saver_kwargs)
+ self.restore_variables(sess, saver, import_scope)
+ self.run_init_ops(sess, tags, import_scope)
+ return self.get_meta_graph_def_from_tags(tags)
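
To make the split concrete, here is a hedged sketch of the three-step workflow the class exposes: import the graph, then restore variables and run init ops in a session. The export directory path is a placeholder, and the "foo_graph" tag is borrowed from the test setup in loader_test.py below.

    # Sketch only: assumes an export directory saved with the tag "foo_graph",
    # as in loader_test.py; the path is a hypothetical placeholder.
    from tensorflow.python.client import session
    from tensorflow.python.framework import ops
    from tensorflow.python.saved_model import loader_impl

    loader = loader_impl.SavedModelLoader("/tmp/my_saved_model")

    # Step 1: import the MetaGraphDef into a fresh graph; returns the saver
    # defined by the MetaGraph, or None if the SavedModel has no saver.
    graph = ops.Graph()
    saver = loader.load_graph(graph, ["foo_graph"])

    # Steps 2 and 3: restore variable values, then run the main_op or
    # legacy_init_op with any asset tensors fed in.
    with session.Session(graph=graph) as sess:
      loader.restore_variables(sess, saver)
      loader.run_init_ops(sess, ["foo_graph"])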
diff --git a/tensorflow/python/saved_model/loader_test.py b/tensorflow/python/saved_model/loader_test.py
new file mode 100644
index 0000000000..ce18859f6b
--- /dev/null
+++ b/tensorflow/python/saved_model/loader_test.py
@@ -0,0 +1,217 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for SavedModelLoader class."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+
+from tensorflow.python.client import session
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import ops
+from tensorflow.python.lib.io import file_io
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import state_ops
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import test
+from tensorflow.python.saved_model import builder as saved_model_builder
+from tensorflow.python.saved_model import loader_impl
+from tensorflow.python.saved_model import signature_def_utils
+from tensorflow.python.saved_model import utils
+from tensorflow.python.training import saver as tf_saver
+
+
+def _get_export_dir(label):
+ return os.path.join(test.get_temp_dir(), label)
+
+SIMPLE_ADD_SAVED_MODEL = _get_export_dir("simple_add_saved_model")
+SAVED_MODEL_WITH_MAIN_OP = _get_export_dir("saved_model_with_main_op")
+
+
+class SavedModelLoaderTest(test.TestCase):
+
+ def setUp(self):
+ """Write test SavedModels to a temp directory."""
+ with session.Session(graph=ops.Graph()) as sess:
+ x = variables.Variable(5, name="x")
+ y = variables.Variable(11, name="y")
+ z = x + y
+ sess.run(variables.global_variables_initializer())
+
+ foo_sig_def = signature_def_utils.build_signature_def(
+ {"foo_input": utils.build_tensor_info(x)},
+ {"foo_output": utils.build_tensor_info(z)})
+ bar_sig_def = signature_def_utils.build_signature_def(
+ {"bar_x": utils.build_tensor_info(x),
+ "bar_y": utils.build_tensor_info(y)},
+ {"bar_z": utils.build_tensor_info(z)})
+
+ builder = saved_model_builder.SavedModelBuilder(SIMPLE_ADD_SAVED_MODEL)
+ builder.add_meta_graph_and_variables(
+ sess, ["foo_graph"], {"foo": foo_sig_def, "bar": bar_sig_def})
+ builder.save()
+
+ # Write SavedModel with a main_op
+ assign_op = control_flow_ops.group(state_ops.assign(y, 7))
+
+ builder = saved_model_builder.SavedModelBuilder(SAVED_MODEL_WITH_MAIN_OP)
+ builder.add_meta_graph_and_variables(
+ sess, ["foo_graph"], {"foo": foo_sig_def, "bar": bar_sig_def},
+ main_op=assign_op)
+ builder.save()
+
+ def tearDown(self):
+ file_io.delete_recursively(test.get_temp_dir())
+
+ def test_load_function(self):
+ loader = loader_impl.SavedModelLoader(SIMPLE_ADD_SAVED_MODEL)
+ with self.test_session(graph=ops.Graph()) as sess:
+ loader.load(sess, ["foo_graph"])
+ self.assertEqual(5, sess.graph.get_tensor_by_name("x:0").eval())
+ self.assertEqual(11, sess.graph.get_tensor_by_name("y:0").eval())
+
+ loader2 = loader_impl.SavedModelLoader(SAVED_MODEL_WITH_MAIN_OP)
+ with self.test_session(graph=ops.Graph()) as sess:
+ loader2.load(sess, ["foo_graph"])
+ self.assertEqual(5, sess.graph.get_tensor_by_name("x:0").eval())
+ self.assertEqual(7, sess.graph.get_tensor_by_name("y:0").eval())
+
+ def test_load_graph(self):
+ loader = loader_impl.SavedModelLoader(SIMPLE_ADD_SAVED_MODEL)
+ graph = ops.Graph()
+ loader.load_graph(graph, ["foo_graph"])
+
+ x = graph.get_tensor_by_name("x:0")
+ y = graph.get_tensor_by_name("y:0")
+
+ with self.assertRaises(KeyError):
+ graph.get_tensor_by_name("z:0")
+
+ with self.test_session(graph=graph) as sess:
+ # Check that x and y are not initialized
+ with self.assertRaises(errors.FailedPreconditionError):
+ sess.run(x)
+ with self.assertRaises(errors.FailedPreconditionError):
+ sess.run(y)
+
+ def test_load_with_import_scope(self):
+ loader = loader_impl.SavedModelLoader(SAVED_MODEL_WITH_MAIN_OP)
+ with self.test_session(graph=ops.Graph()) as sess:
+ saver = loader.load_graph(sess.graph, ["foo_graph"], import_scope="baz")
+
+ # The default saver should not work when the import scope is set.
+ with self.assertRaises(errors.NotFoundError):
+ loader.restore_variables(sess, tf_saver.Saver())
+
+ loader.restore_variables(sess, saver)
+ loader.run_init_ops(sess, ["foo_graph"])
+
+ self.assertEqual(5, sess.graph.get_tensor_by_name("baz/x:0").eval())
+ self.assertEqual(7, sess.graph.get_tensor_by_name("baz/y:0").eval())
+
+ # Test combined load function.
+ loader = loader_impl.SavedModelLoader(SAVED_MODEL_WITH_MAIN_OP)
+ with self.test_session(graph=ops.Graph()) as sess:
+ loader.load(sess, ["foo_graph"], import_scope="baa")
+ self.assertEqual(5, sess.graph.get_tensor_by_name("baa/x:0").eval())
+ self.assertEqual(7, sess.graph.get_tensor_by_name("baa/y:0").eval())
+
+ def test_restore_variables(self):
+ loader = loader_impl.SavedModelLoader(SAVED_MODEL_WITH_MAIN_OP)
+ with self.test_session(graph=ops.Graph()) as sess:
+ x = variables.Variable(0, name="x")
+ y = variables.Variable(0, name="y")
+ z = x * y
+
+ sess.run(variables.global_variables_initializer())
+
+ # There are variables to restore, so a saver must be created.
+ with self.assertRaises(ValueError):
+ loader.restore_variables(sess, None)
+
+ loader.restore_variables(sess, tf_saver.Saver())
+ self.assertEqual(55, z.eval())
+
+ def test_run_init_op(self):
+ loader = loader_impl.SavedModelLoader(SAVED_MODEL_WITH_MAIN_OP)
+ graph = ops.Graph()
+ saver = loader.load_graph(graph, ["foo_graph"])
+ with self.test_session(graph=graph) as sess:
+ loader.restore_variables(sess, saver)
+ self.assertEqual(5, sess.graph.get_tensor_by_name("x:0").eval())
+ self.assertEqual(11, sess.graph.get_tensor_by_name("y:0").eval())
+
+ loader.run_init_ops(sess, ["foo_graph"])
+ self.assertEqual(5, sess.graph.get_tensor_by_name("x:0").eval())
+ self.assertEqual(7, sess.graph.get_tensor_by_name("y:0").eval())
+
+ def test_parse_saved_model(self):
+ loader = loader_impl.SavedModelLoader(SIMPLE_ADD_SAVED_MODEL)
+ meta_graph = loader.get_meta_graph_def_from_tags(["foo_graph"])
+ self.assertIsNotNone(meta_graph)
+ self.assertIn("foo", meta_graph.signature_def)
+ self.assertIn("bar", meta_graph.signature_def)
+
+ def test_load_invalid_meta_graph(self):
+ loader = loader_impl.SavedModelLoader(SIMPLE_ADD_SAVED_MODEL)
+ with self.assertRaises(RuntimeError):
+ loader.get_meta_graph_def_from_tags([])
+ with self.assertRaises(RuntimeError):
+ loader.get_meta_graph_def_from_tags([""])
+ with self.assertRaises(RuntimeError):
+ loader.get_meta_graph_def_from_tags(["not_a_graph"])
+
+ def test_load_saved_model_with_no_variables(self):
+ """Test that SavedModel runs saver when there appear to be no variables.
+
+ When no variables are detected, this may mean that the variables were saved
+ to different collections, or the collections weren't saved to the
+ SavedModel. If the SavedModel MetaGraphDef contains a saver, it should still
+ run in either of these cases.
+ """
+ path = _get_export_dir("no_variable_saved_model")
+ with session.Session(graph=ops.Graph()) as sess:
+ x = variables.Variable(5, name="x", collections=["not_global_variable"])
+ y = variables.Variable(11, name="y", collections=["not_global_variable"])
+ self.assertFalse(variables._all_saveable_objects())
+ z = x + y
+ sess.run(variables.variables_initializer([x, y]))
+
+ foo_sig_def = signature_def_utils.build_signature_def(
+ {"foo_input": utils.build_tensor_info(x)},
+ {"foo_output": utils.build_tensor_info(z)})
+
+ builder = saved_model_builder.SavedModelBuilder(path)
+ builder.add_meta_graph_and_variables(
+ sess, ["foo_graph"], {"foo": foo_sig_def},
+ saver=tf_saver.Saver([x, y]))
+ builder.save()
+
+ loader = loader_impl.SavedModelLoader(path)
+ with self.test_session(graph=ops.Graph()) as sess:
+ saver = loader.load_graph(sess.graph, ["foo_graph"])
+ self.assertFalse(variables._all_saveable_objects())
+ self.assertIsNotNone(saver)
+
+ with self.test_session(graph=ops.Graph()) as sess:
+ loader.load(sess, ["foo_graph"])
+ self.assertEqual(5, sess.graph.get_tensor_by_name("x:0").eval())
+ self.assertEqual(11, sess.graph.get_tensor_by_name("y:0").eval())
+
+
+if __name__ == "__main__":
+ test.main()