aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/framework/meta_graph.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/framework/meta_graph.py')
-rw-r--r--tensorflow/python/framework/meta_graph.py68
1 files changed, 65 insertions, 3 deletions
diff --git a/tensorflow/python/framework/meta_graph.py b/tensorflow/python/framework/meta_graph.py
index 923e76fc9c..33631282bd 100644
--- a/tensorflow/python/framework/meta_graph.py
+++ b/tensorflow/python/framework/meta_graph.py
@@ -696,6 +696,67 @@ def import_scoped_meta_graph(meta_graph_or_file,
Raises:
ValueError: If the graph_def contains unbound inputs.
"""
+ return import_scoped_meta_graph_with_return_elements(
+ meta_graph_or_file, clear_devices, graph, import_scope, input_map,
+ unbound_inputs_col_name, restore_collections_predicate)[0]
+
+
+def import_scoped_meta_graph_with_return_elements(
+ meta_graph_or_file,
+ clear_devices=False,
+ graph=None,
+ import_scope=None,
+ input_map=None,
+ unbound_inputs_col_name="unbound_inputs",
+ restore_collections_predicate=(lambda key: True),
+ return_elements=None):
+ """Imports graph from `MetaGraphDef` and returns vars and return elements.
+
+ This function takes a `MetaGraphDef` protocol buffer as input. If
+ the argument is a file containing a `MetaGraphDef` protocol buffer ,
+ it constructs a protocol buffer from the file content. The function
+ then adds all the nodes from the `graph_def` field to the
+ current graph, recreates the desired collections, and returns a dictionary of
+ all the Variables imported into the name scope.
+
+ In combination with `export_scoped_meta_graph()`, this function can be used to
+
+ * Serialize a graph along with other Python objects such as `QueueRunner`,
+ `Variable` into a `MetaGraphDef`.
+
+ * Restart training from a saved graph and checkpoints.
+
+ * Run inference from a saved graph and checkpoints.
+
+ Args:
+ meta_graph_or_file: `MetaGraphDef` protocol buffer or filename (including
+ the path) containing a `MetaGraphDef`.
+ clear_devices: Boolean which controls whether to clear device information
+ from graph_def. Default false.
+ graph: The `Graph` to import into. If `None`, use the default graph.
+ import_scope: Optional `string`. Name scope into which to import the
+ subgraph. If `None`, the graph is imported to the root name scope.
+ input_map: A dictionary mapping input names (as strings) in `graph_def` to
+ `Tensor` objects. The values of the named input tensors in the imported
+ graph will be re-mapped to the respective `Tensor` values.
+ unbound_inputs_col_name: Collection name for looking up unbound inputs.
+ restore_collections_predicate: a predicate on collection names. A collection
+ named c (i.e whose key is c) will be restored iff
+ 1) `restore_collections_predicate(c)` is True, and
+ 2) `c != unbound_inputs_col_name`.
+ return_elements: A list of strings containing operation names in the
+ `MetaGraphDef` that will be returned as `Operation` objects; and/or
+ tensor names in `MetaGraphDef` that will be returned as `Tensor` objects.
+
+ Returns:
+ A tuple of (
+ dictionary of all the `Variables` imported into the name scope,
+ list of `Operation` or `Tensor` objects from the `return_elements` list).
+
+ Raises:
+ ValueError: If the graph_def contains unbound inputs.
+
+ """
if context.executing_eagerly():
raise ValueError("Exporting/importing meta graphs is not supported when "
"eager execution is enabled.")
@@ -737,11 +798,12 @@ def import_scoped_meta_graph(meta_graph_or_file,
scope_to_prepend_to_names = graph.unique_name(
import_scope or "", mark_as_used=False)
- importer.import_graph_def(
+ imported_return_elements = importer.import_graph_def(
input_graph_def,
name=(import_scope or scope_to_prepend_to_names),
input_map=input_map,
- producer_op_list=producer_op_list)
+ producer_op_list=producer_op_list,
+ return_elements=return_elements)
# Restores all the other collections.
variable_objects = {}
@@ -806,7 +868,7 @@ def import_scoped_meta_graph(meta_graph_or_file,
for v in variables:
var_list[ops.strip_name_scope(v.name, scope_to_prepend_to_names)] = v
- return var_list
+ return var_list, imported_return_elements
def export_scoped_meta_graph(filename=None,