-rw-r--r--  tensorflow/core/kernels/quantized_reshape_op.cc        4
-rw-r--r--  tensorflow/tools/quantization/quantize_graph.py      118
-rw-r--r--  tensorflow/tools/quantization/quantize_graph_test.py  75
3 files changed, 175 insertions(+), 22 deletions(-)
diff --git a/tensorflow/core/kernels/quantized_reshape_op.cc b/tensorflow/core/kernels/quantized_reshape_op.cc
index d49edd3feb..bd76c94ede 100644
--- a/tensorflow/core/kernels/quantized_reshape_op.cc
+++ b/tensorflow/core/kernels/quantized_reshape_op.cc
@@ -50,8 +50,8 @@ class QuantizedReshapeOp : public ReshapeOp {
                              .TypeConstraint<type>("T"), \
                          QuantizedReshapeOp)

-TF_CALL_quint8(REGISTER_CPU_KERNEL);
-TF_CALL_qint32(REGISTER_CPU_KERNEL);
+REGISTER_CPU_KERNEL(::tensorflow::quint8);
+REGISTER_CPU_KERNEL(::tensorflow::qint32);

#undef REGISTER_CPU_KERNEL
diff --git a/tensorflow/tools/quantization/quantize_graph.py b/tensorflow/tools/quantization/quantize_graph.py
index 894806f186..d11fbf65ed 100644
--- a/tensorflow/tools/quantization/quantize_graph.py
+++ b/tensorflow/tools/quantization/quantize_graph.py
@@ -28,6 +28,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import collections
import re
import numpy as np
import tensorflow as tf
@@ -283,6 +284,11 @@ def quantize_weight_eightbit(input_node, quantization_mode):
  return [quint8_const_node, min_node, max_node, dequantize_node]

+EightbitizeRecursionState = collections.namedtuple(
+    "EightbitizeRecursionState", ["already_visited", "output_node_stack",
+                                  "merged_with_fake_quant"])
+
+
class GraphRewriter(object):
  """Takes a float graph, and rewrites it in quantized form."""
@@ -316,6 +322,9 @@ class GraphRewriter(object):
    else:
      self.input_range = None

+    # Data that is valid only during the recursive call to rewrite the graph.
+    self.state = None
+
  def create_nodes_map(self, graph):
    """Builds a mapping of node names to their defs from the graph."""
    nodes_map = {}
@@ -353,11 +362,12 @@ class GraphRewriter(object):
    output_nodes = [self.nodes_map[output_node_name]
                    for output_node_name in output_node_names]

-    self.already_visited = {}
-    self.layers_eightbitized = []
+    self.state = EightbitizeRecursionState(already_visited={},
+                                           output_node_stack=[],
+                                           merged_with_fake_quant={})
    for output_node in output_nodes:
      self.eightbitize_nodes_recursively(output_node)
-    self.output_graph = self.quantize_weights(self.output_graph, b"MIN_FIRST")
+    self.state = None
    if self.input_range:
      self.add_output_graph_node(create_constant_node(
          "quantized_input_min_value", self.input_range[0], tf.float32, []))
@@ -477,20 +487,54 @@ class GraphRewriter(object):
    set_attr_string(dequantize_node, "mode", b"MIN_FIRST")
    self.add_output_graph_node(dequantize_node)

+  def should_merge_with_fake_quant_node(self):
+    """Should the current node merge with self.state.output_node_stack[-1]?"""
+    if not self.state.output_node_stack:
+      return False
+    top = self.state.output_node_stack[-1]
+    return top[1] == 0 and top[0].op in ["FakeQuantWithMinMaxVars"]
+
+  def should_quantize_const(self, node):
+    """Should the given Const node be quantized up-front for its consumer?"""
+    if not self.state.output_node_stack:
+      return False
+    top = self.state.output_node_stack[-1]
+    if not top[2]:  # The quantize_input flag recorded for this input edge.
+      return False
+    assert tf.as_dtype(node.attr["dtype"].type) == tf.float32, (
+        "Quantizing constant %s" % node.name)
+    return True
+
  def eightbitize_nodes_recursively(self, current_node):
    """The entry point for transforming a graph into full eight bit."""
-    self.already_visited[current_node.name] = True
-    for input_node_name in current_node.input:
+    if current_node.name in self.state.already_visited:
+      if (self.should_merge_with_fake_quant_node() or
+          current_node.name in self.state.merged_with_fake_quant):
+        raise ValueError("Unsupported graph structure: output of node %s "
+                         "is processed by a FakeQuant* node and should have "
+                         "no other outputs." % current_node.name)
+      return
+    self.state.already_visited[current_node.name] = True
+
+    for i, input_node_name in enumerate(current_node.input):
+      quantize_input = False
+      if current_node.op in ("MatMul", "Conv2D", "BiasAdd", "MaxPool",
+                             "AvgPool", "Relu", "Relu6",
+                             "BatchNormWithGlobalNormalization"):
+        quantize_input = True
+      elif current_node.op == "Concat" and i > 0:
+        quantize_input = True
+      elif current_node.op == "Reshape" and i == 0:
+        quantize_input = True
+
+      self.state.output_node_stack.append((current_node, i, quantize_input))
+
      input_node_name = node_name_from_input(input_node_name)
-      if input_node_name in self.already_visited:
-        continue
      input_node = self.nodes_map[input_node_name]
      self.eightbitize_nodes_recursively(input_node)
+
+      self.state.output_node_stack.pop()
+
    if current_node.op == "MatMul":
      self.eightbitize_mat_mul_node(current_node)
    elif current_node.op == "Conv2D":
      self.eightbitize_conv_node(current_node)
-      self.layers_eightbitized.append(current_node.name)
    elif current_node.op == "BiasAdd":
      self.eightbitize_bias_add_node(current_node)
    elif current_node.op == "MaxPool" or current_node.op == "AvgPool":
@@ -508,11 +552,29 @@ class GraphRewriter(object):
    elif (self.input_range and
          current_node.op in ("Placeholder", "PlaceholderV2")):
      self.eightbitize_placeholder_node(current_node)
+    elif current_node.op == "FakeQuantWithMinMaxVars":
+      # It will have been merged into the underlying node.
+      pass
+    elif current_node.op == "Const":
+      if self.should_quantize_const(current_node):
+        for n in quantize_weight_eightbit(current_node, b"MIN_FIRST"):
+          self.add_output_graph_node(n)
+      else:
+        new_node = tf.NodeDef()
+        new_node.CopyFrom(current_node)
+        self.add_output_graph_node(new_node)
    else:
      new_node = tf.NodeDef()
      new_node.CopyFrom(current_node)
      self.add_output_graph_node(new_node)

+    if (self.should_merge_with_fake_quant_node() and
+        current_node.name not in self.state.merged_with_fake_quant):
+      raise ValueError(
+          "FakeQuant* node %s failed to merge with node %s of type %s" % (
+              self.state.output_node_stack[-1][0].name, current_node.name,
+              current_node.op))
+
  def add_eightbit_prologue_nodes(self, original_node):
    """Adds input conversion nodes to handle quantizing the underlying node."""
    namespace_prefix = original_node.name + "_eightbit"
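As a usage sketch of the new Const branch: quantize_weight_eightbit (its return value appears in the first hunk of this file) replaces a float Const with a quint8 Const, its min/max Consts, and a Dequantize named after the original. Assuming the import path the test file uses:

    import tensorflow as tf
    from tensorflow.tools.quantization import quantize_graph

    weights = quantize_graph.create_constant_node(
        "weights", value=[1.0, 2.0, 3.0, 4.0], dtype=tf.float32, shape=[4])
    output_graph_def = tf.GraphDef()
    # Emit the four replacement nodes returned by the helper.
    for n in quantize_graph.quantize_weight_eightbit(weights, b"MIN_FIRST"):
      output_graph_def.node.extend([n])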
@@ -583,16 +645,26 @@ class GraphRewriter(object):
        quantized_output_name, quantized_output_name + ":1",
        quantized_output_name + ":2"
    ]
-    requant_range_node = create_node(
-        "RequantizationRange", original_node.name + "_eightbit_requant_range",
-        quantized_outputs)
-    set_attr_dtype(requant_range_node, "Tinput", tf.qint32)
-    self.add_output_graph_node(requant_range_node)
-
+    min_max_inputs = None
+    if self.should_merge_with_fake_quant_node():
+      # Use the inputs to the FakeQuantWithMinMaxVars node as the inputs to
+      # Requantize.
+      fake_quant_node = self.state.output_node_stack[-1][0]
+      min_max_inputs = [fake_quant_node.input[1], fake_quant_node.input[2]]
+      assert original_node.name not in self.state.merged_with_fake_quant
+      self.state.merged_with_fake_quant[original_node.name] = True
+    else:
+      # Add a RequantizationRange node for finding the min and max values.
+      requant_range_node = create_node(
+          "RequantizationRange",
+          original_node.name + "_eightbit_requant_range", quantized_outputs)
+      set_attr_dtype(requant_range_node, "Tinput", tf.qint32)
+      self.add_output_graph_node(requant_range_node)
+      min_max_inputs = [requant_range_node.name + ":0",
+                        requant_range_node.name + ":1"]
    requantize_node = create_node(
        "Requantize", original_node.name + "_eightbit_requantize",
-        (quantized_outputs +
-         [requant_range_node.name + ":0", requant_range_node.name + ":1"]))
+        quantized_outputs + min_max_inputs)
    set_attr_dtype(requantize_node, "Tinput", tf.qint32)
    set_attr_dtype(requantize_node, "out_type", tf.quint8)
    self.add_output_graph_node(requantize_node)
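Condensed, the branch above picks one pair of min/max tensor names for Requantize. A sketch of just that decision as a hypothetical helper (not part of the file):

    def choose_min_max_inputs(fake_quant_node, requant_range_name):
      """Pick the two range inputs for the Requantize node.

      When merging with a FakeQuantWithMinMaxVars consumer, its second and
      third inputs already name the desired min/max tensors; otherwise use
      the two outputs of the freshly added RequantizationRange node.
      """
      if fake_quant_node is not None:
        return [fake_quant_node.input[1], fake_quant_node.input[2]]
      return [requant_range_name + ":0", requant_range_name + ":1"]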
@@ -600,12 +672,20 @@ class GraphRewriter(object):
  def add_dequantize_result_node(self, quantized_output_name,
                                 original_node_name, min_tensor_index=1):
+    min_max_inputs = [
+        "%s:%s" % (quantized_output_name, min_tensor_index),
+        "%s:%s" % (quantized_output_name, (min_tensor_index + 1))]
    dequantize_name = original_node_name
+    if self.should_merge_with_fake_quant_node():
+      fake_quant_node = self.state.output_node_stack[-1][0]
+      if original_node_name not in self.state.merged_with_fake_quant:
+        min_max_inputs = [fake_quant_node.input[1], fake_quant_node.input[2]]
+        self.state.merged_with_fake_quant[original_node_name] = True
+      dequantize_name = fake_quant_node.name
+
    dequantize_node = create_node(
        "Dequantize", dequantize_name,
-        [quantized_output_name,
-         "%s:%s" % (quantized_output_name, min_tensor_index),
-         "%s:%s" % (quantized_output_name, (min_tensor_index + 1))])
+        [quantized_output_name, min_max_inputs[0], min_max_inputs[1]])
    set_attr_dtype(dequantize_node, "T", tf.quint8)
    set_attr_string(dequantize_node, "mode", b"MIN_FIRST")
    self.add_output_graph_node(dequantize_node)
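The net effect of the merged path above, as a fragment (assumes fake_quant_node and quantized_output_name are in scope; create_node and the set_attr_* helpers are this file's own): the Dequantize is emitted under the FakeQuant node's name, so downstream consumers of the FakeQuant output are rewired to the dequantized result without editing their input lists:

    dequantize_node = create_node(
        "Dequantize", fake_quant_node.name,
        [quantized_output_name, fake_quant_node.input[1],
         fake_quant_node.input[2]])
    set_attr_dtype(dequantize_node, "T", tf.quint8)
    set_attr_string(dequantize_node, "mode", b"MIN_FIRST")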
diff --git a/tensorflow/tools/quantization/quantize_graph_test.py b/tensorflow/tools/quantization/quantize_graph_test.py
index 1521240f28..42f06ef2ed 100644
--- a/tensorflow/tools/quantization/quantize_graph_test.py
+++ b/tensorflow/tools/quantization/quantize_graph_test.py
@@ -160,7 +160,7 @@ def get_top_value(input_values):
  return max_index, max_value


-def test_graph(float_graph_def, input_map, output_names):
+def test_graph(float_graph_def, input_map, output_names, log_graph=False):
  """Runs the float graph through the rewriter and tests the results."""
  float_results = run_graph_def(float_graph_def, input_map,
                                [output_name + ":0"
@@ -184,6 +184,9 @@ def test_graph(float_graph_def, input_map, output_names):
  for expected, result in zip(float_results, eightbit_results):
    assert are_tensors_near(expected, result, 1.0)

+  if log_graph:
+    tf.logging.info("8bit:\n%s", str(eightbit_graph_def))
+
  # Test the weights_rounded mode. This uses the default bit_depth.
  weights_rounded_rewriter = quantize_graph.GraphRewriter(
      float_graph_def, "weights_rounded", quantized_input_range=None)
@@ -580,6 +583,40 @@ class QuantizeGraphTest(tf.test.TestCase):
    float_graph_def.node.extend([relu_node])
    test_graph(float_graph_def, {}, [relu_name])

+  def test_relu_w_fake_quant_w_min_max_vars(self):
+    input_node = quantize_graph.create_constant_node(
+        "input", value=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
+        dtype=tf.float32, shape=[1, 2, 6, 1])
+    relu_node = quantize_graph.create_node("Relu", "relu",
+                                           [input_node.name])
+    quantize_graph.set_attr_dtype(relu_node, "T", tf.float32)
+
+    min_node = quantize_graph.create_constant_node(
+        "min_relu", value=0, dtype=tf.float32, shape=[])
+    max_node = quantize_graph.create_constant_node(
+        "max_relu", value=12, dtype=tf.float32, shape=[])
+    fake_quant_node = quantize_graph.create_node(
+        "FakeQuantWithMinMaxVars", "fake_quant",
+        [relu_node.name, min_node.name, max_node.name])
+
+    float_graph_def = tf.GraphDef()
+    float_graph_def.node.extend([input_node, relu_node, min_node, max_node,
+                                 fake_quant_node])
+    test_graph(float_graph_def, {}, [fake_quant_node.name], log_graph=True)
+
+    # Verify that the FakeQuant node was merged: the rewritten graph should
+    # contain no standalone Quantize ops and a single Dequantize at the end.
+    eightbit_rewriter = quantize_graph.GraphRewriter(float_graph_def,
+                                                     "eightbit",
+                                                     quantized_input_range=None)
+    eightbit_graph_def = eightbit_rewriter.rewrite([fake_quant_node.name])
+
+    ops = [node.op for node in eightbit_graph_def.node]
+    # No quantize since all inputs are const and can be quantized up-front.
+    self.assertEqual(0, ops.count("QuantizeV2") + ops.count("Quantize"))
+
+    # One dequantize at the end.
+    self.assertEqual(1, ops.count("Dequantize"))
+
  def test_relu6(self):
    input_constant_name = "input_constant"
    relu6_name = "relu6"
@@ -720,6 +757,42 @@ class QuantizeGraphTest(tf.test.TestCase):
ops.count("QuantizeV2") + ops.count("Quantize"))
self.assertEqual(len(output_names), ops.count("Dequantize"))
+ def test_bias_add_w_fake_quant_w_min_max_vars(self):
+ input_node = quantize_graph.create_constant_node(
+ "input", value=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
+ dtype=tf.float32, shape=[1, 1, 2, 5])
+ offset_node = quantize_graph.create_constant_node(
+ "offset", value=[1, 2, 3, 4, 5], dtype=tf.float32, shape=[5])
+ bias_add_node = quantize_graph.create_node(
+ "BiasAdd", "bias_add", [input_node.name, offset_node.name])
+ quantize_graph.set_attr_dtype(bias_add_node, "T", tf.float32)
+
+ min_node = quantize_graph.create_constant_node(
+ "min_bias_add", value=-.5, dtype=tf.float32, shape=[])
+ max_node = quantize_graph.create_constant_node(
+ "max_bias_add", value=15.5, dtype=tf.float32, shape=[])
+ fake_quant_node = quantize_graph.create_node(
+ "FakeQuantWithMinMaxVars", "fake_quant",
+ [bias_add_node.name, min_node.name, max_node.name])
+
+ float_graph_def = tf.GraphDef()
+ float_graph_def.node.extend([input_node, offset_node, bias_add_node,
+ min_node, max_node, fake_quant_node])
+ test_graph(float_graph_def, {}, [fake_quant_node.name], log_graph=True)
+
+ # Verify there is only one Quantize and one Requantize op.
+ eightbit_rewriter = quantize_graph.GraphRewriter(float_graph_def,
+ "eightbit",
+ quantized_input_range=None)
+ eightbit_graph_def = eightbit_rewriter.rewrite([fake_quant_node.name])
+
+ ops = [node.op for node in eightbit_graph_def.node]
+ # No quantize since all inputs are const and can be quantized up-front.
+ self.assertEqual(0, ops.count("QuantizeV2") + ops.count("Quantize"))
+
+ # One dequantize at the end.
+ self.assertEqual(1, ops.count("Dequantize"))
+
def test_remove_redundant_quantization(self):
a_constant_name = "a_constant"
a_constant_min_name = "a_constant_min"