aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/quantize/python/quantize.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/quantize/python/quantize.py')
-rw-r--r--tensorflow/contrib/quantize/python/quantize.py24
1 files changed, 19 insertions, 5 deletions
diff --git a/tensorflow/contrib/quantize/python/quantize.py b/tensorflow/contrib/quantize/python/quantize.py
index cb66fd1f76..2ddbd73ea6 100644
--- a/tensorflow/contrib/quantize/python/quantize.py
+++ b/tensorflow/contrib/quantize/python/quantize.py
@@ -455,6 +455,24 @@ class _LayerMatch(object):
return self._bias_add_op
+def _FollowedByFakeQuant(tensor):
+ """Returns True if the tensor is followed by a FakeQuant."""
+ fake_quant_ops = set([
+ 'FakeQuantWithMinMaxVars', 'FakeQuantWithMinMaxArgs',
+ 'FakeQuantWithMinMaxVarsPerChannel'
+ ])
+ pass_through_ops = set(['Reshape', 'Identity'])
+ consumers = tensor.consumers()
+ while consumers:
+ c = consumers.pop()
+ if c.type in fake_quant_ops:
+ return True
+ elif c.type in pass_through_ops:
+ for output in c.outputs:
+ consumers.extend(output.consumers())
+ return False
+
+
def _InsertQuantOp(context,
name,
producer,
@@ -535,11 +553,7 @@ def _InsertQuantOp(context,
# Prevent ops from being quantized multiple times. Bypass ops can sometimes
# overlap between multiple matches, so we need to ensure that we don't
# add duplicate FakeQuant operations.
- fake_quant_ops = set([
- 'FakeQuantWithMinMaxVars',
- 'FakeQuantWithMinMaxArgs'
- ])
- if fake_quant_ops.intersection(set([c.type for c in inputs.consumers()])):
+ if _FollowedByFakeQuant(inputs):
return
if moving_avg: