aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/tools/strip_unused_lib.py
diff options
context:
space:
mode:
authorGravatar Mark Daoust <markdaoust@google.com>2017-05-12 05:29:21 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-05-12 05:33:11 -0700
commitebb278520add4b046e283f81398df03395b5d342 (patch)
treec1cedf934122beab6e69751e6876f5e8df6e0c95 /tensorflow/python/tools/strip_unused_lib.py
parent8a13f77edad2c39e10dddd0043c8b97f6e08751d (diff)
Give clear errors for bad input names.
PiperOrigin-RevId: 155857515
Diffstat (limited to 'tensorflow/python/tools/strip_unused_lib.py')
-rw-r--r--tensorflow/python/tools/strip_unused_lib.py17
1 files changed, 16 insertions, 1 deletions
diff --git a/tensorflow/python/tools/strip_unused_lib.py b/tensorflow/python/tools/strip_unused_lib.py
index 8f9e20ab8e..b1d1956076 100644
--- a/tensorflow/python/tools/strip_unused_lib.py
+++ b/tensorflow/python/tools/strip_unused_lib.py
@@ -41,14 +41,26 @@ def strip_unused(input_graph_def, input_node_names, output_node_names,
a list that specifies one value per input node name.
Returns:
- A GraphDef with all unnecessary ops removed.
+ A `GraphDef` with all unnecessary ops removed.
+
+ Raises:
+ ValueError: If any element in `input_node_names` refers to a tensor instead
+ of an operation.
+ KeyError: If any element in `input_node_names` is not found in the graph.
"""
+ for name in input_node_names:
+ if ":" in name:
+ raise ValueError("Name '%s' appears to refer to a Tensor, "
+ "not a Operation." % name)
+
# Here we replace the nodes we're going to override as inputs with
# placeholders so that any unused nodes that are inputs to them are
# automatically stripped out by extract_sub_graph().
+ not_found = {name for name in input_node_names}
inputs_replaced_graph_def = graph_pb2.GraphDef()
for node in input_graph_def.node:
if node.name in input_node_names:
+ not_found.remove(node.name)
placeholder_node = node_def_pb2.NodeDef()
placeholder_node.op = "Placeholder"
placeholder_node.name = node.name
@@ -67,6 +79,9 @@ def strip_unused(input_graph_def, input_node_names, output_node_names,
else:
inputs_replaced_graph_def.node.extend([copy.deepcopy(node)])
+ if not_found:
+ raise KeyError("The following input nodes were not found: %s\n" % not_found)
+
output_graph_def = graph_util.extract_sub_graph(inputs_replaced_graph_def,
output_node_names)
return output_graph_def