aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/quantize
diff options
context:
space:
mode:
authorGravatar manipopopo <pwmutantbread@gmail.com>2018-06-11 14:19:48 +0000
committerGravatar manipopopo <pwmutantbread@gmail.com>2018-09-20 08:53:47 +0000
commitf44805f8333aaf76d392bb565fe2381be07ccf2a (patch)
tree15f6226c7a11a3e0a1ab0397fcc6241b21caed28 /tensorflow/contrib/quantize
parente514555a9572e00243083a8ec6e58c8deed5a501 (diff)
Fix routing of delayed quantized tensors
Diffstat (limited to 'tensorflow/contrib/quantize')
-rw-r--r--tensorflow/contrib/quantize/python/quantize.py4
-rw-r--r--tensorflow/contrib/quantize/python/quantize_graph_test.py19
2 files changed, 20 insertions, 3 deletions
diff --git a/tensorflow/contrib/quantize/python/quantize.py b/tensorflow/contrib/quantize/python/quantize.py
index 6f34308fdb..ccf58c7a8a 100644
--- a/tensorflow/contrib/quantize/python/quantize.py
+++ b/tensorflow/contrib/quantize/python/quantize.py
@@ -593,6 +593,10 @@ def _InsertQuantOp(context,
name=name_prefix + '/delayed_quant')
else:
quant = fake_quant_op.outputs[0]
+ if quant_delay and quant_delay > 0:
+ name_prefix = '/'.join(quant.name.split('/')[:-1])
+ quant = quant.graph.get_tensor_by_name(
+ name_prefix + '/delayed_quant/Merge:0')
if consumers:
tensors_modified_count = common.RerouteTensor(
diff --git a/tensorflow/contrib/quantize/python/quantize_graph_test.py b/tensorflow/contrib/quantize/python/quantize_graph_test.py
index d3e7264ba4..36d87039a5 100644
--- a/tensorflow/contrib/quantize/python/quantize_graph_test.py
+++ b/tensorflow/contrib/quantize/python/quantize_graph_test.py
@@ -309,13 +309,19 @@ class QuantizeGraphTest(test_util.TensorFlowTestCase):
def testWithSharedWeights(self):
self._RunTestOverAllRewrites(self._TestWithSharedWeights)
+ self._RunTestOverTrainingRewrites(
+ lambda rewrite_fn: self._TestWithSharedWeights(rewrite_fn,
+ quant_delay=1))
- def _TestWithSharedWeights(self, rewrite_fn):
+ def _TestWithSharedWeights(self, rewrite_fn, quant_delay=None):
with ops.Graph().as_default() as g:
conv = template.make_template('shared_weights_conv', self._ConvLayer)
conv()
conv()
- rewrite_fn()
+ if quant_delay is None:
+ rewrite_fn()
+ else:
+ rewrite_fn(quant_delay=quant_delay)
conv_ops = [op for op in g.get_operations() if op.type == 'Conv2D']
weights_quants = [
@@ -324,8 +330,15 @@ class QuantizeGraphTest(test_util.TensorFlowTestCase):
]
# Check that the shared weights variable is not quantized multiple times
self.assertTrue(len(weights_quants) == 1)
- # Check that the Conv2D operations get the quantized weights
weights_quant_tensor = weights_quants[0].outputs[0]
+ if quant_delay:
+ delayed_weights_quants = [
+ op for op in g.get_operations()
+ if 'weights_quant' in op.name and op.type == 'Merge'
+ ]
+ self.assertTrue(len(delayed_weights_quants) == 1)
+ weights_quant_tensor = delayed_weights_quants[0].outputs[0]
+ # Check that the Conv2D operations get the quantized weights
self.assertTrue(all(weights_quant_tensor in op.inputs for op in conv_ops))
def _ConvLayer(