Diffstat (limited to 'tensorflow/contrib/quantize/python/quantize_parameterized_test.py')
-rw-r--r--  tensorflow/contrib/quantize/python/quantize_parameterized_test.py  76
1 file changed, 74 insertions(+), 2 deletions(-)
diff --git a/tensorflow/contrib/quantize/python/quantize_parameterized_test.py b/tensorflow/contrib/quantize/python/quantize_parameterized_test.py
index 5e3af0a567..31a2955ddb 100644
--- a/tensorflow/contrib/quantize/python/quantize_parameterized_test.py
+++ b/tensorflow/contrib/quantize/python/quantize_parameterized_test.py
@@ -654,8 +654,80 @@ class QuantizeTest(test_util.TensorFlowTestCase):
     graph_def_after = str(graph.as_graph_def())
     self.assertEqual(graph_def_before, graph_def_after)
 
-  def _BatchNormParams(self, fused=False):
-    return {'center': True, 'scale': True, 'decay': 1.0 - 0.003, 'fused': fused}
+  def testBatchNormForcedUpdates(self):
+    parameter_list = [
+        # (activation, activation_op_name, fused_batch_norm)
+        (nn_ops.relu6, 'Relu6', False),
+        (nn_ops.relu, 'Relu', False),
+        (array_ops.identity, 'Identity', False),
+        (nn_ops.relu6, 'Relu6', True),
+        (nn_ops.relu, 'Relu', True),
+        (array_ops.identity, 'Identity', True),
+    ]
+    for params in parameter_list:
+      self._TestBatchNormForcedUpdates(params[0], params[1], params[2], False)
+      self._TestBatchNormForcedUpdates(params[0], params[1], params[2], True)
+
+  def _TestBatchNormForcedUpdates(self, activation, activation_op_name,
+                                  fused_batch_norm, use_resource):
+ """post_activation bypass quantization should happen with forced updates."""
+    graph = ops.Graph()
+    with graph.as_default():
+      variable_scope.get_variable_scope().set_use_resource(use_resource)
+      batch_size, height, width, depth = 5, 128, 128, 3
+      input1 = array_ops.zeros((batch_size, height, width, depth))
+      input2 = array_ops.zeros((batch_size, height // 2, width // 2, 32))
+      # Setting updates_collections to None forces the batch norm updates to
+      # run in place, which adds an extra Identity operation after each
+      # batch norm.
+      bn_params = self._BatchNormParams(
+          fused=fused_batch_norm, force_updates=True)
+      conv = conv2d(
+          input1,
+          32, [5, 5],
+          stride=2,
+          padding='SAME',
+          weights_initializer=self._WeightInit(0.09),
+          activation_fn=activation,
+          normalizer_fn=batch_norm,
+          normalizer_params=bn_params,
+          scope='test/test')
+      bypass_tensor = math_ops.add(conv, input2, name='test/add')
+      # The output of the post_activation bypass is consumed by another layer.
+      _ = conv2d(
+          bypass_tensor,
+          32, [5, 5],
+          stride=2,
+          padding='SAME',
+          weights_initializer=self._WeightInit(0.09),
+          normalizer_fn=batch_norm,
+          normalizer_params=bn_params,
+          activation_fn=activation,
+          scope='test/unused')
+
+    fold_batch_norms.FoldBatchNorms(graph, is_training=True)
+    quantize.Quantize(graph, is_training=True)
+
+    # Ensure that the bypass node is preceded and followed by a
+    # FakeQuantWithMinMaxVars operation, since the output of the Add isn't an
+    # activation.
+    self.assertIn('FakeQuantWithMinMaxVars',
+                  [c.type for c in bypass_tensor.consumers()])
+    self.assertIn('FakeQuantWithMinMaxVars',
+                  [i.op.type for i in bypass_tensor.op.inputs])
+
+    with open('/tmp/bn_quant_test.pbtxt', 'w') as f:
+      f.write(str(graph.as_graph_def()))
+
+  def _BatchNormParams(self, fused=False, force_updates=False):
+    params = {
+        'center': True,
+        'scale': True,
+        'decay': 1.0 - 0.003,
+        'fused': fused
+    }
+    if force_updates:
+      params['updates_collections'] = None
+    return params
 
   def _WeightInit(self, stddev):
     """Returns truncated normal variable initializer.