aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/tools/strip_unused_lib.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/tools/strip_unused_lib.py')
-rw-r--r--tensorflow/python/tools/strip_unused_lib.py17
1 files changed, 16 insertions, 1 deletions
diff --git a/tensorflow/python/tools/strip_unused_lib.py b/tensorflow/python/tools/strip_unused_lib.py
index 8f9e20ab8e..b1d1956076 100644
--- a/tensorflow/python/tools/strip_unused_lib.py
+++ b/tensorflow/python/tools/strip_unused_lib.py
@@ -41,14 +41,26 @@ def strip_unused(input_graph_def, input_node_names, output_node_names,
a list that specifies one value per input node name.
Returns:
- A GraphDef with all unnecessary ops removed.
+ A `GraphDef` with all unnecessary ops removed.
+
+ Raises:
+ ValueError: If any element in `input_node_names` refers to a tensor instead
+ of an operation.
+ KeyError: If any element in `input_node_names` is not found in the graph.
"""
+ for name in input_node_names:
+ if ":" in name:
+ raise ValueError("Name '%s' appears to refer to a Tensor, "
+ "not a Operation." % name)
+
# Here we replace the nodes we're going to override as inputs with
# placeholders so that any unused nodes that are inputs to them are
# automatically stripped out by extract_sub_graph().
+ not_found = {name for name in input_node_names}
inputs_replaced_graph_def = graph_pb2.GraphDef()
for node in input_graph_def.node:
if node.name in input_node_names:
+ not_found.remove(node.name)
placeholder_node = node_def_pb2.NodeDef()
placeholder_node.op = "Placeholder"
placeholder_node.name = node.name
@@ -67,6 +79,9 @@ def strip_unused(input_graph_def, input_node_names, output_node_names,
else:
inputs_replaced_graph_def.node.extend([copy.deepcopy(node)])
+ if not_found:
+ raise KeyError("The following input nodes were not found: %s\n" % not_found)
+
output_graph_def = graph_util.extract_sub_graph(inputs_replaced_graph_def,
output_node_names)
return output_graph_def