diff options
author | Pete Warden <petewarden@google.com> | 2016-08-07 14:15:48 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2016-08-07 15:32:34 -0700 |
commit | 7c7014fd41cdf4e24f923b9e79c249d717aa508f (patch) | |
tree | 6b251a2f01452b6aeccebeace94bb5df706c5f48 /tensorflow/contrib/quantization | |
parent | f93960d0afdcf59457b614158ee5575ca2acfe15 (diff) |
Optimizing graphs for inference.
Change: 129581148
Diffstat (limited to 'tensorflow/contrib/quantization')
-rw-r--r-- | tensorflow/contrib/quantization/tools/quantize_graph.py | 74 | ||||
-rw-r--r-- | tensorflow/contrib/quantization/tools/quantize_graph_test.py | 64 |
2 files changed, 2 insertions, 136 deletions
diff --git a/tensorflow/contrib/quantization/tools/quantize_graph.py b/tensorflow/contrib/quantization/tools/quantize_graph.py index d999797f81..60cbca9834 100644 --- a/tensorflow/contrib/quantization/tools/quantize_graph.py +++ b/tensorflow/contrib/quantization/tools/quantize_graph.py @@ -327,7 +327,7 @@ class GraphRewriter(object): for output_node in output_nodes: self.quantize_nodes_recursively(output_node) elif self.mode == "eightbit": - self.set_input_graph(self.remove_unneeded_nodes(self.input_graph)) + self.set_input_graph(graph_util.remove_training_nodes(self.input_graph)) self.already_visited = {} self.layers_eightbitized = [] for output_node in output_nodes: @@ -963,78 +963,6 @@ class GraphRewriter(object): output_graph.node.extend([output_node]) return output_graph - def remove_unneeded_nodes(self, input_graph): - """Prunes out nodes that aren't needed for inference. - - There are nodes like Identity and CheckNumerics that are only useful - during training, and can be removed in graphs that will be used for - nothing but inference. Here we identify and remove them, returning an - equivalent graph. - - Args: - input_graph: Model to analyze and prune. - - Returns: - A list of nodes with the unnecessary ones removed. - """ - - types_to_remove = {"CheckNumerics": True} - - input_nodes = input_graph.node - names_to_remove = {} - for node in input_nodes: - if node.op in types_to_remove: - names_to_remove[node.name] = True - - nodes_after_removal = [] - for node in input_nodes: - if node.name in names_to_remove: - continue - new_node = tf.NodeDef() - new_node.CopyFrom(node) - input_before_removal = node.input - del new_node.input[:] - for full_input_name in input_before_removal: - input_name = re.sub(r"^\^", "", full_input_name) - if input_name in names_to_remove: - continue - new_node.input.append(full_input_name) - nodes_after_removal.append(new_node) - - types_to_splice = {"Identity": True} - names_to_splice = {} - for node in nodes_after_removal: - if node.op in types_to_splice: - # We don't want to remove nodes that have control edge inputs, because - # they might be involved in subtle dependency issues that removing them - # will jeopardize. - has_control_edge = False - for input_name in node.input: - if re.match(r"^\^", input_name): - has_control_edge = True - if not has_control_edge: - names_to_splice[node.name] = node.input[0] - - nodes_after_splicing = [] - for node in nodes_after_removal: - if node.name in names_to_splice: - continue - new_node = tf.NodeDef() - new_node.CopyFrom(node) - input_before_removal = node.input - del new_node.input[:] - for full_input_name in input_before_removal: - input_name = re.sub(r"^\^", "", full_input_name) - if input_name in names_to_splice: - new_node.input.append(names_to_splice[input_name]) - else: - new_node.input.append(full_input_name) - nodes_after_splicing.append(new_node) - - output_graph = tf.GraphDef() - output_graph.node.extend(nodes_after_splicing) - return output_graph - def set_input_graph(self, new_input_graph): self.input_graph = new_input_graph self.nodes_map = self.create_nodes_map(self.input_graph) diff --git a/tensorflow/contrib/quantization/tools/quantize_graph_test.py b/tensorflow/contrib/quantization/tools/quantize_graph_test.py index df3eac5f2c..4826ea2689 100644 --- a/tensorflow/contrib/quantization/tools/quantize_graph_test.py +++ b/tensorflow/contrib/quantization/tools/quantize_graph_test.py @@ -286,67 +286,6 @@ class QuantizeGraphTest(tf.test.TestCase): test_graph(float_graph_def, {}, [concat_name]) - def test_remove_unneeded_nodes(self): - a_constant_name = "a_constant" - b_constant_name = "b_constant" - a_check_name = "a_check" - b_check_name = "b_check" - a_identity_name = "a_identity" - b_identity_name = "b_identity" - add_name = "add" - graph_def = tf.GraphDef() - a_constant = quantize_graph.create_constant_node(a_constant_name, - value=1, - dtype=tf.float32, - shape=[]) - graph_def.node.extend([a_constant]) - a_check_node = quantize_graph.create_node("CheckNumerics", a_check_name, - [a_constant_name]) - graph_def.node.extend([a_check_node]) - a_identity_node = quantize_graph.create_node("Identity", a_identity_name, - [a_constant_name, - "^" + a_check_name]) - graph_def.node.extend([a_identity_node]) - b_constant = quantize_graph.create_constant_node(b_constant_name, - value=1, - dtype=tf.float32, - shape=[]) - graph_def.node.extend([b_constant]) - b_check_node = quantize_graph.create_node("CheckNumerics", b_check_name, - [b_constant_name]) - graph_def.node.extend([b_check_node]) - b_identity_node = quantize_graph.create_node("Identity", b_identity_name, - [b_constant_name, - "^" + b_check_name]) - graph_def.node.extend([b_identity_node]) - add_node = quantize_graph.create_node("Add", add_name, - [a_identity_name, - b_identity_name]) - quantize_graph.set_attr_dtype(add_node, "T", tf.float32) - graph_def.node.extend([add_node]) - - expected_output = tf.GraphDef() - a_constant = quantize_graph.create_constant_node(a_constant_name, - value=1, - dtype=tf.float32, - shape=[]) - expected_output.node.extend([a_constant]) - b_constant = quantize_graph.create_constant_node(b_constant_name, - value=1, - dtype=tf.float32, - shape=[]) - expected_output.node.extend([b_constant]) - add_node = quantize_graph.create_node("Add", add_name, - [a_constant_name, - b_constant_name]) - quantize_graph.set_attr_dtype(add_node, "T", tf.float32) - expected_output.node.extend([add_node]) - - rewriter = quantize_graph.GraphRewriter(graph_def, [add_name]) - output = rewriter.remove_unneeded_nodes(graph_def) - stripped_output = graph_util.extract_sub_graph(output, [add_name]) - self.assertProtoEquals(expected_output, stripped_output) - def test_multiple_outputs(self): input_constant_name = "input_constant" split_constant_name = "split_constant" @@ -479,8 +418,7 @@ class QuantizeGraphTest(tf.test.TestCase): quantize_graph.set_attr_dtype(add_node, "T", tf.float32) expected_output.node.extend([add_node]) - rewriter = quantize_graph.GraphRewriter(graph_def, [add_name]) - output = rewriter.remove_unneeded_nodes(graph_def) + output = graph_util.remove_training_nodes(graph_def) stripped_output = graph_util.extract_sub_graph(output, [add_name]) self.assertProtoEquals(expected_output, stripped_output) |