aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/quantization
diff options
context:
space:
mode:
authorGravatar Pete Warden <petewarden@google.com>2016-08-07 14:15:48 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-08-07 15:32:34 -0700
commit7c7014fd41cdf4e24f923b9e79c249d717aa508f (patch)
tree6b251a2f01452b6aeccebeace94bb5df706c5f48 /tensorflow/contrib/quantization
parentf93960d0afdcf59457b614158ee5575ca2acfe15 (diff)
Optimizing graphs for inference.
Change: 129581148
Diffstat (limited to 'tensorflow/contrib/quantization')
-rw-r--r--tensorflow/contrib/quantization/tools/quantize_graph.py74
-rw-r--r--tensorflow/contrib/quantization/tools/quantize_graph_test.py64
2 files changed, 2 insertions, 136 deletions
diff --git a/tensorflow/contrib/quantization/tools/quantize_graph.py b/tensorflow/contrib/quantization/tools/quantize_graph.py
index d999797f81..60cbca9834 100644
--- a/tensorflow/contrib/quantization/tools/quantize_graph.py
+++ b/tensorflow/contrib/quantization/tools/quantize_graph.py
@@ -327,7 +327,7 @@ class GraphRewriter(object):
for output_node in output_nodes:
self.quantize_nodes_recursively(output_node)
elif self.mode == "eightbit":
- self.set_input_graph(self.remove_unneeded_nodes(self.input_graph))
+ self.set_input_graph(graph_util.remove_training_nodes(self.input_graph))
self.already_visited = {}
self.layers_eightbitized = []
for output_node in output_nodes:
@@ -963,78 +963,6 @@ class GraphRewriter(object):
output_graph.node.extend([output_node])
return output_graph
- def remove_unneeded_nodes(self, input_graph):
- """Prunes out nodes that aren't needed for inference.
-
- There are nodes like Identity and CheckNumerics that are only useful
- during training, and can be removed in graphs that will be used for
- nothing but inference. Here we identify and remove them, returning an
- equivalent graph.
-
- Args:
- input_graph: Model to analyze and prune.
-
- Returns:
- A list of nodes with the unnecessary ones removed.
- """
-
- types_to_remove = {"CheckNumerics": True}
-
- input_nodes = input_graph.node
- names_to_remove = {}
- for node in input_nodes:
- if node.op in types_to_remove:
- names_to_remove[node.name] = True
-
- nodes_after_removal = []
- for node in input_nodes:
- if node.name in names_to_remove:
- continue
- new_node = tf.NodeDef()
- new_node.CopyFrom(node)
- input_before_removal = node.input
- del new_node.input[:]
- for full_input_name in input_before_removal:
- input_name = re.sub(r"^\^", "", full_input_name)
- if input_name in names_to_remove:
- continue
- new_node.input.append(full_input_name)
- nodes_after_removal.append(new_node)
-
- types_to_splice = {"Identity": True}
- names_to_splice = {}
- for node in nodes_after_removal:
- if node.op in types_to_splice:
- # We don't want to remove nodes that have control edge inputs, because
- # they might be involved in subtle dependency issues that removing them
- # will jeopardize.
- has_control_edge = False
- for input_name in node.input:
- if re.match(r"^\^", input_name):
- has_control_edge = True
- if not has_control_edge:
- names_to_splice[node.name] = node.input[0]
-
- nodes_after_splicing = []
- for node in nodes_after_removal:
- if node.name in names_to_splice:
- continue
- new_node = tf.NodeDef()
- new_node.CopyFrom(node)
- input_before_removal = node.input
- del new_node.input[:]
- for full_input_name in input_before_removal:
- input_name = re.sub(r"^\^", "", full_input_name)
- if input_name in names_to_splice:
- new_node.input.append(names_to_splice[input_name])
- else:
- new_node.input.append(full_input_name)
- nodes_after_splicing.append(new_node)
-
- output_graph = tf.GraphDef()
- output_graph.node.extend(nodes_after_splicing)
- return output_graph
-
def set_input_graph(self, new_input_graph):
self.input_graph = new_input_graph
self.nodes_map = self.create_nodes_map(self.input_graph)
diff --git a/tensorflow/contrib/quantization/tools/quantize_graph_test.py b/tensorflow/contrib/quantization/tools/quantize_graph_test.py
index df3eac5f2c..4826ea2689 100644
--- a/tensorflow/contrib/quantization/tools/quantize_graph_test.py
+++ b/tensorflow/contrib/quantization/tools/quantize_graph_test.py
@@ -286,67 +286,6 @@ class QuantizeGraphTest(tf.test.TestCase):
test_graph(float_graph_def, {}, [concat_name])
- def test_remove_unneeded_nodes(self):
- a_constant_name = "a_constant"
- b_constant_name = "b_constant"
- a_check_name = "a_check"
- b_check_name = "b_check"
- a_identity_name = "a_identity"
- b_identity_name = "b_identity"
- add_name = "add"
- graph_def = tf.GraphDef()
- a_constant = quantize_graph.create_constant_node(a_constant_name,
- value=1,
- dtype=tf.float32,
- shape=[])
- graph_def.node.extend([a_constant])
- a_check_node = quantize_graph.create_node("CheckNumerics", a_check_name,
- [a_constant_name])
- graph_def.node.extend([a_check_node])
- a_identity_node = quantize_graph.create_node("Identity", a_identity_name,
- [a_constant_name,
- "^" + a_check_name])
- graph_def.node.extend([a_identity_node])
- b_constant = quantize_graph.create_constant_node(b_constant_name,
- value=1,
- dtype=tf.float32,
- shape=[])
- graph_def.node.extend([b_constant])
- b_check_node = quantize_graph.create_node("CheckNumerics", b_check_name,
- [b_constant_name])
- graph_def.node.extend([b_check_node])
- b_identity_node = quantize_graph.create_node("Identity", b_identity_name,
- [b_constant_name,
- "^" + b_check_name])
- graph_def.node.extend([b_identity_node])
- add_node = quantize_graph.create_node("Add", add_name,
- [a_identity_name,
- b_identity_name])
- quantize_graph.set_attr_dtype(add_node, "T", tf.float32)
- graph_def.node.extend([add_node])
-
- expected_output = tf.GraphDef()
- a_constant = quantize_graph.create_constant_node(a_constant_name,
- value=1,
- dtype=tf.float32,
- shape=[])
- expected_output.node.extend([a_constant])
- b_constant = quantize_graph.create_constant_node(b_constant_name,
- value=1,
- dtype=tf.float32,
- shape=[])
- expected_output.node.extend([b_constant])
- add_node = quantize_graph.create_node("Add", add_name,
- [a_constant_name,
- b_constant_name])
- quantize_graph.set_attr_dtype(add_node, "T", tf.float32)
- expected_output.node.extend([add_node])
-
- rewriter = quantize_graph.GraphRewriter(graph_def, [add_name])
- output = rewriter.remove_unneeded_nodes(graph_def)
- stripped_output = graph_util.extract_sub_graph(output, [add_name])
- self.assertProtoEquals(expected_output, stripped_output)
-
def test_multiple_outputs(self):
input_constant_name = "input_constant"
split_constant_name = "split_constant"
@@ -479,8 +418,7 @@ class QuantizeGraphTest(tf.test.TestCase):
quantize_graph.set_attr_dtype(add_node, "T", tf.float32)
expected_output.node.extend([add_node])
- rewriter = quantize_graph.GraphRewriter(graph_def, [add_name])
- output = rewriter.remove_unneeded_nodes(graph_def)
+ output = graph_util.remove_training_nodes(graph_def)
stripped_output = graph_util.extract_sub_graph(output, [add_name])
self.assertProtoEquals(expected_output, stripped_output)