diff options
Diffstat (limited to 'tensorflow/contrib/quantization/tools/quantize_graph.py')
-rw-r--r-- | tensorflow/contrib/quantization/tools/quantize_graph.py | 8 |
1 files changed, 4 insertions, 4 deletions
diff --git a/tensorflow/contrib/quantization/tools/quantize_graph.py b/tensorflow/contrib/quantization/tools/quantize_graph.py index 60cbca9834..0a814dadae 100644 --- a/tensorflow/contrib/quantization/tools/quantize_graph.py +++ b/tensorflow/contrib/quantization/tools/quantize_graph.py @@ -66,12 +66,12 @@ flags.DEFINE_boolean("load_quantization_so", True, def print_input_nodes(current_node, nodes_map, indent, already_visited): print(" " * indent + current_node.op + ":" + current_node.name) + already_visited[current_node.name] = True for input_node_name in current_node.input: if input_node_name in already_visited: continue input_node = nodes_map[input_node_name] print_input_nodes(input_node, nodes_map, indent + 1, already_visited) - already_visited[current_node.name] = True def create_node(op, name, inputs): @@ -350,13 +350,13 @@ class GraphRewriter(object): def round_nodes_recursively(self, current_node): """The entry point for simple rounding quantization.""" + self.already_visited[current_node.name] = True for input_node_name in current_node.input: input_node_name = node_name_from_input(input_node_name) if input_node_name in self.already_visited: continue input_node = self.nodes_map[input_node_name] self.round_nodes_recursively(input_node) - self.already_visited[current_node.name] = True nodes_to_quantize = ["Conv2D", "BiasAdd", "MatMul"] if any(current_node.op in s for s in nodes_to_quantize): new_node = tf.NodeDef() @@ -381,13 +381,13 @@ class GraphRewriter(object): def quantize_nodes_recursively(self, current_node): """The entry point for quantizing nodes to eight bit and back.""" + self.already_visited[current_node.name] = True for input_node_name in current_node.input: input_node_name = node_name_from_input(input_node_name) if input_node_name in self.already_visited: continue input_node = self.nodes_map[input_node_name] self.quantize_nodes_recursively(input_node) - self.already_visited[current_node.name] = True nodes_to_quantize = ["Conv2D", "BiasAdd", "MatMul"] if any(current_node.op in s for s in nodes_to_quantize): for input_name in current_node.input: @@ -448,13 +448,13 @@ class GraphRewriter(object): def eightbitize_nodes_recursively(self, current_node): """The entry point for transforming a graph into full eight bit.""" + self.already_visited[current_node.name] = True for input_node_name in current_node.input: input_node_name = node_name_from_input(input_node_name) if input_node_name in self.already_visited: continue input_node = self.nodes_map[input_node_name] self.eightbitize_nodes_recursively(input_node) - self.already_visited[current_node.name] = True if current_node.op == "MatMul": self.eightbitize_mat_mul_node(current_node) elif current_node.op == "Conv2D": |