diff options
author | manipopopo <pwmutantbread@gmail.com> | 2018-06-10 16:30:40 +0000 |
---|---|---|
committer | manipopopo <pwmutantbread@gmail.com> | 2018-09-20 08:52:25 +0000 |
commit | e514555a9572e00243083a8ec6e58c8deed5a501 (patch) | |
tree | 227960e7489061ca9a6097113249297cfd90a9c9 | |
parent | 62e41201e291b241bfad0b902ab6aa785ee06059 (diff) |
Fix routing of quantized tensors
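Previously, when a tensor was already followed by a FakeQuant op, its
consumers were not rerouted to the quantized output and kept reading the
original tensor. The existing FakeQuant op's output is now reused for those
consumers (activation quantization is still skipped in that case).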
-rw-r--r-- | tensorflow/contrib/quantize/python/quantize.py | 80 | ||||
-rw-r--r-- | tensorflow/contrib/quantize/python/quantize_graph_test.py | 22 |
2 files changed, 64 insertions, 38 deletions
diff --git a/tensorflow/contrib/quantize/python/quantize.py b/tensorflow/contrib/quantize/python/quantize.py index e88db0acd5..6f34308fdb 100644 --- a/tensorflow/contrib/quantize/python/quantize.py +++ b/tensorflow/contrib/quantize/python/quantize.py @@ -454,8 +454,8 @@ class _LayerMatch(object): return self._bias_add_op -def _FollowedByFakeQuant(tensor): - """Returns True if the tensor is followed by a FakeQuant.""" +def _GetFollowingFakeQuantOp(tensor): + """Returns the following FakeQuant op if it exists else None.""" fake_quant_ops = set([ 'FakeQuantWithMinMaxVars', 'FakeQuantWithMinMaxArgs', 'FakeQuantWithMinMaxVarsPerChannel' @@ -465,11 +465,11 @@ def _FollowedByFakeQuant(tensor): while consumers: c = consumers.pop() if c.type in fake_quant_ops: - return True + return c elif c.type in pass_through_ops: for output in c.outputs: consumers.extend(output.consumers()) - return False + return None def _InsertQuantOp(context, @@ -552,43 +552,47 @@ def _InsertQuantOp(context, # Prevent ops from being quantized multiple times. Bypass ops can sometimes # overlap between multiple matches, so we need to ensure that we don't # add duplicate FakeQuant operations. 
- if _FollowedByFakeQuant(inputs): + fake_quant_op = _GetFollowingFakeQuantOp(inputs) + if fake_quant_op is not None and name == 'act_quant': return - if moving_avg: - quant = ( - quant_ops.MovingAvgQuantize( - inputs, - init_min=init_min, - init_max=init_max, - ema_decay=ema_decay, - is_training=is_training, - num_bits=bits, - narrow_range=narrow_range, - vars_collection=vars_collection, - name_prefix=name_prefix)) + if fake_quant_op is None: + if moving_avg: + quant = ( + quant_ops.MovingAvgQuantize( + inputs, + init_min=init_min, + init_max=init_max, + ema_decay=ema_decay, + is_training=is_training, + num_bits=bits, + narrow_range=narrow_range, + vars_collection=vars_collection, + name_prefix=name_prefix)) + else: + quant = ( + quant_ops.LastValueQuantize( + inputs, + init_min=init_min, + init_max=init_max, + is_training=is_training, + num_bits=bits, + narrow_range=narrow_range, + vars_collection=vars_collection, + name_prefix=name_prefix)) + + if quant_delay and quant_delay > 0: + activate_quant = math_ops.greater_equal( + common.CreateOrGetQuantizationStep(), + quant_delay, + name=name_prefix + '/activate_quant') + quant = control_flow_ops.cond( + activate_quant, + lambda: quant, + lambda: inputs, + name=name_prefix + '/delayed_quant') else: - quant = ( - quant_ops.LastValueQuantize( - inputs, - init_min=init_min, - init_max=init_max, - is_training=is_training, - num_bits=bits, - narrow_range=narrow_range, - vars_collection=vars_collection, - name_prefix=name_prefix)) - - if quant_delay and quant_delay > 0: - activate_quant = math_ops.greater_equal( - common.CreateOrGetQuantizationStep(), - quant_delay, - name=name_prefix + '/activate_quant') - quant = control_flow_ops.cond( - activate_quant, - lambda: quant, - lambda: inputs, - name=name_prefix + '/delayed_quant') + quant = fake_quant_op.outputs[0] if consumers: tensors_modified_count = common.RerouteTensor( diff --git a/tensorflow/contrib/quantize/python/quantize_graph_test.py 
b/tensorflow/contrib/quantize/python/quantize_graph_test.py index e80d2183a6..d3e7264ba4 100644 --- a/tensorflow/contrib/quantize/python/quantize_graph_test.py +++ b/tensorflow/contrib/quantize/python/quantize_graph_test.py @@ -27,6 +27,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import init_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_ops +from tensorflow.python.ops import template from tensorflow.python.platform import googletest @@ -306,6 +307,27 @@ class QuantizeGraphTest(test_util.TensorFlowTestCase): # No ops should be inserted or removed. self.assertEqual(op_names_before_rewrite, op_names_after_rewrite) + def testWithSharedWeights(self): + self._RunTestOverAllRewrites(self._TestWithSharedWeights) + + def _TestWithSharedWeights(self, rewrite_fn): + with ops.Graph().as_default() as g: + conv = template.make_template('shared_weights_conv', self._ConvLayer) + conv() + conv() + rewrite_fn() + + conv_ops = [op for op in g.get_operations() if op.type == 'Conv2D'] + weights_quants = [ + op for op in g.get_operations() + if 'weights_quant' in op.name and op.type == 'FakeQuantWithMinMaxVars' + ] + # Check that the shared weights variable is not quantized multiple times + self.assertTrue(len(weights_quants) == 1) + # Check that the Conv2D operations get the quantized weights + weights_quant_tensor = weights_quants[0].outputs[0] + self.assertTrue(all(weights_quant_tensor in op.inputs for op in conv_ops)) + def _ConvLayer( self, input_tensor=None, scope='test', pre_activation_bypass=False, post_activation_bypass=False): |