diff options
author | 2016-11-08 09:52:50 -0800 | |
---|---|---|
committer | 2016-11-08 16:21:34 -0800 | |
commit | 42e9d54c833f6c16b9c864a0cdb2191fceb0e7dd (patch) | |
tree | d0a8647bcbe9babb1120dabc5d5309b923be0fdd /tensorflow/tools/quantization | |
parent | 811b43a56f0f1d05925cf054cc80d8e6f6490b7e (diff) |
Add --quantized_fallback_min and --quantized_fallback_max for use
when experimenting with quantized graphs.
Change: 138529416
Diffstat (limited to 'tensorflow/tools/quantization')
-rw-r--r-- | tensorflow/tools/quantization/quantize_graph.py | 59 | ||||
-rw-r--r-- | tensorflow/tools/quantization/quantize_graph_test.py | 49 |
2 files changed, 103 insertions, 5 deletions
diff --git a/tensorflow/tools/quantization/quantize_graph.py b/tensorflow/tools/quantization/quantize_graph.py index d11fbf65ed..aa29dc2326 100644 --- a/tensorflow/tools/quantization/quantize_graph.py +++ b/tensorflow/tools/quantization/quantize_graph.py @@ -66,6 +66,18 @@ flags.DEFINE_float("quantized_input_min", 0, flags.DEFINE_float("quantized_input_max", 1, "The maximum of the actual input range when " "--quantized_input") +flags.DEFINE_float( + "quantized_fallback_min", None, + "The fallback 'min' value to use for layers which lack min-max " + "information. Note: this should be considered a coarse tool just good " + "enough for experimentation purposes, since graphs quantized in this way " + "would be very inaccurate.") +flags.DEFINE_float( + "quantized_fallback_max", None, + "The fallback 'max' value to use for layers which lack min-max " + "information. Note: this should be considered a coarse tool just good " + "enough for experimentation purposes, since graphs quantized in this way " + "would be very inaccurate.") def print_input_nodes(current_node, nodes_map, indent, already_visited): @@ -292,7 +304,8 @@ EightbitizeRecursionState = collections.namedtuple( class GraphRewriter(object): """Takes a float graph, and rewrites it in quantized form.""" - def __init__(self, input_graph, mode, quantized_input_range): + def __init__(self, input_graph, mode, quantized_input_range, + fallback_quantization_range=None): """Sets up the class to rewrite a float graph. Args: @@ -302,6 +315,10 @@ class GraphRewriter(object): quantized_input_range: if set, assume the input is quantized and represents the range [quantized_input_range[0], quantized_input_range[1]] + fallback_quantization_range: if set, then for nodes where the quantization + range can't be inferred from the graph, use the range + [fallback_quantization_range[0], fallback_quantization_range[1]) instead + of using a RequantizationRange node in the graph. Raises: ValueError: Two nodes with the same name were found in the graph. @@ -322,6 +339,20 @@ class GraphRewriter(object): else: self.input_range = None + if fallback_quantization_range: + self.fallback_quantization_range = [fallback_quantization_range[0], + fallback_quantization_range[1]] + if (self.fallback_quantization_range[0] >= + self.fallback_quantization_range[1]): + raise ValueError("Invalid fallback_quantization_range: [%s,%s]" % + self.fallback_quantization_range) + if self.mode != "eightbit": + raise ValueError( + "fallback_quantization_range can only be " + "specified in eightbit mode") + else: + self.fallback_quantization_range = None + # Data that is valid only during the recursive call to rewrite the graph. self.state = None @@ -373,6 +404,13 @@ class GraphRewriter(object): "quantized_input_min_value", self.input_range[0], tf.float32, [])) self.add_output_graph_node(create_constant_node( "quantized_input_max_value", self.input_range[1], tf.float32, [])) + if self.fallback_quantization_range: + self.add_output_graph_node(create_constant_node( + "fallback_quantization_min_value", + self.fallback_quantization_range[0], tf.float32, [])) + self.add_output_graph_node(create_constant_node( + "fallback_quantization_max_value", + self.fallback_quantization_range[1], tf.float32, [])) if FLAGS.strip_redundant_quantization: self.output_graph = self.remove_redundant_quantization( self.output_graph) @@ -563,6 +601,11 @@ class GraphRewriter(object): new_node = tf.NodeDef() new_node.CopyFrom(current_node) self.add_output_graph_node(new_node) + + ################################################################### + # Note: if more cases are added here, you may need to update the op + # name lists in the loop over children at the start of the function. + ################################################################### else: new_node = tf.NodeDef() new_node.CopyFrom(current_node) @@ -653,6 +696,9 @@ class GraphRewriter(object): min_max_inputs = [fake_quant_node.input[1], fake_quant_node.input[2]] assert original_node.name not in self.state.merged_with_fake_quant self.state.merged_with_fake_quant[original_node.name] = True + elif self.fallback_quantization_range: + min_max_inputs = ["fallback_quantization_min_value:0", + "fallback_quantization_max_value:0"] else: # Add a RequantizationRange node for finding the min and max values. requant_range_node = create_node( @@ -1187,7 +1233,16 @@ def main(unused_args): quantized_input_range = [FLAGS.quantized_input_min, FLAGS.quantized_input_max] - rewriter = GraphRewriter(tf_graph, FLAGS.mode, quantized_input_range) + fallback_quantization_range = None + if (FLAGS.quantized_fallback_min is not None or + FLAGS.quantized_fallback_max is not None): + assert FLAGS.quantized_fallback_min is not None + assert FLAGS.quantized_fallback_max is not None + fallback_quantization_range = [FLAGS.quantized_fallback_min, + FLAGS.quantized_fallback_max] + + rewriter = GraphRewriter(tf_graph, FLAGS.mode, quantized_input_range, + fallback_quantization_range) output_graph = rewriter.rewrite(FLAGS.output_node_names.split(",")) diff --git a/tensorflow/tools/quantization/quantize_graph_test.py b/tensorflow/tools/quantization/quantize_graph_test.py index 42f06ef2ed..8c2938a28d 100644 --- a/tensorflow/tools/quantization/quantize_graph_test.py +++ b/tensorflow/tools/quantization/quantize_graph_test.py @@ -781,18 +781,61 @@ class QuantizeGraphTest(tf.test.TestCase): test_graph(float_graph_def, {}, [fake_quant_node.name], log_graph=True) # Verify there is only one Quantize and one Requantize op. - eightbit_rewriter = quantize_graph.GraphRewriter(float_graph_def, - "eightbit", - quantized_input_range=None) + # Pass in fallback_quantization_range, although it will have no effect + # because the FakeQuantWithMinMaxVars are used instead. + eightbit_rewriter = quantize_graph.GraphRewriter( + float_graph_def, "eightbit", quantized_input_range=None, + fallback_quantization_range=[-100, 100]) eightbit_graph_def = eightbit_rewriter.rewrite([fake_quant_node.name]) ops = [node.op for node in eightbit_graph_def.node] + node_names = [node.name for node in eightbit_graph_def.node] # No quantize since all inputs are const and can be quantized up-front. self.assertEqual(0, ops.count("QuantizeV2") + ops.count("Quantize")) # One dequantize at the end. self.assertEqual(1, ops.count("Dequantize")) + # The fallback constants are not in the graph. + self.assertEqual(0, node_names.count("fallback_quantization_min_value")) + self.assertEqual(0, node_names.count("fallback_quantization_max_value")) + + def test_bias_add_w_fallback_min_max_vars(self): + input_node = quantize_graph.create_constant_node( + "input", value=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10], + dtype=tf.float32, shape=[1, 1, 2, 5]) + offset_node = quantize_graph.create_constant_node( + "offset", value=[1, 2, 3, 4, 5], dtype=tf.float32, shape=[5]) + bias_add_node = quantize_graph.create_node( + "BiasAdd", "bias_add", [input_node.name, offset_node.name]) + quantize_graph.set_attr_dtype(bias_add_node, "T", tf.float32) + + float_graph_def = tf.GraphDef() + float_graph_def.node.extend([input_node, offset_node, bias_add_node]) + test_graph(float_graph_def, {}, [bias_add_node.name], log_graph=True) + + # Verify there is only one Quantize, one Requantize op, and no + # RequantizationRange op. + eightbit_rewriter = quantize_graph.GraphRewriter( + float_graph_def, "eightbit", quantized_input_range=None, + fallback_quantization_range=[-.5, 15.5]) + eightbit_graph_def = eightbit_rewriter.rewrite([bias_add_node.name]) + + ops = [node.op for node in eightbit_graph_def.node] + node_names = [node.name for node in eightbit_graph_def.node] + # No quantize since all inputs are const and can be quantized up-front. + self.assertEqual(0, ops.count("QuantizeV2") + ops.count("Quantize")) + + # One dequantize at the end. + self.assertEqual(1, ops.count("Dequantize")) + + # No RequantizationRange + self.assertEqual(0, ops.count("RequantizationRange")) + + # The fallback constants are in the graph. + self.assertEqual(1, node_names.count("fallback_quantization_min_value")) + self.assertEqual(1, node_names.count("fallback_quantization_max_value")) + def test_remove_redundant_quantization(self): a_constant_name = "a_constant" a_constant_min_name = "a_constant_min" |