diff options
author | 2017-05-12 05:29:21 -0700 | |
---|---|---|
committer | 2017-05-12 05:33:11 -0700 | |
commit | ebb278520add4b046e283f81398df03395b5d342 (patch) | |
tree | c1cedf934122beab6e69751e6876f5e8df6e0c95 /tensorflow/python/tools/strip_unused_lib.py | |
parent | 8a13f77edad2c39e10dddd0043c8b97f6e08751d (diff) |
Give clear errors for bad input names.
PiperOrigin-RevId: 155857515
Diffstat (limited to 'tensorflow/python/tools/strip_unused_lib.py')
-rw-r--r-- | tensorflow/python/tools/strip_unused_lib.py | 17 |
1 files changed, 16 insertions, 1 deletions
diff --git a/tensorflow/python/tools/strip_unused_lib.py b/tensorflow/python/tools/strip_unused_lib.py index 8f9e20ab8e..b1d1956076 100644 --- a/tensorflow/python/tools/strip_unused_lib.py +++ b/tensorflow/python/tools/strip_unused_lib.py @@ -41,14 +41,26 @@ def strip_unused(input_graph_def, input_node_names, output_node_names, a list that specifies one value per input node name. Returns: - A GraphDef with all unnecessary ops removed. + A `GraphDef` with all unnecessary ops removed. + + Raises: + ValueError: If any element in `input_node_names` refers to a tensor instead + of an operation. + KeyError: If any element in `input_node_names` is not found in the graph. """ + for name in input_node_names: + if ":" in name: + raise ValueError("Name '%s' appears to refer to a Tensor, " + "not a Operation." % name) + # Here we replace the nodes we're going to override as inputs with # placeholders so that any unused nodes that are inputs to them are # automatically stripped out by extract_sub_graph(). + not_found = {name for name in input_node_names} inputs_replaced_graph_def = graph_pb2.GraphDef() for node in input_graph_def.node: if node.name in input_node_names: + not_found.remove(node.name) placeholder_node = node_def_pb2.NodeDef() placeholder_node.op = "Placeholder" placeholder_node.name = node.name @@ -67,6 +79,9 @@ def strip_unused(input_graph_def, input_node_names, output_node_names, else: inputs_replaced_graph_def.node.extend([copy.deepcopy(node)]) + if not_found: + raise KeyError("The following input nodes were not found: %s\n" % not_found) + output_graph_def = graph_util.extract_sub_graph(inputs_replaced_graph_def, output_node_names) return output_graph_def |