diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-03-30 08:23:30 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-03-30 08:26:01 -0700 |
commit | ddbb2c52db5cfab02b80b2ef563d8d6251dcfe77 (patch) | |
tree | 0aae80b864368eb4a6e90c80fb4d0a0767a4cbc4 /tensorflow/contrib/quantize | |
parent | 330c2a831dfff5640ebc2e2811749c6557f6198a (diff) |
Fix a crash in Quantize() when tf.contrib.framework.get_name_scope() == None.
PiperOrigin-RevId: 191068059
Diffstat (limited to 'tensorflow/contrib/quantize')
-rw-r--r-- | tensorflow/contrib/quantize/python/quantize.py | 4 | ||||
-rw-r--r-- | tensorflow/contrib/quantize/python/quantize_test.py | 21 |
2 files changed, 24 insertions, 1 deletion
diff --git a/tensorflow/contrib/quantize/python/quantize.py b/tensorflow/contrib/quantize/python/quantize.py index 2889016a84..d53d4d7b10 100644 --- a/tensorflow/contrib/quantize/python/quantize.py +++ b/tensorflow/contrib/quantize/python/quantize.py @@ -416,7 +416,9 @@ def _InsertQuantOp(context, # name_prefix starts with 'TPUReplicate/loop/'; without dropping it # variables are created as TPUReplicate/loop/TPUReplicate/loop/..., which # breaks things later. - name_prefix = common.DropStringPrefix(name_prefix, ops.get_name_scope() + '/') + name_scope = ops.get_name_scope() + if name_scope: + name_prefix = common.DropStringPrefix(name_prefix, name_scope + '/') inputs = producer.outputs[0] # Prevent ops from being quantized multiple times. Bypass ops can sometimes diff --git a/tensorflow/contrib/quantize/python/quantize_test.py b/tensorflow/contrib/quantize/python/quantize_test.py index 98f05c8bfc..8d057d3710 100644 --- a/tensorflow/contrib/quantize/python/quantize_test.py +++ b/tensorflow/contrib/quantize/python/quantize_test.py @@ -247,6 +247,27 @@ class QuantizeTest(test_util.TensorFlowTestCase): self.assertTrue(not op.name.startswith('name_scope/name_scope/'), 'Broken op: %s' % op.name) + def testWithNullNameScope(self): + self._RunTestOverParameters(self._TestWithNullNameScope) + + def _TestWithNullNameScope(self, is_training): + graph = ops.Graph() + with graph.as_default(): + with graph.name_scope(None): + batch_size, height, width, depth = 5, 128, 128, 3 + input1 = array_ops.zeros((batch_size, height, width, depth)) + _ = conv2d( + input1, + 32, [5, 5], + stride=2, + padding='SAME', + weights_initializer=self._WeightInit(0.09), + activation_fn=None, + scope='test') + + quantize.Quantize(graph, is_training, weight_bits=8, activation_bits=8) + # Passes if Quantize() does not crash. + def _WeightInit(self, stddev): """Returns truncated normal variable initializer. |