author     A. Unique TensorFlower <gardener@tensorflow.org>  2016-11-08 09:52:50 -0800
committer  TensorFlower Gardener <gardener@tensorflow.org>   2016-11-08 16:21:34 -0800
commit     42e9d54c833f6c16b9c864a0cdb2191fceb0e7dd (patch)
tree       d0a8647bcbe9babb1120dabc5d5309b923be0fdd /tensorflow/tools/quantization
parent     811b43a56f0f1d05925cf054cc80d8e6f6490b7e (diff)
Add --quantized_fallback_min and --quantized_fallback_max for use
when experimenting with quantized graphs.

Change: 138529416
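A rough invocation sketch for the new flags (hedged: --mode, --output_node_names, and the two fallback flags appear in this patch or the surrounding file; the --input/--output flags, file names, and node name here are assumptions for illustration):

    python tensorflow/tools/quantization/quantize_graph.py \
        --input=float_model.pb \
        --output=quantized_model.pb \
        --output_node_names=softmax \
        --mode=eightbit \
        --quantized_fallback_min=-10.0 \
        --quantized_fallback_max=10.0

Both fallback flags must be given together; main() asserts that neither is set alone.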
Diffstat (limited to 'tensorflow/tools/quantization')
-rw-r--r--  tensorflow/tools/quantization/quantize_graph.py       | 59
-rw-r--r--  tensorflow/tools/quantization/quantize_graph_test.py  | 49
2 files changed, 103 insertions(+), 5 deletions(-)
diff --git a/tensorflow/tools/quantization/quantize_graph.py b/tensorflow/tools/quantization/quantize_graph.py
index d11fbf65ed..aa29dc2326 100644
--- a/tensorflow/tools/quantization/quantize_graph.py
+++ b/tensorflow/tools/quantization/quantize_graph.py
@@ -66,6 +66,18 @@ flags.DEFINE_float("quantized_input_min", 0,
flags.DEFINE_float("quantized_input_max", 1,
"The maximum of the actual input range when "
"--quantized_input")
+flags.DEFINE_float(
+ "quantized_fallback_min", None,
+ "The fallback 'min' value to use for layers which lack min-max "
+ "information. Note: this should be considered a coarse tool just good "
+ "enough for experimentation purposes, since graphs quantized in this way "
+ "would be very inaccurate.")
+flags.DEFINE_float(
+ "quantized_fallback_max", None,
+ "The fallback 'max' value to use for layers which lack min-max "
+ "information. Note: this should be considered a coarse tool just good "
+ "enough for experimentation purposes, since graphs quantized in this way "
+ "would be very inaccurate.")
def print_input_nodes(current_node, nodes_map, indent, already_visited):
@@ -292,7 +304,8 @@ EightbitizeRecursionState = collections.namedtuple(
class GraphRewriter(object):
"""Takes a float graph, and rewrites it in quantized form."""
- def __init__(self, input_graph, mode, quantized_input_range):
+ def __init__(self, input_graph, mode, quantized_input_range,
+ fallback_quantization_range=None):
"""Sets up the class to rewrite a float graph.
Args:
@@ -302,6 +315,10 @@ class GraphRewriter(object):
quantized_input_range: if set, assume the input is
quantized and represents the range
[quantized_input_range[0], quantized_input_range[1]]
+ fallback_quantization_range: if set, then for nodes where the quantization
+ range can't be inferred from the graph, use the range
+ [fallback_quantization_range[0], fallback_quantization_range[1]] instead
+ of using a RequantizationRange node in the graph.
Raises:
ValueError: Two nodes with the same name were found in the graph.
@@ -322,6 +339,20 @@ class GraphRewriter(object):
else:
self.input_range = None
+ if fallback_quantization_range:
+ self.fallback_quantization_range = [fallback_quantization_range[0],
+ fallback_quantization_range[1]]
+ if (self.fallback_quantization_range[0] >=
+ self.fallback_quantization_range[1]):
+ raise ValueError("Invalid fallback_quantization_range: [%s,%s]" %
+ self.fallback_quantization_range)
+ if self.mode != "eightbit":
+ raise ValueError(
+ "fallback_quantization_range can only be "
+ "specified in eightbit mode")
+ else:
+ self.fallback_quantization_range = None
+
# Data that is valid only during the recursive call to rewrite the graph.
self.state = None
@@ -373,6 +404,13 @@ class GraphRewriter(object):
"quantized_input_min_value", self.input_range[0], tf.float32, []))
self.add_output_graph_node(create_constant_node(
"quantized_input_max_value", self.input_range[1], tf.float32, []))
+ if self.fallback_quantization_range:
+ self.add_output_graph_node(create_constant_node(
+ "fallback_quantization_min_value",
+ self.fallback_quantization_range[0], tf.float32, []))
+ self.add_output_graph_node(create_constant_node(
+ "fallback_quantization_max_value",
+ self.fallback_quantization_range[1], tf.float32, []))
if FLAGS.strip_redundant_quantization:
self.output_graph = self.remove_redundant_quantization(
self.output_graph)
@@ -563,6 +601,11 @@ class GraphRewriter(object):
new_node = tf.NodeDef()
new_node.CopyFrom(current_node)
self.add_output_graph_node(new_node)
+
+ ###################################################################
+ # Note: if more cases are added here, you may need to update the op
+ # name lists in the loop over children at the start of the function.
+ ###################################################################
else:
new_node = tf.NodeDef()
new_node.CopyFrom(current_node)
@@ -653,6 +696,9 @@ class GraphRewriter(object):
min_max_inputs = [fake_quant_node.input[1], fake_quant_node.input[2]]
assert original_node.name not in self.state.merged_with_fake_quant
self.state.merged_with_fake_quant[original_node.name] = True
+ elif self.fallback_quantization_range:
+ min_max_inputs = ["fallback_quantization_min_value:0",
+ "fallback_quantization_max_value:0"]
else:
# Add a RequantizationRange node for finding the min and max values.
requant_range_node = create_node(
@@ -1187,7 +1233,16 @@ def main(unused_args):
quantized_input_range = [FLAGS.quantized_input_min,
FLAGS.quantized_input_max]
- rewriter = GraphRewriter(tf_graph, FLAGS.mode, quantized_input_range)
+ fallback_quantization_range = None
+ if (FLAGS.quantized_fallback_min is not None or
+ FLAGS.quantized_fallback_max is not None):
+ assert FLAGS.quantized_fallback_min is not None
+ assert FLAGS.quantized_fallback_max is not None
+ fallback_quantization_range = [FLAGS.quantized_fallback_min,
+ FLAGS.quantized_fallback_max]
+
+ rewriter = GraphRewriter(tf_graph, FLAGS.mode, quantized_input_range,
+ fallback_quantization_range)
output_graph = rewriter.rewrite(FLAGS.output_node_names.split(","))
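For experiments that drive the rewriter from Python rather than through the flags, a minimal sketch (the input/output paths and the output node name are hypothetical; GraphRewriter's signature and rewrite() call are as patched above):

    import tensorflow as tf
    from tensorflow.tools.quantization import quantize_graph

    # Load a float GraphDef from disk (hypothetical path).
    float_graph_def = tf.GraphDef()
    with open("float_model.pb", "rb") as f:
        float_graph_def.ParseFromString(f.read())

    # Nodes whose quantization range cannot be inferred fall back to
    # [-10, 10] instead of getting a RequantizationRange node; the
    # fallback range is only valid in "eightbit" mode.
    rewriter = quantize_graph.GraphRewriter(
        float_graph_def, "eightbit", quantized_input_range=None,
        fallback_quantization_range=[-10.0, 10.0])
    output_graph_def = rewriter.rewrite(["softmax"])  # hypothetical output node

    with open("quantized_model.pb", "wb") as f:
        f.write(output_graph_def.SerializeToString())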
diff --git a/tensorflow/tools/quantization/quantize_graph_test.py b/tensorflow/tools/quantization/quantize_graph_test.py
index 42f06ef2ed..8c2938a28d 100644
--- a/tensorflow/tools/quantization/quantize_graph_test.py
+++ b/tensorflow/tools/quantization/quantize_graph_test.py
@@ -781,18 +781,61 @@ class QuantizeGraphTest(tf.test.TestCase):
test_graph(float_graph_def, {}, [fake_quant_node.name], log_graph=True)
# Verify there is only one Quantize and one Requantize op.
- eightbit_rewriter = quantize_graph.GraphRewriter(float_graph_def,
- "eightbit",
- quantized_input_range=None)
+ # Pass in fallback_quantization_range, although it will have no effect
+ # because the FakeQuantWithMinMaxVars node's range is used instead.
+ eightbit_rewriter = quantize_graph.GraphRewriter(
+ float_graph_def, "eightbit", quantized_input_range=None,
+ fallback_quantization_range=[-100, 100])
eightbit_graph_def = eightbit_rewriter.rewrite([fake_quant_node.name])
ops = [node.op for node in eightbit_graph_def.node]
+ node_names = [node.name for node in eightbit_graph_def.node]
# No quantize since all inputs are const and can be quantized up-front.
self.assertEqual(0, ops.count("QuantizeV2") + ops.count("Quantize"))
# One dequantize at the end.
self.assertEqual(1, ops.count("Dequantize"))
+ # The fallback constants are not in the graph.
+ self.assertEqual(0, node_names.count("fallback_quantization_min_value"))
+ self.assertEqual(0, node_names.count("fallback_quantization_max_value"))
+
+ def test_bias_add_w_fallback_min_max_vars(self):
+ input_node = quantize_graph.create_constant_node(
+ "input", value=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
+ dtype=tf.float32, shape=[1, 1, 2, 5])
+ offset_node = quantize_graph.create_constant_node(
+ "offset", value=[1, 2, 3, 4, 5], dtype=tf.float32, shape=[5])
+ bias_add_node = quantize_graph.create_node(
+ "BiasAdd", "bias_add", [input_node.name, offset_node.name])
+ quantize_graph.set_attr_dtype(bias_add_node, "T", tf.float32)
+
+ float_graph_def = tf.GraphDef()
+ float_graph_def.node.extend([input_node, offset_node, bias_add_node])
+ test_graph(float_graph_def, {}, [bias_add_node.name], log_graph=True)
+
+ # Verify that no Quantize or RequantizationRange ops are added: the
+ # inputs are constants, and the fallback range supplies the min/max.
+ eightbit_rewriter = quantize_graph.GraphRewriter(
+ float_graph_def, "eightbit", quantized_input_range=None,
+ fallback_quantization_range=[-.5, 15.5])
+ eightbit_graph_def = eightbit_rewriter.rewrite([bias_add_node.name])
+
+ ops = [node.op for node in eightbit_graph_def.node]
+ node_names = [node.name for node in eightbit_graph_def.node]
+ # No quantize since all inputs are const and can be quantized up-front.
+ self.assertEqual(0, ops.count("QuantizeV2") + ops.count("Quantize"))
+
+ # One dequantize at the end.
+ self.assertEqual(1, ops.count("Dequantize"))
+
+ # No RequantizationRange
+ self.assertEqual(0, ops.count("RequantizationRange"))
+
+ # The fallback constants are in the graph.
+ self.assertEqual(1, node_names.count("fallback_quantization_min_value"))
+ self.assertEqual(1, node_names.count("fallback_quantization_max_value"))
+
def test_remove_redundant_quantization(self):
a_constant_name = "a_constant"
a_constant_min_name = "a_constant_min"