author    Yu-Cheng Ling <ycling@google.com>  2018-04-09 10:48:10 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>  2018-04-09 10:51:47 -0700
commit  7576a99c49679dc17ff806acb1a5150f5d16ee58 (patch)
tree    005681689fdcaf4c46a03a6ca605df415fed77fc /tensorflow/contrib/quantize
parent  1ad181b6334ec339ab823cd122e19b7a1ad1a6f7 (diff)
Add `scope` parameter in experimental Quantization API.
This enables quantizing subgraphs of the entire graph. It's useful for networks like Inception since we don't want to quantize the AuxLogits scope.

PiperOrigin-RevId: 192150416
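The commit message above mentions excluding Inception's AuxLogits; a minimal usage sketch of the new `scope` argument follows (the toy two-scope conv model and the 'scope1'/'scope2' names are illustrative, not part of this change):

    import tensorflow as tf
    from tensorflow.contrib.quantize.python import quantize_graph

    graph = tf.Graph()
    with graph.as_default():
      inputs = tf.zeros((5, 128, 128, 3))
      with tf.name_scope('scope1'):
        net = tf.contrib.layers.conv2d(inputs, 32, [5, 5],
                                       activation_fn=tf.nn.relu6)
      with tf.name_scope('scope2'):
        net = tf.contrib.layers.conv2d(net, 32, [5, 5],
                                       activation_fn=tf.nn.relu6)
      # Insert fake-quant ops only for ops whose names start with 'scope1/';
      # ops under 'scope2/' are left untouched.
      quantize_graph.experimental_create_training_graph(
          input_graph=graph, scope='scope1/')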
Diffstat (limited to 'tensorflow/contrib/quantize')
-rw-r--r--  tensorflow/contrib/quantize/python/quantize.py             70
-rw-r--r--  tensorflow/contrib/quantize/python/quantize_graph.py       26
-rw-r--r--  tensorflow/contrib/quantize/python/quantize_graph_test.py  110
-rw-r--r--  tensorflow/contrib/quantize/python/quantize_test.py        30
4 files changed, 208 insertions, 28 deletions
diff --git a/tensorflow/contrib/quantize/python/quantize.py b/tensorflow/contrib/quantize/python/quantize.py
index d53d4d7b10..d2d0426d23 100644
--- a/tensorflow/contrib/quantize/python/quantize.py
+++ b/tensorflow/contrib/quantize/python/quantize.py
@@ -27,6 +27,7 @@ from tensorflow.contrib.quantize.python import quant_ops
from tensorflow.python.framework import ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
+from tensorflow.python.platform import tf_logging as logging
# Quantizable operation types that are supported by the quantization rewrite.
_QUANTIZABLE_TYPES = {'Conv2D', 'MatMul', 'DepthwiseConv2dNative'}
@@ -41,9 +42,16 @@ def Quantize(graph,
activation_bits=8,
ema_decay=0.999,
quant_delay=None,
- vars_collection=ops.GraphKeys.GLOBAL_VARIABLES):
+ vars_collection=ops.GraphKeys.GLOBAL_VARIABLES,
+ scope=None):
"""Updates graph with quantization operations.
+ Currently we quantize the following tensors:
+ * Conv/MatMul: Quantize the weights if the layer matches.
+ * Activation: Quantize the output if the layer matches.
+ * Bypass/Post-activation Bypass: Quantize both input and output
+ if the layer matches.
+
Args:
graph: Graph to modify.
is_training: Whether quantizing training graph or eval graph.
@@ -57,13 +65,21 @@ def Quantize(graph,
training.
vars_collection: (Optional) Collection where to store the variables for
quantization interval ends.
+ scope: The scope to be transformed. If it's not None, only the ops which
+ are in this scope will be transformed.
Raises:
ValueError: When quantization fails.
"""
+ if scope and not scope.endswith('/'):
+ scope += '/'
+
input_to_ops_map = input_to_ops.InputToOps(graph)
for layer_match in _FindLayersToQuantize(graph):
# Quantize the weights.
context = _GetContextFromOp(layer_match.layer_op)
+
+ # If `scope` is given, only quantize it if the consumer of weights
+ # (the layer op) is in the right scope.
_InsertQuantOp(
context,
'weights_quant',
@@ -74,7 +90,8 @@ def Quantize(graph,
quant_delay=quant_delay,
narrow_range=True,
vars_collection=vars_collection,
- bits=weight_bits)
+ bits=weight_bits,
+ consumer_scope=scope)
# Quantize the activations.
consumer_ops = input_to_ops_map.ConsumerOperations(
@@ -82,6 +99,9 @@ def Quantize(graph,
add_context = context
if layer_match.bypass_op:
add_context = re.search(r'^(.*)/([^/]+)', context).group(1)
+
+ # If `scope` is given, only quantize it if the producer of the activation
+ # output (usually the activation op itself) is in the right scope.
_InsertQuantOp(
add_context,
'act_quant',
@@ -93,11 +113,14 @@ def Quantize(graph,
quant_delay=quant_delay,
vars_collection=vars_collection,
bits=activation_bits,
- init_min=0.0)
+ init_min=0.0,
+ producer_scope=scope)
# Quantize the inputs and output to the bypass (if it exists). The input to
# the bypass is the bias add, and the output is the activation.
if layer_match.bypass_op is not None:
+ # If `scope` is given, only quantize it if both the producer and the
+ # consumer are in the right scope.
_InsertQuantOp(
context,
'conv_quant',
@@ -107,7 +130,9 @@ def Quantize(graph,
ema_decay=ema_decay,
quant_delay=quant_delay,
vars_collection=vars_collection,
- bits=activation_bits)
+ bits=activation_bits,
+ producer_scope=scope,
+ consumer_scope=scope)
_InsertQuantOp(
add_context,
'add_quant',
@@ -118,12 +143,16 @@ def Quantize(graph,
ema_decay=ema_decay,
quant_delay=quant_delay,
vars_collection=vars_collection,
- bits=activation_bits)
+ bits=activation_bits,
+ producer_scope=scope,
+ consumer_scope=scope)
# Quantize bypass ops that occur after the activation.
if layer_match.post_activation_bypass_op is not None:
post_activation_bypass_context = re.search(
r'^(.*)/([^/]+)', layer_match.post_activation_bypass_op.name).group(1)
+ # If `scope` is given, only quantize it if the producer is in the right
+ # scope.
_InsertQuantOp(
post_activation_bypass_context,
'post_activation_bypass_quant',
@@ -135,7 +164,8 @@ def Quantize(graph,
ema_decay=ema_decay,
quant_delay=quant_delay,
vars_collection=vars_collection,
- bits=activation_bits)
+ bits=activation_bits,
+ producer_scope=scope)
def _FindLayersToQuantize(graph):
@@ -382,7 +412,9 @@ def _InsertQuantOp(context,
ema_decay=0.999,
quant_delay=None,
vars_collection=ops.GraphKeys.GLOBAL_VARIABLES,
- narrow_range=False):
+ narrow_range=False,
+ producer_scope=None,
+ consumer_scope=None):
"""Inserts a quant op between a producer op and (multiple) consumer ops.
Args:
@@ -407,10 +439,34 @@ def _InsertQuantOp(context,
quantization interval ends.
narrow_range: Whether to use the narrow quantization range
[1; 2^bits - 1] or wide range [0; 2^bits - 1].
+ producer_scope: The restriction of producer scope. If not None, the new op
+ will be inserted only when the producer is in this scope.
+ consumer_scope: The restriction of consumer scope. If not None, the new op
+ will be inserted only when all the consumers are in this scope.
Raises:
ValueError: When producer operation is not directly connected to the
consumer operation.
"""
+ if producer_scope and not producer.name.startswith(producer_scope):
+ logging.info(
+ '_InsertQuantOp ignores context="%s" name="%s" '
+ 'because producer "%s" is not in scope "%s"',
+ context, name, producer.name, producer_scope)
+ return
+
+ if consumer_scope:
+ consumers_in_scope = []
+ for consumer in consumers:
+ if consumer.name.startswith(consumer_scope):
+ consumers_in_scope.append(consumer)
+ else:
+ logging.info(
+ '_InsertQuantOp context="%s" name="%s" ignores '
+ 'consumer "%s" because it is not in scope "%s"',
+ context, name, consumer.name, consumer_scope)
+ return
+ consumers = consumers_in_scope
+
name_prefix = _AddContextToName(context, name)
# This is needed on TPU where name_scope == 'TPUReplicate/loop', and
# name_prefix starts with 'TPUReplicate/loop/'; without dropping it
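Both new parameters reduce to name-prefix matching on op names. A standalone sketch of that rule, assuming it mirrors the checks added in Quantize() and _InsertQuantOp above (the _in_scope helper and the op names are hypothetical):

    def _in_scope(op_name, scope):
      # No scope restriction given: every op is eligible.
      if not scope:
        return True
      # Quantize() appends a trailing '/', so 'scope1' and 'scope1/' behave the
      # same and 'scope10/...' does not accidentally match 'scope1'.
      if not scope.endswith('/'):
        scope += '/'
      return op_name.startswith(scope)

    assert _in_scope('scope1/Conv/weights/read', 'scope1')
    assert not _in_scope('scope10/Conv/weights/read', 'scope1')

Per the calls above, weights_quant applies the check to the consumer (the layer op), act_quant to the producer, and the bypass quant ops to both producer and consumer.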
diff --git a/tensorflow/contrib/quantize/python/quantize_graph.py b/tensorflow/contrib/quantize/python/quantize_graph.py
index 0b74b438ac..11d052d7f4 100644
--- a/tensorflow/contrib/quantize/python/quantize_graph.py
+++ b/tensorflow/contrib/quantize/python/quantize_graph.py
@@ -28,7 +28,8 @@ def _create_graph(input_graph=None,
weight_bits=8,
activation_bits=8,
quant_delay=None,
- freeze_bn_delay=None):
+ freeze_bn_delay=None,
+ scope=None):
"""Rewrites an input_graph in place for simulated quantization.
The graph has fake quantization ops inserted to simulate the error
@@ -48,6 +49,8 @@ def _create_graph(input_graph=None,
frozen and used instead of batch statistics during training.
freeze_bn_delay should be greater than quant_delay and should correspond
to the number of steps when training has almost converged
+ scope: The scope to be transformed. If it's not None, only the ops which
+ are in this scope will be transformed.
Raises:
ValueError: If elements contains an element that isn't a tf.Tensor or
@@ -66,7 +69,8 @@ def _create_graph(input_graph=None,
is_training,
quant_delay=quant_delay,
weight_bits=weight_bits,
- activation_bits=activation_bits)
+ activation_bits=activation_bits,
+ scope=scope)
def create_training_graph(input_graph=None, quant_delay=0):
@@ -133,7 +137,8 @@ def experimental_create_training_graph(input_graph=None,
weight_bits=8,
activation_bits=8,
quant_delay=0,
- freeze_bn_delay=None):
+ freeze_bn_delay=None,
+ scope=None):
"""Rewrites a training input_graph in place for simulated quantization.
Variables added by the rewrite get added to the global variables collection.
@@ -165,6 +170,8 @@ def experimental_create_training_graph(input_graph=None,
frozen and used instead of batch statistics during training.
freeze_bn_delay should be greater than quant_delay and should correspond
to when training has almost converged
+ scope: The scope to be transformed. If it's not None, only the ops which
+ are in this scope will be transformed.
Raises:
ValueError: If elements contains an element that isn't a tf.Tensor or
@@ -177,12 +184,14 @@ def experimental_create_training_graph(input_graph=None,
weight_bits=weight_bits,
activation_bits=activation_bits,
quant_delay=quant_delay,
- freeze_bn_delay=freeze_bn_delay)
+ freeze_bn_delay=freeze_bn_delay,
+ scope=scope)
def experimental_create_eval_graph(input_graph=None,
weight_bits=8,
- activation_bits=8):
+ activation_bits=8,
+ scope=None):
"""Rewrites an eval input_graph in place for simulated quantization.
Variables added by the rewrite get added to the global variables collection.
@@ -200,8 +209,8 @@ def experimental_create_eval_graph(input_graph=None,
default graph.
weight_bits: Number of bits to use for quantizing weights.
activation_bits: Number of bits to use for quantizing activations.
-
-
+ scope: The scope to be transformed. If it's not None, only the ops which
+ are in this scope will be transformed.
Raises:
ValueError: If elements contains an element that isn't a tf.Tensor or
@@ -211,4 +220,5 @@ def experimental_create_eval_graph(input_graph=None,
input_graph=input_graph,
is_training=False,
weight_bits=weight_bits,
- activation_bits=activation_bits)
+ activation_bits=activation_bits,
+ scope=scope)
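The eval rewrite exposes the same parameter (only the experimental_* entry points gain `scope` in this change). A hedged example, reusing the illustrative 'scope1' layout from the sketch above:

    import tensorflow as tf
    from tensorflow.contrib.quantize.python import quantize_graph

    eval_graph = tf.Graph()
    with eval_graph.as_default():
      inputs = tf.zeros((1, 128, 128, 3))
      with tf.name_scope('scope1'):
        net = tf.contrib.layers.conv2d(inputs, 32, [5, 5],
                                       activation_fn=tf.nn.relu6)
      # Inference-time fake quantization restricted to 'scope1/'.
      quantize_graph.experimental_create_eval_graph(
          input_graph=eval_graph, scope='scope1/')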
diff --git a/tensorflow/contrib/quantize/python/quantize_graph_test.py b/tensorflow/contrib/quantize/python/quantize_graph_test.py
index b9d03c1bc0..caf8ff28d5 100644
--- a/tensorflow/contrib/quantize/python/quantize_graph_test.py
+++ b/tensorflow/contrib/quantize/python/quantize_graph_test.py
@@ -66,6 +66,20 @@ class QuantizeGraphTest(test_util.TensorFlowTestCase):
for fn in rewrite_fns:
test_fn(fn)
+ def _RunTestOverExperimentalRewritesWithScope(self, test_fn, scope):
+ def with_absent_scope(fn):
+ def fn_with_absent_scope(*args):
+ fn(*args, scope=scope)
+ return fn_with_absent_scope
+ rewrite_fns = [
+ with_absent_scope(
+ quantize_graph.experimental_create_training_graph),
+ with_absent_scope(
+ quantize_graph.experimental_create_eval_graph),
+ ]
+ for fn in rewrite_fns:
+ test_fn(fn)
+
def testRewrite(self):
self._RunTestOverAllRewrites(self._TestRewrite)
@@ -99,6 +113,34 @@ class QuantizeGraphTest(test_util.TensorFlowTestCase):
# Ensure that variables were added.
self.assertTrue(len(orig_variable_names) < len(q_variables))
+ def testWithPreActivationBypass(self):
+ self._RunTestOverAllRewrites(self._TestWithPreActivationBypass)
+
+ def _TestWithPreActivationBypass(self, rewrite_fn):
+ # Tests that the default graph is correctly used when no args are provided
+ # to rewrite_fn.
+ with ops.Graph().as_default() as g:
+ self._ConvLayer(pre_activation_bypass=True, scope='scope1')
+ rewrite_fn()
+
+ op_names = [op.name for op in g.get_operations()]
+ self.assertTrue(
+ any('scope1/add_quant/' in name for name in op_names))
+
+ def testWithPostActivationBypass(self):
+ self._RunTestOverAllRewrites(self._TestWithPostActivationBypass)
+
+ def _TestWithPostActivationBypass(self, rewrite_fn):
+ # Tests that the default graph is correctly used when no args are provided
+ # to rewrite_fn.
+ with ops.Graph().as_default() as g:
+ self._ConvLayer(post_activation_bypass=True, scope='scope1')
+ rewrite_fn()
+
+ op_names = [op.name for op in g.get_operations()]
+ self.assertTrue(any(
+ 'scope1/post_activation_bypass_quant/' in name for name in op_names))
+
def testQuantDelay(self):
self._RunTestOverTrainingRewrites(self._TestQuantDelay)
@@ -224,20 +266,66 @@ class QuantizeGraphTest(test_util.TensorFlowTestCase):
graph_def_after = str(g.as_graph_def())
self.assertEqual(graph_def_before, graph_def_after)
- def _ConvLayer(self):
+ def testRewriteWithScope(self):
+ self._RunTestOverExperimentalRewritesWithScope(
+ self._TestRewriteWithScope, 'scope1')
+
+ def _TestRewriteWithScope(self, rewrite_fn):
+ graph = ops.Graph()
+ with graph.as_default():
+ scope1_output = self._ConvLayer(scope='scope1')
+ self._ConvLayer(input_tensor=scope1_output, scope='scope2')
+
+ rewrite_fn(graph)
+
+ op_names = [op.name for op in graph.get_operations()]
+ # The weights and activations of scope1 are quantized, but not those of scope2.
+ self.assertTrue(
+ any('scope1/Conv/act_quant' in name for name in op_names))
+ self.assertTrue(
+ any('scope1/Conv/weights_quant' in name for name in op_names))
+ self.assertFalse(
+ any('scope2/Conv/act_quant' in name for name in op_names))
+ self.assertFalse(
+ any('scope2/Conv/weights_quant' in name for name in op_names))
+
+ def testRewriteWithNonMatchingScope(self):
+ self._RunTestOverExperimentalRewritesWithScope(
+ self._TestRewriteWithNonMatchingScope, 'NonExistingScope')
+
+ def _TestRewriteWithNonMatchingScope(self, rewrite_fn):
+ graph = ops.Graph()
+ with graph.as_default():
+ self._ConvLayer()
+
+ op_names_before_rewrite = set([op.name for op in graph.get_operations()])
+ rewrite_fn(graph)
+ op_names_after_rewrite = set([op.name for op in graph.get_operations()])
+
+ # No ops should be inserted or removed.
+ self.assertEqual(op_names_before_rewrite, op_names_after_rewrite)
+
+ def _ConvLayer(
+ self, input_tensor=None, scope='test', pre_activation_bypass=False,
+ post_activation_bypass=False):
"""Add a basic convolution layer to the default graph."""
batch_size, height, width, depth = 5, 128, 128, 3
- inputs = array_ops.zeros((batch_size, height, width, depth))
+ if input_tensor is None:
+ input_tensor = array_ops.zeros((batch_size, height, width, depth))
weight_init = init_ops.truncated_normal_initializer
- conv = layers.conv2d(
- inputs,
- 32, [5, 5],
- stride=2,
- padding='SAME',
- weights_initializer=weight_init(0.09),
- activation_fn=None,
- scope='test')
- _ = nn_ops.relu6(conv)
+ with ops.name_scope(scope):
+ output = layers.conv2d(
+ input_tensor,
+ depth, [5, 5],
+ padding='SAME',
+ weights_initializer=weight_init(0.09),
+ activation_fn=None)
+ if pre_activation_bypass:
+ output += input_tensor
+ output = nn_ops.relu6(output)
+ if post_activation_bypass:
+ output += input_tensor
+ return output
if __name__ == '__main__':
diff --git a/tensorflow/contrib/quantize/python/quantize_test.py b/tensorflow/contrib/quantize/python/quantize_test.py
index 8d057d3710..d37c83d683 100644
--- a/tensorflow/contrib/quantize/python/quantize_test.py
+++ b/tensorflow/contrib/quantize/python/quantize_test.py
@@ -254,12 +254,11 @@ class QuantizeTest(test_util.TensorFlowTestCase):
graph = ops.Graph()
with graph.as_default():
with graph.name_scope(None):
- batch_size, height, width, depth = 5, 128, 128, 3
+ batch_size, height, width, depth = 5, 128, 128, 32
input1 = array_ops.zeros((batch_size, height, width, depth))
_ = conv2d(
input1,
32, [5, 5],
- stride=2,
padding='SAME',
weights_initializer=self._WeightInit(0.09),
activation_fn=None,
@@ -268,6 +267,33 @@ class QuantizeTest(test_util.TensorFlowTestCase):
quantize.Quantize(graph, is_training, weight_bits=8, activation_bits=8)
# Passes if Quantize() does not crash.
+ def testWithNonMatchingNameScope(self):
+ self._RunTestOverParameters(self._testWithNonMatchingNameScope)
+
+ def _testWithNonMatchingNameScope(self, is_training):
+ graph = ops.Graph()
+ with graph.as_default():
+ with graph.name_scope('name_scope'):
+ batch_size, height, width, depth = 5, 128, 128, 3
+ input1 = array_ops.zeros((batch_size, height, width, depth))
+ _ = conv2d(
+ input1,
+ 32, [5, 5],
+ stride=2,
+ padding='SAME',
+ weights_initializer=self._WeightInit(0.09),
+ activation_fn=None,
+ scope='test')
+
+ op_names_before_quantize = set([op.name for op in graph.get_operations()])
+ quantize.Quantize(
+ graph, is_training, weight_bits=8, activation_bits=8,
+ scope='NonExisting/')
+ op_names_after_quantize = set([op.name for op in graph.get_operations()])
+
+ # No ops should be inserted or removed.
+ self.assertEqual(op_names_before_quantize, op_names_after_quantize)
+
def _WeightInit(self, stddev):
"""Returns truncated normal variable initializer.