path: root/tensorflow/contrib/quantize
author    Raghuraman Krishnamoorthi <raghuramank@google.com>  2018-08-01 15:58:03 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>     2018-08-01 16:02:12 -0700
commit    626317cb35524e4deb0851a65ca5dd5ca61d431f (patch)
tree      4faa9c3c53cc6bd0b43ad9aa0b4e2b1621c6c80f /tensorflow/contrib/quantize
parent    2ff0263775f82f1951a344ebfff46a6b233bf266 (diff)
Generalize the quantization rewriter to handle separable convolutions. Insert fake quant ops for the weights of both the depthwise and regular convolutions inside a separable convolution op, and also insert fake quant ops for the activations produced by the first depthwise convolution.
PiperOrigin-RevId: 207009650
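
For context, a minimal sketch (not part of this commit; it assumes TF 1.x with tf.contrib available, and all names are illustrative) of the graph shape this rewrite targets. separable_conv2d lowers to a DepthwiseConv2dNative followed by a pointwise (1x1) Conv2D with no activation between the two, which is why the matcher added below looks for a conv feeding directly into another conv.

import tensorflow as tf
from tensorflow.contrib.layers import separable_conv2d

graph = tf.Graph()
with graph.as_default():
  inputs = tf.zeros((1, 128, 128, 3))
  # Lowers to DepthwiseConv2dNative -> Conv2D with no activation between
  # them; after this change the rewriter places fake quant ops on both
  # weight tensors and on the intermediate depthwise output.
  net = separable_conv2d(
      inputs,
      num_outputs=8,
      kernel_size=[3, 3],
      depth_multiplier=1,
      activation_fn=tf.nn.relu6,
      scope='sep')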
Diffstat (limited to 'tensorflow/contrib/quantize')
-rw-r--r--  tensorflow/contrib/quantize/python/quantize.py       21
-rw-r--r--  tensorflow/contrib/quantize/python/quantize_test.py  55
2 files changed, 76 insertions, 0 deletions
diff --git a/tensorflow/contrib/quantize/python/quantize.py b/tensorflow/contrib/quantize/python/quantize.py
index 4fc315d901..903faeff11 100644
--- a/tensorflow/contrib/quantize/python/quantize.py
+++ b/tensorflow/contrib/quantize/python/quantize.py
@@ -261,6 +261,16 @@ def _FindLayersToQuantize(graph):
layer_output_pattern = graph_matcher.OneofPattern(
[batch_to_space_pattern, layer_pattern])
+
+ # For separable convolutions, we are looking for a conv, followed by a conv
+ # with no activations between the two.
+ sep_conv_pattern = graph_matcher.OpTypePattern(
+ '|'.join(_QUANTIZABLE_TYPES),
+ inputs=[
+ graph_matcher.OneofPattern([layer_output_pattern]),
+ graph_matcher.OpTypePattern('*')
+ ],
+ ordered_inputs=False)
folded_bias_mul_pattern = graph_matcher.OpTypePattern(
'Mul',
inputs=[graph_matcher.OpTypePattern('*'), layer_output_pattern],
@@ -393,6 +403,17 @@ def _FindLayersToQuantize(graph):
layer_matches.append(
_LayerMatch(layer_op, weight_tensor, activation_op, None, None, None))
+ # Look for separable convolutions here
+ sep_conv_matcher = graph_matcher.GraphMatcher(sep_conv_pattern)
+ for match_result in sep_conv_matcher.match_graph(graph):
+ layer_op = match_result.get_op(layer_pattern)
+ weight_tensor = match_result.get_tensor(weight_identity_pattern)
+ activation_op = match_result.get_op(layer_pattern)
+ if layer_op not in matched_layer_set:
+ matched_layer_set.add(layer_op)
+ layer_matches.append(
+ _LayerMatch(layer_op, weight_tensor, activation_op, None, None, None))
+
return layer_matches
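
The new sep_conv_pattern matches any quantizable op whose input is itself the output of a quantizable layer, i.e. the conv-feeding-conv shape a separable convolution lowers to. Below is an illustrative restatement of that idiom as standalone code; the real patterns are built inside _FindLayersToQuantize, and the _QUANTIZABLE_TYPES set shown here is an assumption based on the ops the rewriter handles.

import tensorflow as tf
from tensorflow.contrib.quantize.python import graph_matcher

_QUANTIZABLE_TYPES = {'Conv2D', 'MatMul', 'DepthwiseConv2dNative'}

# Inner quantizable layer, e.g. the depthwise stage of a separable conv.
inner_layer = graph_matcher.OpTypePattern('|'.join(_QUANTIZABLE_TYPES))
# Outer quantizable op fed directly by the inner layer: the pointwise
# Conv2D consuming the depthwise output with no activation in between.
# ordered_inputs=False allows the weight input on either side.
sep_conv = graph_matcher.OpTypePattern(
    '|'.join(_QUANTIZABLE_TYPES),
    inputs=[inner_layer, graph_matcher.OpTypePattern('*')],
    ordered_inputs=False)

matcher = graph_matcher.GraphMatcher(sep_conv)
for match in matcher.match_graph(tf.get_default_graph()):
  inner_op = match.get_op(inner_layer)  # the depthwise conv in the match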
diff --git a/tensorflow/contrib/quantize/python/quantize_test.py b/tensorflow/contrib/quantize/python/quantize_test.py
index 92ca4a1b0c..98209fffb9 100644
--- a/tensorflow/contrib/quantize/python/quantize_test.py
+++ b/tensorflow/contrib/quantize/python/quantize_test.py
@@ -122,12 +122,67 @@ class QuantizeTest(test_util.TensorFlowTestCase):
array_ops.identity(node, name='control_dependency')
quantize.Quantize(graph, is_training, weight_bits=8, activation_bits=8)
+ # Check if output of bias add is quantized
+ quantization_node_name = 'FakeQuantWithMinMaxVars'
+ conv_quant = graph.get_operation_by_name('test/test/conv_quant/' +
+ quantization_node_name)
+ self.assertEqual(conv_quant.type, quantization_node_name)
+
+ for op in graph.get_operations():
+ if op.type == quantization_node_name:
+ quant_op = graph.get_operation_by_name(op.name)
+ # Scan through all FakeQuant operations, ensuring that the activation
+ # identity op isn't in the consumers of the operation.
+ consumers = []
+ for output in quant_op.outputs:
+ consumers.extend(output.consumers())
+ self.assertNotIn('test/relu6', [c.name for c in consumers])
+
+ def testInsertQuantOpInSeparableConv2d(self):
+ self._RunTestOverParameters(self._TestInsertQuantOpInSeparableConv2d)
+
+ def _TestInsertQuantOpInSeparableConv2d(self, is_training):
+ graph = ops.Graph()
+ with graph.as_default():
+ batch_size, height, width, depth = 5, 128, 128, 3
+ input1 = array_ops.zeros((batch_size, height, width, depth))
+ input2 = array_ops.zeros((batch_size, height / 2, width / 2, depth))
+ conv = separable_conv2d(
+ input1,
+ 3, [5, 5],
+ stride=2,
+ depth_multiplier=1.0,
+ padding='SAME',
+ weights_initializer=self._WeightInit(0.09),
+ activation_fn=None,
+ scope='test/test')
+ node = math_ops.add(conv, input2, name='test/add')
+ node = nn_ops.relu6(node, name='test/relu6')
+ update_barrier = control_flow_ops.no_op(name='update_barrier')
+ with ops.control_dependencies([update_barrier]):
+ array_ops.identity(node, name='control_dependency')
+
+ quantize.Quantize(graph, is_training, weight_bits=8, activation_bits=8)
+ # Check if output of bias add is quantized
quantization_node_name = 'FakeQuantWithMinMaxVars'
conv_quant = graph.get_operation_by_name('test/test/conv_quant/' +
quantization_node_name)
self.assertEqual(conv_quant.type, quantization_node_name)
+ # Check if weights for both convs inside separable conv are quantized
+ pointwise_weight_quant = graph.get_operation_by_name(
+ 'test/test/weights_quant/' + quantization_node_name)
+ self.assertEqual(pointwise_weight_quant.type, quantization_node_name)
+ depthwise_weight_quant = graph.get_operation_by_name(
+ 'test/test/separable_conv2d/weights_quant/' + quantization_node_name)
+ self.assertEqual(depthwise_weight_quant.type, quantization_node_name)
+
+ # Check if activations after first depthwise conv are quantized.
+ depthwise_act_quant = graph.get_operation_by_name(
+ 'test/test/separable_conv2d/act_quant/' + quantization_node_name)
+ self.assertEqual(depthwise_act_quant.type, quantization_node_name)
+
for op in graph.get_operations():
if op.type == quantization_node_name:
quant_op = graph.get_operation_by_name(op.name)
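
The tests above drive the rewriter directly through quantize.Quantize; in user code the usual entry point is the public TF 1.x quantize API. A hedged sketch (not from this commit) of how the new behavior surfaces there:

import tensorflow as tf

g = tf.Graph()
with g.as_default():
  inputs = tf.zeros((1, 128, 128, 3))
  net = tf.contrib.layers.separable_conv2d(
      inputs, num_outputs=8, kernel_size=[3, 3], depth_multiplier=1,
      scope='sep')
  # With this change, FakeQuantWithMinMaxVars ops are inserted for the
  # weights of both convolutions inside the separable conv and for the
  # intermediate depthwise activations. create_eval_graph is the
  # inference-time counterpart.
  tf.contrib.quantize.create_training_graph(input_graph=g)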