aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/quantize/python/quantize_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/quantize/python/quantize_test.py')
-rw-r--r--tensorflow/contrib/quantize/python/quantize_test.py54
1 files changed, 54 insertions, 0 deletions
diff --git a/tensorflow/contrib/quantize/python/quantize_test.py b/tensorflow/contrib/quantize/python/quantize_test.py
index 06ebcdfee1..212d902a3c 100644
--- a/tensorflow/contrib/quantize/python/quantize_test.py
+++ b/tensorflow/contrib/quantize/python/quantize_test.py
@@ -471,6 +471,60 @@ class QuantizeTest(test_util.TensorFlowTestCase):
self.assertTrue(
'part/test/test/weights_quant/FakeQuantWithMinMaxVars' in op_names)
+ def testSkipReshapeQuantization(self):
+ self._RunTestOverParameters(self._TestSkipReshapeQuantization)
+
+ def _TestSkipReshapeQuantization(self, is_training):
+ graph = ops.Graph()
+ with graph.as_default():
+ batch_size, height, width, depth = 5, 128, 128, 3
+ input1 = array_ops.zeros((batch_size, height, width, depth))
+ conv = conv2d(
+ input1,
+ 32, [5, 5],
+ stride=2,
+ padding='SAME',
+ weights_initializer=self._WeightInit(0.09),
+ activation_fn=nn_ops.relu6,
+ scope='test/test')
+
+ reshape = array_ops.reshape(
+ conv, (int(10), int(height / 2), int(width / 2), int(16)))
+
+ # Insert a fake quant node after the reshape. We will check that one isn't
+ # insert before.
+ array_ops.fake_quant_with_min_max_vars(reshape, -1, 1)
+
+ quantize.Quantize(graph, is_training, weight_bits=8, activation_bits=8)
+
+ # Ensure that there isn't a FakeQuant added before the reshape.
+ self.assertFalse(
+ 'FakeQuantWithMinMaxVars' in [i.op.type for i in reshape.op.inputs])
+
+ graph = ops.Graph()
+ with graph.as_default():
+ batch_size, height, width, depth = 5, 128, 128, 3
+ input1 = array_ops.zeros((batch_size, height, width, depth))
+ conv = conv2d(
+ input1,
+ 32, [5, 5],
+ stride=2,
+ padding='SAME',
+ weights_initializer=self._WeightInit(0.09),
+ activation_fn=nn_ops.relu6,
+ scope='test/test')
+
+ reshape = array_ops.reshape(
+ conv, (int(10), int(height / 2), int(width / 2), int(16)))
+
+ # If no fake quant is added after the reshape, a FakeQuant should be added
+ # before the reshape.
+ quantize.Quantize(graph, is_training, weight_bits=8, activation_bits=8)
+
+ # Ensure that there isn't a FakeQuant added before the reshape.
+ self.assertTrue(
+ 'FakeQuantWithMinMaxVars' in [i.op.type for i in reshape.op.inputs])
+
def _WeightInit(self, stddev):
"""Returns truncated normal variable initializer.