aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/quantization/tools/quantize_graph.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/quantization/tools/quantize_graph.py')
-rw-r--r--tensorflow/contrib/quantization/tools/quantize_graph.py8
1 files changed, 4 insertions, 4 deletions
diff --git a/tensorflow/contrib/quantization/tools/quantize_graph.py b/tensorflow/contrib/quantization/tools/quantize_graph.py
index 60cbca9834..0a814dadae 100644
--- a/tensorflow/contrib/quantization/tools/quantize_graph.py
+++ b/tensorflow/contrib/quantization/tools/quantize_graph.py
@@ -66,12 +66,12 @@ flags.DEFINE_boolean("load_quantization_so", True,
def print_input_nodes(current_node, nodes_map, indent, already_visited):
print(" " * indent + current_node.op + ":" + current_node.name)
+ already_visited[current_node.name] = True
for input_node_name in current_node.input:
if input_node_name in already_visited:
continue
input_node = nodes_map[input_node_name]
print_input_nodes(input_node, nodes_map, indent + 1, already_visited)
- already_visited[current_node.name] = True
def create_node(op, name, inputs):
@@ -350,13 +350,13 @@ class GraphRewriter(object):
def round_nodes_recursively(self, current_node):
"""The entry point for simple rounding quantization."""
+ self.already_visited[current_node.name] = True
for input_node_name in current_node.input:
input_node_name = node_name_from_input(input_node_name)
if input_node_name in self.already_visited:
continue
input_node = self.nodes_map[input_node_name]
self.round_nodes_recursively(input_node)
- self.already_visited[current_node.name] = True
nodes_to_quantize = ["Conv2D", "BiasAdd", "MatMul"]
if any(current_node.op in s for s in nodes_to_quantize):
new_node = tf.NodeDef()
@@ -381,13 +381,13 @@ class GraphRewriter(object):
def quantize_nodes_recursively(self, current_node):
"""The entry point for quantizing nodes to eight bit and back."""
+ self.already_visited[current_node.name] = True
for input_node_name in current_node.input:
input_node_name = node_name_from_input(input_node_name)
if input_node_name in self.already_visited:
continue
input_node = self.nodes_map[input_node_name]
self.quantize_nodes_recursively(input_node)
- self.already_visited[current_node.name] = True
nodes_to_quantize = ["Conv2D", "BiasAdd", "MatMul"]
if any(current_node.op in s for s in nodes_to_quantize):
for input_name in current_node.input:
@@ -448,13 +448,13 @@ class GraphRewriter(object):
def eightbitize_nodes_recursively(self, current_node):
"""The entry point for transforming a graph into full eight bit."""
+ self.already_visited[current_node.name] = True
for input_node_name in current_node.input:
input_node_name = node_name_from_input(input_node_name)
if input_node_name in self.already_visited:
continue
input_node = self.nodes_map[input_node_name]
self.eightbitize_nodes_recursively(input_node)
- self.already_visited[current_node.name] = True
if current_node.op == "MatMul":
self.eightbitize_mat_mul_node(current_node)
elif current_node.op == "Conv2D":