aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/saved_model
diff options
context:
space:
mode:
authorGravatar Katherine Wu <kathywu@google.com>2018-06-15 18:04:21 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-06-15 18:13:04 -0700
commit68af4047fdfa89fa7b7d222a50a38eb0a469d946 (patch)
treefdaa28614a2a07db8c207541f8a88f1f833703ee /tensorflow/python/saved_model
parent1aebd982d7d911504dfd47b99a56461c67ceddad (diff)
Automated g4 rollback of changelist 200747752
PiperOrigin-RevId: 200802842
Diffstat (limited to 'tensorflow/python/saved_model')
-rw-r--r--tensorflow/python/saved_model/BUILD24
-rw-r--r--tensorflow/python/saved_model/loader_impl.py175
-rw-r--r--tensorflow/python/saved_model/loader_test.py180
3 files changed, 31 insertions, 348 deletions
diff --git a/tensorflow/python/saved_model/BUILD b/tensorflow/python/saved_model/BUILD
index 076f2d8760..81786fbf43 100644
--- a/tensorflow/python/saved_model/BUILD
+++ b/tensorflow/python/saved_model/BUILD
@@ -87,30 +87,6 @@ py_library(
"//tensorflow/python:platform",
"//tensorflow/python:training",
"//tensorflow/python:util",
- "//tensorflow/python:variables",
- ],
-)
-
-py_test(
- name = "loader_test",
- size = "small",
- srcs = ["loader_test.py"],
- srcs_version = "PY2AND3",
- visibility = ["//visibility:private"],
- deps = [
- ":builder",
- ":loader",
- ":signature_def_utils",
- ":utils",
- "//tensorflow/python:client",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:control_flow_ops",
- "//tensorflow/python:errors",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:lib",
- "//tensorflow/python:state_ops",
- "//tensorflow/python:training",
- "//tensorflow/python:variables",
],
)
diff --git a/tensorflow/python/saved_model/loader_impl.py b/tensorflow/python/saved_model/loader_impl.py
index 6770aaef36..d1bd8d47ae 100644
--- a/tensorflow/python/saved_model/loader_impl.py
+++ b/tensorflow/python/saved_model/loader_impl.py
@@ -28,7 +28,6 @@ from tensorflow.core.protobuf import meta_graph_pb2
from tensorflow.core.protobuf import saved_model_pb2
from tensorflow.python.framework import ops
from tensorflow.python.lib.io import file_io
-from tensorflow.python.ops import variables
from tensorflow.python.platform import tf_logging
from tensorflow.python.saved_model import constants
from tensorflow.python.training import saver as tf_saver
@@ -208,56 +207,11 @@ def load(sess, tags, export_dir, import_scope=None, **saver_kwargs):
Raises:
RuntimeError: MetaGraphDef associated with the tags cannot be found.
"""
- loader = SavedModelLoader(export_dir)
- return loader.load(sess, tags, import_scope, **saver_kwargs)
-
-
-class SavedModelLoader(object):
- """Load graphs and restore variable values from a `SavedModel`."""
-
- def __init__(self, export_dir):
- """Creates a `SavedModelLoader`.
-
- Args:
- export_dir: Directory in which the SavedModel protocol buffer and
- variables to be loaded are located.
- """
- self._export_dir = export_dir
- self._variables_path = os.path.join(
- compat.as_bytes(export_dir),
- compat.as_bytes(constants.VARIABLES_DIRECTORY),
- compat.as_bytes(constants.VARIABLES_FILENAME))
- self._saved_model = _parse_saved_model(export_dir)
-
- @property
- def export_dir(self):
- """Directory containing the SavedModel."""
- return self._export_dir
-
- @property
- def variables_path(self):
- """Path to variable checkpoint files."""
- return self._variables_path
-
- @property
- def saved_model(self):
- """SavedModel object parsed from the export directory."""
- return self._saved_model
-
- def get_meta_graph_def_from_tags(self, tags):
- """Return MetaGraphDef with the exact specified tags.
-
- Args:
- tags: A list or set of string tags that identify the MetaGraphDef.
-
- Returns:
- MetaGraphDef with the same tags.
-
- Raises:
- RuntimeError: if no metagraphs were found with the associated tags.
- """
+ with sess.graph.as_default():
+ # Build the SavedModel protocol buffer and find requested meta graph def.
+ saved_model = _parse_saved_model(export_dir)
found_match = False
- for meta_graph_def in self._saved_model.meta_graphs:
+ for meta_graph_def in saved_model.meta_graphs:
if set(meta_graph_def.meta_info_def.tags) == set(tags):
meta_graph_def_to_load = meta_graph_def
found_match = True
@@ -269,99 +223,32 @@ class SavedModelLoader(object):
" could not be found in SavedModel. To inspect available tag-sets in"
" the SavedModel, please use the SavedModel CLI: `saved_model_cli`"
)
- return meta_graph_def_to_load
- def load_graph(self, graph, tags, import_scope=None, **saver_kwargs):
- """Load ops and nodes from SavedModel MetaGraph into graph.
+ # Build a saver by importing the meta graph def to load.
+ saver = tf_saver.import_meta_graph(
+ meta_graph_def_to_load, import_scope=import_scope, **saver_kwargs)
+
+ if saver:
+ # Build the checkpoint path where the variables are located.
+ variables_path = os.path.join(
+ compat.as_bytes(export_dir),
+ compat.as_bytes(constants.VARIABLES_DIRECTORY),
+ compat.as_bytes(constants.VARIABLES_FILENAME))
+
+ # Restore the variables using the built saver in the provided session.
+ saver.restore(sess, variables_path)
+ else:
+ tf_logging.info("The specified SavedModel has no variables; no "
+ "checkpoints were restored.")
+
+ # Get asset tensors, if any.
+ asset_tensors_dictionary = _get_asset_tensors(
+ export_dir, meta_graph_def_to_load, import_scope=import_scope)
+
+ main_op_tensor = (
+ _get_main_op_tensor(meta_graph_def_to_load) or
+ (_get_legacy_init_op_tensor(meta_graph_def_to_load)))
+ if main_op_tensor is not None:
+ sess.run(fetches=[main_op_tensor], feed_dict=asset_tensors_dictionary)
- Args:
- graph: tf.Graph object.
- tags: a set of string tags identifying a MetaGraphDef.
- import_scope: Optional `string` -- if specified, prepend this string
- followed by '/' to all loaded tensor names. This scope is applied to
- tensor instances loaded into the passed session, but it is *not* written
- through to the static `MetaGraphDef` protocol buffer that is returned.
- **saver_kwargs: keyword arguments to pass to tf.train.import_meta_graph.
-
- Returns:
- Saver defined by the MetaGraph, which can be used to restore the variable
- values.
- """
- meta_graph_def = self.get_meta_graph_def_from_tags(tags)
- with graph.as_default():
- return tf_saver.import_meta_graph(
- meta_graph_def, import_scope=import_scope, **saver_kwargs)
-
- def restore_variables(self, sess, saver, import_scope=None):
- """Restore SavedModel variable values into the session.
-
- Args:
- sess: tf.Session to restore variable values.
- saver: a tf.train.Saver object. Can be None if there are no variables in
- graph. This may be the saver returned by the load_graph() function, or a
- default `tf.train.Saver()`.
- import_scope: Optional `string` -- if specified, prepend this string
- followed by '/' to all loaded tensor names. This scope is applied to
- tensor instances loaded into the passed session, but it is *not* written
- through to the static `MetaGraphDef` protocol buffer that is returned.
-
- Raises:
- ValueError: if no saver was passed to the saver argument, and there are
- variables in the graph.
- """
- with sess.graph.as_default():
- if not variables._all_saveable_objects(scope=import_scope): # pylint: disable=protected-access
- tf_logging.info("The specified SavedModel has no variables; no "
- "checkpoints were restored.")
- elif isinstance(saver, tf_saver.Saver):
- saver.restore(sess, self._variables_path)
- else:
- raise ValueError(
- "No tf.train.Saver object was passed to the function "
- "SavedModelLoader.restore_variables. Since there are variables in "
- "the graph, a saver is required.")
-
- def run_init_ops(self, sess, tags, import_scope=None):
- """Run initialization ops defined in the `MetaGraphDef`.
-
- Args:
- sess: tf.Session to restore variable values.
- tags: a set of string tags identifying a MetaGraphDef.
- import_scope: Optional `string` -- if specified, prepend this string
- followed by '/' to all loaded tensor names. This scope is applied to
- tensor instances loaded into the passed session, but it is *not* written
- through to the static `MetaGraphDef` protocol buffer that is returned.
- """
- meta_graph_def = self.get_meta_graph_def_from_tags(tags)
- with sess.graph.as_default():
- # Get asset tensors, if any.
- asset_tensors_dictionary = _get_asset_tensors(
- self._export_dir, meta_graph_def, import_scope=import_scope)
-
- main_op_tensor = (
- _get_main_op_tensor(meta_graph_def) or
- (_get_legacy_init_op_tensor(meta_graph_def)))
- if main_op_tensor is not None:
- sess.run(fetches=[main_op_tensor], feed_dict=asset_tensors_dictionary)
-
- def load(self, sess, tags, import_scope=None, **saver_kwargs):
- """Load the MetaGraphDef graph and restore variable values into the session.
-
- Args:
- sess: tf.Session to restore variable values.
- tags: a set of string tags identifying a MetaGraphDef.
- import_scope: Optional `string` -- if specified, prepend this string
- followed by '/' to all loaded tensor names. This scope is applied to
- tensor instances loaded into the passed session, but it is *not* written
- through to the static `MetaGraphDef` protocol buffer that is returned.
- **saver_kwargs: keyword arguments to pass to tf.train.import_meta_graph.
-
- Returns:
- `MetagraphDef` proto of the graph that was loaded.
- """
- with sess.graph.as_default():
- saver = self.load_graph(sess.graph, tags, import_scope,
- **saver_kwargs)
- self.restore_variables(sess, saver, import_scope)
- self.run_init_ops(sess, tags, import_scope)
- return self.get_meta_graph_def_from_tags(tags)
+ return meta_graph_def_to_load
diff --git a/tensorflow/python/saved_model/loader_test.py b/tensorflow/python/saved_model/loader_test.py
deleted file mode 100644
index 2ec2519c89..0000000000
--- a/tensorflow/python/saved_model/loader_test.py
+++ /dev/null
@@ -1,180 +0,0 @@
-# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for SavedModelLoader class."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import os
-
-from tensorflow.python.client import session
-from tensorflow.python.framework import errors
-from tensorflow.python.framework import ops
-from tensorflow.python.lib.io import file_io
-from tensorflow.python.ops import control_flow_ops
-from tensorflow.python.ops import state_ops
-from tensorflow.python.ops import variables
-from tensorflow.python.platform import test
-from tensorflow.python.saved_model import builder as saved_model_builder
-from tensorflow.python.saved_model import loader_impl
-from tensorflow.python.saved_model import signature_def_utils
-from tensorflow.python.saved_model import utils
-from tensorflow.python.training import saver as tf_saver
-
-
-def _get_export_dir(label):
- return os.path.join(test.get_temp_dir(), label)
-
-SIMPLE_ADD_SAVED_MODEL = _get_export_dir("simple_add_saved_model")
-SAVED_MODEL_WITH_MAIN_OP = _get_export_dir("saved_model_with_main_op")
-
-
-class SavedModelLoaderTest(test.TestCase):
-
- def setUp(self):
- """Write test SavedModels to a temp directory."""
- with session.Session(graph=ops.Graph()) as sess:
- x = variables.Variable(5, name="x")
- y = variables.Variable(11, name="y")
- z = x + y
- sess.run(variables.global_variables_initializer())
-
- foo_sig_def = signature_def_utils.build_signature_def(
- {"foo_input": utils.build_tensor_info(x)},
- {"foo_output": utils.build_tensor_info(z)})
- bar_sig_def = signature_def_utils.build_signature_def(
- {"bar_x": utils.build_tensor_info(x),
- "bar_y": utils.build_tensor_info(y)},
- {"bar_z": utils.build_tensor_info(z)})
-
- builder = saved_model_builder.SavedModelBuilder(SIMPLE_ADD_SAVED_MODEL)
- builder.add_meta_graph_and_variables(
- sess, ["foo_graph"], {"foo": foo_sig_def, "bar": bar_sig_def})
- builder.save()
-
- # Write SavedModel with a main_op
- assign_op = control_flow_ops.group(state_ops.assign(y, 7))
-
- builder = saved_model_builder.SavedModelBuilder(SAVED_MODEL_WITH_MAIN_OP)
- builder.add_meta_graph_and_variables(
- sess, ["foo_graph"], {"foo": foo_sig_def, "bar": bar_sig_def},
- main_op=assign_op)
- builder.save()
-
- def tearDown(self):
- file_io.delete_recursively(test.get_temp_dir())
-
- def test_load_function(self):
- loader = loader_impl.SavedModelLoader(SIMPLE_ADD_SAVED_MODEL)
- with self.test_session(graph=ops.Graph()) as sess:
- loader.load(sess, ["foo_graph"])
- self.assertEqual(5, sess.graph.get_tensor_by_name("x:0").eval())
- self.assertEqual(11, sess.graph.get_tensor_by_name("y:0").eval())
-
- loader2 = loader_impl.SavedModelLoader(SAVED_MODEL_WITH_MAIN_OP)
- with self.test_session(graph=ops.Graph()) as sess:
- loader2.load(sess, ["foo_graph"])
- self.assertEqual(5, sess.graph.get_tensor_by_name("x:0").eval())
- self.assertEqual(7, sess.graph.get_tensor_by_name("y:0").eval())
-
- def test_load_graph(self):
- loader = loader_impl.SavedModelLoader(SIMPLE_ADD_SAVED_MODEL)
- graph = ops.Graph()
- loader.load_graph(graph, ["foo_graph"])
-
- x = graph.get_tensor_by_name("x:0")
- y = graph.get_tensor_by_name("y:0")
-
- with self.assertRaises(KeyError):
- graph.get_tensor_by_name("z:0")
-
- with self.test_session(graph=graph) as sess:
- # Check that x and y are not initialized
- with self.assertRaises(errors.FailedPreconditionError):
- sess.run(x)
- with self.assertRaises(errors.FailedPreconditionError):
- sess.run(y)
-
- def test_load_with_import_scope(self):
- loader = loader_impl.SavedModelLoader(SAVED_MODEL_WITH_MAIN_OP)
- with self.test_session(graph=ops.Graph()) as sess:
- saver = loader.load_graph(sess.graph, ["foo_graph"], import_scope="baz")
-
- # The default saver should not work when the import scope is set.
- with self.assertRaises(errors.NotFoundError):
- loader.restore_variables(sess, tf_saver.Saver())
-
- loader.restore_variables(sess, saver)
- loader.run_init_ops(sess, ["foo_graph"])
-
- self.assertEqual(5, sess.graph.get_tensor_by_name("baz/x:0").eval())
- self.assertEqual(7, sess.graph.get_tensor_by_name("baz/y:0").eval())
-
- # Test combined load function.
- loader = loader_impl.SavedModelLoader(SAVED_MODEL_WITH_MAIN_OP)
- with self.test_session(graph=ops.Graph()) as sess:
- loader.load(sess, ["foo_graph"], import_scope="baa")
- self.assertEqual(5, sess.graph.get_tensor_by_name("baa/x:0").eval())
- self.assertEqual(7, sess.graph.get_tensor_by_name("baa/y:0").eval())
-
- def test_restore_variables(self):
- loader = loader_impl.SavedModelLoader(SAVED_MODEL_WITH_MAIN_OP)
- with self.test_session(graph=ops.Graph()) as sess:
- x = variables.Variable(0, name="x")
- y = variables.Variable(0, name="y")
- z = x * y
-
- sess.run(variables.global_variables_initializer())
-
- # There are variables to restore, so a saver must be created.
- with self.assertRaises(ValueError):
- loader.restore_variables(sess, None)
-
- loader.restore_variables(sess, tf_saver.Saver())
- self.assertEqual(55, z.eval())
-
- def test_run_init_op(self):
- loader = loader_impl.SavedModelLoader(SAVED_MODEL_WITH_MAIN_OP)
- graph = ops.Graph()
- saver = loader.load_graph(graph, ["foo_graph"])
- with self.test_session(graph=graph) as sess:
- loader.restore_variables(sess, saver)
- self.assertEqual(5, sess.graph.get_tensor_by_name("x:0").eval())
- self.assertEqual(11, sess.graph.get_tensor_by_name("y:0").eval())
-
- loader.run_init_ops(sess, ["foo_graph"])
- self.assertEqual(5, sess.graph.get_tensor_by_name("x:0").eval())
- self.assertEqual(7, sess.graph.get_tensor_by_name("y:0").eval())
-
- def test_parse_saved_model(self):
- loader = loader_impl.SavedModelLoader(SIMPLE_ADD_SAVED_MODEL)
- meta_graph = loader.get_meta_graph_def_from_tags(["foo_graph"])
- self.assertIsNotNone(meta_graph)
- self.assertIn("foo", meta_graph.signature_def)
- self.assertIn("bar", meta_graph.signature_def)
-
- def test_load_invalid_meta_graph(self):
- loader = loader_impl.SavedModelLoader(SIMPLE_ADD_SAVED_MODEL)
- with self.assertRaises(RuntimeError):
- loader.get_meta_graph_def_from_tags([])
- with self.assertRaises(RuntimeError):
- loader.get_meta_graph_def_from_tags([""])
- with self.assertRaises(RuntimeError):
- loader.get_meta_graph_def_from_tags(["not_a_graph"])
-
-
-if __name__ == "__main__":
- test.main()