aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/quantize
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-03-20 07:27:16 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-03-20 07:31:58 -0700
commita40c8024f9beec346c2c1d98e9238c5d48ea0dca (patch)
tree1fd1b4d17d4c6100d6843014cdccfc7580222555 /tensorflow/contrib/quantize
parente28a79eae228be8e65b5dff8bb8aa5ee2f41f70a (diff)
Drop name_scope from operation names during quantization to avoid doubling it up.
PiperOrigin-RevId: 189737746
Diffstat (limited to 'tensorflow/contrib/quantize')
-rw-r--r--tensorflow/contrib/quantize/python/common.py8
-rw-r--r--tensorflow/contrib/quantize/python/quantize.py6
-rw-r--r--tensorflow/contrib/quantize/python/quantize_test.py24
3 files changed, 38 insertions, 0 deletions
diff --git a/tensorflow/contrib/quantize/python/common.py b/tensorflow/contrib/quantize/python/common.py
index 3138149468..bf648e158e 100644
--- a/tensorflow/contrib/quantize/python/common.py
+++ b/tensorflow/contrib/quantize/python/common.py
@@ -123,3 +123,11 @@ def CreateOrGetQuantizationStep():
# normal variables to return a tensor of the same name.
return array_ops.identity(
state_ops.assign_add(quantization_step_tensor, 1))
+
+
+def DropStringPrefix(s, prefix):
+ """If the string starts with this prefix, drops it."""
+ if s.startswith(prefix):
+ return s[len(prefix):]
+ else:
+ return s
diff --git a/tensorflow/contrib/quantize/python/quantize.py b/tensorflow/contrib/quantize/python/quantize.py
index 9780e6dbcc..2b5b877e8e 100644
--- a/tensorflow/contrib/quantize/python/quantize.py
+++ b/tensorflow/contrib/quantize/python/quantize.py
@@ -367,6 +367,12 @@ def _InsertQuantOp(context,
consumer operation.
"""
name_prefix = _AddContextToName(context, name)
+ # This is needed on TPU where name_scope == 'TPUReplicate/loop', and
+ # name_prefix starts with 'TPUReplicate/loop/'; without dropping it
+ # variables are created as TPUReplicate/loop/TPUReplicate/loop/..., which
+ # breaks things later.
+ name_prefix = common.DropStringPrefix(name_prefix, ops.get_name_scope() + '/')
+
inputs = producer.outputs[0]
if moving_avg:
quant = (
diff --git a/tensorflow/contrib/quantize/python/quantize_test.py b/tensorflow/contrib/quantize/python/quantize_test.py
index 8e60f4b661..216310abe4 100644
--- a/tensorflow/contrib/quantize/python/quantize_test.py
+++ b/tensorflow/contrib/quantize/python/quantize_test.py
@@ -164,6 +164,30 @@ class QuantizeTest(test_util.TensorFlowTestCase):
self.assertTrue('FakeQuantWithMinMaxVars' in
[i.op.type for i in bypass_tensor.op.inputs])
+ def testWithNameScope(self):
+ self._RunTestOverParameters(self._TestWithNameScope)
+
+ def _TestWithNameScope(self, is_training):
+ graph = ops.Graph()
+ with graph.as_default():
+ with graph.name_scope('name_scope'):
+ batch_size, height, width, depth = 5, 128, 128, 3
+ input1 = array_ops.zeros((batch_size, height, width, depth))
+ _ = conv2d(
+ input1,
+ 32, [5, 5],
+ stride=2,
+ padding='SAME',
+ weights_initializer=self._WeightInit(0.09),
+ activation_fn=None,
+ scope='test')
+
+ quantize.Quantize(graph, is_training, weight_bits=8, activation_bits=8)
+
+ for op in graph.get_operations():
+ self.assertTrue(not op.name.startswith('name_scope/name_scope/'),
+ 'Broken op: %s' % op.name)
+
def _WeightInit(self, stddev):
"""Returns truncated normal variable initializer.