diff options
author | manipopopo <pwmutantbread@gmail.com> | 2018-06-11 14:19:48 +0000 |
---|---|---|
committer | manipopopo <pwmutantbread@gmail.com> | 2018-09-20 08:53:47 +0000 |
commit | f44805f8333aaf76d392bb565fe2381be07ccf2a (patch) | |
tree | 15f6226c7a11a3e0a1ab0397fcc6241b21caed28 /tensorflow/contrib/quantize | |
parent | e514555a9572e00243083a8ec6e58c8deed5a501 (diff) |
Fix routing of delayed quantized tensors
Diffstat (limited to 'tensorflow/contrib/quantize')
-rw-r--r-- | tensorflow/contrib/quantize/python/quantize.py | 4 | ||||
-rw-r--r-- | tensorflow/contrib/quantize/python/quantize_graph_test.py | 19 |
2 files changed, 20 insertions, 3 deletions
diff --git a/tensorflow/contrib/quantize/python/quantize.py b/tensorflow/contrib/quantize/python/quantize.py index 6f34308fdb..ccf58c7a8a 100644 --- a/tensorflow/contrib/quantize/python/quantize.py +++ b/tensorflow/contrib/quantize/python/quantize.py @@ -593,6 +593,10 @@ def _InsertQuantOp(context, name=name_prefix + '/delayed_quant') else: quant = fake_quant_op.outputs[0] + if quant_delay and quant_delay > 0: + name_prefix = '/'.join(quant.name.split('/')[:-1]) + quant = quant.graph.get_tensor_by_name( + name_prefix + '/delayed_quant/Merge:0') if consumers: tensors_modified_count = common.RerouteTensor( diff --git a/tensorflow/contrib/quantize/python/quantize_graph_test.py b/tensorflow/contrib/quantize/python/quantize_graph_test.py index d3e7264ba4..36d87039a5 100644 --- a/tensorflow/contrib/quantize/python/quantize_graph_test.py +++ b/tensorflow/contrib/quantize/python/quantize_graph_test.py @@ -309,13 +309,19 @@ class QuantizeGraphTest(test_util.TensorFlowTestCase): def testWithSharedWeights(self): self._RunTestOverAllRewrites(self._TestWithSharedWeights) + self._RunTestOverTrainingRewrites( + lambda rewrite_fn: self._TestWithSharedWeights(rewrite_fn, + quant_delay=1)) - def _TestWithSharedWeights(self, rewrite_fn): + def _TestWithSharedWeights(self, rewrite_fn, quant_delay=None): with ops.Graph().as_default() as g: conv = template.make_template('shared_weights_conv', self._ConvLayer) conv() conv() - rewrite_fn() + if quant_delay is None: + rewrite_fn() + else: + rewrite_fn(quant_delay=quant_delay) conv_ops = [op for op in g.get_operations() if op.type == 'Conv2D'] weights_quants = [ @@ -324,8 +330,15 @@ class QuantizeGraphTest(test_util.TensorFlowTestCase): ] # Check that the shared weights variable is not quantized multiple times self.assertTrue(len(weights_quants) == 1) - # Check that the Conv2D operations get the quantized weights weights_quant_tensor = weights_quants[0].outputs[0] + if quant_delay: + delayed_weights_quants = [ + op for op in g.get_operations() + if 'weights_quant' in op.name and op.type == 'Merge' + ] + self.assertTrue(len(delayed_weights_quants) == 1) + weights_quant_tensor = delayed_weights_quants[0].outputs[0] + # Check that the Conv2D operations get the quantized weights self.assertTrue(all(weights_quant_tensor in op.inputs for op in conv_ops)) def _ConvLayer( |