aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2016-11-01 14:10:52 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-11-01 15:18:15 -0700
commit95f7166b8860f568f056a6c20ff626f6a7f069fc (patch)
tree4c287e4b2e82927a87729b65edf9c8693f015a45
parent1dadfdd27650e21e0c679e615ddd377f380c574a (diff)
Change quantize_graph in eightbit mode to remove FakeQuantWithMinMaxVars
nodes and use the information provided by them to set the min/max values on quantization-related nodes. In eightbit mode, also changed how constant weights are quantized - instead of doing it as a step after the main recursion, do it during the main recursion. This allows the float inputs to FakeQuantWithMinMaxVars to be excluded from quantization. In eightbit mode, maintain more state in the stack during recursion. Also change quantize reshape registration to register always and not use TF_CALL_xyz; this matches other quantized ops. Change: 137877226
-rw-r--r--tensorflow/core/kernels/quantized_reshape_op.cc4
-rw-r--r--tensorflow/tools/quantization/quantize_graph.py118
-rw-r--r--tensorflow/tools/quantization/quantize_graph_test.py75
3 files changed, 175 insertions, 22 deletions
diff --git a/tensorflow/core/kernels/quantized_reshape_op.cc b/tensorflow/core/kernels/quantized_reshape_op.cc
index d49edd3feb..bd76c94ede 100644
--- a/tensorflow/core/kernels/quantized_reshape_op.cc
+++ b/tensorflow/core/kernels/quantized_reshape_op.cc
@@ -50,8 +50,8 @@ class QuantizedReshapeOp : public ReshapeOp {
.TypeConstraint<type>("T"), \
QuantizedReshapeOp)
-TF_CALL_quint8(REGISTER_CPU_KERNEL);
-TF_CALL_qint32(REGISTER_CPU_KERNEL);
+REGISTER_CPU_KERNEL(::tensorflow::quint8);
+REGISTER_CPU_KERNEL(::tensorflow::qint32);
#undef REGISTER_CPU_KERNEL
diff --git a/tensorflow/tools/quantization/quantize_graph.py b/tensorflow/tools/quantization/quantize_graph.py
index 894806f186..d11fbf65ed 100644
--- a/tensorflow/tools/quantization/quantize_graph.py
+++ b/tensorflow/tools/quantization/quantize_graph.py
@@ -28,6 +28,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import collections
import re
import numpy as np
import tensorflow as tf
@@ -283,6 +284,11 @@ def quantize_weight_eightbit(input_node, quantization_mode):
return [quint8_const_node, min_node, max_node, dequantize_node]
+EightbitizeRecursionState = collections.namedtuple(
+ "EightbitizeRecursionState", ["already_visited", "output_node_stack",
+ "merged_with_fake_quant"])
+
+
class GraphRewriter(object):
"""Takes a float graph, and rewrites it in quantized form."""
@@ -316,6 +322,9 @@ class GraphRewriter(object):
else:
self.input_range = None
+ # Data that is valid only during the recursive call to rewrite the graph.
+ self.state = None
+
def create_nodes_map(self, graph):
"""Builds a mapping of node names to their defs from the graph."""
nodes_map = {}
@@ -353,11 +362,12 @@ class GraphRewriter(object):
output_nodes = [self.nodes_map[output_node_name]
for output_node_name in output_node_names]
- self.already_visited = {}
- self.layers_eightbitized = []
+ self.state = EightbitizeRecursionState(already_visited={},
+ output_node_stack=[],
+ merged_with_fake_quant={})
for output_node in output_nodes:
self.eightbitize_nodes_recursively(output_node)
- self.output_graph = self.quantize_weights(self.output_graph, b"MIN_FIRST")
+ self.state = None
if self.input_range:
self.add_output_graph_node(create_constant_node(
"quantized_input_min_value", self.input_range[0], tf.float32, []))
@@ -477,20 +487,54 @@ class GraphRewriter(object):
set_attr_string(dequantize_node, "mode", b"MIN_FIRST")
self.add_output_graph_node(dequantize_node)
+ def should_merge_with_fake_quant_node(self):
+ """Should the current node merge with self.state.output_node_stack[-1]?"""
+ if not self.state.output_node_stack: return False
+ top = self.state.output_node_stack[-1]
+ return top[1] == 0 and top[0].op in ["FakeQuantWithMinMaxVars"]
+
+ def should_quantize_const(self, node):
+ if not self.state.output_node_stack: return False
+ top = self.state.output_node_stack[-1]
+ if not top[2]: return False
+ assert tf.as_dtype(node.attr["dtype"].type) == tf.float32, (
+ "Quantizing constant %s" % node.name)
+ return True
+
def eightbitize_nodes_recursively(self, current_node):
"""The entry point for transforming a graph into full eight bit."""
- self.already_visited[current_node.name] = True
- for input_node_name in current_node.input:
+ if current_node.name in self.state.already_visited:
+ if (self.should_merge_with_fake_quant_node() or
+ current_node.name in self.state.merged_with_fake_quant):
+ raise ValueError("Unsupported graph structure: output of node %s "
+ "is processed by a FakeQuant* node and should have "
+ "no other outputs.", current_node.name)
+ return
+ self.state.already_visited[current_node.name] = True
+
+ for i, input_node_name in enumerate(current_node.input):
+ quantize_input = False
+ if current_node.op in ("MatMul", "Conv2D", "BiasAdd", "MaxPool",
+ "AvgPool", "Relu", "Relu6",
+ "BatchNormWithGlobalNormalization"):
+ quantize_input = True
+ elif current_node.op == "Concat" and i > 0:
+ quantize_input = True
+ elif current_node.op == "Reshape" and i == 0:
+ quantize_input = True
+
+ self.state.output_node_stack.append((current_node, i, quantize_input))
+
input_node_name = node_name_from_input(input_node_name)
- if input_node_name in self.already_visited:
- continue
input_node = self.nodes_map[input_node_name]
self.eightbitize_nodes_recursively(input_node)
+
+ self.state.output_node_stack.pop()
+
if current_node.op == "MatMul":
self.eightbitize_mat_mul_node(current_node)
elif current_node.op == "Conv2D":
self.eightbitize_conv_node(current_node)
- self.layers_eightbitized.append(current_node.name)
elif current_node.op == "BiasAdd":
self.eightbitize_bias_add_node(current_node)
elif current_node.op == "MaxPool" or current_node.op == "AvgPool":
@@ -508,11 +552,29 @@ class GraphRewriter(object):
elif (self.input_range and
current_node.op in ("Placeholder", "PlaceholderV2")):
self.eightbitize_placeholder_node(current_node)
+ elif current_node.op == "FakeQuantWithMinMaxVars":
+ # It will have been merged into the underlying node.
+ pass
+ elif current_node.op == "Const":
+ if self.should_quantize_const(current_node):
+ for n in quantize_weight_eightbit(current_node, b"MIN_FIRST"):
+ self.add_output_graph_node(n)
+ else:
+ new_node = tf.NodeDef()
+ new_node.CopyFrom(current_node)
+ self.add_output_graph_node(new_node)
else:
new_node = tf.NodeDef()
new_node.CopyFrom(current_node)
self.add_output_graph_node(new_node)
+ if (self.should_merge_with_fake_quant_node() and
+ current_node.name not in self.state.merged_with_fake_quant):
+ raise ValueError(
+ "FakeQuant* node %s failed to merge with node %s of type %s" % (
+ self.state.output_node_stack[-1][0], current_node.name,
+ current_node.op))
+
def add_eightbit_prologue_nodes(self, original_node):
"""Adds input conversion nodes to handle quantizing the underlying node."""
namespace_prefix = original_node.name + "_eightbit"
@@ -583,16 +645,26 @@ class GraphRewriter(object):
quantized_output_name, quantized_output_name + ":1",
quantized_output_name + ":2"
]
- requant_range_node = create_node(
- "RequantizationRange", original_node.name + "_eightbit_requant_range",
- quantized_outputs)
- set_attr_dtype(requant_range_node, "Tinput", tf.qint32)
- self.add_output_graph_node(requant_range_node)
-
+ min_max_inputs = None
+ if self.should_merge_with_fake_quant_node():
+ # Use the inputs to the FakeQuantWithMinMaxVars node as the inputs to
+ # Requantize.
+ fake_quant_node = self.state.output_node_stack[-1][0]
+ min_max_inputs = [fake_quant_node.input[1], fake_quant_node.input[2]]
+ assert original_node.name not in self.state.merged_with_fake_quant
+ self.state.merged_with_fake_quant[original_node.name] = True
+ else:
+ # Add a RequantizationRange node for finding the min and max values.
+ requant_range_node = create_node(
+ "RequantizationRange", original_node.name + "_eightbit_requant_range",
+ quantized_outputs)
+ set_attr_dtype(requant_range_node, "Tinput", tf.qint32)
+ self.add_output_graph_node(requant_range_node)
+ min_max_inputs = [requant_range_node.name + ":0",
+ requant_range_node.name + ":1"]
requantize_node = create_node(
"Requantize", original_node.name + "_eightbit_requantize",
- (quantized_outputs +
- [requant_range_node.name + ":0", requant_range_node.name + ":1"]))
+ quantized_outputs + min_max_inputs)
set_attr_dtype(requantize_node, "Tinput", tf.qint32)
set_attr_dtype(requantize_node, "out_type", tf.quint8)
self.add_output_graph_node(requantize_node)
@@ -600,12 +672,20 @@ class GraphRewriter(object):
def add_dequantize_result_node(self, quantized_output_name,
original_node_name, min_tensor_index=1):
+ min_max_inputs = [
+ "%s:%s" % (quantized_output_name, min_tensor_index),
+ "%s:%s" % (quantized_output_name, (min_tensor_index + 1))]
dequantize_name = original_node_name
+ if self.should_merge_with_fake_quant_node():
+ fake_quant_node = self.state.output_node_stack[-1][0]
+ if original_node_name not in self.state.merged_with_fake_quant:
+ min_max_inputs = [fake_quant_node.input[1], fake_quant_node.input[2]]
+ self.state.merged_with_fake_quant[original_node_name] = True
+ dequantize_name = fake_quant_node.name
+
dequantize_node = create_node(
"Dequantize", dequantize_name,
- [quantized_output_name,
- "%s:%s" % (quantized_output_name, min_tensor_index),
- "%s:%s" % (quantized_output_name, (min_tensor_index + 1))])
+ [quantized_output_name, min_max_inputs[0], min_max_inputs[1]])
set_attr_dtype(dequantize_node, "T", tf.quint8)
set_attr_string(dequantize_node, "mode", b"MIN_FIRST")
self.add_output_graph_node(dequantize_node)
diff --git a/tensorflow/tools/quantization/quantize_graph_test.py b/tensorflow/tools/quantization/quantize_graph_test.py
index 1521240f28..42f06ef2ed 100644
--- a/tensorflow/tools/quantization/quantize_graph_test.py
+++ b/tensorflow/tools/quantization/quantize_graph_test.py
@@ -160,7 +160,7 @@ def get_top_value(input_values):
return max_index, max_value
-def test_graph(float_graph_def, input_map, output_names):
+def test_graph(float_graph_def, input_map, output_names, log_graph=False):
"""Runs the float graph through the rewriter and tests the results."""
float_results = run_graph_def(float_graph_def, input_map,
[output_name + ":0"
@@ -184,6 +184,9 @@ def test_graph(float_graph_def, input_map, output_names):
for expected, result in zip(float_results, eightbit_results):
assert are_tensors_near(expected, result, 1.0)
+ if log_graph:
+ tf.logging.info("8bit:\n%s", str(eightbit_graph_def))
+
# Test the weights_rounded mode. This uses the default bit_depth.
weights_rounded_rewriter = quantize_graph.GraphRewriter(
float_graph_def, "weights_rounded", quantized_input_range=None)
@@ -580,6 +583,40 @@ class QuantizeGraphTest(tf.test.TestCase):
float_graph_def.node.extend([relu_node])
test_graph(float_graph_def, {}, [relu_name])
+ def test_relu_w_fake_quant_w_min_max_vars(self):
+ input_node = quantize_graph.create_constant_node(
+ "input", value=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
+ dtype=tf.float32, shape=[1, 2, 6, 1])
+ relu_node = quantize_graph.create_node("Relu", "relu",
+ [input_node.name])
+ quantize_graph.set_attr_dtype(relu_node, "T", tf.float32)
+
+ min_node = quantize_graph.create_constant_node(
+ "min_bias_add", value=0, dtype=tf.float32, shape=[])
+ max_node = quantize_graph.create_constant_node(
+ "max_bias_add", value=12, dtype=tf.float32, shape=[])
+ fake_quant_node = quantize_graph.create_node(
+ "FakeQuantWithMinMaxVars", "fake_quant",
+ [relu_node.name, min_node.name, max_node.name])
+
+ float_graph_def = tf.GraphDef()
+ float_graph_def.node.extend([input_node, relu_node, min_node, max_node,
+ fake_quant_node])
+ test_graph(float_graph_def, {}, [fake_quant_node.name], log_graph=True)
+
+ # Verify there is only one Quantize and one Requantize op.
+ eightbit_rewriter = quantize_graph.GraphRewriter(float_graph_def,
+ "eightbit",
+ quantized_input_range=None)
+ eightbit_graph_def = eightbit_rewriter.rewrite([fake_quant_node.name])
+
+ ops = [node.op for node in eightbit_graph_def.node]
+ # No quantize since all inputs are const and can be quantized up-front.
+ self.assertEqual(0, ops.count("QuantizeV2") + ops.count("Quantize"))
+
+ # One dequantize at the end.
+ self.assertEqual(1, ops.count("Dequantize"))
+
def test_relu6(self):
input_constant_name = "input_constant"
relu6_name = "relu6"
@@ -720,6 +757,42 @@ class QuantizeGraphTest(tf.test.TestCase):
ops.count("QuantizeV2") + ops.count("Quantize"))
self.assertEqual(len(output_names), ops.count("Dequantize"))
+ def test_bias_add_w_fake_quant_w_min_max_vars(self):
+ input_node = quantize_graph.create_constant_node(
+ "input", value=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
+ dtype=tf.float32, shape=[1, 1, 2, 5])
+ offset_node = quantize_graph.create_constant_node(
+ "offset", value=[1, 2, 3, 4, 5], dtype=tf.float32, shape=[5])
+ bias_add_node = quantize_graph.create_node(
+ "BiasAdd", "bias_add", [input_node.name, offset_node.name])
+ quantize_graph.set_attr_dtype(bias_add_node, "T", tf.float32)
+
+ min_node = quantize_graph.create_constant_node(
+ "min_bias_add", value=-.5, dtype=tf.float32, shape=[])
+ max_node = quantize_graph.create_constant_node(
+ "max_bias_add", value=15.5, dtype=tf.float32, shape=[])
+ fake_quant_node = quantize_graph.create_node(
+ "FakeQuantWithMinMaxVars", "fake_quant",
+ [bias_add_node.name, min_node.name, max_node.name])
+
+ float_graph_def = tf.GraphDef()
+ float_graph_def.node.extend([input_node, offset_node, bias_add_node,
+ min_node, max_node, fake_quant_node])
+ test_graph(float_graph_def, {}, [fake_quant_node.name], log_graph=True)
+
+ # Verify there is only one Quantize and one Requantize op.
+ eightbit_rewriter = quantize_graph.GraphRewriter(float_graph_def,
+ "eightbit",
+ quantized_input_range=None)
+ eightbit_graph_def = eightbit_rewriter.rewrite([fake_quant_node.name])
+
+ ops = [node.op for node in eightbit_graph_def.node]
+ # No quantize since all inputs are const and can be quantized up-front.
+ self.assertEqual(0, ops.count("QuantizeV2") + ops.count("Quantize"))
+
+ # One dequantize at the end.
+ self.assertEqual(1, ops.count("Dequantize"))
+
def test_remove_redundant_quantization(self):
a_constant_name = "a_constant"
a_constant_min_name = "a_constant_min"