aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/quantize
diff options
context:
space:
mode:
authorGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-26 22:56:02 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-26 22:56:16 -0700
commitd3f14ef70cdf113f9d330c1f7c638003429a1dc4 (patch)
treea43d6b5c50a81455147e620f7791ed21dd9b8d1c /tensorflow/contrib/quantize
parent5df53ab7eb81c67459e2a95e8fbcb71999c703ad (diff)
parentf44805f8333aaf76d392bb565fe2381be07ccf2a (diff)
Merge pull request #19894 from manipopopo:fix_quantize
PiperOrigin-RevId: 214724610
Diffstat (limited to 'tensorflow/contrib/quantize')
-rw-r--r-- tensorflow/contrib/quantize/python/quantize.py            | 115
-rw-r--r-- tensorflow/contrib/quantize/python/quantize_graph_test.py |  37
2 files changed, 111 insertions, 41 deletions
diff --git a/tensorflow/contrib/quantize/python/quantize.py b/tensorflow/contrib/quantize/python/quantize.py
index 5e63d33db8..afb9de8370 100644
--- a/tensorflow/contrib/quantize/python/quantize.py
+++ b/tensorflow/contrib/quantize/python/quantize.py
@@ -461,8 +461,8 @@ class _LayerMatch(object):
return self._bias_add_op
-def _FollowedByFakeQuant(tensor):
- """Returns True if the tensor is followed by a FakeQuant."""
+def _GetFollowingFakeQuantOp(tensor):
+ """Returns the following FakeQuant op if it exists else None."""
fake_quant_ops = set([
'FakeQuantWithMinMaxVars', 'FakeQuantWithMinMaxArgs',
'FakeQuantWithMinMaxVarsPerChannel'
@@ -472,11 +472,11 @@ def _FollowedByFakeQuant(tensor):
while consumers:
c = consumers.pop()
if c.type in fake_quant_ops:
- return True
+ return c
elif c.type in pass_through_ops:
for output in c.outputs:
consumers.extend(output.consumers())
- return False
+ return None
def _InsertQuantOp(context,
@@ -559,44 +559,77 @@ def _InsertQuantOp(context,
# Prevent ops from being quantized multiple times. Bypass ops can sometimes
# overlap between multiple matches, so we need to ensure that we don't
# add duplicate FakeQuant operations.
- if _FollowedByFakeQuant(inputs):
- return
-
- if moving_avg:
- quant = (
- quant_ops.MovingAvgQuantize(
- inputs,
- init_min=init_min,
- init_max=init_max,
- ema_decay=ema_decay,
- is_training=is_training,
- num_bits=bits,
- narrow_range=narrow_range,
- vars_collection=vars_collection,
- name_prefix=name_prefix))
+ fake_quant_op = _GetFollowingFakeQuantOp(inputs)
+
+  # If we find that we are attempting to insert a fake quant op following
+  # a fake quant, we skip inserting a fake quant op.
+
+ if fake_quant_op is None:
+ if moving_avg:
+ quant = (
+ quant_ops.MovingAvgQuantize(
+ inputs,
+ init_min=init_min,
+ init_max=init_max,
+ ema_decay=ema_decay,
+ is_training=is_training,
+ num_bits=bits,
+ narrow_range=narrow_range,
+ vars_collection=vars_collection,
+ name_prefix=name_prefix))
+ else:
+ quant = (
+ quant_ops.LastValueQuantize(
+ inputs,
+ init_min=init_min,
+ init_max=init_max,
+ is_training=is_training,
+ num_bits=bits,
+ narrow_range=narrow_range,
+ vars_collection=vars_collection,
+ name_prefix=name_prefix))
+
+ if quant_delay and quant_delay > 0:
+ activate_quant = math_ops.greater_equal(
+ common.CreateOrGetQuantizationStep(),
+ quant_delay,
+ name=name_prefix + '/activate_quant')
+ quant = control_flow_ops.cond(
+ activate_quant,
+ lambda: quant,
+ lambda: inputs,
+ name=name_prefix + '/delayed_quant')
else:
- quant = (
- quant_ops.LastValueQuantize(
- inputs,
- init_min=init_min,
- init_max=init_max,
- is_training=is_training,
- num_bits=bits,
- narrow_range=narrow_range,
- vars_collection=vars_collection,
- name_prefix=name_prefix))
-
- if quant_delay and quant_delay > 0:
- activate_quant = math_ops.greater_equal(
- common.CreateOrGetQuantizationStep(),
- quant_delay,
- name=name_prefix + '/activate_quant')
- quant = control_flow_ops.cond(
- activate_quant,
- lambda: quant,
- lambda: inputs,
- name=name_prefix + '/delayed_quant')
-
+ # If a fake quant op is present already, make sure that
+ # any downstream use of the tensor reroutes to the appropriate quantized
+ # tensor. If there is no quant_delay, this is simply the output of the
+ # fake quant op. If there is a quant delay, we reroute to the output
+ # of the delayed quant operation, which inserts quantization only after
+ # a specified quant_delay
+
+ quant = fake_quant_op.outputs[0]
+ if quant_delay and quant_delay > 0:
+ name_prefix = '/'.join(quant.name.split('/')[:-1])
+ quant = quant.graph.get_tensor_by_name(name_prefix +
+ '/delayed_quant/Merge:0')
+ pruned_consumer_set = set()
+ for consumer in consumers:
+ fake_quant_dest_op = _GetFollowingFakeQuantOp(consumer.outputs[0])
+ if (fake_quant_dest_op is None or
+ fake_quant_dest_op.name != fake_quant_op.name):
+ pruned_consumer_set.add(consumer)
+ consumers = pruned_consumer_set
+
+ # If we have
+ # input->pass_through->fake_quant
+ # there is nothing to reroute.
+ #
+  # If we have
+  #  input-> pass_through->fake_quant
+  #                 |-> consumer
+  # Then we reroute such that:
+  #  input-> pass_through->fake_quant
+  #                                |-> consumer
if consumers:
tensors_modified_count = common.RerouteTensor(
quant, inputs, can_modify=consumers)
diff --git a/tensorflow/contrib/quantize/python/quantize_graph_test.py b/tensorflow/contrib/quantize/python/quantize_graph_test.py
index e80d2183a6..a9fc6c3c61 100644
--- a/tensorflow/contrib/quantize/python/quantize_graph_test.py
+++ b/tensorflow/contrib/quantize/python/quantize_graph_test.py
@@ -27,6 +27,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
+from tensorflow.python.ops import template
from tensorflow.python.platform import googletest
@@ -306,6 +307,42 @@ class QuantizeGraphTest(test_util.TensorFlowTestCase):
# No ops should be inserted or removed.
self.assertEqual(op_names_before_rewrite, op_names_after_rewrite)
+ def testWithSharedWeights(self):
+
+ self._RunTestOverAllRewrites(self._TestWithSharedWeights)
+ self._RunTestOverTrainingRewrites(self._TestRewriteWithSharedWeights)
+
+ def _TestRewriteWithSharedWeights(self, rewrite_fn, quant_delay=1):
+ self._TestWithSharedWeights(rewrite_fn, quant_delay)
+
+ def _TestWithSharedWeights(self, rewrite_fn, quant_delay=None):
+ with ops.Graph().as_default() as g:
+ conv = template.make_template('shared_weights_conv', self._ConvLayer)
+ conv()
+ conv()
+ if quant_delay is None:
+ rewrite_fn()
+ else:
+ rewrite_fn(quant_delay=quant_delay)
+
+ conv_ops = [op for op in g.get_operations() if op.type == 'Conv2D']
+ weights_quants = [
+ op for op in g.get_operations()
+ if 'weights_quant' in op.name and op.type == 'FakeQuantWithMinMaxVars'
+ ]
+ # Check that the shared weights variable is not quantized multiple times
+ self.assertTrue(len(weights_quants) == 1)
+ weights_quant_tensor = weights_quants[0].outputs[0]
+ if quant_delay:
+ delayed_weights_quants = [
+ op for op in g.get_operations()
+ if 'weights_quant' in op.name and op.type == 'Merge'
+ ]
+ self.assertTrue(len(delayed_weights_quants) == 1)
+ weights_quant_tensor = delayed_weights_quants[0].outputs[0]
+ # Check that the Conv2D operations get the quantized weights
+ self.assertTrue(all(weights_quant_tensor in op.inputs for op in conv_ops))
+
def _ConvLayer(
self, input_tensor=None, scope='test', pre_activation_bypass=False,
post_activation_bypass=False):