diff options
author | 2018-03-20 07:27:16 -0700 | |
---|---|---|
committer | 2018-03-20 07:31:58 -0700 | |
commit | a40c8024f9beec346c2c1d98e9238c5d48ea0dca (patch) | |
tree | 1fd1b4d17d4c6100d6843014cdccfc7580222555 /tensorflow/contrib/quantize | |
parent | e28a79eae228be8e65b5dff8bb8aa5ee2f41f70a (diff) |
Drop name_scope from operation names during quantization to avoid doubling it up.
PiperOrigin-RevId: 189737746
Diffstat (limited to 'tensorflow/contrib/quantize')
-rw-r--r-- | tensorflow/contrib/quantize/python/common.py | 8 | ||||
-rw-r--r-- | tensorflow/contrib/quantize/python/quantize.py | 6 | ||||
-rw-r--r-- | tensorflow/contrib/quantize/python/quantize_test.py | 24 |
3 files changed, 38 insertions, 0 deletions
diff --git a/tensorflow/contrib/quantize/python/common.py b/tensorflow/contrib/quantize/python/common.py index 3138149468..bf648e158e 100644 --- a/tensorflow/contrib/quantize/python/common.py +++ b/tensorflow/contrib/quantize/python/common.py @@ -123,3 +123,11 @@ def CreateOrGetQuantizationStep(): # normal variables to return a tensor of the same name. return array_ops.identity( state_ops.assign_add(quantization_step_tensor, 1)) + + +def DropStringPrefix(s, prefix): + """If the string starts with this prefix, drops it.""" + if s.startswith(prefix): + return s[len(prefix):] + else: + return s diff --git a/tensorflow/contrib/quantize/python/quantize.py b/tensorflow/contrib/quantize/python/quantize.py index 9780e6dbcc..2b5b877e8e 100644 --- a/tensorflow/contrib/quantize/python/quantize.py +++ b/tensorflow/contrib/quantize/python/quantize.py @@ -367,6 +367,12 @@ def _InsertQuantOp(context, consumer operation. """ name_prefix = _AddContextToName(context, name) + # This is needed on TPU where name_scope == 'TPUReplicate/loop', and + # name_prefix starts with 'TPUReplicate/loop/'; without dropping it + # variables are created as TPUReplicate/loop/TPUReplicate/loop/..., which + # breaks things later. + name_prefix = common.DropStringPrefix(name_prefix, ops.get_name_scope() + '/') + inputs = producer.outputs[0] if moving_avg: quant = ( diff --git a/tensorflow/contrib/quantize/python/quantize_test.py b/tensorflow/contrib/quantize/python/quantize_test.py index 8e60f4b661..216310abe4 100644 --- a/tensorflow/contrib/quantize/python/quantize_test.py +++ b/tensorflow/contrib/quantize/python/quantize_test.py @@ -164,6 +164,30 @@ class QuantizeTest(test_util.TensorFlowTestCase): self.assertTrue('FakeQuantWithMinMaxVars' in [i.op.type for i in bypass_tensor.op.inputs]) + def testWithNameScope(self): + self._RunTestOverParameters(self._TestWithNameScope) + + def _TestWithNameScope(self, is_training): + graph = ops.Graph() + with graph.as_default(): + with graph.name_scope('name_scope'): + batch_size, height, width, depth = 5, 128, 128, 3 + input1 = array_ops.zeros((batch_size, height, width, depth)) + _ = conv2d( + input1, + 32, [5, 5], + stride=2, + padding='SAME', + weights_initializer=self._WeightInit(0.09), + activation_fn=None, + scope='test') + + quantize.Quantize(graph, is_training, weight_bits=8, activation_bits=8) + + for op in graph.get_operations(): + self.assertTrue(not op.name.startswith('name_scope/name_scope/'), + 'Broken op: %s' % op.name) + def _WeightInit(self, stddev): """Returns truncated normal variable initializer. |