diff options
author | Suharsh Sivakumar <suharshs@google.com> | 2018-08-14 21:46:04 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-14 21:53:28 -0700 |
commit | d4b8e21ef9ac66d5871f37303218f47e40a1a02c (patch) | |
tree | 1e6562ee1667a6ae4ebad159eccc529ca1a56f9c /tensorflow/contrib/quantize | |
parent | c7ddeb5a7c572d990eb6fac6ed370d1c26deea2a (diff) |
The rewriter should not add nodes if the tensor is followed by a pass-through op and then a FakeQuant
PiperOrigin-RevId: 208767575
Diffstat (limited to 'tensorflow/contrib/quantize')
-rw-r--r-- | tensorflow/contrib/quantize/python/quantize.py | 24 | ||||
-rw-r--r-- | tensorflow/contrib/quantize/python/quantize_test.py | 54 |
2 files changed, 73 insertions, 5 deletions
diff --git a/tensorflow/contrib/quantize/python/quantize.py b/tensorflow/contrib/quantize/python/quantize.py index cb66fd1f76..2ddbd73ea6 100644 --- a/tensorflow/contrib/quantize/python/quantize.py +++ b/tensorflow/contrib/quantize/python/quantize.py @@ -455,6 +455,24 @@ class _LayerMatch(object): return self._bias_add_op +def _FollowedByFakeQuant(tensor): + """Returns True if the tensor is followed by a FakeQuant.""" + fake_quant_ops = set([ + 'FakeQuantWithMinMaxVars', 'FakeQuantWithMinMaxArgs', + 'FakeQuantWithMinMaxVarsPerChannel' + ]) + pass_through_ops = set(['Reshape', 'Identity']) + consumers = tensor.consumers() + while consumers: + c = consumers.pop() + if c.type in fake_quant_ops: + return True + elif c.type in pass_through_ops: + for output in c.outputs: + consumers.extend(output.consumers()) + return False + + def _InsertQuantOp(context, name, producer, @@ -535,11 +553,7 @@ def _InsertQuantOp(context, # Prevent ops from being quantized multiple times. Bypass ops can sometimes # overlap between multiple matches, so we need to ensure that we don't # add duplicate FakeQuant operations. 
- fake_quant_ops = set([ - 'FakeQuantWithMinMaxVars', - 'FakeQuantWithMinMaxArgs' - ]) - if fake_quant_ops.intersection(set([c.type for c in inputs.consumers()])): + if _FollowedByFakeQuant(inputs): return if moving_avg: diff --git a/tensorflow/contrib/quantize/python/quantize_test.py b/tensorflow/contrib/quantize/python/quantize_test.py index 06ebcdfee1..212d902a3c 100644 --- a/tensorflow/contrib/quantize/python/quantize_test.py +++ b/tensorflow/contrib/quantize/python/quantize_test.py @@ -471,6 +471,60 @@ class QuantizeTest(test_util.TensorFlowTestCase): self.assertTrue( 'part/test/test/weights_quant/FakeQuantWithMinMaxVars' in op_names) + def testSkipReshapeQuantization(self): + self._RunTestOverParameters(self._TestSkipReshapeQuantization) + + def _TestSkipReshapeQuantization(self, is_training): + graph = ops.Graph() + with graph.as_default(): + batch_size, height, width, depth = 5, 128, 128, 3 + input1 = array_ops.zeros((batch_size, height, width, depth)) + conv = conv2d( + input1, + 32, [5, 5], + stride=2, + padding='SAME', + weights_initializer=self._WeightInit(0.09), + activation_fn=nn_ops.relu6, + scope='test/test') + + reshape = array_ops.reshape( + conv, (int(10), int(height / 2), int(width / 2), int(16))) + + # Insert a fake quant node after the reshape. We will check that one isn't + # inserted before. + array_ops.fake_quant_with_min_max_vars(reshape, -1, 1) + + quantize.Quantize(graph, is_training, weight_bits=8, activation_bits=8) + + # Ensure that there isn't a FakeQuant added before the reshape. 
+ self.assertFalse( + 'FakeQuantWithMinMaxVars' in [i.op.type for i in reshape.op.inputs]) + + graph = ops.Graph() + with graph.as_default(): + batch_size, height, width, depth = 5, 128, 128, 3 + input1 = array_ops.zeros((batch_size, height, width, depth)) + conv = conv2d( + input1, + 32, [5, 5], + stride=2, + padding='SAME', + weights_initializer=self._WeightInit(0.09), + activation_fn=nn_ops.relu6, + scope='test/test') + + reshape = array_ops.reshape( + conv, (int(10), int(height / 2), int(width / 2), int(16))) + + # If no fake quant is added after the reshape, a FakeQuant should be added + # before the reshape. + quantize.Quantize(graph, is_training, weight_bits=8, activation_bits=8) + + # Ensure that there is a FakeQuant added before the reshape. + self.assertTrue( + 'FakeQuantWithMinMaxVars' in [i.op.type for i in reshape.op.inputs]) + def _WeightInit(self, stddev): """Returns truncated normal variable initializer. |