-rw-r--r--  tensorflow/core/kernels/quantized_reshape_op.cc      |   4
-rw-r--r--  tensorflow/tools/quantization/quantize_graph.py      | 118
-rw-r--r--  tensorflow/tools/quantization/quantize_graph_test.py |  75
3 files changed, 175 insertions, 22 deletions
diff --git a/tensorflow/core/kernels/quantized_reshape_op.cc b/tensorflow/core/kernels/quantized_reshape_op.cc
index d49edd3feb..bd76c94ede 100644
--- a/tensorflow/core/kernels/quantized_reshape_op.cc
+++ b/tensorflow/core/kernels/quantized_reshape_op.cc
@@ -50,8 +50,8 @@ class QuantizedReshapeOp : public ReshapeOp {
                               .TypeConstraint<type>("T"), \
                           QuantizedReshapeOp)
 
-TF_CALL_quint8(REGISTER_CPU_KERNEL);
-TF_CALL_qint32(REGISTER_CPU_KERNEL);
+REGISTER_CPU_KERNEL(::tensorflow::quint8);
+REGISTER_CPU_KERNEL(::tensorflow::qint32);
 
 #undef REGISTER_CPU_KERNEL
 
diff --git a/tensorflow/tools/quantization/quantize_graph.py b/tensorflow/tools/quantization/quantize_graph.py
index 894806f186..d11fbf65ed 100644
--- a/tensorflow/tools/quantization/quantize_graph.py
+++ b/tensorflow/tools/quantization/quantize_graph.py
@@ -28,6 +28,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import collections
 import re
 import numpy as np
 import tensorflow as tf
@@ -283,6 +284,11 @@ def quantize_weight_eightbit(input_node, quantization_mode):
   return [quint8_const_node, min_node, max_node, dequantize_node]
 
 
+EightbitizeRecursionState = collections.namedtuple(
+    "EightbitizeRecursionState", ["already_visited", "output_node_stack",
+                                  "merged_with_fake_quant"])
+
+
 class GraphRewriter(object):
   """Takes a float graph, and rewrites it in quantized form."""
 
@@ -316,6 +322,9 @@ class GraphRewriter(object):
     else:
       self.input_range = None
 
+    # Data that is valid only during the recursive call to rewrite the graph.
+    self.state = None
+
   def create_nodes_map(self, graph):
     """Builds a mapping of node names to their defs from the graph."""
     nodes_map = {}
@@ -353,11 +362,12 @@ class GraphRewriter(object):
       output_nodes = [self.nodes_map[output_node_name]
                       for output_node_name in output_node_names]
 
-      self.already_visited = {}
-      self.layers_eightbitized = []
+      self.state = EightbitizeRecursionState(already_visited={},
+                                             output_node_stack=[],
+                                             merged_with_fake_quant={})
       for output_node in output_nodes:
         self.eightbitize_nodes_recursively(output_node)
-      self.output_graph = self.quantize_weights(self.output_graph, b"MIN_FIRST")
+      self.state = None
       if self.input_range:
         self.add_output_graph_node(create_constant_node(
             "quantized_input_min_value", self.input_range[0], tf.float32, []))
@@ -477,20 +487,54 @@ class GraphRewriter(object):
     set_attr_string(dequantize_node, "mode", b"MIN_FIRST")
     self.add_output_graph_node(dequantize_node)
 
+  def should_merge_with_fake_quant_node(self):
+    """Should the current node merge with self.state.output_node_stack[-1]?"""
+    if not self.state.output_node_stack: return False
+    top = self.state.output_node_stack[-1]
+    return top[1] == 0 and top[0].op in ["FakeQuantWithMinMaxVars"]
+
+  def should_quantize_const(self, node):
+    if not self.state.output_node_stack: return False
+    top = self.state.output_node_stack[-1]
+    if not top[2]: return False
+    assert tf.as_dtype(node.attr["dtype"].type) == tf.float32, (
+        "Quantizing constant %s" % node.name)
+    return True
+
   def eightbitize_nodes_recursively(self, current_node):
     """The entry point for transforming a graph into full eight bit."""
