about summary refs log tree commit diff homepage
path: root/tensorflow/contrib/quantize
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-03-16 10:29:11 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-03-16 10:33:24 -0700
commit4091e498ba8dedc8f4ad5952dfe1262e735e7f42 (patch)
treeb45d4d81d2c3845a11fbb1999a7431473bd74d89 /tensorflow/contrib/quantize
parent86398ed80a09030255226678eee5d4be583a61c4 (diff)
Fix naming BatchNorm_Fold//batch_norm_correction -> BatchNorm_Fold/batch_norm_correction.
PiperOrigin-RevId: 189358090
Diffstat (limited to 'tensorflow/contrib/quantize')
-rw-r--r--tensorflow/contrib/quantize/python/fold_batch_norms.py3
-rw-r--r--tensorflow/contrib/quantize/python/fold_batch_norms_test.py12
2 files changed, 14 insertions, 1 deletion
diff --git a/tensorflow/contrib/quantize/python/fold_batch_norms.py b/tensorflow/contrib/quantize/python/fold_batch_norms.py
index b278265639..e8a0d41425 100644
--- a/tensorflow/contrib/quantize/python/fold_batch_norms.py
+++ b/tensorflow/contrib/quantize/python/fold_batch_norms.py
@@ -317,7 +317,8 @@ def _ComputeBatchNormCorrections(context, match, freeze_batch_norm_delay,
"""
g = ops.get_default_graph()
- with g.name_scope(context + '/batch_norm_correction'):
+ prefix = '' if not context else context + '/'
+ with g.name_scope(prefix + 'batch_norm_correction'):
recip_sigma_mv = math_ops.rsqrt(
match.moving_variance_tensor + match.batch_epsilon)
recip_sigma = math_ops.rsqrt(match.variance_tensor + match.batch_epsilon)
diff --git a/tensorflow/contrib/quantize/python/fold_batch_norms_test.py b/tensorflow/contrib/quantize/python/fold_batch_norms_test.py
index c90a18ab03..af31467476 100644
--- a/tensorflow/contrib/quantize/python/fold_batch_norms_test.py
+++ b/tensorflow/contrib/quantize/python/fold_batch_norms_test.py
@@ -128,6 +128,9 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase):
output_op_names = ['test/Add' if with_bypass else 'test/' + relu_op_name]
self._AssertOutputGoesToOps(folded_add, g, output_op_names)
+ for op in g.get_operations():
+ self.assertFalse('//' in op.name, 'Double slash in op %s' % op.name)
+
def testFoldConv2d(self):
self._RunTestOverParameters(self._TestFoldConv2d)
@@ -196,6 +199,9 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase):
output_op_names = ['test/Add' if with_bypass else 'test/' + relu_op_name]
self._AssertOutputGoesToOps(folded_add, g, output_op_names)
+ for op in g.get_operations():
+ self.assertFalse('//' in op.name, 'Double slash in op %s' % op.name)
+
def testFoldConv2dUnknownShape(self):
self._RunTestOverParameters(self._TestFoldConv2dUnknownShape)
@@ -260,6 +266,9 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase):
output_op_names = ['test/Add' if with_bypass else 'test/' + relu_op_name]
self._AssertOutputGoesToOps(folded_add, g, output_op_names)
+ for op in g.get_operations():
+ self.assertFalse('//' in op.name, 'Double slash in op %s' % op.name)
+
def testFoldFullyConnectedLayer(self):
self._RunTestOverParameters(self._TestFoldFullyConnectedLayer)
@@ -337,6 +346,9 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase):
output_op_names = ['test/Add' if with_bypass else 'test/' + relu_op_name]
self._AssertOutputGoesToOps(folded_add, g, output_op_names)
+ for op in g.get_operations():
+ self.assertFalse('//' in op.name, 'Double slash in op %s' % op.name)
+
def testFoldDepthwiseConv2d(self):
self._RunTestOverParameters(self._TestFoldDepthwiseConv2d)