aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/saved_model/loader_impl.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/saved_model/loader_impl.py')
-rw-r--r--tensorflow/python/saved_model/loader_impl.py55
1 files changed, 19 insertions, 36 deletions
diff --git a/tensorflow/python/saved_model/loader_impl.py b/tensorflow/python/saved_model/loader_impl.py
index e5f649fdab..16077f52fa 100644
--- a/tensorflow/python/saved_model/loader_impl.py
+++ b/tensorflow/python/saved_model/loader_impl.py
@@ -116,11 +116,14 @@ def _get_asset_tensors(export_dir, meta_graph_def_to_load, import_scope=None):
return asset_tensor_dict
-def _get_main_op_tensor(meta_graph_def_to_load):
+def _get_main_op_tensor(
+ meta_graph_def_to_load, init_op_key=constants.MAIN_OP_KEY):
"""Gets the main op tensor, if one exists.
Args:
meta_graph_def_to_load: The meta graph def from the SavedModel to be loaded.
+ init_op_key: name of collection to check; should be one of MAIN_OP_KEY
+ or the deprecated LEGACY_INIT_OP_KEY
Returns:
The main op tensor, if it exists and `None` otherwise.
@@ -131,38 +134,15 @@ def _get_main_op_tensor(meta_graph_def_to_load):
"""
collection_def = meta_graph_def_to_load.collection_def
main_op_tensor = None
- if constants.MAIN_OP_KEY in collection_def:
- main_ops = collection_def[constants.MAIN_OP_KEY].node_list.value
+ if init_op_key in collection_def:
+ main_ops = collection_def[init_op_key].node_list.value
if len(main_ops) != 1:
- raise RuntimeError("Expected exactly one SavedModel main op.")
- main_op_tensor = ops.get_collection(constants.MAIN_OP_KEY)[0]
+ raise RuntimeError("Expected exactly one SavedModel main op. "
+ "Found: {}".format(main_ops))
+ main_op_tensor = ops.get_collection(init_op_key)[0]
return main_op_tensor
-def _get_legacy_init_op_tensor(meta_graph_def_to_load):
- """Gets the legacy init op tensor, if one exists.
-
- Args:
- meta_graph_def_to_load: The meta graph def from the SavedModel to be loaded.
-
- Returns:
- The legacy init op tensor, if it exists and `None` otherwise.
-
- Raises:
- RuntimeError: If the collection def corresponding to the legacy init op key
- has other than exactly one tensor.
- """
- collection_def = meta_graph_def_to_load.collection_def
- legacy_init_op_tensor = None
- if constants.LEGACY_INIT_OP_KEY in collection_def:
- legacy_init_ops = collection_def[
- constants.LEGACY_INIT_OP_KEY].node_list.value
- if len(legacy_init_ops) != 1:
- raise RuntimeError("Expected exactly one legacy serving init op.")
- legacy_init_op_tensor = ops.get_collection(constants.LEGACY_INIT_OP_KEY)[0]
- return legacy_init_op_tensor
-
-
@tf_export("saved_model.loader.maybe_saved_model_directory")
def maybe_saved_model_directory(export_dir):
"""Checks whether the provided export directory could contain a SavedModel.
@@ -284,12 +264,15 @@ class SavedModelLoader(object):
**saver_kwargs: keyword arguments to pass to tf.train.import_meta_graph.
Returns:
- Saver defined by the MetaGraph, which can be used to restore the variable
- values.
+ A tuple of
+ * Saver defined by the MetaGraph, which can be used to restore the
+ variable values.
+ * List of `Operation`/`Tensor` objects returned from
+ `tf.import_graph_def` (may be `None`).
"""
meta_graph_def = self.get_meta_graph_def_from_tags(tags)
with graph.as_default():
- return tf_saver.import_meta_graph(
+ return tf_saver._import_meta_graph_with_return_elements( # pylint: disable=protected-access
meta_graph_def, import_scope=import_scope, **saver_kwargs)
def restore_variables(self, sess, saver, import_scope=None):
@@ -340,8 +323,8 @@ class SavedModelLoader(object):
self._export_dir, meta_graph_def, import_scope=import_scope)
main_op_tensor = (
- _get_main_op_tensor(meta_graph_def) or
- (_get_legacy_init_op_tensor(meta_graph_def)))
+ _get_main_op_tensor(meta_graph_def, constants.MAIN_OP_KEY) or
+ _get_main_op_tensor(meta_graph_def, constants.LEGACY_INIT_OP_KEY))
if main_op_tensor is not None:
sess.run(fetches=[main_op_tensor], feed_dict=asset_tensors_dictionary)
@@ -361,8 +344,8 @@ class SavedModelLoader(object):
`MetagraphDef` proto of the graph that was loaded.
"""
with sess.graph.as_default():
- saver = self.load_graph(sess.graph, tags, import_scope,
- **saver_kwargs)
+ saver, _ = self.load_graph(sess.graph, tags, import_scope,
+ **saver_kwargs)
self.restore_variables(sess, saver, import_scope)
self.run_init_ops(sess, tags, import_scope)
return self.get_meta_graph_def_from_tags(tags)