-    self.already_visited[current_node.name] = True
-    for input_node_name in current_node.input:
+    if current_node.name in self.state.already_visited:
+      if (self.should_merge_with_fake_quant_node() or
+          current_node.name in self.state.merged_with_fake_quant):
+        raise ValueError("Unsupported graph structure: output of node %s "
+                         "is processed by a FakeQuant* node and should have "
+                         "no other outputs.", current_node.name)
+      return
+    self.state.already_visited[current_node.name] = True
+
+    for i, input_node_name in enumerate(current_node.input):
+      quantize_input = False
+      if current_node.op in ("MatMul", "Conv2D", "BiasAdd", "MaxPool",
+                             "AvgPool", "Relu", "Relu6",
+                             "BatchNormWithGlobalNormalization"):
+        quantize_input = True
+      elif current_node.op == "Concat" and i > 0:
+        quantize_input = True
+      elif current_node.op == "Reshape" and i == 0:
+        quantize_input = True
+
+      self.state.output_node_stack.append((current_node, i, quantize_input))
+
       input_node_name = node_name_from_input(input_node_name)
-      if input_node_name in self.already_visited:
-        continue
       input_node = self.nodes_map[input_node_name]
       self.eightbitize_nodes_recursively(input_node)
+
+      self.state.output_node_stack.pop()
+
     if current_node.op == "MatMul":
       self.eightbitize_mat_mul_node(current_node)
     elif current_node.op == "Conv2D":
       self.eightbitize_conv_node(current_node)
-      self.layers_eightbitized.append(current_node.name)
     elif current_node.op == "BiasAdd":
       self.eightbitize_bias_add_node(current_node)
     elif current_node.op == "MaxPool" or current_node.op == "AvgPool":
@@ -508,11 +552,29 @@ class GraphRewriter(object):
     elif (self.input_range and
           current_node.op in ("Placeholder", "PlaceholderV2")):
       self.eightbitize_placeholder_node(current_node)
+    elif current_node.op == "FakeQuantWithMinMaxVars":
+      # It will have been merged into the underlying node.
+      pass
+    elif current_node.op == "Const":
+      if self.should_quantize_const(current_node):
+        for n in quantize_weight_eightbit(current_node, b"MIN_FIRST"):
+          self.add_output_graph_node(n)
+      else:
+        new_node = tf.NodeDef()
+        new_node.CopyFrom(current_node)
+        self.add_output_graph_node(new_node)
     else:
       new_node = tf.NodeDef()
       new_node.CopyFrom(current_node)
       self.add_output_graph_node(new_node)
 
+    if (self.should_merge_with_fake_quant_node() and
+        current_node.name not in self.state.merged_with_fake_quant):
+      raise ValueError(
+          "FakeQuant* node %s failed to merge with node %s of type %s" % (
+              self.state.output_node_stack[-1][0], current_node.name,
+              current_node.op))
+
   def add_eightbit_prologue_nodes(self, original_node):
     """Adds input conversion nodes to handle quantizing the underlying node."""
     namespace_prefix = original_node.name + "_eightbit"
@@ -583,16 +645,26 @@ class GraphRewriter(object):
         quantized_output_name, quantized_output_name + ":1",
         quantized_output_name + ":2"
     ]
