diff options
Diffstat (limited to 'tensorflow/contrib/quantize/python/quantize.py')
-rw-r--r-- | tensorflow/contrib/quantize/python/quantize.py | 24 |
1 files changed, 19 insertions, 5 deletions
diff --git a/tensorflow/contrib/quantize/python/quantize.py b/tensorflow/contrib/quantize/python/quantize.py index cb66fd1f76..2ddbd73ea6 100644 --- a/tensorflow/contrib/quantize/python/quantize.py +++ b/tensorflow/contrib/quantize/python/quantize.py @@ -455,6 +455,24 @@ class _LayerMatch(object): return self._bias_add_op +def _FollowedByFakeQuant(tensor): + """Returns True if the tensor is followed by a FakeQuant.""" + fake_quant_ops = set([ + 'FakeQuantWithMinMaxVars', 'FakeQuantWithMinMaxArgs', + 'FakeQuantWithMinMaxVarsPerChannel' + ]) + pass_through_ops = set(['Reshape', 'Identity']) + consumers = tensor.consumers() + while consumers: + c = consumers.pop() + if c.type in fake_quant_ops: + return True + elif c.type in pass_through_ops: + for output in c.outputs: + consumers.extend(output.consumers()) + return False + + def _InsertQuantOp(context, name, producer, @@ -535,11 +553,7 @@ def _InsertQuantOp(context, # Prevent ops from being quantized multiple times. Bypass ops can sometimes # overlap between multiple matches, so we need to ensure that we don't # add duplicate FakeQuant operations. - fake_quant_ops = set([ - 'FakeQuantWithMinMaxVars', - 'FakeQuantWithMinMaxArgs' - ]) - if fake_quant_ops.intersection(set([c.type for c in inputs.consumers()])): + if _FollowedByFakeQuant(inputs): return if moving_avg: |