aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/quantize
diff options
context:
space:
mode:
authorGravatar Mingxing Tan <tanmingxing@google.com>2018-10-03 21:06:27 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-03 21:11:03 -0700
commit8a437200e14c8e09fcc8e952679d489909f175c8 (patch)
tree9b454a277941b1da74dc1466c791381cd6e544f5 /tensorflow/contrib/quantize
parent2e19f32d28ab88b5bd3dd4f6d42a54040591dfbb (diff)
BEGIN_PUBLIC
Rollback some quantization changes that break some models. END_PUBLIC Automated rollback of commit d3f14ef70cdf113f9d330c1f7c638003429a1dc4. Revert #19894. PiperOrigin-RevId: 215678307
Diffstat (limited to 'tensorflow/contrib/quantize')
-rw-r--r--tensorflow/contrib/quantize/python/quantize.py115
-rw-r--r--tensorflow/contrib/quantize/python/quantize_graph_test.py37
2 files changed, 41 insertions, 111 deletions
diff --git a/tensorflow/contrib/quantize/python/quantize.py b/tensorflow/contrib/quantize/python/quantize.py
index afb9de8370..5e63d33db8 100644
--- a/tensorflow/contrib/quantize/python/quantize.py
+++ b/tensorflow/contrib/quantize/python/quantize.py
@@ -461,8 +461,8 @@ class _LayerMatch(object):
return self._bias_add_op
-def _GetFollowingFakeQuantOp(tensor):
- """Returns the following FakeQuant op if it exists else None."""
+def _FollowedByFakeQuant(tensor):
+ """Returns True if the tensor is followed by a FakeQuant."""
fake_quant_ops = set([
'FakeQuantWithMinMaxVars', 'FakeQuantWithMinMaxArgs',
'FakeQuantWithMinMaxVarsPerChannel'
@@ -472,11 +472,11 @@ def _GetFollowingFakeQuantOp(tensor):
while consumers:
c = consumers.pop()
if c.type in fake_quant_ops:
- return c
+ return True
elif c.type in pass_through_ops:
for output in c.outputs:
consumers.extend(output.consumers())
- return None
+ return False
def _InsertQuantOp(context,
@@ -559,77 +559,44 @@ def _InsertQuantOp(context,
# Prevent ops from being quantized multiple times. Bypass ops can sometimes
# overlap between multiple matches, so we need to ensure that we don't
# add duplicate FakeQuant operations.
- fake_quant_op = _GetFollowingFakeQuantOp(inputs)
-
- # If we find that we are attempting to insert a fake quant op following
- # a fake quant, we skip inserting a fake quant op
-
- if fake_quant_op is None:
- if moving_avg:
- quant = (
- quant_ops.MovingAvgQuantize(
- inputs,
- init_min=init_min,
- init_max=init_max,
- ema_decay=ema_decay,
- is_training=is_training,
- num_bits=bits,
- narrow_range=narrow_range,
- vars_collection=vars_collection,
- name_prefix=name_prefix))
- else:
- quant = (
- quant_ops.LastValueQuantize(
- inputs,
- init_min=init_min,
- init_max=init_max,
- is_training=is_training,
- num_bits=bits,
- narrow_range=narrow_range,
- vars_collection=vars_collection,
- name_prefix=name_prefix))
-
- if quant_delay and quant_delay > 0:
- activate_quant = math_ops.greater_equal(
- common.CreateOrGetQuantizationStep(),
- quant_delay,
- name=name_prefix + '/activate_quant')
- quant = control_flow_ops.cond(
- activate_quant,
- lambda: quant,
- lambda: inputs,
- name=name_prefix + '/delayed_quant')
+ if _FollowedByFakeQuant(inputs):
+ return
+
+ if moving_avg:
+ quant = (
+ quant_ops.MovingAvgQuantize(
+ inputs,
+ init_min=init_min,
+ init_max=init_max,
+ ema_decay=ema_decay,
+ is_training=is_training,
+ num_bits=bits,
+ narrow_range=narrow_range,
+ vars_collection=vars_collection,
+ name_prefix=name_prefix))
else:
- # If a fake quant op is present already, make sure that
- # any downstream use of the tensor reroutes to the appropriate quantized
- # tensor. If there is no quant_delay, this is simply the output of the
- # fake quant op. If there is a quant delay, we reroute to the output
- # of the delayed quant operation, which inserts quantization only after
- # a specified quant_delay
-
- quant = fake_quant_op.outputs[0]
- if quant_delay and quant_delay > 0:
- name_prefix = '/'.join(quant.name.split('/')[:-1])
- quant = quant.graph.get_tensor_by_name(name_prefix +
- '/delayed_quant/Merge:0')
- pruned_consumer_set = set()
- for consumer in consumers:
- fake_quant_dest_op = _GetFollowingFakeQuantOp(consumer.outputs[0])
- if (fake_quant_dest_op is None or
- fake_quant_dest_op.name != fake_quant_op.name):
- pruned_consumer_set.add(consumer)
- consumers = pruned_consumer_set
-
- # If we have
- # input->pass_through->fake_quant
- # there is nothing to reroute.
- #
- # If we have
- # input-> pass_through->fake_quant
- # |-> consumer
- # Then we reroute such that:
- # input-> pass_through->fake_quant
- # |-> consumer
+ quant = (
+ quant_ops.LastValueQuantize(
+ inputs,
+ init_min=init_min,
+ init_max=init_max,
+ is_training=is_training,
+ num_bits=bits,
+ narrow_range=narrow_range,
+ vars_collection=vars_collection,
+ name_prefix=name_prefix))
+
+ if quant_delay and quant_delay > 0:
+ activate_quant = math_ops.greater_equal(
+ common.CreateOrGetQuantizationStep(),
+ quant_delay,
+ name=name_prefix + '/activate_quant')
+ quant = control_flow_ops.cond(
+ activate_quant,
+ lambda: quant,
+ lambda: inputs,
+ name=name_prefix + '/delayed_quant')
+
if consumers:
tensors_modified_count = common.RerouteTensor(
quant, inputs, can_modify=consumers)
diff --git a/tensorflow/contrib/quantize/python/quantize_graph_test.py b/tensorflow/contrib/quantize/python/quantize_graph_test.py
index a9fc6c3c61..e80d2183a6 100644
--- a/tensorflow/contrib/quantize/python/quantize_graph_test.py
+++ b/tensorflow/contrib/quantize/python/quantize_graph_test.py
@@ -27,7 +27,6 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
-from tensorflow.python.ops import template
from tensorflow.python.platform import googletest
@@ -307,42 +306,6 @@ class QuantizeGraphTest(test_util.TensorFlowTestCase):
# No ops should be inserted or removed.
self.assertEqual(op_names_before_rewrite, op_names_after_rewrite)
- def testWithSharedWeights(self):
-
- self._RunTestOverAllRewrites(self._TestWithSharedWeights)
- self._RunTestOverTrainingRewrites(self._TestRewriteWithSharedWeights)
-
- def _TestRewriteWithSharedWeights(self, rewrite_fn, quant_delay=1):
- self._TestWithSharedWeights(rewrite_fn, quant_delay)
-
- def _TestWithSharedWeights(self, rewrite_fn, quant_delay=None):
- with ops.Graph().as_default() as g:
- conv = template.make_template('shared_weights_conv', self._ConvLayer)
- conv()
- conv()
- if quant_delay is None:
- rewrite_fn()
- else:
- rewrite_fn(quant_delay=quant_delay)
-
- conv_ops = [op for op in g.get_operations() if op.type == 'Conv2D']
- weights_quants = [
- op for op in g.get_operations()
- if 'weights_quant' in op.name and op.type == 'FakeQuantWithMinMaxVars'
- ]
- # Check that the shared weights variable is not quantized multiple times
- self.assertTrue(len(weights_quants) == 1)
- weights_quant_tensor = weights_quants[0].outputs[0]
- if quant_delay:
- delayed_weights_quants = [
- op for op in g.get_operations()
- if 'weights_quant' in op.name and op.type == 'Merge'
- ]
- self.assertTrue(len(delayed_weights_quants) == 1)
- weights_quant_tensor = delayed_weights_quants[0].outputs[0]
- # Check that the Conv2D operations get the quantized weights
- self.assertTrue(all(weights_quant_tensor in op.inputs for op in conv_ops))
-
def _ConvLayer(
self, input_tensor=None, scope='test', pre_activation_bypass=False,
post_activation_bypass=False):