diff options
author | 2017-11-06 14:45:54 -0800 | |
---|---|---|
committer | 2017-11-06 14:50:57 -0800 | |
commit | bf806939cd3a75181d6e317a6ea1f7f56c340eda (patch) | |
tree | 667857a2913671719b9feb5c4d2c1ed007cd5eae | |
parent | 389eaade3fd29426f4263430e20384ec1ae912be (diff) |
Create a routine that can collapse a subgraph into a fused op
PiperOrigin-RevId: 174765540
5 files changed, 256 insertions, 35 deletions
diff --git a/tensorflow/contrib/framework/BUILD b/tensorflow/contrib/framework/BUILD index 4d81564af8..e8dad886a1 100644 --- a/tensorflow/contrib/framework/BUILD +++ b/tensorflow/contrib/framework/BUILD @@ -24,6 +24,7 @@ tf_custom_op_py_library( "python/framework/__init__.py", "python/framework/checkpoint_utils.py", "python/framework/experimental.py", + "python/framework/graph_util.py", "python/framework/tensor_util.py", "python/ops/__init__.py", "python/ops/accumulate_n_v2.py", @@ -233,6 +234,17 @@ py_test( ) py_test( + name = "graph_util_test", + srcs = ["python/framework/graph_util_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":framework_py", + "//tensorflow/python:client_testlib", + "//tensorflow/python:platform", + ], +) + +py_test( name = "tensor_util_test", srcs = ["python/framework/tensor_util_test.py"], srcs_version = "PY2AND3", diff --git a/tensorflow/contrib/framework/python/framework/__init__.py b/tensorflow/contrib/framework/python/framework/__init__.py index c8e6a46854..2d49771ab7 100644 --- a/tensorflow/contrib/framework/python/framework/__init__.py +++ b/tensorflow/contrib/framework/python/framework/__init__.py @@ -21,6 +21,7 @@ from __future__ import print_function # pylint: disable=wildcard-import from tensorflow.contrib.framework.python.framework.checkpoint_utils import * from tensorflow.contrib.framework.python.framework.experimental import experimental +from tensorflow.contrib.framework.python.framework.graph_util import * from tensorflow.contrib.framework.python.framework.tensor_util import * # pylint: enable=wildcard-import from tensorflow.python.util import decorator_utils diff --git a/tensorflow/contrib/framework/python/framework/graph_util.py b/tensorflow/contrib/framework/python/framework/graph_util.py new file mode 100644 index 0000000000..8ab8711db4 --- /dev/null +++ b/tensorflow/contrib/framework/python/framework/graph_util.py @@ -0,0 +1,128 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Helpers to manipulate a tensor graph in python. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +import copy +import six + +# pylint: disable=unused-import +from tensorflow.core.framework import graph_pb2 +from tensorflow.core.framework import node_def_pb2 +from tensorflow.python.framework.graph_util_impl import _assert_nodes_are_present +from tensorflow.python.framework.graph_util_impl import _bfs_for_reachable_nodes +from tensorflow.python.framework.graph_util_impl import _extract_graph_summary +from tensorflow.python.framework.graph_util_impl import _node_name + +__all__ = ["fuse_op"] + + +def fuse_op(graph_def, input_nodes, output_nodes, output_dtypes, + output_quantized, op_name, op_type): + """Fuse subgraph between input_nodes and output_nodes into a single custom op. + + Args: + graph_def: A graph_pb2.GraphDef proto. + input_nodes: input nodes to the subgraph to be fused. + output_nodes: output nodes to the subgraph to be fused. + output_dtypes: A list of output datatypes for the custom op + output_quantized: A boolean flag that indicates if output is quantized + op_name: fused op name. + op_type: fused op type. + Returns: + The GraphDef of the new graph. + + Raises: + TypeError: If 'graph_def' is not a graph_pb2.GraphDef proto. + """ + + if not isinstance(graph_def, graph_pb2.GraphDef): + raise TypeError("graph_def must be a graph_pb2.GraphDef proto.") + + if isinstance(input_nodes, six.string_types): + raise TypeError("input_nodes must be a list.") + + if isinstance(output_nodes, six.string_types): + raise TypeError("output_nodes must be a list.") + + name_to_input_name, name_to_node, name_to_seq_num = _extract_graph_summary( + graph_def) + _assert_nodes_are_present(name_to_node, input_nodes + output_nodes) + + # Nodes upto and including input_nodes + reachable_by_input = _bfs_for_reachable_nodes(input_nodes, name_to_input_name) + # Nodes upto and including output_nodes + reachable_by_output = _bfs_for_reachable_nodes(output_nodes, + name_to_input_name) + + # Set of nodes in the list input_nodes + input_nodes_set = set(input_nodes) + + # Set of nodes in the list output_nodes + output_nodes_set = set(output_nodes) + + nodes_post_output = [] + for node in graph_def.node: + n = _node_name(node.name) + if n in reachable_by_output: + if n not in reachable_by_input and n not in output_nodes_set: + # n is between input and output, i.e., part of the fused op + next_to_visit = [n] + while next_to_visit: + cur_node = next_to_visit[0] + del next_to_visit[0] + if cur_node in reachable_by_input and cur_node not in input_nodes_set: + raise TypeError("Node %s uses input %s not in input_nodes." % + (n, cur_node)) + if cur_node not in input_nodes_set: + next_to_visit += name_to_input_name[cur_node] + else: + nodes_post_output.append(n) + + # Add all nodes upto the input nodes + out = graph_pb2.GraphDef() + reachable_by_input_sorted = sorted( + list(reachable_by_input), key=lambda n: name_to_seq_num[n]) + for node in reachable_by_input_sorted: + out.node.extend([copy.deepcopy(name_to_node[node])]) + + # Add the custom op + new_node = node_def_pb2.NodeDef() + for node in input_nodes: + new_node.input.append(node) + new_node.attr["_output_types"].list.type[:] = output_dtypes + new_node.attr["_output_quantized"].b = output_quantized + new_node.op = op_type + new_node.name = op_name + out.node.extend([new_node]) + + # Add the nodes in the output of the custom op + for index, n in enumerate(output_nodes): + assert len(name_to_node[n].input) == 1 + new_node = copy.deepcopy(name_to_node[n]) + del new_node.input[:] + new_node.input.append(op_name + (":" + str(index) if index != 0 else "")) + out.node.extend([new_node]) + + # Add the nodes post output_nodes + for n in nodes_post_output: + out.node.extend([copy.deepcopy(name_to_node[n])]) + + out.library.CopyFrom(graph_def.library) + out.versions.CopyFrom(graph_def.versions) + return out diff --git a/tensorflow/contrib/framework/python/framework/graph_util_test.py b/tensorflow/contrib/framework/python/framework/graph_util_test.py new file mode 100644 index 0000000000..87b992e22e --- /dev/null +++ b/tensorflow/contrib/framework/python/framework/graph_util_test.py @@ -0,0 +1,61 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""@graph_util tests.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.framework.python.framework import graph_util +from tensorflow.core.framework import graph_pb2 +from tensorflow.core.framework import node_def_pb2 +from tensorflow.core.framework import types_pb2 +from tensorflow.python.platform import test + + +def GetNewNode(name, op, input_nodes): + new_node = node_def_pb2.NodeDef() + new_node.op = op + new_node.name = name + for node in input_nodes: + new_node.input.append(node) + return new_node + + +class GraphUtilTest(test.TestCase): + + def testGraphUtil(self): + graph_def = graph_pb2.GraphDef() + node_a = GetNewNode('A', 'Placeholder', []) + node_b = GetNewNode('B', 'Op1', ['A']) + node_c = GetNewNode('C', 'Op1', ['B']) + node_d = GetNewNode('D', 'Op1', ['C']) + node_e = GetNewNode('E', 'Op1', ['D']) + graph_def.node.extend([node_a, node_b, node_c, node_d, node_e]) + fused_graph_def = graph_util.fuse_op( + graph_def, ['A'], ['D'], [types_pb2.DT_FLOAT], True, 'FusedOp', 'Op2') + self.assertEqual(len(fused_graph_def.node), 4) + self.assertEqual(fused_graph_def.node[0].name, 'A') + self.assertEqual(fused_graph_def.node[1].name, 'FusedOp') + self.assertEqual(fused_graph_def.node[1].input[0], 'A') + self.assertEqual(fused_graph_def.node[1].op, 'Op2') + self.assertEqual(fused_graph_def.node[1].attr['_output_quantized'].b, True) + self.assertEqual(fused_graph_def.node[1].attr['_output_types'].list.type, + [types_pb2.DT_FLOAT]) + self.assertEqual(fused_graph_def.node[2].name, 'D') + self.assertEqual(fused_graph_def.node[3].name, 'E') + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/python/framework/graph_util_impl.py b/tensorflow/python/framework/graph_util_impl.py index ce85747d7c..6c7b455388 100644 --- a/tensorflow/python/framework/graph_util_impl.py +++ b/tensorflow/python/framework/graph_util_impl.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== - """Helpers to manipulate a tensor graph in python. """ @@ -108,6 +107,46 @@ def _node_name(n): return n.split(":")[0] +def _extract_graph_summary(graph_def): + """Extracts useful information from the graph and returns them.""" + name_to_input_name = {} # Keyed by the dest node name. + name_to_node = {} # Keyed by node name. + + # Keeps track of node sequences. It is important to still output the + # operations in the original order. + name_to_seq_num = {} # Keyed by node name. + seq = 0 + for node in graph_def.node: + n = _node_name(node.name) + name_to_node[n] = node + name_to_input_name[n] = [_node_name(x) for x in node.input] + name_to_seq_num[n] = seq + seq += 1 + return name_to_input_name, name_to_node, name_to_seq_num + + +def _assert_nodes_are_present(name_to_node, nodes): + """Assert that nodes are present in the graph.""" + for d in nodes: + assert d in name_to_node, "%s is not in graph" % d + + +def _bfs_for_reachable_nodes(target_nodes, name_to_input_name): + """Breadth first search for reachable nodes from target nodes.""" + nodes_to_keep = set() + # Breadth first search to find all the nodes that we should keep. + next_to_visit = target_nodes[:] + while next_to_visit: + n = next_to_visit[0] + del next_to_visit[0] + if n in nodes_to_keep: + # Already visited this node. + continue + nodes_to_keep.add(n) + next_to_visit += name_to_input_name[n] + return nodes_to_keep + + def extract_sub_graph(graph_def, dest_nodes): """Extract the subgraph that can reach any of the nodes in 'dest_nodes'. @@ -127,40 +166,18 @@ def extract_sub_graph(graph_def, dest_nodes): if isinstance(dest_nodes, six.string_types): raise TypeError("dest_nodes must be a list.") - edges = {} # Keyed by the dest node name. - name_to_node_map = {} # Keyed by node name. - - # Keeps track of node sequences. It is important to still output the - # operations in the original order. - node_seq = {} # Keyed by node name. - seq = 0 - for node in graph_def.node: - n = _node_name(node.name) - name_to_node_map[n] = node - edges[n] = [_node_name(x) for x in node.input] - node_seq[n] = seq - seq += 1 - - for d in dest_nodes: - assert d in name_to_node_map, "%s is not in graph" % d + name_to_input_name, name_to_node, name_to_seq_num = _extract_graph_summary( + graph_def) + _assert_nodes_are_present(name_to_node, dest_nodes) - nodes_to_keep = set() - # Breadth first search to find all the nodes that we should keep. - next_to_visit = dest_nodes[:] - while next_to_visit: - n = next_to_visit[0] - del next_to_visit[0] - if n in nodes_to_keep: - # Already visited this node. - continue - nodes_to_keep.add(n) - next_to_visit += edges[n] + nodes_to_keep = _bfs_for_reachable_nodes(dest_nodes, name_to_input_name) - nodes_to_keep_list = sorted(list(nodes_to_keep), key=lambda n: node_seq[n]) + nodes_to_keep_list = sorted( + list(nodes_to_keep), key=lambda n: name_to_seq_num[n]) # Now construct the output GraphDef out = graph_pb2.GraphDef() for n in nodes_to_keep_list: - out.node.extend([copy.deepcopy(name_to_node_map[n])]) + out.node.extend([copy.deepcopy(name_to_node[n])]) out.library.CopyFrom(graph_def.library) out.versions.CopyFrom(graph_def.versions) @@ -181,7 +198,9 @@ def tensor_shape_from_node_def_name(graph, input_name): return shape -def convert_variables_to_constants(sess, input_graph_def, output_node_names, +def convert_variables_to_constants(sess, + input_graph_def, + output_node_names, variable_names_whitelist=None, variable_names_blacklist=None): """Replaces all the variables in a graph with constants of the same values. @@ -237,10 +256,10 @@ def convert_variables_to_constants(sess, input_graph_def, output_node_names, dtype = input_node.attr["dtype"] data = found_variables[input_node.name] output_node.attr["dtype"].CopyFrom(dtype) - output_node.attr["value"].CopyFrom(attr_value_pb2.AttrValue( - tensor=tensor_util.make_tensor_proto(data, - dtype=dtype.type, - shape=data.shape))) + output_node.attr["value"].CopyFrom( + attr_value_pb2.AttrValue( + tensor=tensor_util.make_tensor_proto( + data, dtype=dtype.type, shape=data.shape))) how_many_converted += 1 else: output_node.CopyFrom(input_node) |