aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/quantize
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-03-30 08:23:30 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-03-30 08:26:01 -0700
commitddbb2c52db5cfab02b80b2ef563d8d6251dcfe77 (patch)
tree0aae80b864368eb4a6e90c80fb4d0a0767a4cbc4 /tensorflow/contrib/quantize
parent330c2a831dfff5640ebc2e2811749c6557f6198a (diff)
Fix a crash in Quantize() when tf.contrib.framework.get_name_scope() == None.
PiperOrigin-RevId: 191068059
Diffstat (limited to 'tensorflow/contrib/quantize')
-rw-r--r--tensorflow/contrib/quantize/python/quantize.py4
-rw-r--r--tensorflow/contrib/quantize/python/quantize_test.py21
2 files changed, 24 insertions, 1 deletions
diff --git a/tensorflow/contrib/quantize/python/quantize.py b/tensorflow/contrib/quantize/python/quantize.py
index 2889016a84..d53d4d7b10 100644
--- a/tensorflow/contrib/quantize/python/quantize.py
+++ b/tensorflow/contrib/quantize/python/quantize.py
@@ -416,7 +416,9 @@ def _InsertQuantOp(context,
# name_prefix starts with 'TPUReplicate/loop/'; without dropping it
# variables are created as TPUReplicate/loop/TPUReplicate/loop/..., which
# breaks things later.
- name_prefix = common.DropStringPrefix(name_prefix, ops.get_name_scope() + '/')
+ name_scope = ops.get_name_scope()
+ if name_scope:
+ name_prefix = common.DropStringPrefix(name_prefix, name_scope + '/')
inputs = producer.outputs[0]
# Prevent ops from being quantized multiple times. Bypass ops can sometimes
diff --git a/tensorflow/contrib/quantize/python/quantize_test.py b/tensorflow/contrib/quantize/python/quantize_test.py
index 98f05c8bfc..8d057d3710 100644
--- a/tensorflow/contrib/quantize/python/quantize_test.py
+++ b/tensorflow/contrib/quantize/python/quantize_test.py
@@ -247,6 +247,27 @@ class QuantizeTest(test_util.TensorFlowTestCase):
self.assertTrue(not op.name.startswith('name_scope/name_scope/'),
'Broken op: %s' % op.name)
+ def testWithNullNameScope(self):
+ self._RunTestOverParameters(self._TestWithNullNameScope)
+
+ def _TestWithNullNameScope(self, is_training):
+ graph = ops.Graph()
+ with graph.as_default():
+ with graph.name_scope(None):
+ batch_size, height, width, depth = 5, 128, 128, 3
+ input1 = array_ops.zeros((batch_size, height, width, depth))
+ _ = conv2d(
+ input1,
+ 32, [5, 5],
+ stride=2,
+ padding='SAME',
+ weights_initializer=self._WeightInit(0.09),
+ activation_fn=None,
+ scope='test')
+
+ quantize.Quantize(graph, is_training, weight_bits=8, activation_bits=8)
+ # Passes if Quantize() does not crash.
+
def _WeightInit(self, stddev):
"""Returns truncated normal variable initializer.