author    manipopopo <pwmutantbread@gmail.com>    2018-06-10 16:30:40 +0000
committer manipopopo <pwmutantbread@gmail.com>    2018-09-20 08:52:25 +0000
commit    e514555a9572e00243083a8ec6e58c8deed5a501 (patch)
tree      227960e7489061ca9a6097113249297cfd90a9c9
parent    62e41201e291b241bfad0b902ab6aa785ee06059 (diff)
Fix routing of quantized tensors
When a tensor had already been quantized, its consumer was left reading the original tensor instead of being rerouted to the output of the existing FakeQuant op.
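
As a sketch of the failure mode (plain Python; Tensor, Op, and insert_quant_op are illustrative stand-ins, not the TensorFlow graph API): when two consumers share one tensor, the pre-fix code returned as soon as it saw an existing FakeQuant, so only the first consumer was rerouted.

class Tensor:
    def __init__(self, name):
        self.name = name
        self.consumers = []  # ops that read this tensor

class Op:
    def __init__(self, op_type, inputs):
        self.type = op_type
        self.inputs = list(inputs)
        self.outputs = [Tensor(op_type + ':0')]
        for t in inputs:
            t.consumers.append(self)

def get_following_fake_quant_op(tensor):
    # Simplified stand-in for _GetFollowingFakeQuantOp below.
    return next((c for c in tensor.consumers if c.type == 'FakeQuant'), None)

def insert_quant_op(tensor, consumer):
    # Post-fix behavior: reuse an existing FakeQuant instead of bailing out,
    # and always reroute the consumer to the quantized output.
    fake_quant = get_following_fake_quant_op(tensor)
    if fake_quant is None:
        fake_quant = Op('FakeQuant', [tensor])
    quant = fake_quant.outputs[0]
    consumer.inputs = [quant if t is tensor else t for t in consumer.inputs]

weights = Tensor('shared_weights:0')
conv1 = Op('Conv2D', [weights])
conv2 = Op('Conv2D', [weights])
insert_quant_op(weights, conv1)
insert_quant_op(weights, conv2)  # pre-fix: returned early, conv2 kept raw weights
assert conv1.inputs[0] is conv2.inputs[0]  # both read the quantized tensor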
-rw-r--r--  tensorflow/contrib/quantize/python/quantize.py             80
-rw-r--r--  tensorflow/contrib/quantize/python/quantize_graph_test.py  22
2 files changed, 64 insertions(+), 38 deletions(-)
diff --git a/tensorflow/contrib/quantize/python/quantize.py b/tensorflow/contrib/quantize/python/quantize.py
index e88db0acd5..6f34308fdb 100644
--- a/tensorflow/contrib/quantize/python/quantize.py
+++ b/tensorflow/contrib/quantize/python/quantize.py
@@ -454,8 +454,8 @@ class _LayerMatch(object):
return self._bias_add_op
-def _FollowedByFakeQuant(tensor):
- """Returns True if the tensor is followed by a FakeQuant."""
+def _GetFollowingFakeQuantOp(tensor):
+ """Returns the following FakeQuant op if it exists else None."""
fake_quant_ops = set([
'FakeQuantWithMinMaxVars', 'FakeQuantWithMinMaxArgs',
'FakeQuantWithMinMaxVarsPerChannel'
@@ -465,11 +465,11 @@ def _FollowedByFakeQuant(tensor):
while consumers:
c = consumers.pop()
if c.type in fake_quant_ops:
- return True
+ return c
elif c.type in pass_through_ops:
for output in c.outputs:
consumers.extend(output.consumers())
- return False
+ return None
def _InsertQuantOp(context,
@@ -552,43 +552,47 @@ def _InsertQuantOp(context,
# Prevent ops from being quantized multiple times. Bypass ops can sometimes
# overlap between multiple matches, so we need to ensure that we don't
# add duplicate FakeQuant operations.
- if _FollowedByFakeQuant(inputs):
+ fake_quant_op = _GetFollowingFakeQuantOp(inputs)
+ if fake_quant_op is not None and name == 'act_quant':
return
- if moving_avg:
- quant = (
- quant_ops.MovingAvgQuantize(
- inputs,
- init_min=init_min,
- init_max=init_max,
- ema_decay=ema_decay,
- is_training=is_training,
- num_bits=bits,
- narrow_range=narrow_range,
- vars_collection=vars_collection,
- name_prefix=name_prefix))
+ if fake_quant_op is None:
+ if moving_avg:
+ quant = (
+ quant_ops.MovingAvgQuantize(
+ inputs,
+ init_min=init_min,
+ init_max=init_max,
+ ema_decay=ema_decay,
+ is_training=is_training,
+ num_bits=bits,
+ narrow_range=narrow_range,
+ vars_collection=vars_collection,
+ name_prefix=name_prefix))
+ else:
+ quant = (
+ quant_ops.LastValueQuantize(
+ inputs,
+ init_min=init_min,
+ init_max=init_max,
+ is_training=is_training,
+ num_bits=bits,
+ narrow_range=narrow_range,
+ vars_collection=vars_collection,
+ name_prefix=name_prefix))
+
+ if quant_delay and quant_delay > 0:
+ activate_quant = math_ops.greater_equal(
+ common.CreateOrGetQuantizationStep(),
+ quant_delay,
+ name=name_prefix + '/activate_quant')
+ quant = control_flow_ops.cond(
+ activate_quant,
+ lambda: quant,
+ lambda: inputs,
+ name=name_prefix + '/delayed_quant')
else:
- quant = (
- quant_ops.LastValueQuantize(
- inputs,
- init_min=init_min,
- init_max=init_max,
- is_training=is_training,
- num_bits=bits,
- narrow_range=narrow_range,
- vars_collection=vars_collection,
- name_prefix=name_prefix))
-
- if quant_delay and quant_delay > 0:
- activate_quant = math_ops.greater_equal(
- common.CreateOrGetQuantizationStep(),
- quant_delay,
- name=name_prefix + '/activate_quant')
- quant = control_flow_ops.cond(
- activate_quant,
- lambda: quant,
- lambda: inputs,
- name=name_prefix + '/delayed_quant')
+ quant = fake_quant_op.outputs[0]
if consumers:
tensors_modified_count = common.RerouteTensor(
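
A hedged end-to-end sketch of the scenario this hunk fixes, written against the TF 1.x contrib API the patch targets; the input shape, filter count, and scope name below are arbitrary examples, not values from the patch.

import tensorflow as tf  # TensorFlow 1.x with contrib available

g = tf.Graph()
with g.as_default():
    images = tf.placeholder(tf.float32, [1, 8, 8, 3])

    def _conv_layer(x):
        return tf.layers.conv2d(x, filters=4, kernel_size=3,
                                activation=tf.nn.relu)

    # Both calls reuse one set of weights.
    conv = tf.make_template('shared_weights_conv', _conv_layer)
    conv(images)
    conv(images)

    tf.contrib.quantize.create_training_graph(input_graph=g)

# With the fix, a single weights FakeQuant feeds both Conv2D ops.
weights_quants = [op for op in g.get_operations()
                  if op.type == 'FakeQuantWithMinMaxVars'
                  and 'weights_quant' in op.name]
assert len(weights_quants) == 1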
diff --git a/tensorflow/contrib/quantize/python/quantize_graph_test.py b/tensorflow/contrib/quantize/python/quantize_graph_test.py
index e80d2183a6..d3e7264ba4 100644
--- a/tensorflow/contrib/quantize/python/quantize_graph_test.py
+++ b/tensorflow/contrib/quantize/python/quantize_graph_test.py
@@ -27,6 +27,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
+from tensorflow.python.ops import template
from tensorflow.python.platform import googletest
@@ -306,6 +307,27 @@ class QuantizeGraphTest(test_util.TensorFlowTestCase):
# No ops should be inserted or removed.
self.assertEqual(op_names_before_rewrite, op_names_after_rewrite)
+ def testWithSharedWeights(self):
+ self._RunTestOverAllRewrites(self._TestWithSharedWeights)
+
+ def _TestWithSharedWeights(self, rewrite_fn):
+ with ops.Graph().as_default() as g:
+ conv = template.make_template('shared_weights_conv', self._ConvLayer)
+ conv()
+ conv()
+ rewrite_fn()
+
+ conv_ops = [op for op in g.get_operations() if op.type == 'Conv2D']
+ weights_quants = [
+ op for op in g.get_operations()
+ if 'weights_quant' in op.name and op.type == 'FakeQuantWithMinMaxVars'
+ ]
+ # Check that the shared weights variable is not quantized multiple times
+ self.assertTrue(len(weights_quants) == 1)
+ # Check that the Conv2D operations get the quantized weights
+ weights_quant_tensor = weights_quants[0].outputs[0]
+ self.assertTrue(all(weights_quant_tensor in op.inputs for op in conv_ops))
+
def _ConvLayer(
self, input_tensor=None, scope='test', pre_activation_bypass=False,
post_activation_bypass=False):
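
For background on the test's use of template.make_template, a minimal TF 1.x sketch of the variable sharing involved (the layer body is an arbitrary example): both calls read the same variable, which is why the rewrite must reroute every consumer of the shared weights rather than stopping at the first.

import tensorflow as tf  # TensorFlow 1.x

def _layer(x):
    w = tf.get_variable('w', [3], initializer=tf.ones_initializer())
    return x * w

layer = tf.make_template('shared', _layer)
layer(tf.constant([1.0, 2.0, 3.0]))
layer(tf.constant([4.0, 5.0, 6.0]))

# One variable backs both calls; quantizing "the weights" therefore touches
# a tensor with multiple consumers.
assert len(tf.trainable_variables()) == 1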