diff options
author    | 2016-10-07 07:38:22 -0800
committer | 2016-10-07 08:47:45 -0700
commit    | 5328a426fe2d76dabd833e774711b2d56f13f9a8 (patch)
tree      | 0eecf1c011105d39507387a609c3ea88c717ab96 /tensorflow/contrib/quantization
parent    | a58b45b28f27995bf422057a6c4fc2228db74dde (diff)
In quantize_graph.py, reset output_nodes after set_input_graph is called, in
case an output node was rewritten because its input was removed.
Change: 135479283
Diffstat (limited to 'tensorflow/contrib/quantization')
-rw-r--r-- | tensorflow/contrib/quantization/tools/quantize_graph.py      | 10
-rw-r--r-- | tensorflow/contrib/quantization/tools/quantize_graph_test.py |  9
2 files changed, 14 insertions, 5 deletions
diff --git a/tensorflow/contrib/quantization/tools/quantize_graph.py b/tensorflow/contrib/quantization/tools/quantize_graph.py index 5ded556691..d9982a5cb1 100644 --- a/tensorflow/contrib/quantization/tools/quantize_graph.py +++ b/tensorflow/contrib/quantization/tools/quantize_graph.py @@ -328,6 +328,8 @@ class GraphRewriter(object): self.quantize_nodes_recursively(output_node) elif self.mode == "eightbit": self.set_input_graph(graph_util.remove_training_nodes(self.input_graph)) + output_nodes = [self.nodes_map[output_node_name] + for output_node_name in output_node_names] self.already_visited = {} self.layers_eightbitized = [] for output_node in output_nodes: @@ -350,11 +352,11 @@ class GraphRewriter(object): def round_nodes_recursively(self, current_node): """The entry point for simple rounding quantization.""" + if self.already_visited[current_node.name]: + return self.already_visited[current_node.name] = True for input_node_name in current_node.input: input_node_name = node_name_from_input(input_node_name) - if input_node_name in self.already_visited: - continue input_node = self.nodes_map[input_node_name] self.round_nodes_recursively(input_node) nodes_to_quantize = ["Conv2D", "BiasAdd", "MatMul"] @@ -381,11 +383,11 @@ class GraphRewriter(object): def quantize_nodes_recursively(self, current_node): """The entry point for quantizing nodes to eight bit and back.""" + if self.already_visited[current_node.name]: + return self.already_visited[current_node.name] = True for input_node_name in current_node.input: input_node_name = node_name_from_input(input_node_name) - if input_node_name in self.already_visited: - continue input_node = self.nodes_map[input_node_name] self.quantize_nodes_recursively(input_node) nodes_to_quantize = ["Conv2D", "BiasAdd", "MatMul"] diff --git a/tensorflow/contrib/quantization/tools/quantize_graph_test.py b/tensorflow/contrib/quantization/tools/quantize_graph_test.py index 4826ea2689..24009e7a5a 100644 --- 
a/tensorflow/contrib/quantization/tools/quantize_graph_test.py +++ b/tensorflow/contrib/quantization/tools/quantize_graph_test.py @@ -350,7 +350,14 @@ class QuantizeGraphTest(tf.test.TestCase): [input_constant_name]) quantize_graph.set_attr_dtype(identity_node, "T", tf.float32) float_graph_def.node.extend([identity_node]) - test_graph(float_graph_def, {}, [identity_name]) + + mul_name = "mul" + mul_node = quantize_graph.create_node("Mul", mul_name, + [identity_name, identity_name]) + quantize_graph.set_attr_dtype(mul_node, "T", tf.float32) + float_graph_def.node.extend([mul_node]) + + test_graph(float_graph_def, {}, [mul_name]) def test_keep_control_edges(self): no_op_name = "no_op" |