about | summary | refs | log | tree | commit | diff | homepage
diff options
context:
space:
mode:
author A. Unique TensorFlower <gardener@tensorflow.org> 2017-11-06 14:45:54 -0800
committer TensorFlower Gardener <gardener@tensorflow.org> 2017-11-06 14:50:57 -0800
commit bf806939cd3a75181d6e317a6ea1f7f56c340eda (patch)
tree 667857a2913671719b9feb5c4d2c1ed007cd5eae
parent 389eaade3fd29426f4263430e20384ec1ae912be (diff)
Create a routine that can collapse a subgraph into a fused op
PiperOrigin-RevId: 174765540
-rw-r--r-- tensorflow/contrib/framework/BUILD 12
-rw-r--r-- tensorflow/contrib/framework/python/framework/__init__.py 1
-rw-r--r-- tensorflow/contrib/framework/python/framework/graph_util.py 128
-rw-r--r-- tensorflow/contrib/framework/python/framework/graph_util_test.py 61
-rw-r--r-- tensorflow/python/framework/graph_util_impl.py 89
5 files changed, 256 insertions, 35 deletions
diff --git a/tensorflow/contrib/framework/BUILD b/tensorflow/contrib/framework/BUILD
index 4d81564af8..e8dad886a1 100644
--- a/tensorflow/contrib/framework/BUILD
+++ b/tensorflow/contrib/framework/BUILD
@@ -24,6 +24,7 @@ tf_custom_op_py_library(
"python/framework/__init__.py",
"python/framework/checkpoint_utils.py",
"python/framework/experimental.py",
+ "python/framework/graph_util.py",
"python/framework/tensor_util.py",
"python/ops/__init__.py",
"python/ops/accumulate_n_v2.py",
@@ -233,6 +234,17 @@ py_test(
)
py_test(
+ name = "graph_util_test",
+ srcs = ["python/framework/graph_util_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":framework_py",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:platform",
+ ],
+)
+
+py_test(
name = "tensor_util_test",
srcs = ["python/framework/tensor_util_test.py"],
srcs_version = "PY2AND3",
diff --git a/tensorflow/contrib/framework/python/framework/__init__.py b/tensorflow/contrib/framework/python/framework/__init__.py
index c8e6a46854..2d49771ab7 100644
--- a/tensorflow/contrib/framework/python/framework/__init__.py
+++ b/tensorflow/contrib/framework/python/framework/__init__.py
@@ -21,6 +21,7 @@ from __future__ import print_function
# pylint: disable=wildcard-import
from tensorflow.contrib.framework.python.framework.checkpoint_utils import *
from tensorflow.contrib.framework.python.framework.experimental import experimental
+from tensorflow.contrib.framework.python.framework.graph_util import *
from tensorflow.contrib.framework.python.framework.tensor_util import *
# pylint: enable=wildcard-import
from tensorflow.python.util import decorator_utils
diff --git a/tensorflow/contrib/framework/python/framework/graph_util.py b/tensorflow/contrib/framework/python/framework/graph_util.py
new file mode 100644
index 0000000000..8ab8711db4
--- /dev/null
+++ b/tensorflow/contrib/framework/python/framework/graph_util.py
@@ -0,0 +1,128 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Helpers to manipulate a tensor graph in python.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+import copy
+import six
+
+# pylint: disable=unused-import
+from tensorflow.core.framework import graph_pb2
+from tensorflow.core.framework import node_def_pb2
+from tensorflow.python.framework.graph_util_impl import _assert_nodes_are_present
+from tensorflow.python.framework.graph_util_impl import _bfs_for_reachable_nodes
+from tensorflow.python.framework.graph_util_impl import _extract_graph_summary
+from tensorflow.python.framework.graph_util_impl import _node_name
+
+__all__ = ["fuse_op"]
+
+
+def fuse_op(graph_def, input_nodes, output_nodes, output_dtypes,
+ output_quantized, op_name, op_type):
+ """Fuse subgraph between input_nodes and output_nodes into a single custom op.
+
+ Args:
+ graph_def: A graph_pb2.GraphDef proto.
+ input_nodes: input nodes to the subgraph to be fused.
+ output_nodes: output nodes to the subgraph to be fused.
+ output_dtypes: A list of output datatypes for the custom op
+ output_quantized: A boolean flag that indicates if output is quantized
+ op_name: fused op name.
+ op_type: fused op type.
+ Returns:
+ The GraphDef of the new graph.
+
+ Raises:
+ TypeError: If 'graph_def' is not a graph_pb2.GraphDef proto.
+ """
+
+ if not isinstance(graph_def, graph_pb2.GraphDef):
+ raise TypeError("graph_def must be a graph_pb2.GraphDef proto.")
+
+ if isinstance(input_nodes, six.string_types):
+ raise TypeError("input_nodes must be a list.")
+
+ if isinstance(output_nodes, six.string_types):
+ raise TypeError("output_nodes must be a list.")
+
+ name_to_input_name, name_to_node, name_to_seq_num = _extract_graph_summary(
+ graph_def)
+ _assert_nodes_are_present(name_to_node, input_nodes + output_nodes)
+
+ # Nodes up to and including input_nodes
+ reachable_by_input = _bfs_for_reachable_nodes(input_nodes, name_to_input_name)
+ # Nodes up to and including output_nodes
+ reachable_by_output = _bfs_for_reachable_nodes(output_nodes,
+ name_to_input_name)
+
+ # Set of nodes in the list input_nodes
+ input_nodes_set = set(input_nodes)
+
+ # Set of nodes in the list output_nodes
+ output_nodes_set = set(output_nodes)
+
+ nodes_post_output = []
+ for node in graph_def.node:
+ n = _node_name(node.name)
+ if n in reachable_by_output:
+ if n not in reachable_by_input and n not in output_nodes_set:
+ # n is between input and output, i.e., part of the fused op
+ next_to_visit = [n]
+ while next_to_visit:
+ cur_node = next_to_visit[0]
+ del next_to_visit[0]
+ if cur_node in reachable_by_input and cur_node not in input_nodes_set:
+ raise TypeError("Node %s uses input %s not in input_nodes." %
+ (n, cur_node))
+ if cur_node not in input_nodes_set:
+ next_to_visit += name_to_input_name[cur_node]
+ else:
+ nodes_post_output.append(n)
+
+ # Add all nodes up to the input nodes
+ out = graph_pb2.GraphDef()
+ reachable_by_input_sorted = sorted(
+ list(reachable_by_input), key=lambda n: name_to_seq_num[n])
+ for node in reachable_by_input_sorted:
+ out.node.extend([copy.deepcopy(name_to_node[node])])
+
+ # Add the custom op
+ new_node = node_def_pb2.NodeDef()
+ for node in input_nodes:
+ new_node.input.append(node)
+ new_node.attr["_output_types"].list.type[:] = output_dtypes
+ new_node.attr["_output_quantized"].b = output_quantized
+ new_node.op = op_type
+ new_node.name = op_name
+ out.node.extend([new_node])
+
+ # Add the nodes in the output of the custom op
+ for index, n in enumerate(output_nodes):
+ assert len(name_to_node[n].input) == 1
+ new_node = copy.deepcopy(name_to_node[n])
+ del new_node.input[:]
+ new_node.input.append(op_name + (":" + str(index) if index != 0 else ""))
+ out.node.extend([new_node])
+
+ # Add the nodes post output_nodes
+ for n in nodes_post_output:
+ out.node.extend([copy.deepcopy(name_to_node[n])])
+
+ out.library.CopyFrom(graph_def.library)
+ out.versions.CopyFrom(graph_def.versions)
+ return out
diff --git a/tensorflow/contrib/framework/python/framework/graph_util_test.py b/tensorflow/contrib/framework/python/framework/graph_util_test.py
new file mode 100644
index 0000000000..87b992e22e
--- /dev/null
+++ b/tensorflow/contrib/framework/python/framework/graph_util_test.py
@@ -0,0 +1,61 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""@graph_util tests."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.framework.python.framework import graph_util
+from tensorflow.core.framework import graph_pb2
+from tensorflow.core.framework import node_def_pb2
+from tensorflow.core.framework import types_pb2
+from tensorflow.python.platform import test
+
+
+def GetNewNode(name, op, input_nodes):
+ new_node = node_def_pb2.NodeDef()
+ new_node.op = op
+ new_node.name = name
+ for node in input_nodes:
+ new_node.input.append(node)
+ return new_node
+
+
+class GraphUtilTest(test.TestCase):
+
+ def testGraphUtil(self):
+ graph_def = graph_pb2.GraphDef()
+ node_a = GetNewNode('A', 'Placeholder', [])
+ node_b = GetNewNode('B', 'Op1', ['A'])
+ node_c = GetNewNode('C', 'Op1', ['B'])
+ node_d = GetNewNode('D', 'Op1', ['C'])
+ node_e = GetNewNode('E', 'Op1', ['D'])
+ graph_def.node.extend([node_a, node_b, node_c, node_d, node_e])
+ fused_graph_def = graph_util.fuse_op(
+ graph_def, ['A'], ['D'], [types_pb2.DT_FLOAT], True, 'FusedOp', 'Op2')
+ self.assertEqual(len(fused_graph_def.node), 4)
+ self.assertEqual(fused_graph_def.node[0].name, 'A')
+ self.assertEqual(fused_graph_def.node[1].name, 'FusedOp')
+ self.assertEqual(fused_graph_def.node[1].input[0], 'A')
+ self.assertEqual(fused_graph_def.node[1].op, 'Op2')
+ self.assertEqual(fused_graph_def.node[1].attr['_output_quantized'].b, True)
+ self.assertEqual(fused_graph_def.node[1].attr['_output_types'].list.type,
+ [types_pb2.DT_FLOAT])
+ self.assertEqual(fused_graph_def.node[2].name, 'D')
+ self.assertEqual(fused_graph_def.node[3].name, 'E')
+
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/python/framework/graph_util_impl.py b/tensorflow/python/framework/graph_util_impl.py
index ce85747d7c..6c7b455388 100644
--- a/tensorflow/python/framework/graph_util_impl.py
+++ b/tensorflow/python/framework/graph_util_impl.py
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-
"""Helpers to manipulate a tensor graph in python.
"""
@@ -108,6 +107,46 @@ def _node_name(n):
return n.split(":")[0]
+def _extract_graph_summary(graph_def):
+ """Extracts useful information from the graph and returns them."""
+ name_to_input_name = {} # Keyed by the dest node name.
+ name_to_node = {} # Keyed by node name.
+
+ # Keeps track of node sequences. It is important to still output the
+ # operations in the original order.
+ name_to_seq_num = {} # Keyed by node name.
+ seq = 0
+ for node in graph_def.node:
+ n = _node_name(node.name)
+ name_to_node[n] = node
+ name_to_input_name[n] = [_node_name(x) for x in node.input]
+ name_to_seq_num[n] = seq
+ seq += 1
+ return name_to_input_name, name_to_node, name_to_seq_num
+
+
+def _assert_nodes_are_present(name_to_node, nodes):
+ """Assert that nodes are present in the graph."""
+ for d in nodes:
+ assert d in name_to_node, "%s is not in graph" % d
+
+
+def _bfs_for_reachable_nodes(target_nodes, name_to_input_name):
+ """Breadth first search for reachable nodes from target nodes."""
+ nodes_to_keep = set()
+ # Breadth first search to find all the nodes that we should keep.
+ next_to_visit = target_nodes[:]
+ while next_to_visit:
+ n = next_to_visit[0]
+ del next_to_visit[0]
+ if n in nodes_to_keep:
+ # Already visited this node.
+ continue
+ nodes_to_keep.add(n)
+ next_to_visit += name_to_input_name[n]
+ return nodes_to_keep
+
+
def extract_sub_graph(graph_def, dest_nodes):
"""Extract the subgraph that can reach any of the nodes in 'dest_nodes'.
@@ -127,40 +166,18 @@ def extract_sub_graph(graph_def, dest_nodes):
if isinstance(dest_nodes, six.string_types):
raise TypeError("dest_nodes must be a list.")
- edges = {} # Keyed by the dest node name.
- name_to_node_map = {} # Keyed by node name.
-
- # Keeps track of node sequences. It is important to still output the
- # operations in the original order.
- node_seq = {} # Keyed by node name.
- seq = 0
- for node in graph_def.node:
- n = _node_name(node.name)
- name_to_node_map[n] = node
- edges[n] = [_node_name(x) for x in node.input]
- node_seq[n] = seq
- seq += 1
-
- for d in dest_nodes:
- assert d in name_to_node_map, "%s is not in graph" % d
+ name_to_input_name, name_to_node, name_to_seq_num = _extract_graph_summary(
+ graph_def)
+ _assert_nodes_are_present(name_to_node, dest_nodes)
- nodes_to_keep = set()
- # Breadth first search to find all the nodes that we should keep.
- next_to_visit = dest_nodes[:]
- while next_to_visit:
- n = next_to_visit[0]
- del next_to_visit[0]
- if n in nodes_to_keep:
- # Already visited this node.
- continue
- nodes_to_keep.add(n)
- next_to_visit += edges[n]
+ nodes_to_keep = _bfs_for_reachable_nodes(dest_nodes, name_to_input_name)
- nodes_to_keep_list = sorted(list(nodes_to_keep), key=lambda n: node_seq[n])
+ nodes_to_keep_list = sorted(
+ list(nodes_to_keep), key=lambda n: name_to_seq_num[n])
# Now construct the output GraphDef
out = graph_pb2.GraphDef()
for n in nodes_to_keep_list:
- out.node.extend([copy.deepcopy(name_to_node_map[n])])
+ out.node.extend([copy.deepcopy(name_to_node[n])])
out.library.CopyFrom(graph_def.library)
out.versions.CopyFrom(graph_def.versions)
@@ -181,7 +198,9 @@ def tensor_shape_from_node_def_name(graph, input_name):
return shape
-def convert_variables_to_constants(sess, input_graph_def, output_node_names,
+def convert_variables_to_constants(sess,
+ input_graph_def,
+ output_node_names,
variable_names_whitelist=None,
variable_names_blacklist=None):
"""Replaces all the variables in a graph with constants of the same values.
@@ -237,10 +256,10 @@ def convert_variables_to_constants(sess, input_graph_def, output_node_names,
dtype = input_node.attr["dtype"]
data = found_variables[input_node.name]
output_node.attr["dtype"].CopyFrom(dtype)
- output_node.attr["value"].CopyFrom(attr_value_pb2.AttrValue(
- tensor=tensor_util.make_tensor_proto(data,
- dtype=dtype.type,
- shape=data.shape)))
+ output_node.attr["value"].CopyFrom(
+ attr_value_pb2.AttrValue(
+ tensor=tensor_util.make_tensor_proto(
+ data, dtype=dtype.type, shape=data.shape)))
how_many_converted += 1
else:
output_node.CopyFrom(input_node)