aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/quantize
diff options
context:
space:
mode:
authorGravatar Suharsh Sivakumar <suharshs@google.com>2018-08-14 21:46:04 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-14 21:53:28 -0700
commit d4b8e21ef9ac66d5871f37303218f47e40a1a02c (patch)
tree 1e6562ee1667a6ae4ebad159eccc529ca1a56f9c /tensorflow/contrib/quantize
parent c7ddeb5a7c572d990eb6fac6ed370d1c26deea2a (diff)
The rewriter should not add nodes if followed by a pass-through op and then a FakeQuant
PiperOrigin-RevId: 208767575
Diffstat (limited to 'tensorflow/contrib/quantize')
-rw-r--r-- tensorflow/contrib/quantize/python/quantize.py 24
-rw-r--r-- tensorflow/contrib/quantize/python/quantize_test.py 54
2 files changed, 73 insertions, 5 deletions
diff --git a/tensorflow/contrib/quantize/python/quantize.py b/tensorflow/contrib/quantize/python/quantize.py
index cb66fd1f76..2ddbd73ea6 100644
--- a/tensorflow/contrib/quantize/python/quantize.py
+++ b/tensorflow/contrib/quantize/python/quantize.py
@@ -455,6 +455,24 @@ class _LayerMatch(object):
return self._bias_add_op
+def _FollowedByFakeQuant(tensor):
+ """Returns True if the tensor is followed by a FakeQuant."""
+ fake_quant_ops = set([
+ 'FakeQuantWithMinMaxVars', 'FakeQuantWithMinMaxArgs',
+ 'FakeQuantWithMinMaxVarsPerChannel'
+ ])
+ pass_through_ops = set(['Reshape', 'Identity'])
+ consumers = tensor.consumers()
+ while consumers:
+ c = consumers.pop()
+ if c.type in fake_quant_ops:
+ return True
+ elif c.type in pass_through_ops:
+ for output in c.outputs:
+ consumers.extend(output.consumers())
+ return False
+
+
def _InsertQuantOp(context,
name,
producer,
@@ -535,11 +553,7 @@ def _InsertQuantOp(context,
# Prevent ops from being quantized multiple times. Bypass ops can sometimes
# overlap between multiple matches, so we need to ensure that we don't
# add duplicate FakeQuant operations.
- fake_quant_ops = set([
- 'FakeQuantWithMinMaxVars',
- 'FakeQuantWithMinMaxArgs'
- ])
- if fake_quant_ops.intersection(set([c.type for c in inputs.consumers()])):
+ if _FollowedByFakeQuant(inputs):
return
if moving_avg:
diff --git a/tensorflow/contrib/quantize/python/quantize_test.py b/tensorflow/contrib/quantize/python/quantize_test.py
index 06ebcdfee1..212d902a3c 100644
--- a/tensorflow/contrib/quantize/python/quantize_test.py
+++ b/tensorflow/contrib/quantize/python/quantize_test.py
@@ -471,6 +471,60 @@ class QuantizeTest(test_util.TensorFlowTestCase):
self.assertTrue(
'part/test/test/weights_quant/FakeQuantWithMinMaxVars' in op_names)
+ def testSkipReshapeQuantization(self):
+ self._RunTestOverParameters(self._TestSkipReshapeQuantization)
+
+ def _TestSkipReshapeQuantization(self, is_training):
+ graph = ops.Graph()
+ with graph.as_default():
+ batch_size, height, width, depth = 5, 128, 128, 3
+ input1 = array_ops.zeros((batch_size, height, width, depth))
+ conv = conv2d(
+ input1,
+ 32, [5, 5],
+ stride=2,
+ padding='SAME',
+ weights_initializer=self._WeightInit(0.09),
+ activation_fn=nn_ops.relu6,
+ scope='test/test')
+
+ reshape = array_ops.reshape(
+ conv, (int(10), int(height / 2), int(width / 2), int(16)))
+
+ # Insert a fake quant node after the reshape. We will check that one isn't
+ # inserted before it.
+ array_ops.fake_quant_with_min_max_vars(reshape, -1, 1)
+
+ quantize.Quantize(graph, is_training, weight_bits=8, activation_bits=8)
+
+ # Ensure that there isn't a FakeQuant added before the reshape.
+ self.assertFalse(
+ 'FakeQuantWithMinMaxVars' in [i.op.type for i in reshape.op.inputs])
+
+ graph = ops.Graph()
+ with graph.as_default():
+ batch_size, height, width, depth = 5, 128, 128, 3
+ input1 = array_ops.zeros((batch_size, height, width, depth))
+ conv = conv2d(
+ input1,
+ 32, [5, 5],
+ stride=2,
+ padding='SAME',
+ weights_initializer=self._WeightInit(0.09),
+ activation_fn=nn_ops.relu6,
+ scope='test/test')
+
+ reshape = array_ops.reshape(
+ conv, (int(10), int(height / 2), int(width / 2), int(16)))
+
+ # If no fake quant is added after the reshape, a FakeQuant should be added
+ # before the reshape.
+ quantize.Quantize(graph, is_training, weight_bits=8, activation_bits=8)
+
+ # Ensure that there is a FakeQuant added before the reshape.
+ self.assertTrue(
+ 'FakeQuantWithMinMaxVars' in [i.op.type for i in reshape.op.inputs])
+
def _WeightInit(self, stddev):
"""Returns truncated normal variable initializer.