author    Suharsh Sivakumar <suharshs@google.com>    2018-04-23 15:57:16 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>    2018-04-23 16:04:56 -0700
commit    762fa5f6ead8f662e5cc14420293cb369f2b9615 (patch)
tree      c2e7b118d00193c3d6df879bb582699e2d6d10b2 /tensorflow/contrib/quantize
parent    ff15c81e2b92ef8fb47bb15790cffd18377a4ef2 (diff)
FakeQuant operations before ReLUs (which occur after bypass nodes) aren't needed.
PiperOrigin-RevId: 193999591
Diffstat (limited to 'tensorflow/contrib/quantize')
-rw-r--r--  tensorflow/contrib/quantize/python/quantize.py             | 68
-rw-r--r--  tensorflow/contrib/quantize/python/quantize_graph_test.py  | 14
-rw-r--r--  tensorflow/contrib/quantize/python/quantize_test.py        | 57
3 files changed, 87 insertions, 52 deletions
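
The change below hinges on one rule: when a bypass Add feeds directly into an activation, the activation is fused into the Add at inference time, so a FakeQuant op inserted between them would never correspond to a real quantization boundary. The following is a minimal, self-contained sketch of that rule only, not the contrib/quantize implementation; the op-type set and helper name are illustrative assumptions.

# Minimal sketch of the rule this commit enforces (illustrative only, not
# the contrib/quantize implementation).

# Hypothetical set of fusable activation op types, for illustration.
_ACTIVATION_TYPES = {'Relu', 'Relu6', 'Identity'}


def should_insert_fake_quant(consumer_op_types):
  """Returns True if a FakeQuant op should be inserted after a bypass Add.

  Args:
    consumer_op_types: op-type strings of the ops consuming the Add's output.
  """
  # If any consumer is an activation, the activation will be fused into the
  # Add at inference time, so skip the FakeQuant op.
  return not any(t in _ACTIVATION_TYPES for t in consumer_op_types)


# An Add followed by a ReLU is left unquantized; an Add feeding a plain
# Conv2D still gets a FakeQuant op after it.
assert not should_insert_fake_quant(['Relu'])
assert should_insert_fake_quant(['Conv2D'])
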
diff --git a/tensorflow/contrib/quantize/python/quantize.py b/tensorflow/contrib/quantize/python/quantize.py
index d2d0426d23..efc1a94b3c 100644
--- a/tensorflow/contrib/quantize/python/quantize.py
+++ b/tensorflow/contrib/quantize/python/quantize.py
@@ -133,19 +133,27 @@ def Quantize(graph,
bits=activation_bits,
producer_scope=scope,
consumer_scope=scope)
- _InsertQuantOp(
- add_context,
- 'add_quant',
- layer_match.bypass_op,
- input_to_ops_map.ConsumerOperations(layer_match.bypass_op),
- is_training,
- moving_avg=True,
- ema_decay=ema_decay,
- quant_delay=quant_delay,
- vars_collection=vars_collection,
- bits=activation_bits,
- producer_scope=scope,
- consumer_scope=scope)
+ # Make sure the op following this isn't an activation. If it is, we
+ # shouldn't quantize it, since the activation will be fused into the
+ # Add at inference time.
+ consumers = input_to_ops_map.ConsumerOperations(layer_match.bypass_op)
+ if any(consumer.type in _ACTIVATION_TYPES for consumer in consumers):
+ logging.info('Skipping %s, because it is followed by an activation.',
+ layer_match.bypass_op.name)
+ else:
+ _InsertQuantOp(
+ add_context,
+ 'add_quant',
+ layer_match.bypass_op,
+ consumers,
+ is_training,
+ moving_avg=True,
+ ema_decay=ema_decay,
+ quant_delay=quant_delay,
+ vars_collection=vars_collection,
+ bits=activation_bits,
+ producer_scope=scope,
+ consumer_scope=scope)
# Quantize bypass ops that occur after the activation.
if layer_match.post_activation_bypass_op is not None:
@@ -153,19 +161,27 @@ def Quantize(graph,
r'^(.*)/([^/]+)', layer_match.post_activation_bypass_op.name).group(1)
# If `scope` is given, only quantize it if the producer is in the right
# scope.
- _InsertQuantOp(
- post_activation_bypass_context,
- 'post_activation_bypass_quant',
- layer_match.post_activation_bypass_op,
- input_to_ops_map.ConsumerOperations(
- layer_match.post_activation_bypass_op),
- is_training,
- moving_avg=True,
- ema_decay=ema_decay,
- quant_delay=quant_delay,
- vars_collection=vars_collection,
- bits=activation_bits,
- producer_scope=scope)
+ # Make sure the op following this isn't an activation. If it is, we
+ # shouldn't quantize it, since the activation will be fused into the
+ # Add at inference time.
+ consumers = input_to_ops_map.ConsumerOperations(
+ layer_match.post_activation_bypass_op)
+ if any(consumer.type in _ACTIVATION_TYPES for consumer in consumers):
+ logging.info('Skipping %s, because it is followed by an activation.',
+ layer_match.post_activation_bypass_op.name)
+ else:
+ _InsertQuantOp(
+ post_activation_bypass_context,
+ 'post_activation_bypass_quant',
+ layer_match.post_activation_bypass_op,
+ consumers,
+ is_training,
+ moving_avg=True,
+ ema_decay=ema_decay,
+ quant_delay=quant_delay,
+ vars_collection=vars_collection,
+ bits=activation_bits,
+ producer_scope=scope)
def _FindLayersToQuantize(graph):
diff --git a/tensorflow/contrib/quantize/python/quantize_graph_test.py b/tensorflow/contrib/quantize/python/quantize_graph_test.py
index caf8ff28d5..54faf582f1 100644
--- a/tensorflow/contrib/quantize/python/quantize_graph_test.py
+++ b/tensorflow/contrib/quantize/python/quantize_graph_test.py
@@ -113,20 +113,6 @@ class QuantizeGraphTest(test_util.TensorFlowTestCase):
# Ensure that variables were added.
self.assertTrue(len(orig_variable_names) < len(q_variables))
- def testWithPreActivationBypass(self):
- self._RunTestOverAllRewrites(self._TestWithPreActivationBypass)
-
- def _TestWithPreActivationBypass(self, rewrite_fn):
- # Tests that the default graph is correctly used when no args are provided
- # to rewrite_fn.
- with ops.Graph().as_default() as g:
- self._ConvLayer(pre_activation_bypass=True, scope='scope1')
- rewrite_fn()
-
- op_names = [op.name for op in g.get_operations()]
- self.assertTrue(
- any('scope1/add_quant/' in name for name in op_names))
-
def testWithPostActivationBypass(self):
self._RunTestOverAllRewrites(self._TestWithPostActivationBypass)
diff --git a/tensorflow/contrib/quantize/python/quantize_test.py b/tensorflow/contrib/quantize/python/quantize_test.py
index d37c83d683..5e479f3946 100644
--- a/tensorflow/contrib/quantize/python/quantize_test.py
+++ b/tensorflow/contrib/quantize/python/quantize_test.py
@@ -82,9 +82,22 @@ class QuantizeTest(test_util.TensorFlowTestCase):
quantize.Quantize(graph, is_training, weight_bits=8, activation_bits=8)
quantization_node_name = 'FakeQuantWithMinMaxVars'
- add_quant = graph.get_operation_by_name('test/add_quant/' +
- quantization_node_name)
- self.assertEqual(add_quant.type, quantization_node_name)
+ conv_quant = graph.get_operation_by_name('test/test/conv_quant/' +
+ quantization_node_name)
+ self.assertEqual(conv_quant.type, quantization_node_name)
+
+ # Scan through all FakeQuant operations, ensuring that the activation
+ # isn't in the consumers of the operation. Since activations are folded
+ # into the preceding operation during inference, the FakeQuant operation
+ # after the activation is all that is needed.
+ for op in graph.get_operations():
+ if op.type == quantization_node_name:
+ quant_op = graph.get_operation_by_name(op.name)
+ consumers = []
+ for output in quant_op.outputs:
+ consumers.extend(output.consumers())
+
+ self.assertNotIn('test/identity', [c.name for c in consumers])
def testInsertQuantOpForAddAfterSeparableConv2d(self):
self._RunTestOverParameters(
@@ -109,9 +122,20 @@ class QuantizeTest(test_util.TensorFlowTestCase):
quantize.Quantize(graph, is_training, weight_bits=8, activation_bits=8)
quantization_node_name = 'FakeQuantWithMinMaxVars'
- add_quant = graph.get_operation_by_name('test/add_quant/' +
- quantization_node_name)
- self.assertEqual(add_quant.type, quantization_node_name)
+ conv_quant = graph.get_operation_by_name('test/test/conv_quant/' +
+ quantization_node_name)
+ self.assertEqual(conv_quant.type, quantization_node_name)
+
+ for op in graph.get_operations():
+ if op.type == quantization_node_name:
+ quant_op = graph.get_operation_by_name(op.name)
+ # Scan through all FakeQuant operations, ensuring that the activation
+ # identity op isn't in the consumers of the operation.
+ consumers = []
+ for output in quant_op.outputs:
+ consumers.extend(output.consumers())
+
+ self.assertNotIn('test/identity', [c.name for c in consumers])
def testFinalLayerQuantized(self):
self._RunTestOverParameters(self._TestFinalLayerQuantized)
@@ -153,12 +177,21 @@ class QuantizeTest(test_util.TensorFlowTestCase):
activation_fn=array_ops.identity,
scope='test/test')
bypass_tensor = math_ops.add(conv, input2, name='test/add')
- _ = array_ops.identity(bypass_tensor, name='test/output')
+ # The output of the post_activation bypass will be another layer.
+ _ = conv2d(
+ bypass_tensor,
+ 32, [5, 5],
+ stride=2,
+ padding='SAME',
+ weights_initializer=self._WeightInit(0.09),
+ activation_fn=array_ops.identity,
+ scope='test/unused')
quantize.Quantize(graph, is_training, weight_bits=8, activation_bits=8)
- # Ensure that the bypass node is preceded and followed by
- # FakeQuantWithMinMaxVars operations.
+ # Ensure that the bypass node is both preceded and followed by a
+ # FakeQuantWithMinMaxVars operation, since the output of the Add isn't an
+ # activation.
self.assertTrue('FakeQuantWithMinMaxVars' in
[c.type for c in bypass_tensor.consumers()])
self.assertTrue('FakeQuantWithMinMaxVars' in
@@ -198,9 +231,9 @@ class QuantizeTest(test_util.TensorFlowTestCase):
quantize.Quantize(graph, is_training, weight_bits=8, activation_bits=8)
- # Ensure that the bypass node is preceded and followed by
- # FakeQuantWithMinMaxVars operations.
- self.assertTrue('FakeQuantWithMinMaxVars' in
+ # Ensure that the bypass node is preceded by a FakeQuantWithMinMaxVars
+ # operation, and NOT followed by one.
+ self.assertTrue('FakeQuantWithMinMaxVars' not in
[c.type for c in bypass_tensor.consumers()])
self.assertTrue('FakeQuantWithMinMaxVars' in
[i.op.type for i in bypass_tensor.op.inputs])
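
The updated tests rely on enumerating the consumers of each FakeQuant op and asserting that no activation op is among them. Below is a minimal sketch of that consumer-enumeration pattern, assuming TF 1.x graph-mode APIs; the graph and op names here are made up for illustration and are not taken from the tests above.

import tensorflow as tf  # TF 1.x graph-mode APIs assumed.


def op_consumers(op):
  """Collects every op that consumes any output of `op`."""
  consumers = []
  for output in op.outputs:
    consumers.extend(output.consumers())
  return consumers


with tf.Graph().as_default() as g:
  x = tf.placeholder(tf.float32, [1, 4], name='x')
  y = tf.add(x, 1.0, name='add')
  _ = tf.nn.relu(y, name='relu')

  add_op = g.get_operation_by_name('add')
  # The ReLU consumes the Add's output, so a rewriter following this commit's
  # rule would skip inserting a FakeQuant op between the Add and the ReLU.
  assert 'relu' in [c.name for c in op_consumers(add_op)]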