author | Yu-Cheng Ling <ycling@google.com> | 2018-04-09 10:48:10 -0700
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-04-09 10:51:47 -0700
commit | 7576a99c49679dc17ff806acb1a5150f5d16ee58 (patch)
tree | 005681689fdcaf4c46a03a6ca605df415fed77fc /tensorflow/contrib/quantize
parent | 1ad181b6334ec339ab823cd122e19b7a1ad1a6f7 (diff)
Add `scope` parameter to the experimental quantization API.

This makes it possible to quantize only a subgraph of the whole graph. It is
useful for networks like Inception, where we don't want to quantize the
AuxLogits scope.
PiperOrigin-RevId: 192150416
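For context, a minimal usage sketch of the new `scope` argument follows. The two name scopes and the tiny two-layer conv graph are illustrative only, not part of this change; the sketch assumes the TF 1.x `tf.contrib.quantize` rewrite APIs extended here.

```python
import tensorflow as tf

graph = tf.Graph()
with graph.as_default():
  images = tf.zeros((5, 128, 128, 3))
  with tf.name_scope('scope1'):
    net = tf.contrib.layers.conv2d(images, 32, [5, 5], activation_fn=tf.nn.relu6)
  with tf.name_scope('scope2'):
    net = tf.contrib.layers.conv2d(net, 32, [5, 5], activation_fn=tf.nn.relu6)

  # Insert fake-quant ops only for ops whose names start with 'scope1/';
  # layers under 'scope2/' are left untouched (e.g. an AuxLogits-style head).
  tf.contrib.quantize.experimental_create_training_graph(
      input_graph=graph, scope='scope1/')
```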
Diffstat (limited to 'tensorflow/contrib/quantize')
-rw-r--r-- | tensorflow/contrib/quantize/python/quantize.py | 70
-rw-r--r-- | tensorflow/contrib/quantize/python/quantize_graph.py | 26
-rw-r--r-- | tensorflow/contrib/quantize/python/quantize_graph_test.py | 110
-rw-r--r-- | tensorflow/contrib/quantize/python/quantize_test.py | 30
4 files changed, 208 insertions, 28 deletions
diff --git a/tensorflow/contrib/quantize/python/quantize.py b/tensorflow/contrib/quantize/python/quantize.py
index d53d4d7b10..d2d0426d23 100644
--- a/tensorflow/contrib/quantize/python/quantize.py
+++ b/tensorflow/contrib/quantize/python/quantize.py
@@ -27,6 +27,7 @@ from tensorflow.contrib.quantize.python import quant_ops
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import math_ops
+from tensorflow.python.platform import tf_logging as logging

 # Quantizable operation types that are supported by the quantization rewrite.
 _QUANTIZABLE_TYPES = {'Conv2D', 'MatMul', 'DepthwiseConv2dNative'}
@@ -41,9 +42,16 @@ def Quantize(graph,
              activation_bits=8,
              ema_decay=0.999,
              quant_delay=None,
-             vars_collection=ops.GraphKeys.GLOBAL_VARIABLES):
+             vars_collection=ops.GraphKeys.GLOBAL_VARIABLES,
+             scope=None):
   """Updates graph with quantization operations.

+  Currently we quantize the following tensors:
+  * Conv/MatMul: Quantize the weights if it matches.
+  * Activation: Quantize the output if it matches.
+  * Bypass/Post-activation Bypass: Quantize both input and output
+    if it matches.
+
   Args:
     graph: Graph to modify.
     is_training: Whether quantizing training graph or eval graph.
@@ -57,13 +65,21 @@ def Quantize(graph,
       training.
     vars_collection: (Optional) Collection where to store the variables for
       quantization interval ends.
+    scope: The scope to be transformed. If it's not None, only the ops which
+      are in this scope will be transformed.

   Raises:
     ValueError: When quantization fails.
   """
+  if scope and not scope.endswith('/'):
+    scope += '/'
+
   input_to_ops_map = input_to_ops.InputToOps(graph)
   for layer_match in _FindLayersToQuantize(graph):
     # Quantize the weights.
     context = _GetContextFromOp(layer_match.layer_op)
+
+    # If `scope` is given, only quantize it if the consumer of weights
+    # (the layer op) is in the right scope.
     _InsertQuantOp(
         context,
         'weights_quant',
@@ -74,7 +90,8 @@ def Quantize(graph,
         quant_delay=quant_delay,
         narrow_range=True,
         vars_collection=vars_collection,
-        bits=weight_bits)
+        bits=weight_bits,
+        consumer_scope=scope)

     # Quantize the activations.
     consumer_ops = input_to_ops_map.ConsumerOperations(
@@ -82,6 +99,9 @@ def Quantize(graph,
     add_context = context
     if layer_match.bypass_op:
       add_context = re.search(r'^(.*)/([^/]+)', context).group(1)
+
+    # If `scope` is given, only quantize it if the producer of weights
+    # (usually it's the layer op) is in the right scope.
     _InsertQuantOp(
         add_context,
         'act_quant',
@@ -93,11 +113,14 @@ def Quantize(graph,
         quant_delay=quant_delay,
         vars_collection=vars_collection,
         bits=activation_bits,
-        init_min=0.0)
+        init_min=0.0,
+        producer_scope=scope)

     # Quantize the inputs and output to the bypass (if it exists). The input to
     # the bypass is the bias add, and the output is the activation.
     if layer_match.bypass_op is not None:
+      # If `scope` is given, only quantize it if the both the producer and the
+      # consumer are in the right scope.
       _InsertQuantOp(
           context,
           'conv_quant',
@@ -107,7 +130,9 @@ def Quantize(graph,
           ema_decay=ema_decay,
           quant_delay=quant_delay,
           vars_collection=vars_collection,
-          bits=activation_bits)
+          bits=activation_bits,
+          producer_scope=scope,
+          consumer_scope=scope)
       _InsertQuantOp(
           add_context,
           'add_quant',
@@ -118,12 +143,16 @@ def Quantize(graph,
           ema_decay=ema_decay,
           quant_delay=quant_delay,
           vars_collection=vars_collection,
-          bits=activation_bits)
+          bits=activation_bits,
+          producer_scope=scope,
+          consumer_scope=scope)

     # Quantize bypass ops that occur after the activation.
     if layer_match.post_activation_bypass_op is not None:
       post_activation_bypass_context = re.search(
           r'^(.*)/([^/]+)', layer_match.post_activation_bypass_op.name).group(1)
+      # If `scope` is given, only quantize it if the producer is in the right
+      # scope.
       _InsertQuantOp(
           post_activation_bypass_context,
           'post_activation_bypass_quant',
@@ -135,7 +164,8 @@ def Quantize(graph,
           ema_decay=ema_decay,
           quant_delay=quant_delay,
           vars_collection=vars_collection,
-          bits=activation_bits)
+          bits=activation_bits,
+          producer_scope=scope)


 def _FindLayersToQuantize(graph):
@@ -382,7 +412,9 @@ def _InsertQuantOp(context,
                    ema_decay=0.999,
                    quant_delay=None,
                    vars_collection=ops.GraphKeys.GLOBAL_VARIABLES,
-                   narrow_range=False):
+                   narrow_range=False,
+                   producer_scope=None,
+                   consumer_scope=None):
   """Inserts a quant op between a producer op and (multiple) consumer ops.

   Args:
@@ -407,10 +439,34 @@ def _InsertQuantOp(context,
       quantization interval ends.
     narrow_range: Whether to use the narrow quantization range [1; 2^bits - 1]
       or wide range [0; 2^bits - 1].
+    producer_scope: The restriction of producer scope. If not None, the new op
+      will be inserted only when the producer is in this scope.
+    consumer_scope: The restriction of producer scope. If not None, the new op
+      will be inserted only when all the consumers are in this scope.

   Raises:
     ValueError: When producer operation is not directly connected to the
       consumer operation.
   """
+  if producer_scope and not producer.name.startswith(producer_scope):
+    logging.info(
+        '_InsertQuantOp ignores context="%s" name="%s" '
+        'because producer "%s" is not in scope "%s"',
+        context, name, producer.name, producer_scope)
+    return
+
+  if consumer_scope:
+    consumers_in_scope = []
+    for consumer in consumers:
+      if consumer.name.startswith(consumer_scope):
+        consumers_in_scope.append(consumer)
+      else:
+        logging.info(
+            '_InsertQuantOp context="%s" name="%s" ignores '
+            'consumer "%s" because it is not in scope "%s"',
+            context, name, consumer.name, consumer_scope)
+        return
+    consumers = consumers_in_scope
+
   name_prefix = _AddContextToName(context, name)
   # This is needed on TPU where name_scope == 'TPUReplicate/loop', and
   # name_prefix starts with 'TPUReplicate/loop/'; without dropping it
diff --git a/tensorflow/contrib/quantize/python/quantize_graph.py b/tensorflow/contrib/quantize/python/quantize_graph.py
index 0b74b438ac..11d052d7f4 100644
--- a/tensorflow/contrib/quantize/python/quantize_graph.py
+++ b/tensorflow/contrib/quantize/python/quantize_graph.py
@@ -28,7 +28,8 @@ def _create_graph(input_graph=None,
                   weight_bits=8,
                   activation_bits=8,
                   quant_delay=None,
-                  freeze_bn_delay=None):
+                  freeze_bn_delay=None,
+                  scope=None):
   """Rewrites an input_graph in place for simulated quantization.

   The graph has fake quantization ops inserted to simulate the error
@@ -48,6 +49,8 @@ def _create_graph(input_graph=None,
       frozen and used instead of batch statistics during training.
       freeze_bn_delay should be greater than quant_delay and should correspond
       to the number of steps when training has almost converged
+    scope: The scope to be transformed. If it's not None, only the ops which
+      are in this scope will be transformed.

   Raises:
     ValueError: If elements contains an element that isn't a tf.Tensor or
@@ -66,7 +69,8 @@ def _create_graph(input_graph=None,
       is_training,
       quant_delay=quant_delay,
       weight_bits=weight_bits,
-      activation_bits=activation_bits)
+      activation_bits=activation_bits,
+      scope=scope)


 def create_training_graph(input_graph=None, quant_delay=0):
@@ -133,7 +137,8 @@ def experimental_create_training_graph(input_graph=None,
                                        weight_bits=8,
                                        activation_bits=8,
                                        quant_delay=0,
-                                       freeze_bn_delay=None):
+                                       freeze_bn_delay=None,
+                                       scope=None):
   """Rewrites a training input_graph in place for simulated quantization.

   Variables added by the rewrite get added to the global variables collection.
@@ -165,6 +170,8 @@ def experimental_create_training_graph(input_graph=None,
       frozen and used instead of batch statistics during training.
       freeze_bn_delay should be greater than quant_delay and should correspond
       to when training has almost converged
+    scope: The scope to be transformed. If it's not None, only the ops which
+      are in this scope will be transformed.

   Raises:
     ValueError: If elements contains an element that isn't a tf.Tensor or
@@ -177,12 +184,14 @@ def experimental_create_training_graph(input_graph=None,
       weight_bits=weight_bits,
       activation_bits=activation_bits,
       quant_delay=quant_delay,
-      freeze_bn_delay=freeze_bn_delay)
+      freeze_bn_delay=freeze_bn_delay,
+      scope=scope)


 def experimental_create_eval_graph(input_graph=None,
                                    weight_bits=8,
-                                   activation_bits=8):
+                                   activation_bits=8,
+                                   scope=None):
   """Rewrites an eval input_graph in place for simulated quantization.

   Variables added by the rewrite get added to the global variables collection.
@@ -200,8 +209,8 @@ def experimental_create_eval_graph(input_graph=None,
       default graph.
     weight_bits: Number of bits to use for quantizing weights.
     activation_bits: Number of bits to use for quantizing activations.
-
-
+    scope: The scope to be transformed. If it's not None, only the ops which
+      are in this scope will be transformed.

   Raises:
     ValueError: If elements contains an element that isn't a tf.Tensor or
@@ -211,4 +220,5 @@ def experimental_create_eval_graph(input_graph=None,
       input_graph=input_graph,
       is_training=False,
       weight_bits=weight_bits,
-      activation_bits=activation_bits)
+      activation_bits=activation_bits,
+      scope=scope)
diff --git a/tensorflow/contrib/quantize/python/quantize_graph_test.py b/tensorflow/contrib/quantize/python/quantize_graph_test.py
index b9d03c1bc0..caf8ff28d5 100644
--- a/tensorflow/contrib/quantize/python/quantize_graph_test.py
+++ b/tensorflow/contrib/quantize/python/quantize_graph_test.py
@@ -66,6 +66,20 @@ class QuantizeGraphTest(test_util.TensorFlowTestCase):
     for fn in rewrite_fns:
       test_fn(fn)

+  def _RunTestOverExperimentalRewritesWithScope(self, test_fn, scope):
+    def with_absent_scope(fn):
+      def fn_with_absent_scope(*args):
+        fn(*args, scope=scope)
+      return fn_with_absent_scope
+    rewrite_fns = [
+        with_absent_scope(
+            quantize_graph.experimental_create_training_graph),
+        with_absent_scope(
+            quantize_graph.experimental_create_eval_graph),
+    ]
+    for fn in rewrite_fns:
+      test_fn(fn)
+
   def testRewrite(self):
     self._RunTestOverAllRewrites(self._TestRewrite)

@@ -99,6 +113,34 @@ class QuantizeGraphTest(test_util.TensorFlowTestCase):
     # Ensure that variables were added.
     self.assertTrue(len(orig_variable_names) < len(q_variables))

+  def testWithPreActivationBypass(self):
+    self._RunTestOverAllRewrites(self._TestWithPreActivationBypass)
+
+  def _TestWithPreActivationBypass(self, rewrite_fn):
+    # Tests that the default graph is correctly used when no args are provided
+    # to rewrite_fn.
+    with ops.Graph().as_default() as g:
+      self._ConvLayer(pre_activation_bypass=True, scope='scope1')
+      rewrite_fn()
+
+      op_names = [op.name for op in g.get_operations()]
+      self.assertTrue(
+          any('scope1/add_quant/' in name for name in op_names))
+
+  def testWithPostActivationBypass(self):
+    self._RunTestOverAllRewrites(self._TestWithPostActivationBypass)
+
+  def _TestWithPostActivationBypass(self, rewrite_fn):
+    # Tests that the default graph is correctly used when no args are provided
+    # to rewrite_fn.
+    with ops.Graph().as_default() as g:
+      self._ConvLayer(post_activation_bypass=True, scope='scope1')
+      rewrite_fn()
+
+      op_names = [op.name for op in g.get_operations()]
+      self.assertTrue(any(
+          'scope1/post_activation_bypass_quant/' in name for name in op_names))
+
   def testQuantDelay(self):
     self._RunTestOverTrainingRewrites(self._TestQuantDelay)

@@ -224,20 +266,66 @@ class QuantizeGraphTest(test_util.TensorFlowTestCase):
       graph_def_after = str(g.as_graph_def())
       self.assertEqual(graph_def_before, graph_def_after)

-  def _ConvLayer(self):
+  def testRewriteWithScope(self):
+    self._RunTestOverExperimentalRewritesWithScope(
+        self._TestRewriteWithScope, 'scope1')
+
+  def _TestRewriteWithScope(self, rewrite_fn):
+    graph = ops.Graph()
+    with graph.as_default():
+      scope1_output = self._ConvLayer(scope='scope1')
+      self._ConvLayer(input_tensor=scope1_output, scope='scope2')
+
+    rewrite_fn(graph)
+
+    op_names = [op.name for op in graph.get_operations()]
+    # The weights and activation of scope1 is quantized, but not scope2.
+    self.assertTrue(
+        any('scope1/Conv/act_quant' in name for name in op_names))
+    self.assertTrue(
+        any('scope1/Conv/weights_quant' in name for name in op_names))
+    self.assertFalse(
+        any('scope2/Conv/act_quant' in name for name in op_names))
+    self.assertFalse(
+        any('scope2/Conv/weights_quant' in name for name in op_names))
+
+  def testRewriteWithNonMatchingScope(self):
+    self._RunTestOverExperimentalRewritesWithScope(
+        self._TestRewriteWithNonMatchingScope, 'NonExistingScope')
+
+  def _TestRewriteWithNonMatchingScope(self, rewrite_fn):
+    graph = ops.Graph()
+    with graph.as_default():
+      self._ConvLayer()
+
+    op_names_before_rewrite = set([op.name for op in graph.get_operations()])
+    rewrite_fn(graph)
+    op_names_after_rewrite = set([op.name for op in graph.get_operations()])
+
+    # No ops should be inserted or removed.
+    self.assertEqual(op_names_before_rewrite, op_names_after_rewrite)
+
+  def _ConvLayer(
+      self, input_tensor=None, scope='test', pre_activation_bypass=False,
+      post_activation_bypass=False):
     """Add a basic convolution layer to the default graph."""
     batch_size, height, width, depth = 5, 128, 128, 3
-    inputs = array_ops.zeros((batch_size, height, width, depth))
+    if input_tensor is None:
+      input_tensor = array_ops.zeros((batch_size, height, width, depth))
     weight_init = init_ops.truncated_normal_initializer
-    conv = layers.conv2d(
-        inputs,
-        32, [5, 5],
-        stride=2,
-        padding='SAME',
-        weights_initializer=weight_init(0.09),
-        activation_fn=None,
-        scope='test')
-    _ = nn_ops.relu6(conv)
+    with ops.name_scope(scope):
+      output = layers.conv2d(
+          input_tensor,
+          depth, [5, 5],
+          padding='SAME',
+          weights_initializer=weight_init(0.09),
+          activation_fn=None)
+      if pre_activation_bypass:
+        output += input_tensor
+      output = nn_ops.relu6(output)
+      if post_activation_bypass:
+        output += input_tensor
+      return output


 if __name__ == '__main__':
diff --git a/tensorflow/contrib/quantize/python/quantize_test.py b/tensorflow/contrib/quantize/python/quantize_test.py
index 8d057d3710..d37c83d683 100644
--- a/tensorflow/contrib/quantize/python/quantize_test.py
+++ b/tensorflow/contrib/quantize/python/quantize_test.py
@@ -254,12 +254,11 @@ class QuantizeTest(test_util.TensorFlowTestCase):
     graph = ops.Graph()
     with graph.as_default():
       with graph.name_scope(None):
-        batch_size, height, width, depth = 5, 128, 128, 3
+        batch_size, height, width, depth = 5, 128, 128, 32
         input1 = array_ops.zeros((batch_size, height, width, depth))
         _ = conv2d(
             input1,
             32, [5, 5],
-            stride=2,
             padding='SAME',
             weights_initializer=self._WeightInit(0.09),
             activation_fn=None,
@@ -268,6 +267,33 @@ class QuantizeTest(test_util.TensorFlowTestCase):
     quantize.Quantize(graph, is_training, weight_bits=8, activation_bits=8)
     # Passes if Quantize() does not crash.

+  def testWithNonMatchingNameScope(self):
+    self._RunTestOverParameters(self._testWithNonMatchingNameScope)
+
+  def _testWithNonMatchingNameScope(self, is_training):
+    graph = ops.Graph()
+    with graph.as_default():
+      with graph.name_scope('name_scope'):
+        batch_size, height, width, depth = 5, 128, 128, 3
+        input1 = array_ops.zeros((batch_size, height, width, depth))
+        _ = conv2d(
+            input1,
+            32, [5, 5],
+            stride=2,
+            padding='SAME',
+            weights_initializer=self._WeightInit(0.09),
+            activation_fn=None,
+            scope='test')
+
+    op_names_before_quantize = set([op.name for op in graph.get_operations()])
+    quantize.Quantize(
+        graph, is_training, weight_bits=8, activation_bits=8,
+        scope='NonExisting/')
+    op_names_after_quantize = set([op.name for op in graph.get_operations()])
+
+    # No ops should be inserted or removed.
+    self.assertEqual(op_names_before_quantize, op_names_after_quantize)
+
   def _WeightInit(self, stddev):
     """Returns truncated normal variable initializer.