diff options
Diffstat (limited to 'tensorflow/python/tools/strip_unused_lib.py')
-rw-r--r-- | tensorflow/python/tools/strip_unused_lib.py | 17 |
1 files changed, 16 insertions, 1 deletions
diff --git a/tensorflow/python/tools/strip_unused_lib.py b/tensorflow/python/tools/strip_unused_lib.py index 8f9e20ab8e..b1d1956076 100644 --- a/tensorflow/python/tools/strip_unused_lib.py +++ b/tensorflow/python/tools/strip_unused_lib.py @@ -41,14 +41,26 @@ def strip_unused(input_graph_def, input_node_names, output_node_names, a list that specifies one value per input node name. Returns: - A GraphDef with all unnecessary ops removed. + A `GraphDef` with all unnecessary ops removed. + + Raises: + ValueError: If any element in `input_node_names` refers to a tensor instead + of an operation. + KeyError: If any element in `input_node_names` is not found in the graph. """ + for name in input_node_names: + if ":" in name: + raise ValueError("Name '%s' appears to refer to a Tensor, " + "not a Operation." % name) + # Here we replace the nodes we're going to override as inputs with # placeholders so that any unused nodes that are inputs to them are # automatically stripped out by extract_sub_graph(). + not_found = {name for name in input_node_names} inputs_replaced_graph_def = graph_pb2.GraphDef() for node in input_graph_def.node: if node.name in input_node_names: + not_found.remove(node.name) placeholder_node = node_def_pb2.NodeDef() placeholder_node.op = "Placeholder" placeholder_node.name = node.name @@ -67,6 +79,9 @@ def strip_unused(input_graph_def, input_node_names, output_node_names, else: inputs_replaced_graph_def.node.extend([copy.deepcopy(node)]) + if not_found: + raise KeyError("The following input nodes were not found: %s\n" % not_found) + output_graph_def = graph_util.extract_sub_graph(inputs_replaced_graph_def, output_node_names) return output_graph_def |