diff options
author | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-26 22:56:02 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-26 22:56:16 -0700 |
commit | d3f14ef70cdf113f9d330c1f7c638003429a1dc4 (patch) | |
tree | a43d6b5c50a81455147e620f7791ed21dd9b8d1c /tensorflow/contrib/quantize | |
parent | 5df53ab7eb81c67459e2a95e8fbcb71999c703ad (diff) | |
parent | f44805f8333aaf76d392bb565fe2381be07ccf2a (diff) |
Merge pull request #19894 from manipopopo:fix_quantize
PiperOrigin-RevId: 214724610
Diffstat (limited to 'tensorflow/contrib/quantize')
-rw-r--r-- | tensorflow/contrib/quantize/python/quantize.py | 115 | ||||
-rw-r--r-- | tensorflow/contrib/quantize/python/quantize_graph_test.py | 37 |
2 files changed, 111 insertions, 41 deletions
diff --git a/tensorflow/contrib/quantize/python/quantize.py b/tensorflow/contrib/quantize/python/quantize.py index 5e63d33db8..afb9de8370 100644 --- a/tensorflow/contrib/quantize/python/quantize.py +++ b/tensorflow/contrib/quantize/python/quantize.py @@ -461,8 +461,8 @@ class _LayerMatch(object): return self._bias_add_op -def _FollowedByFakeQuant(tensor): - """Returns True if the tensor is followed by a FakeQuant.""" +def _GetFollowingFakeQuantOp(tensor): + """Returns the following FakeQuant op if it exists else None.""" fake_quant_ops = set([ 'FakeQuantWithMinMaxVars', 'FakeQuantWithMinMaxArgs', 'FakeQuantWithMinMaxVarsPerChannel' @@ -472,11 +472,11 @@ def _FollowedByFakeQuant(tensor): while consumers: c = consumers.pop() if c.type in fake_quant_ops: - return True + return c elif c.type in pass_through_ops: for output in c.outputs: consumers.extend(output.consumers()) - return False + return None def _InsertQuantOp(context, @@ -559,44 +559,77 @@ def _InsertQuantOp(context, # Prevent ops from being quantized multiple times. Bypass ops can sometimes # overlap between multiple matches, so we need to ensure that we don't # add duplicate FakeQuant operations. - if _FollowedByFakeQuant(inputs): - return - - if moving_avg: - quant = ( - quant_ops.MovingAvgQuantize( - inputs, - init_min=init_min, - init_max=init_max, - ema_decay=ema_decay, - is_training=is_training, - num_bits=bits, - narrow_range=narrow_range, - vars_collection=vars_collection, - name_prefix=name_prefix)) + fake_quant_op = _GetFollowingFakeQuantOp(inputs) + + # If we find that we are attempting to insert a fake quant op following + # a fake quant, we skip inserting a fake quant op + + if fake_quant_op is None: + if moving_avg: + quant = ( + quant_ops.MovingAvgQuantize( + inputs, + init_min=init_min, + init_max=init_max, + ema_decay=ema_decay, + is_training=is_training, + num_bits=bits, + narrow_range=narrow_range, + vars_collection=vars_collection, + name_prefix=name_prefix)) + else: + quant = ( + quant_ops.LastValueQuantize( + inputs, + init_min=init_min, + init_max=init_max, + is_training=is_training, + num_bits=bits, + narrow_range=narrow_range, + vars_collection=vars_collection, + name_prefix=name_prefix)) + + if quant_delay and quant_delay > 0: + activate_quant = math_ops.greater_equal( + common.CreateOrGetQuantizationStep(), + quant_delay, + name=name_prefix + '/activate_quant') + quant = control_flow_ops.cond( + activate_quant, + lambda: quant, + lambda: inputs, + name=name_prefix + '/delayed_quant') else: - quant = ( - quant_ops.LastValueQuantize( - inputs, - init_min=init_min, - init_max=init_max, - is_training=is_training, - num_bits=bits, - narrow_range=narrow_range, - vars_collection=vars_collection, - name_prefix=name_prefix)) - - if quant_delay and quant_delay > 0: - activate_quant = math_ops.greater_equal( - common.CreateOrGetQuantizationStep(), - quant_delay, - name=name_prefix + '/activate_quant') - quant = control_flow_ops.cond( - activate_quant, - lambda: quant, - lambda: inputs, - name=name_prefix + '/delayed_quant') - + # If a fake quant op is present already, make sure that + # any downstream use of the tensor reroutes to the appropriate quantized + # tensor. If there is no quant_delay, this is simply the output of the + # fake quant op. If there is a quant delay, we reroute to the output + # of the delayed quant operation, which inserts quantization only after + # a specified quant_delay + + quant = fake_quant_op.outputs[0] + if quant_delay and quant_delay > 0: + name_prefix = '/'.join(quant.name.split('/')[:-1]) + quant = quant.graph.get_tensor_by_name(name_prefix + + '/delayed_quant/Merge:0') + pruned_consumer_set = set() + for consumer in consumers: + fake_quant_dest_op = _GetFollowingFakeQuantOp(consumer.outputs[0]) + if (fake_quant_dest_op is None or + fake_quant_dest_op.name != fake_quant_op.name): + pruned_consumer_set.add(consumer) + consumers = pruned_consumer_set + + # If we have + # input->pass_through->fake_quant + # there is nothing to reroute. + # + # If we have + # input-> pass_through->fake_quant + # |-> consumer + # Then we reroute such that: + # input-> pass_through->fake_quant + # |-> consumer if consumers: tensors_modified_count = common.RerouteTensor( quant, inputs, can_modify=consumers) diff --git a/tensorflow/contrib/quantize/python/quantize_graph_test.py b/tensorflow/contrib/quantize/python/quantize_graph_test.py index e80d2183a6..a9fc6c3c61 100644 --- a/tensorflow/contrib/quantize/python/quantize_graph_test.py +++ b/tensorflow/contrib/quantize/python/quantize_graph_test.py @@ -27,6 +27,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import init_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_ops +from tensorflow.python.ops import template from tensorflow.python.platform import googletest @@ -306,6 +307,42 @@ class QuantizeGraphTest(test_util.TensorFlowTestCase): # No ops should be inserted or removed. self.assertEqual(op_names_before_rewrite, op_names_after_rewrite) + def testWithSharedWeights(self): + + self._RunTestOverAllRewrites(self._TestWithSharedWeights) + self._RunTestOverTrainingRewrites(self._TestRewriteWithSharedWeights) + + def _TestRewriteWithSharedWeights(self, rewrite_fn, quant_delay=1): + self._TestWithSharedWeights(rewrite_fn, quant_delay) + + def _TestWithSharedWeights(self, rewrite_fn, quant_delay=None): + with ops.Graph().as_default() as g: + conv = template.make_template('shared_weights_conv', self._ConvLayer) + conv() + conv() + if quant_delay is None: + rewrite_fn() + else: + rewrite_fn(quant_delay=quant_delay) + + conv_ops = [op for op in g.get_operations() if op.type == 'Conv2D'] + weights_quants = [ + op for op in g.get_operations() + if 'weights_quant' in op.name and op.type == 'FakeQuantWithMinMaxVars' + ] + # Check that the shared weights variable is not quantized multiple times + self.assertTrue(len(weights_quants) == 1) + weights_quant_tensor = weights_quants[0].outputs[0] + if quant_delay: + delayed_weights_quants = [ + op for op in g.get_operations() + if 'weights_quant' in op.name and op.type == 'Merge' + ] + self.assertTrue(len(delayed_weights_quants) == 1) + weights_quant_tensor = delayed_weights_quants[0].outputs[0] + # Check that the Conv2D operations get the quantized weights + self.assertTrue(all(weights_quant_tensor in op.inputs for op in conv_ops)) + def _ConvLayer( self, input_tensor=None, scope='test', pre_activation_bypass=False, post_activation_bypass=False): |