diff options
Diffstat (limited to 'tensorflow/contrib/quantize/python/quantize_test.py')
-rw-r--r-- | tensorflow/contrib/quantize/python/quantize_test.py | 54 |
1 files changed, 54 insertions, 0 deletions
diff --git a/tensorflow/contrib/quantize/python/quantize_test.py b/tensorflow/contrib/quantize/python/quantize_test.py index 06ebcdfee1..212d902a3c 100644 --- a/tensorflow/contrib/quantize/python/quantize_test.py +++ b/tensorflow/contrib/quantize/python/quantize_test.py @@ -471,6 +471,60 @@ class QuantizeTest(test_util.TensorFlowTestCase): self.assertTrue( 'part/test/test/weights_quant/FakeQuantWithMinMaxVars' in op_names) + def testSkipReshapeQuantization(self): + self._RunTestOverParameters(self._TestSkipReshapeQuantization) + + def _TestSkipReshapeQuantization(self, is_training): + graph = ops.Graph() + with graph.as_default(): + batch_size, height, width, depth = 5, 128, 128, 3 + input1 = array_ops.zeros((batch_size, height, width, depth)) + conv = conv2d( + input1, + 32, [5, 5], + stride=2, + padding='SAME', + weights_initializer=self._WeightInit(0.09), + activation_fn=nn_ops.relu6, + scope='test/test') + + reshape = array_ops.reshape( + conv, (int(10), int(height / 2), int(width / 2), int(16))) + + # Insert a fake quant node after the reshape. We will check that one isn't + # insert before. + array_ops.fake_quant_with_min_max_vars(reshape, -1, 1) + + quantize.Quantize(graph, is_training, weight_bits=8, activation_bits=8) + + # Ensure that there isn't a FakeQuant added before the reshape. + self.assertFalse( + 'FakeQuantWithMinMaxVars' in [i.op.type for i in reshape.op.inputs]) + + graph = ops.Graph() + with graph.as_default(): + batch_size, height, width, depth = 5, 128, 128, 3 + input1 = array_ops.zeros((batch_size, height, width, depth)) + conv = conv2d( + input1, + 32, [5, 5], + stride=2, + padding='SAME', + weights_initializer=self._WeightInit(0.09), + activation_fn=nn_ops.relu6, + scope='test/test') + + reshape = array_ops.reshape( + conv, (int(10), int(height / 2), int(width / 2), int(16))) + + # If no fake quant is added after the reshape, a FakeQuant should be added + # before the reshape. + quantize.Quantize(graph, is_training, weight_bits=8, activation_bits=8) + + # Ensure that there isn't a FakeQuant added before the reshape. + self.assertTrue( + 'FakeQuantWithMinMaxVars' in [i.op.type for i in reshape.op.inputs]) + def _WeightInit(self, stddev): """Returns truncated normal variable initializer. |