-    requant_range_node = create_node(
-        "RequantizationRange", original_node.name + "_eightbit_requant_range",
-        quantized_outputs)
-    set_attr_dtype(requant_range_node, "Tinput", tf.qint32)
-    self.add_output_graph_node(requant_range_node)
-
+    min_max_inputs = None
+    if self.should_merge_with_fake_quant_node():
+      # Use the inputs to the FakeQuantWithMinMaxVars node as the inputs to
+      # Requantize.
+      fake_quant_node = self.state.output_node_stack[-1][0]
+      min_max_inputs = [fake_quant_node.input[1], fake_quant_node.input[2]]
+      assert original_node.name not in self.state.merged_with_fake_quant
+      self.state.merged_with_fake_quant[original_node.name] = True
+    else:
+      # Add a RequantizationRange node for finding the min and max values.
+      requant_range_node = create_node(
+          "RequantizationRange",
+          original_node.name + "_eightbit_requant_range",
+          quantized_outputs)
+      set_attr_dtype(requant_range_node, "Tinput", tf.qint32)
+      self.add_output_graph_node(requant_range_node)
+      min_max_inputs = [requant_range_node.name + ":0",
+                        requant_range_node.name + ":1"]
     requantize_node = create_node(
         "Requantize", original_node.name + "_eightbit_requantize",
-        (quantized_outputs +
-         [requant_range_node.name + ":0", requant_range_node.name + ":1"]))
+        quantized_outputs + min_max_inputs)
     set_attr_dtype(requantize_node, "Tinput", tf.qint32)
     set_attr_dtype(requantize_node, "out_type", tf.quint8)
     self.add_output_graph_node(requantize_node)
@@ -600,12 +672,20 @@ class GraphRewriter(object):
 
   def add_dequantize_result_node(self, quantized_output_name,
                                  original_node_name, min_tensor_index=1):
+    min_max_inputs = [
+        "%s:%s" % (quantized_output_name, min_tensor_index),
+        "%s:%s" % (quantized_output_name, (min_tensor_index + 1))]
     dequantize_name = original_node_name
+    if self.should_merge_with_fake_quant_node():
+      fake_quant_node = self.state.output_node_stack[-1][0]
+      if original_node_name not in self.state.merged_with_fake_quant:
+        min_max_inputs = [fake_quant_node.input[1], fake_quant_node.input[2]]
+        self.state.merged_with_fake_quant[original_node_name] = True
+      dequantize_name = fake_quant_node.name
+
     dequantize_node = create_node(
         "Dequantize", dequantize_name,
-        [quantized_output_name,
-         "%s:%s" % (quantized_output_name, min_tensor_index),
-         "%s:%s" % (quantized_output_name, (min_tensor_index + 1))])
+        [quantized_output_name, min_max_inputs[0], min_max_inputs[1]])
     set_attr_dtype(dequantize_node, "T", tf.quint8)
     set_attr_string(dequantize_node, "mode", b"MIN_FIRST")
     self.add_output_graph_node(dequantize_node)
diff --git a/tensorflow/tools/quantization/quantize_graph_test.py b/tensorflow/tools/quantization/quantize_graph_test.py
index 1521240f28..42f06ef2ed 100644
--- a/tensorflow/tools/quantization/quantize_graph_test.py
+++ b/tensorflow/tools/quantization/quantize_graph_test.py
@@ -160,7 +160,7 @@ def get_top_value(input_values):
   return max_index, max_value
 
 
