diff options
author | Mingxing Tan <tanmingxing@google.com> | 2018-10-03 21:06:27 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-10-03 21:11:03 -0700 |
commit | 8a437200e14c8e09fcc8e952679d489909f175c8 (patch) | |
tree | 9b454a277941b1da74dc1466c791381cd6e544f5 /tensorflow/contrib/quantize | |
parent | 2e19f32d28ab88b5bd3dd4f6d42a54040591dfbb (diff) |
BEGIN_PUBLIC
Rollback some quantization changes that break some models.
END_PUBLIC
Automated rollback of commit d3f14ef70cdf113f9d330c1f7c638003429a1dc4. Revert #19894.
PiperOrigin-RevId: 215678307
Diffstat (limited to 'tensorflow/contrib/quantize')
-rw-r--r-- | tensorflow/contrib/quantize/python/quantize.py | 115 | ||||
-rw-r--r-- | tensorflow/contrib/quantize/python/quantize_graph_test.py | 37 |
2 files changed, 41 insertions, 111 deletions
diff --git a/tensorflow/contrib/quantize/python/quantize.py b/tensorflow/contrib/quantize/python/quantize.py index afb9de8370..5e63d33db8 100644 --- a/tensorflow/contrib/quantize/python/quantize.py +++ b/tensorflow/contrib/quantize/python/quantize.py @@ -461,8 +461,8 @@ class _LayerMatch(object): return self._bias_add_op -def _GetFollowingFakeQuantOp(tensor): - """Returns the following FakeQuant op if it exists else None.""" +def _FollowedByFakeQuant(tensor): + """Returns True if the tensor is followed by a FakeQuant.""" fake_quant_ops = set([ 'FakeQuantWithMinMaxVars', 'FakeQuantWithMinMaxArgs', 'FakeQuantWithMinMaxVarsPerChannel' @@ -472,11 +472,11 @@ def _GetFollowingFakeQuantOp(tensor): while consumers: c = consumers.pop() if c.type in fake_quant_ops: - return c + return True elif c.type in pass_through_ops: for output in c.outputs: consumers.extend(output.consumers()) - return None + return False def _InsertQuantOp(context, @@ -559,77 +559,44 @@ def _InsertQuantOp(context, # Prevent ops from being quantized multiple times. Bypass ops can sometimes # overlap between multiple matches, so we need to ensure that we don't # add duplicate FakeQuant operations. 
- fake_quant_op = _GetFollowingFakeQuantOp(inputs) - - # If we find that we are attempting to insert a fake quant op following - # a fake quant, we skip inserting a fake quant op - - if fake_quant_op is None: - if moving_avg: - quant = ( - quant_ops.MovingAvgQuantize( - inputs, - init_min=init_min, - init_max=init_max, - ema_decay=ema_decay, - is_training=is_training, - num_bits=bits, - narrow_range=narrow_range, - vars_collection=vars_collection, - name_prefix=name_prefix)) - else: - quant = ( - quant_ops.LastValueQuantize( - inputs, - init_min=init_min, - init_max=init_max, - is_training=is_training, - num_bits=bits, - narrow_range=narrow_range, - vars_collection=vars_collection, - name_prefix=name_prefix)) - - if quant_delay and quant_delay > 0: - activate_quant = math_ops.greater_equal( - common.CreateOrGetQuantizationStep(), - quant_delay, - name=name_prefix + '/activate_quant') - quant = control_flow_ops.cond( - activate_quant, - lambda: quant, - lambda: inputs, - name=name_prefix + '/delayed_quant') + if _FollowedByFakeQuant(inputs): + return + + if moving_avg: + quant = ( + quant_ops.MovingAvgQuantize( + inputs, + init_min=init_min, + init_max=init_max, + ema_decay=ema_decay, + is_training=is_training, + num_bits=bits, + narrow_range=narrow_range, + vars_collection=vars_collection, + name_prefix=name_prefix)) else: - # If a fake quant op is present already, make sure that - # any downstream use of the tensor reroutes to the appropriate quantized - # tensor. If there is no quant_delay, this is simply the output of the - # fake quant op. 
If there is a quant delay, we reroute to the output - # of the delayed quant operation, which inserts quantization only after - # a specified quant_delay - - quant = fake_quant_op.outputs[0] - if quant_delay and quant_delay > 0: - name_prefix = '/'.join(quant.name.split('/')[:-1]) - quant = quant.graph.get_tensor_by_name(name_prefix + - '/delayed_quant/Merge:0') - pruned_consumer_set = set() - for consumer in consumers: - fake_quant_dest_op = _GetFollowingFakeQuantOp(consumer.outputs[0]) - if (fake_quant_dest_op is None or - fake_quant_dest_op.name != fake_quant_op.name): - pruned_consumer_set.add(consumer) - consumers = pruned_consumer_set - - # If we have - # input->pass_through->fake_quant - # there is nothing to reroute. - # - # If we have - # input-> pass_through->fake_quant - # |-> consumer - # Then we reroute such that: - # input-> pass_through->fake_quant - # |-> consumer + quant = ( + quant_ops.LastValueQuantize( + inputs, + init_min=init_min, + init_max=init_max, + is_training=is_training, + num_bits=bits, + narrow_range=narrow_range, + vars_collection=vars_collection, + name_prefix=name_prefix)) + + if quant_delay and quant_delay > 0: + activate_quant = math_ops.greater_equal( + common.CreateOrGetQuantizationStep(), + quant_delay, + name=name_prefix + '/activate_quant') + quant = control_flow_ops.cond( + activate_quant, + lambda: quant, + lambda: inputs, + name=name_prefix + '/delayed_quant') + if consumers: tensors_modified_count = common.RerouteTensor( quant, inputs, can_modify=consumers) diff --git a/tensorflow/contrib/quantize/python/quantize_graph_test.py b/tensorflow/contrib/quantize/python/quantize_graph_test.py index a9fc6c3c61..e80d2183a6 100644 --- a/tensorflow/contrib/quantize/python/quantize_graph_test.py +++ b/tensorflow/contrib/quantize/python/quantize_graph_test.py @@ -27,7 +27,6 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import init_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops 
import nn_ops -from tensorflow.python.ops import template from tensorflow.python.platform import googletest @@ -307,42 +306,6 @@ class QuantizeGraphTest(test_util.TensorFlowTestCase): # No ops should be inserted or removed. self.assertEqual(op_names_before_rewrite, op_names_after_rewrite) - def testWithSharedWeights(self): - - self._RunTestOverAllRewrites(self._TestWithSharedWeights) - self._RunTestOverTrainingRewrites(self._TestRewriteWithSharedWeights) - - def _TestRewriteWithSharedWeights(self, rewrite_fn, quant_delay=1): - self._TestWithSharedWeights(rewrite_fn, quant_delay) - - def _TestWithSharedWeights(self, rewrite_fn, quant_delay=None): - with ops.Graph().as_default() as g: - conv = template.make_template('shared_weights_conv', self._ConvLayer) - conv() - conv() - if quant_delay is None: - rewrite_fn() - else: - rewrite_fn(quant_delay=quant_delay) - - conv_ops = [op for op in g.get_operations() if op.type == 'Conv2D'] - weights_quants = [ - op for op in g.get_operations() - if 'weights_quant' in op.name and op.type == 'FakeQuantWithMinMaxVars' - ] - # Check that the shared weights variable is not quantized multiple times - self.assertTrue(len(weights_quants) == 1) - weights_quant_tensor = weights_quants[0].outputs[0] - if quant_delay: - delayed_weights_quants = [ - op for op in g.get_operations() - if 'weights_quant' in op.name and op.type == 'Merge' - ] - self.assertTrue(len(delayed_weights_quants) == 1) - weights_quant_tensor = delayed_weights_quants[0].outputs[0] - # Check that the Conv2D operations get the quantized weights - self.assertTrue(all(weights_quant_tensor in op.inputs for op in conv_ops)) - def _ConvLayer( self, input_tensor=None, scope='test', pre_activation_bypass=False, post_activation_bypass=False): |