aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/quantization
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2016-10-07 07:38:22 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-10-07 08:47:45 -0700
commit5328a426fe2d76dabd833e774711b2d56f13f9a8 (patch)
tree0eecf1c011105d39507387a609c3ea88c717ab96 /tensorflow/contrib/quantization
parenta58b45b28f27995bf422057a6c4fc2228db74dde (diff)
In quantize_graph.py, reset output_nodes after set_input_graph is called, in
case an output node was rewritten because its input was removed. Change: 135479283
Diffstat (limited to 'tensorflow/contrib/quantization')
-rw-r--r--tensorflow/contrib/quantization/tools/quantize_graph.py10
-rw-r--r--tensorflow/contrib/quantization/tools/quantize_graph_test.py9
2 files changed, 14 insertions, 5 deletions
diff --git a/tensorflow/contrib/quantization/tools/quantize_graph.py b/tensorflow/contrib/quantization/tools/quantize_graph.py
index 5ded556691..d9982a5cb1 100644
--- a/tensorflow/contrib/quantization/tools/quantize_graph.py
+++ b/tensorflow/contrib/quantization/tools/quantize_graph.py
@@ -328,6 +328,8 @@ class GraphRewriter(object):
self.quantize_nodes_recursively(output_node)
elif self.mode == "eightbit":
self.set_input_graph(graph_util.remove_training_nodes(self.input_graph))
+ output_nodes = [self.nodes_map[output_node_name]
+ for output_node_name in output_node_names]
self.already_visited = {}
self.layers_eightbitized = []
for output_node in output_nodes:
@@ -350,11 +352,11 @@ class GraphRewriter(object):
def round_nodes_recursively(self, current_node):
"""The entry point for simple rounding quantization."""
+ if self.already_visited[current_node.name]:
+ return
self.already_visited[current_node.name] = True
for input_node_name in current_node.input:
input_node_name = node_name_from_input(input_node_name)
- if input_node_name in self.already_visited:
- continue
input_node = self.nodes_map[input_node_name]
self.round_nodes_recursively(input_node)
nodes_to_quantize = ["Conv2D", "BiasAdd", "MatMul"]
@@ -381,11 +383,11 @@ class GraphRewriter(object):
def quantize_nodes_recursively(self, current_node):
"""The entry point for quantizing nodes to eight bit and back."""
+ if self.already_visited[current_node.name]:
+ return
self.already_visited[current_node.name] = True
for input_node_name in current_node.input:
input_node_name = node_name_from_input(input_node_name)
- if input_node_name in self.already_visited:
- continue
input_node = self.nodes_map[input_node_name]
self.quantize_nodes_recursively(input_node)
nodes_to_quantize = ["Conv2D", "BiasAdd", "MatMul"]
diff --git a/tensorflow/contrib/quantization/tools/quantize_graph_test.py b/tensorflow/contrib/quantization/tools/quantize_graph_test.py
index 4826ea2689..24009e7a5a 100644
--- a/tensorflow/contrib/quantization/tools/quantize_graph_test.py
+++ b/tensorflow/contrib/quantization/tools/quantize_graph_test.py
@@ -350,7 +350,14 @@ class QuantizeGraphTest(tf.test.TestCase):
[input_constant_name])
quantize_graph.set_attr_dtype(identity_node, "T", tf.float32)
float_graph_def.node.extend([identity_node])
- test_graph(float_graph_def, {}, [identity_name])
+
+ mul_name = "mul"
+ mul_node = quantize_graph.create_node("Mul", mul_name,
+ [identity_name, identity_name])
+ quantize_graph.set_attr_dtype(mul_node, "T", tf.float32)
+ float_graph_def.node.extend([mul_node])
+
+ test_graph(float_graph_def, {}, [mul_name])
def test_keep_control_edges(self):
no_op_name = "no_op"