-def test_graph(float_graph_def, input_map, output_names):
+def test_graph(float_graph_def, input_map, output_names, log_graph=False):
   """Runs the float graph through the rewriter and tests the results."""
   float_results = run_graph_def(float_graph_def, input_map,
                                 [output_name + ":0"
@@ -184,6 +184,9 @@ def test_graph(float_graph_def, input_map, output_names):
   for expected, result in zip(float_results, eightbit_results):
     assert are_tensors_near(expected, result, 1.0)
 
+  if log_graph:
+    tf.logging.info("8bit:\n%s", str(eightbit_graph_def))
+
   # Test the weights_rounded mode. This uses the default bit_depth.
   weights_rounded_rewriter = quantize_graph.GraphRewriter(
       float_graph_def, "weights_rounded", quantized_input_range=None)
@@ -580,6 +583,40 @@ class QuantizeGraphTest(tf.test.TestCase):
     float_graph_def.node.extend([relu_node])
     test_graph(float_graph_def, {}, [relu_name])
 
+  def test_relu_w_fake_quant_w_min_max_vars(self):
+    input_node = quantize_graph.create_constant_node(
+        "input", value=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
+        dtype=tf.float32, shape=[1, 2, 6, 1])
+    relu_node = quantize_graph.create_node("Relu", "relu",
+                                           [input_node.name])
+    quantize_graph.set_attr_dtype(relu_node, "T", tf.float32)
+
+    min_node = quantize_graph.create_constant_node(
+        "min_bias_add", value=0, dtype=tf.float32, shape=[])
+    max_node = quantize_graph.create_constant_node(
+        "max_bias_add", value=12, dtype=tf.float32, shape=[])
+    fake_quant_node = quantize_graph.create_node(
+        "FakeQuantWithMinMaxVars", "fake_quant",
+        [relu_node.name, min_node.name, max_node.name])
+
+    float_graph_def = tf.GraphDef()
+    float_graph_def.node.extend([input_node, relu_node, min_node, max_node,
+                                 fake_quant_node])
+    test_graph(float_graph_def, {}, [fake_quant_node.name], log_graph=True)
+
+    # Verify there is only one Quantize and one Requantize op.
+    eightbit_rewriter = quantize_graph.GraphRewriter(float_graph_def,
+                                                     "eightbit",
+                                                     quantized_input_range=None)
+    eightbit_graph_def = eightbit_rewriter.rewrite([fake_quant_node.name])
+
+    ops = [node.op for node in eightbit_graph_def.node]
+    # No quantize since all inputs are const and can be quantized up-front.
+    self.assertEqual(0, ops.count("QuantizeV2") + ops.count("Quantize"))
+
+    # One dequantize at the end.
+    self.assertEqual(1, ops.count("Dequantize"))
+
   def test_relu6(self):
     input_constant_name = "input_constant"
     relu6_name = "relu6"
@@ -720,6 +757,42 @@ class QuantizeGraphTest(tf.test.TestCase):
                      ops.count("QuantizeV2") + ops.count("Quantize"))
     self.assertEqual(len(output_names), ops.count("Dequantize"))
 
+  def test_bias_add_w_fake_quant_w_min_max_vars(self):
+    input_node = quantize_graph.create_constant_node(
+        "input", value=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
+        dtype=tf.float32, shape=[1, 1, 2, 5])
+    offset_node = quantize_graph.create_constant_node(
+        "offset", value=[1, 2, 3, 4, 5], dtype=tf.float32, shape=[5])
+    bias_add_node = quantize_graph.create_node(
+        "BiasAdd", "bias_add", [input_node.name, offset_node.name])
+    quantize_graph.set_attr_dtype(bias_add_node, "T", tf.float32)
+
+    min_node = quantize_graph.create_constant_node(
+        "min_bias_add", value=-.5, dtype=tf.float32, shape=[])
+    max_node = quantize_graph.create_constant_node(
+        "max_bias_add", value=15.5, dtype=tf.float32, shape=[])
+    fake_quant_node = quantize_graph.create_node(
+        "FakeQuantWithMinMaxVars", "fake_quant",
+        [bias_add_node.name, min_node.name, max_node.name])
+
+    float_graph_def = tf.GraphDef()
+    float_graph_def.node.extend([input_node, offset_node, bias_add_node,
+                                 min_node, max_node, fake_quant_node])
+    test_graph(float_graph_def, {}, [fake_quant_node.name], log_graph=True)
+
+    # Verify there is only one Quantize and one Requantize op.
+    eightbit_rewriter = quantize_graph.GraphRewriter(float_graph_def,
+                                                     "eightbit",
+                                                     quantized_input_range=None)
+    eightbit_graph_def = eightbit_rewriter.rewrite([fake_quant_node.name])
+
+    ops = [node.op for node in eightbit_graph_def.node]
+    # No quantize since all inputs are const and can be quantized up-front.
+    self.assertEqual(0, ops.count("QuantizeV2") + ops.count("Quantize"))
+
+    # One dequantize at the end.
+    self.assertEqual(1, ops.count("Dequantize"))
+
   def test_remove_redundant_quantization(self):
     a_constant_name = "a_constant"
     a_constant_min_name = "a_constant_min"