diff options
author | Katherine Wu <kathywu@google.com> | 2018-06-15 18:04:21 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-06-15 18:13:04 -0700 |
commit | 68af4047fdfa89fa7b7d222a50a38eb0a469d946 (patch) | |
tree | fdaa28614a2a07db8c207541f8a88f1f833703ee /tensorflow/python/saved_model | |
parent | 1aebd982d7d911504dfd47b99a56461c67ceddad (diff) |
Automated g4 rollback of changelist 200747752
PiperOrigin-RevId: 200802842
Diffstat (limited to 'tensorflow/python/saved_model')
-rw-r--r-- | tensorflow/python/saved_model/BUILD | 24 | ||||
-rw-r--r-- | tensorflow/python/saved_model/loader_impl.py | 175 | ||||
-rw-r--r-- | tensorflow/python/saved_model/loader_test.py | 180 |
3 files changed, 31 insertions, 348 deletions
diff --git a/tensorflow/python/saved_model/BUILD b/tensorflow/python/saved_model/BUILD index 076f2d8760..81786fbf43 100644 --- a/tensorflow/python/saved_model/BUILD +++ b/tensorflow/python/saved_model/BUILD @@ -87,30 +87,6 @@ py_library( "//tensorflow/python:platform", "//tensorflow/python:training", "//tensorflow/python:util", - "//tensorflow/python:variables", - ], -) - -py_test( - name = "loader_test", - size = "small", - srcs = ["loader_test.py"], - srcs_version = "PY2AND3", - visibility = ["//visibility:private"], - deps = [ - ":builder", - ":loader", - ":signature_def_utils", - ":utils", - "//tensorflow/python:client", - "//tensorflow/python:client_testlib", - "//tensorflow/python:control_flow_ops", - "//tensorflow/python:errors", - "//tensorflow/python:framework_ops", - "//tensorflow/python:lib", - "//tensorflow/python:state_ops", - "//tensorflow/python:training", - "//tensorflow/python:variables", ], ) diff --git a/tensorflow/python/saved_model/loader_impl.py b/tensorflow/python/saved_model/loader_impl.py index 6770aaef36..d1bd8d47ae 100644 --- a/tensorflow/python/saved_model/loader_impl.py +++ b/tensorflow/python/saved_model/loader_impl.py @@ -28,7 +28,6 @@ from tensorflow.core.protobuf import meta_graph_pb2 from tensorflow.core.protobuf import saved_model_pb2 from tensorflow.python.framework import ops from tensorflow.python.lib.io import file_io -from tensorflow.python.ops import variables from tensorflow.python.platform import tf_logging from tensorflow.python.saved_model import constants from tensorflow.python.training import saver as tf_saver @@ -208,56 +207,11 @@ def load(sess, tags, export_dir, import_scope=None, **saver_kwargs): Raises: RuntimeError: MetaGraphDef associated with the tags cannot be found. """ - loader = SavedModelLoader(export_dir) - return loader.load(sess, tags, import_scope, **saver_kwargs) - - -class SavedModelLoader(object): - """Load graphs and restore variable values from a `SavedModel`.""" - - def __init__(self, export_dir): - """Creates a `SavedModelLoader`. - - Args: - export_dir: Directory in which the SavedModel protocol buffer and - variables to be loaded are located. - """ - self._export_dir = export_dir - self._variables_path = os.path.join( - compat.as_bytes(export_dir), - compat.as_bytes(constants.VARIABLES_DIRECTORY), - compat.as_bytes(constants.VARIABLES_FILENAME)) - self._saved_model = _parse_saved_model(export_dir) - - @property - def export_dir(self): - """Directory containing the SavedModel.""" - return self._export_dir - - @property - def variables_path(self): - """Path to variable checkpoint files.""" - return self._variables_path - - @property - def saved_model(self): - """SavedModel object parsed from the export directory.""" - return self._saved_model - - def get_meta_graph_def_from_tags(self, tags): - """Return MetaGraphDef with the exact specified tags. - - Args: - tags: A list or set of string tags that identify the MetaGraphDef. - - Returns: - MetaGraphDef with the same tags. - - Raises: - RuntimeError: if no metagraphs were found with the associated tags. - """ + with sess.graph.as_default(): + # Build the SavedModel protocol buffer and find requested meta graph def. + saved_model = _parse_saved_model(export_dir) found_match = False - for meta_graph_def in self._saved_model.meta_graphs: + for meta_graph_def in saved_model.meta_graphs: if set(meta_graph_def.meta_info_def.tags) == set(tags): meta_graph_def_to_load = meta_graph_def found_match = True @@ -269,99 +223,32 @@ class SavedModelLoader(object): " could not be found in SavedModel. To inspect available tag-sets in" " the SavedModel, please use the SavedModel CLI: `saved_model_cli`" ) - return meta_graph_def_to_load - def load_graph(self, graph, tags, import_scope=None, **saver_kwargs): - """Load ops and nodes from SavedModel MetaGraph into graph. + # Build a saver by importing the meta graph def to load. + saver = tf_saver.import_meta_graph( + meta_graph_def_to_load, import_scope=import_scope, **saver_kwargs) + + if saver: + # Build the checkpoint path where the variables are located. + variables_path = os.path.join( + compat.as_bytes(export_dir), + compat.as_bytes(constants.VARIABLES_DIRECTORY), + compat.as_bytes(constants.VARIABLES_FILENAME)) + + # Restore the variables using the built saver in the provided session. + saver.restore(sess, variables_path) + else: + tf_logging.info("The specified SavedModel has no variables; no " + "checkpoints were restored.") + + # Get asset tensors, if any. + asset_tensors_dictionary = _get_asset_tensors( + export_dir, meta_graph_def_to_load, import_scope=import_scope) + + main_op_tensor = ( + _get_main_op_tensor(meta_graph_def_to_load) or + (_get_legacy_init_op_tensor(meta_graph_def_to_load))) + if main_op_tensor is not None: + sess.run(fetches=[main_op_tensor], feed_dict=asset_tensors_dictionary) - Args: - graph: tf.Graph object. - tags: a set of string tags identifying a MetaGraphDef. - import_scope: Optional `string` -- if specified, prepend this string - followed by '/' to all loaded tensor names. This scope is applied to - tensor instances loaded into the passed session, but it is *not* written - through to the static `MetaGraphDef` protocol buffer that is returned. - **saver_kwargs: keyword arguments to pass to tf.train.import_meta_graph. - - Returns: - Saver defined by the MetaGraph, which can be used to restore the variable - values. - """ - meta_graph_def = self.get_meta_graph_def_from_tags(tags) - with graph.as_default(): - return tf_saver.import_meta_graph( - meta_graph_def, import_scope=import_scope, **saver_kwargs) - - def restore_variables(self, sess, saver, import_scope=None): - """Restore SavedModel variable values into the session. - - Args: - sess: tf.Session to restore variable values. - saver: a tf.train.Saver object. Can be None if there are no variables in - graph. This may be the saver returned by the load_graph() function, or a - default `tf.train.Saver()`. - import_scope: Optional `string` -- if specified, prepend this string - followed by '/' to all loaded tensor names. This scope is applied to - tensor instances loaded into the passed session, but it is *not* written - through to the static `MetaGraphDef` protocol buffer that is returned. - - Raises: - ValueError: if no saver was passed to the saver argument, and there are - variables in the graph. - """ - with sess.graph.as_default(): - if not variables._all_saveable_objects(scope=import_scope): # pylint: disable=protected-access - tf_logging.info("The specified SavedModel has no variables; no " - "checkpoints were restored.") - elif isinstance(saver, tf_saver.Saver): - saver.restore(sess, self._variables_path) - else: - raise ValueError( - "No tf.train.Saver object was passed to the function " - "SavedModelLoader.restore_variables. Since there are variables in " - "the graph, a saver is required.") - - def run_init_ops(self, sess, tags, import_scope=None): - """Run initialization ops defined in the `MetaGraphDef`. - - Args: - sess: tf.Session to restore variable values. - tags: a set of string tags identifying a MetaGraphDef. - import_scope: Optional `string` -- if specified, prepend this string - followed by '/' to all loaded tensor names. This scope is applied to - tensor instances loaded into the passed session, but it is *not* written - through to the static `MetaGraphDef` protocol buffer that is returned. - """ - meta_graph_def = self.get_meta_graph_def_from_tags(tags) - with sess.graph.as_default(): - # Get asset tensors, if any. - asset_tensors_dictionary = _get_asset_tensors( - self._export_dir, meta_graph_def, import_scope=import_scope) - - main_op_tensor = ( - _get_main_op_tensor(meta_graph_def) or - (_get_legacy_init_op_tensor(meta_graph_def))) - if main_op_tensor is not None: - sess.run(fetches=[main_op_tensor], feed_dict=asset_tensors_dictionary) - - def load(self, sess, tags, import_scope=None, **saver_kwargs): - """Load the MetaGraphDef graph and restore variable values into the session. - - Args: - sess: tf.Session to restore variable values. - tags: a set of string tags identifying a MetaGraphDef. - import_scope: Optional `string` -- if specified, prepend this string - followed by '/' to all loaded tensor names. This scope is applied to - tensor instances loaded into the passed session, but it is *not* written - through to the static `MetaGraphDef` protocol buffer that is returned. - **saver_kwargs: keyword arguments to pass to tf.train.import_meta_graph. - - Returns: - `MetagraphDef` proto of the graph that was loaded. - """ - with sess.graph.as_default(): - saver = self.load_graph(sess.graph, tags, import_scope, - **saver_kwargs) - self.restore_variables(sess, saver, import_scope) - self.run_init_ops(sess, tags, import_scope) - return self.get_meta_graph_def_from_tags(tags) + return meta_graph_def_to_load diff --git a/tensorflow/python/saved_model/loader_test.py b/tensorflow/python/saved_model/loader_test.py deleted file mode 100644 index 2ec2519c89..0000000000 --- a/tensorflow/python/saved_model/loader_test.py +++ /dev/null @@ -1,180 +0,0 @@ -# Copyright 2018 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for SavedModelLoader class.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import os - -from tensorflow.python.client import session -from tensorflow.python.framework import errors -from tensorflow.python.framework import ops -from tensorflow.python.lib.io import file_io -from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import state_ops -from tensorflow.python.ops import variables -from tensorflow.python.platform import test -from tensorflow.python.saved_model import builder as saved_model_builder -from tensorflow.python.saved_model import loader_impl -from tensorflow.python.saved_model import signature_def_utils -from tensorflow.python.saved_model import utils -from tensorflow.python.training import saver as tf_saver - - -def _get_export_dir(label): - return os.path.join(test.get_temp_dir(), label) - -SIMPLE_ADD_SAVED_MODEL = _get_export_dir("simple_add_saved_model") -SAVED_MODEL_WITH_MAIN_OP = _get_export_dir("saved_model_with_main_op") - - -class SavedModelLoaderTest(test.TestCase): - - def setUp(self): - """Write test SavedModels to a temp directory.""" - with session.Session(graph=ops.Graph()) as sess: - x = variables.Variable(5, name="x") - y = variables.Variable(11, name="y") - z = x + y - sess.run(variables.global_variables_initializer()) - - foo_sig_def = signature_def_utils.build_signature_def( - {"foo_input": utils.build_tensor_info(x)}, - {"foo_output": utils.build_tensor_info(z)}) - bar_sig_def = signature_def_utils.build_signature_def( - {"bar_x": utils.build_tensor_info(x), - "bar_y": utils.build_tensor_info(y)}, - {"bar_z": utils.build_tensor_info(z)}) - - builder = saved_model_builder.SavedModelBuilder(SIMPLE_ADD_SAVED_MODEL) - builder.add_meta_graph_and_variables( - sess, ["foo_graph"], {"foo": foo_sig_def, "bar": bar_sig_def}) - builder.save() - - # Write SavedModel with a main_op - assign_op = control_flow_ops.group(state_ops.assign(y, 7)) - - builder = saved_model_builder.SavedModelBuilder(SAVED_MODEL_WITH_MAIN_OP) - builder.add_meta_graph_and_variables( - sess, ["foo_graph"], {"foo": foo_sig_def, "bar": bar_sig_def}, - main_op=assign_op) - builder.save() - - def tearDown(self): - file_io.delete_recursively(test.get_temp_dir()) - - def test_load_function(self): - loader = loader_impl.SavedModelLoader(SIMPLE_ADD_SAVED_MODEL) - with self.test_session(graph=ops.Graph()) as sess: - loader.load(sess, ["foo_graph"]) - self.assertEqual(5, sess.graph.get_tensor_by_name("x:0").eval()) - self.assertEqual(11, sess.graph.get_tensor_by_name("y:0").eval()) - - loader2 = loader_impl.SavedModelLoader(SAVED_MODEL_WITH_MAIN_OP) - with self.test_session(graph=ops.Graph()) as sess: - loader2.load(sess, ["foo_graph"]) - self.assertEqual(5, sess.graph.get_tensor_by_name("x:0").eval()) - self.assertEqual(7, sess.graph.get_tensor_by_name("y:0").eval()) - - def test_load_graph(self): - loader = loader_impl.SavedModelLoader(SIMPLE_ADD_SAVED_MODEL) - graph = ops.Graph() - loader.load_graph(graph, ["foo_graph"]) - - x = graph.get_tensor_by_name("x:0") - y = graph.get_tensor_by_name("y:0") - - with self.assertRaises(KeyError): - graph.get_tensor_by_name("z:0") - - with self.test_session(graph=graph) as sess: - # Check that x and y are not initialized - with self.assertRaises(errors.FailedPreconditionError): - sess.run(x) - with self.assertRaises(errors.FailedPreconditionError): - sess.run(y) - - def test_load_with_import_scope(self): - loader = loader_impl.SavedModelLoader(SAVED_MODEL_WITH_MAIN_OP) - with self.test_session(graph=ops.Graph()) as sess: - saver = loader.load_graph(sess.graph, ["foo_graph"], import_scope="baz") - - # The default saver should not work when the import scope is set. - with self.assertRaises(errors.NotFoundError): - loader.restore_variables(sess, tf_saver.Saver()) - - loader.restore_variables(sess, saver) - loader.run_init_ops(sess, ["foo_graph"]) - - self.assertEqual(5, sess.graph.get_tensor_by_name("baz/x:0").eval()) - self.assertEqual(7, sess.graph.get_tensor_by_name("baz/y:0").eval()) - - # Test combined load function. - loader = loader_impl.SavedModelLoader(SAVED_MODEL_WITH_MAIN_OP) - with self.test_session(graph=ops.Graph()) as sess: - loader.load(sess, ["foo_graph"], import_scope="baa") - self.assertEqual(5, sess.graph.get_tensor_by_name("baa/x:0").eval()) - self.assertEqual(7, sess.graph.get_tensor_by_name("baa/y:0").eval()) - - def test_restore_variables(self): - loader = loader_impl.SavedModelLoader(SAVED_MODEL_WITH_MAIN_OP) - with self.test_session(graph=ops.Graph()) as sess: - x = variables.Variable(0, name="x") - y = variables.Variable(0, name="y") - z = x * y - - sess.run(variables.global_variables_initializer()) - - # There are variables to restore, so a saver must be created. - with self.assertRaises(ValueError): - loader.restore_variables(sess, None) - - loader.restore_variables(sess, tf_saver.Saver()) - self.assertEqual(55, z.eval()) - - def test_run_init_op(self): - loader = loader_impl.SavedModelLoader(SAVED_MODEL_WITH_MAIN_OP) - graph = ops.Graph() - saver = loader.load_graph(graph, ["foo_graph"]) - with self.test_session(graph=graph) as sess: - loader.restore_variables(sess, saver) - self.assertEqual(5, sess.graph.get_tensor_by_name("x:0").eval()) - self.assertEqual(11, sess.graph.get_tensor_by_name("y:0").eval()) - - loader.run_init_ops(sess, ["foo_graph"]) - self.assertEqual(5, sess.graph.get_tensor_by_name("x:0").eval()) - self.assertEqual(7, sess.graph.get_tensor_by_name("y:0").eval()) - - def test_parse_saved_model(self): - loader = loader_impl.SavedModelLoader(SIMPLE_ADD_SAVED_MODEL) - meta_graph = loader.get_meta_graph_def_from_tags(["foo_graph"]) - self.assertIsNotNone(meta_graph) - self.assertIn("foo", meta_graph.signature_def) - self.assertIn("bar", meta_graph.signature_def) - - def test_load_invalid_meta_graph(self): - loader = loader_impl.SavedModelLoader(SIMPLE_ADD_SAVED_MODEL) - with self.assertRaises(RuntimeError): - loader.get_meta_graph_def_from_tags([]) - with self.assertRaises(RuntimeError): - loader.get_meta_graph_def_from_tags([""]) - with self.assertRaises(RuntimeError): - loader.get_meta_graph_def_from_tags(["not_a_graph"]) - - -if __name__ == "__main__": - test.main() |