author | Raghuraman Krishnamoorthi <raghuramank@google.com> | 2018-09-20 14:31:25 -0700
---|---|---
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-20 14:39:13 -0700 |
commit | f2d30a68169fc00ea444e5bffb2134f8fce92562 (patch) |
tree | 44da2e37d963e027021858b8080c78c8712378db | /tensorflow/contrib/quantize
parent | bf5324fd55a894ac00d10b7cfb2d26f3d9f7f5c9 (diff) |
Remove the restriction on scope for bypass operators. Previously, the scope had to be of the form 'scope/<arbitrary_text>'. Relax the restriction to handle empty scopes, and enable this change for both fused and unfused batch norm layers.
PiperOrigin-RevId: 213883621
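The gist of the change is visible in common.py below: the `_BATCHNORM_RE` pattern (and the activation suffixes) lose their leading `/`, so batch norm ops created under an empty scope, whose names begin directly with `BatchNorm/`, are now recognized. A minimal standalone sketch of the effect (not part of the commit; the op names are illustrative):

```python
import re

# Pattern before and after this commit (see the common.py hunk below).
_BATCHNORM_RE_OLD = re.compile(r'^(.*)/BatchNorm/batchnorm')
_BATCHNORM_RE_NEW = re.compile(r'^(.*)BatchNorm/batchnorm')

for name in ('conv1/BatchNorm/batchnorm_1/add_1',  # named scope
             'BatchNorm/batchnorm_1/add_1'):       # empty scope
  old = _BATCHNORM_RE_OLD.search(name)
  new = _BATCHNORM_RE_NEW.search(name)
  print(repr(old.group(1)) if old else None,
        repr(new.group(1)) if new else None)
# Named scope:  'conv1' -> 'conv1/' (new pattern keeps the trailing '/')
# Empty scope:  None    -> ''      (previously unmatched, now matched)
```

Because the captured prefix now retains its trailing `/`, the downstream code drops the leading `/` from every suffix it concatenates (e.g. `'BatchNorm/batchnorm_1/add_1'` instead of `'/BatchNorm/batchnorm_1/add_1'`).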
Diffstat (limited to 'tensorflow/contrib/quantize')
-rw-r--r-- | tensorflow/contrib/quantize/BUILD | 4
-rw-r--r-- | tensorflow/contrib/quantize/python/common.py | 4
-rw-r--r-- | tensorflow/contrib/quantize/python/common_test.py | 59
-rw-r--r-- | tensorflow/contrib/quantize/python/fold_batch_norms.py | 94
-rw-r--r-- | tensorflow/contrib/quantize/python/quantize.py | 15
-rw-r--r-- | tensorflow/contrib/quantize/python/quantize_parameterized_test.py | 282
6 files changed, 308 insertions, 150 deletions
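The other half of the change is how the bypass `Add` context is derived from a batch norm scope in fold_batch_norms.py and quantize.py: a scope of the form `str1/str2` maps to bypass context `str1`, while a single-component or empty scope now falls back to the root scope instead of raising `AttributeError` on a failed regex match. A hedged sketch of that rule (hypothetical helper name, not the commit's code verbatim):

```python
import re

def bypass_add_context(bn_scope):
  # Mirrors the logic added in fold_batch_norms.py:
  # 'str1/str2' -> 'str1/', 'str1' -> '', '' -> ''.
  match = re.search(r'^(.*)/([^/]+)', bn_scope)
  if match is None:
    return ''  # previously this case crashed on .group(1)
  return match.group(1) + '/'

# The bypass Add op is then looked up as bypass_add_context(bn) + 'Add'.
assert bypass_add_context('str1/str2') == 'str1/'
assert bypass_add_context('str1') == ''
assert bypass_add_context('') == ''
```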
diff --git a/tensorflow/contrib/quantize/BUILD b/tensorflow/contrib/quantize/BUILD
index c59f667f6a..23e3a25d71 100644
--- a/tensorflow/contrib/quantize/BUILD
+++ b/tensorflow/contrib/quantize/BUILD
@@ -20,9 +20,13 @@ py_test(
     srcs_version = "PY2AND3",
     deps = [
         ":common",
+        "//tensorflow/contrib/layers:layers_py",
+        "//tensorflow/python:array_ops",
         "//tensorflow/python:framework_ops",
         "//tensorflow/python:framework_test_lib",
+        "//tensorflow/python:init_ops",
         "//tensorflow/python:math_ops",
+        "//tensorflow/python:nn_ops",
         "//tensorflow/python:platform_test",
         "//tensorflow/python:session",
         "//tensorflow/python:variable_scope",
diff --git a/tensorflow/contrib/quantize/python/common.py b/tensorflow/contrib/quantize/python/common.py
index b27117dd48..e6c04bcf55 100644
--- a/tensorflow/contrib/quantize/python/common.py
+++ b/tensorflow/contrib/quantize/python/common.py
@@ -34,10 +34,10 @@ SKIPPED_PREFIXES = (
     'ScalarSummary')
 
 # Valid activation ops for quantization end points.
-_ACTIVATION_OP_SUFFIXES = ['/Relu6', '/Relu', '/Identity']
+_ACTIVATION_OP_SUFFIXES = ['Relu6', 'Relu', 'Identity']
 
 # Regular expression for recognizing nodes that are part of batch norm group.
-_BATCHNORM_RE = re.compile(r'^(.*)/BatchNorm/batchnorm')
+_BATCHNORM_RE = re.compile(r'^(.*)BatchNorm/batchnorm')
 
 
 def BatchNormGroups(graph):
diff --git a/tensorflow/contrib/quantize/python/common_test.py b/tensorflow/contrib/quantize/python/common_test.py
index 2b26302f8a..a3ce041cea 100644
--- a/tensorflow/contrib/quantize/python/common_test.py
+++ b/tensorflow/contrib/quantize/python/common_test.py
@@ -13,21 +13,26 @@
 # limitations under the License.
 # ==============================================================================
 """Tests for common utilities in this package."""
-
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
-
+from tensorflow.contrib.layers.python.layers import layers
 from tensorflow.contrib.quantize.python import common
 from tensorflow.python.client import session
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import test_util
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import init_ops
 from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import nn_ops
 from tensorflow.python.ops import variable_scope
 from tensorflow.python.ops import variables
 from tensorflow.python.platform import googletest
 
+batch_norm = layers.batch_norm
+conv2d = layers.conv2d
+
 
 class CommonTest(test_util.TensorFlowTestCase):
 
@@ -87,6 +92,56 @@ class CommonTest(test_util.TensorFlowTestCase):
       for i in inputs:
         self.assertIn(i, op.inputs)
 
+  def testBatchNormScope(self):
+    batch_size, height, width, depth = 5, 128, 128, 3
+    g = ops.Graph()
+    with g.as_default():
+      inputs = array_ops.zeros((batch_size, height, width, depth))
+      stride = 1
+      out_depth = 32
+      scope = ''
+      node = conv2d(
+          inputs,
+          out_depth, [2, 2],
+          stride=stride,
+          padding='SAME',
+          weights_initializer=self._WeightInit(0.09),
+          activation_fn=None,
+          normalizer_fn=batch_norm,
+          normalizer_params=self._BatchNormParams(False),
+          scope=scope)
+
+      node = nn_ops.relu(node, name='Relu6')
+    bn_list = common.BatchNormGroups(g)
+    with open('/tmp/common_test.pbtxt', 'w') as f:
+      f.write(str(g.as_graph_def()))
+
+    # Exactly one batch norm layer with empty scope should be found
+    self.assertEqual(len(bn_list), 1)
+    self.assertEqual(bn_list[0], '')
+
+  def _BatchNormParams(self, fused=False, force_updates=False):
+    params = {
+        'center': True,
+        'scale': True,
+        'decay': 1.0 - 0.003,
+        'fused': fused
+    }
+    return params
+
+  def _WeightInit(self, stddev):
+    """Returns a truncated normal variable initializer.
+
+    Function is defined purely to shorten the name so that it stops wrapping.
+
+    Args:
+      stddev: Standard deviation of normal variable.
+
+    Returns:
+      An initializer that initializes with a truncated normal variable.
+    """
+    return init_ops.truncated_normal_initializer(stddev=stddev, seed=1234)
+
 
 if __name__ == '__main__':
   googletest.main()
diff --git a/tensorflow/contrib/quantize/python/fold_batch_norms.py b/tensorflow/contrib/quantize/python/fold_batch_norms.py
index 2971b28f45..e5790a6e13 100644
--- a/tensorflow/contrib/quantize/python/fold_batch_norms.py
+++ b/tensorflow/contrib/quantize/python/fold_batch_norms.py
@@ -95,8 +95,7 @@ def _FoldFusedBatchNorms(graph, is_training, freeze_batch_norm_delay):
           _ComputeBatchNormCorrections(
               context='',
               match=match,
-              freeze_batch_norm_delay=freeze_batch_norm_delay,
-              fused_batch_norm=True))
+              freeze_batch_norm_delay=freeze_batch_norm_delay))
       # The shape of depthwise weights is different, so we need to reshape the
       # multiplier_tensor to ensure that the scaled_weight_tensor has the
       # expected shape.
@@ -296,8 +295,7 @@ def _FindFusedBatchNorms(graph):
           batch_to_space_op=batch_to_space_op)
 
 
-def _ComputeBatchNormCorrections(context, match, freeze_batch_norm_delay,
-                                 fused_batch_norm):
+def _ComputeBatchNormCorrections(context, match, freeze_batch_norm_delay):
   """Computes batch norm correction params.
 
      Before batch normalization is frozen:
@@ -327,14 +325,14 @@ def _ComputeBatchNormCorrections(context, match, freeze_batch_norm_delay,
       computation.
     freeze_batch_norm_delay: Delay in steps at which computation switches
       from regular batch norm to frozen mean and variance.
-    fused_batch_norm: Bool, true if fused batch norm is used.
+
 
   Returns:
     A tuple of correction_scale, correction_recip, correction_offset
   """
 
   g = ops.get_default_graph()
-  prefix = '' if not context else context + '/'
+  prefix = '' if not context else context
   with g.name_scope(prefix + 'batch_norm_correction'):
     recip_sigma_mv = math_ops.rsqrt(
         match.moving_variance_tensor + match.batch_epsilon)
@@ -495,8 +493,23 @@ def _FoldUnfusedBatchNorms(graph, is_training, freeze_batch_norm_delay):
 
     # Treat consumer ops in bypass modules differently since they have Add
    # operations instead of Relu* above.
-    add_bypass_ctx = re.search(r'^(.*)/([^/]+)', bn).group(1)
-    add_bypass = graph.get_operation_by_name(add_bypass_ctx + '/Add')
+    # Changes to make sure that the correct scope is selected for the bypass
+    # add. The rule here is that if the scope is of the form: str1/str2 for
+    # the batch norm,
+    # the bypass add is at scope str1. If bn is of scope just str1, then the
+    # bypass add is at scope ''.
+    # If there is no batch norm, then there is no bypass add.
+    add_bypass_ctx = ''
+    if bn:
+      try:
+        add_bypass_ctx = re.search(r'^(.*)/([^/]+)', bn).group(1)
+      except AttributeError:
+        add_bypass_ctx = ''
+
+    if add_bypass_ctx:
+      add_bypass_ctx = add_bypass_ctx + '/'
+
+    add_bypass = graph.get_operation_by_name(add_bypass_ctx + 'Add')
     nodes_modified_count = common.RerouteTensor(
         folded_op.outputs[0], original_op.outputs[0], can_modify=[add_bypass])
     if nodes_modified_count != 1:
@@ -505,8 +518,8 @@ def _FoldUnfusedBatchNorms(graph, is_training, freeze_batch_norm_delay):
 
 def _IsValidUnfusedBatchNorm(graph, context):
   """Checks that the output of the unfused batch norm has consumers."""
-  add_shift = graph.get_operation_by_name(
-      context + '/BatchNorm/batchnorm_1/add_1')
+  add_shift = graph.get_operation_by_name(context +
+                                          'BatchNorm/batchnorm_1/add_1')
   # Ensure that the output tensor of batch norm has consumers, otherwise this
   # is a dangling node and not a match.
   return bool(add_shift.outputs[0].consumers())
@@ -538,7 +551,8 @@ def _FindMatchingTensor(graph, match_pattern, scope):
     if op.name.endswith(match_pattern):
       split_name = op.name.split('/')
      num_matches = len(set(split_name) & split_context)
-      if num_matches > 0:
+
+      if num_matches > 0 or not scope:
         match_dict[op.name] = num_matches
   # match_dict contains matching op names from graph with values being
   # number of matches to scope. We pick the key with the most matches
@@ -597,21 +611,21 @@ def _GetBatchNormParams(graph, context, has_scaling):
   # op.name = MobilenetV2/expanded_conv_3/depthwise/BatchNorm/moving_mean/read
   # will have 2 matches,scope with a different conv layer will have one match.
 
-  op_suffix_mean = '/BatchNorm/moments/Squeeze'
-  op_suffix_variance = '/BatchNorm/moments/Squeeze_1'
-  op_suffix_epsilon = '/BatchNorm/batchnorm_1/add/y'
-  op_suffix_bn_decay_mean = '/BatchNorm/AssignMovingAvg/decay'
-  op_suffix_bn_decay_var = '/BatchNorm/AssignMovingAvg_1/decay'
+  op_suffix_mean = 'BatchNorm/moments/Squeeze'
+  op_suffix_variance = 'BatchNorm/moments/Squeeze_1'
+  op_suffix_epsilon = 'BatchNorm/batchnorm_1/add/y'
+  op_suffix_bn_decay_mean = 'BatchNorm/AssignMovingAvg/decay'
+  op_suffix_bn_decay_var = 'BatchNorm/AssignMovingAvg_1/decay'
 
   if variable_scope.get_variable_scope().use_resource:
-    op_suffix_gamma = '/BatchNorm/gamma/Read/ReadVariableOp'
+    op_suffix_gamma = 'BatchNorm/gamma/Read/ReadVariableOp'
     op_suffix_moving_variance = (
-        '/BatchNorm/moving_variance/Read/ReadVariableOp')
-    op_suffix_moving_mean = ('/BatchNorm/moving_mean/Read/ReadVariableOp')
+        'BatchNorm/moving_variance/Read/ReadVariableOp')
+    op_suffix_moving_mean = ('BatchNorm/moving_mean/Read/ReadVariableOp')
   else:
-    op_suffix_gamma = '/BatchNorm/gamma'
-    op_suffix_moving_variance = '/BatchNorm/moving_variance/read'
-    op_suffix_moving_mean = '/BatchNorm/moving_mean/read'
+    op_suffix_gamma = 'BatchNorm/gamma'
+    op_suffix_moving_variance = 'BatchNorm/moving_variance/read'
+    op_suffix_moving_mean = 'BatchNorm/moving_mean/read'
   # Parse through list of ops to find relevant ops
 
   batch_mean_tensor = _FindMatchingTensor(graph, op_suffix_mean, context)
@@ -679,8 +693,7 @@ def _CreateFoldedOp(graph, context, has_scaling, freeze_batch_norm_delay,
     the folded graph (add_fold).
   """
   mul_scale_name = 'mul_1' if has_scaling else 'mul'
-  mul_scale = graph.get_operation_by_name(context +
-                                          '/BatchNorm/batchnorm_1/' +
+  mul_scale = graph.get_operation_by_name(context + 'BatchNorm/batchnorm_1/' +
                                           mul_scale_name)
   op_below = mul_scale.inputs[0].op
   # Skip over the BatchToSpace operation in the case of atrous convolutions.
@@ -697,8 +710,7 @@ def _CreateFoldedOp(graph, context, has_scaling, freeze_batch_norm_delay,
         _ComputeBatchNormCorrections(
             context=context,
             match=match,
-            freeze_batch_norm_delay=freeze_batch_norm_delay,
-            fused_batch_norm=False))
+            freeze_batch_norm_delay=freeze_batch_norm_delay))
  # Special handling for weights of depthwise convolution.
   if op_below.type == 'DepthwiseConv2dNative':
     new_shape = [
@@ -706,27 +718,27 @@ def _CreateFoldedOp(graph, context, has_scaling, freeze_batch_norm_delay,
         weights.get_shape().as_list()[3]
     ]
     scale_name = 'mul' if has_scaling else 'Rsqrt'
-    scale = graph.get_operation_by_name(
-        context + '/BatchNorm/batchnorm_1/' + scale_name)
+    scale = graph.get_operation_by_name(context + 'BatchNorm/batchnorm_1/' +
+                                        scale_name)
     scale = array_ops.reshape(scale.outputs[0], new_shape,
-                              context + '/scale_reshape')
+                              context + 'scale_reshape')
 
     if correction_scale is not None:
       correction_scale = array_ops.reshape(correction_scale, new_shape,
-                                           context + '/correction_reshape')
+                                           context + 'correction_reshape')
       with ops.device(mul_scale.device):
         weights = math_ops.multiply(correction_scale, weights,
-                                    context + '/correction_mult')
+                                    context + 'correction_mult')
 
-    mul_fold = _CloneOp(mul_scale, context + '/mul_fold', [(0, weights),
-                                                           (1, scale)])
+    mul_fold = _CloneOp(mul_scale, context + 'mul_fold', [(0, weights),
+                                                          (1, scale)])
   elif op_below.type in ['Conv2D', 'MatMul']:
 
     if correction_scale is not None:
      with ops.device(mul_scale.device):
         weights = math_ops.multiply(correction_scale, weights,
-                                    context + '/correction_mult')
-    mul_fold = _CloneOp(mul_scale, context + '/mul_fold', [(0, weights)])
+                                    context + 'correction_mult')
+    mul_fold = _CloneOp(mul_scale, context + 'mul_fold', [(0, weights)])
   else:
     raise ValueError('Cannot handle operation of type: %s' % op_below.type)
   _AssertShapesMatch('mul_fold', mul_fold.inputs[0], mul_fold.outputs[0])
@@ -734,8 +746,8 @@ def _CreateFoldedOp(graph, context, has_scaling, freeze_batch_norm_delay,
   conv_or_fc_folded = _CloneOp(op_below, op_below.name + '_Fold',
                                [(1, mul_fold.outputs[0])])
 
-  add_shift = graph.get_operation_by_name(
-      context + '/BatchNorm/batchnorm_1/add_1')
+  add_shift = graph.get_operation_by_name(context +
+                                          'BatchNorm/batchnorm_1/add_1')
 
   corrected_output = conv_or_fc_folded.outputs[0]
   # Copy the batch to space operation if we have a atrous convolution.
@@ -748,10 +760,10 @@ def _CreateFoldedOp(graph, context, has_scaling, freeze_batch_norm_delay,
   if correction_offset is not None:
     with ops.device(conv_or_fc_folded.device):
       corrected_output = math_ops.multiply(correction_recip, corrected_output,
-                                           context + '/post_conv_mul')
+                                           context + 'post_conv_mul')
       corrected_output = math_ops.add(corrected_output, (correction_offset),
-                                      context + '/correction_add')
-  add_fold = _CloneOp(add_shift, context + '/add_fold', [(0, corrected_output)])
+                                      context + 'correction_add')
+  add_fold = _CloneOp(add_shift, context + 'add_fold', [(0, corrected_output)])
 
   _AssertShapesMatch('add_fold', add_fold.inputs[0], add_fold.outputs[0])
   return add_shift, add_fold
@@ -930,7 +942,7 @@ def _HasScaling(graph, input_to_ops_map, bn):
 
   Returns:
     A boolean indicating whether this batch norm layer has scaling enabled.
""" - rsqrt_op = graph.get_operation_by_name(bn + '/BatchNorm/batchnorm_1/Rsqrt') + rsqrt_op = graph.get_operation_by_name(bn + 'BatchNorm/batchnorm_1/Rsqrt') rsqrt_consumers = input_to_ops_map.ConsumerOperations(rsqrt_op) return sum(1 for op in rsqrt_consumers if op.type == 'Mul') == 1 diff --git a/tensorflow/contrib/quantize/python/quantize.py b/tensorflow/contrib/quantize/python/quantize.py index e88db0acd5..5e63d33db8 100644 --- a/tensorflow/contrib/quantize/python/quantize.py +++ b/tensorflow/contrib/quantize/python/quantize.py @@ -97,8 +97,11 @@ def Quantize(graph, layer_match.activation_op) add_context = context if layer_match.bypass_op: - add_context = re.search(r'^(.*)/([^/]+)', context).group(1) - + pattern_match_result = re.search(r'^(.*)/([^/]+)', context) + if pattern_match_result is not None: + add_context = pattern_match_result.group(1) + else: + add_context = '' # If `scope` is given, only quantize it if the producer of weights # (usually it's the layer op) is in the right scope. _InsertQuantOp( @@ -156,8 +159,12 @@ def Quantize(graph, # Quantize bypass ops that occur after the activation. if layer_match.post_activation_bypass_op is not None: - post_activation_bypass_context = re.search( - r'^(.*)/([^/]+)', layer_match.post_activation_bypass_op.name).group(1) + pattern_match_result = re.search( + r'^(.*)/([^/]+)', layer_match.post_activation_bypass_op.name) + if pattern_match_result is not None: + post_activation_bypass_context = pattern_match_result.group(1) + else: + post_activation_bypass_context = '' # If `scope` is given, only quantize it if the producer is in the right # scope. # Make sure the op following this isn't an activation. In which case, we diff --git a/tensorflow/contrib/quantize/python/quantize_parameterized_test.py b/tensorflow/contrib/quantize/python/quantize_parameterized_test.py index 31a2955ddb..f6bf57a789 100644 --- a/tensorflow/contrib/quantize/python/quantize_parameterized_test.py +++ b/tensorflow/contrib/quantize/python/quantize_parameterized_test.py @@ -58,85 +58,102 @@ class QuantizeTest(test_util.TensorFlowTestCase): ] for params in parameters_list: # Test everything with resource variables and normal variables. - test_fn(params[0], params[1], params[2], params[3], False) - test_fn(params[0], params[1], params[2], params[3], True) + test_fn(params[0], params[1], params[2], params[3], False, None) + test_fn(params[0], params[1], params[2], params[3], True, None) + # Test with both empty scope and an example scope + test_fn(params[0], params[1], params[2], params[3], False, 'test') + test_fn(params[0], params[1], params[2], params[3], True, 'test') def _AssertCorrectQuantizedGraphWithoutBatchNorm( self, graph, scope, layer, activation_op_name, with_bypass, delay, use_resource): quantization_node_name = 'FakeQuantWithMinMaxVars' - weights_quant = graph.get_operation_by_name(scope + '/weights_quant/' + - quantization_node_name) + conv_scope = self._GetConvScope(scope, with_bypass) + delim = '/' if conv_scope else '' + + if scope: + scope = scope + '/' + weights_quant = graph.get_operation_by_name( + conv_scope + delim + 'weights_quant/' + quantization_node_name) self.assertEqual(weights_quant.type, quantization_node_name) # Assemble the expected inputs. 
     if use_resource:
       expected_inputs = [
-          scope + '/weights_quant/FakeQuantWithMinMaxVars/ReadVariableOp',
-          scope + '/weights_quant/FakeQuantWithMinMaxVars/ReadVariableOp_1',
+          conv_scope + delim +
+          'weights_quant/FakeQuantWithMinMaxVars/ReadVariableOp',
+          conv_scope + delim +
+          'weights_quant/FakeQuantWithMinMaxVars/ReadVariableOp_1',
       ]
       if layer == 'DepthwiseConv2dNative':
-        expected_inputs.append(scope + '/depthwise/ReadVariableOp')
+        expected_inputs.append(conv_scope + delim + 'depthwise/ReadVariableOp')
       else:
-        expected_inputs.append(scope + '/' + layer + '/ReadVariableOp')
+        expected_inputs.append(conv_scope + delim + layer + '/ReadVariableOp')
     else:
       expected_inputs = [
-          scope + '/weights_quant/AssignMinLast',
-          scope + '/weights_quant/AssignMaxLast',
+          conv_scope + delim + 'weights_quant/AssignMinLast',
+          conv_scope + delim + 'weights_quant/AssignMaxLast',
       ]
       if layer == 'DepthwiseConv2dNative':
-        expected_inputs.append(scope + '/depthwise_weights/read')
+        expected_inputs.append(conv_scope + delim + 'depthwise_weights/read')
       else:
-        expected_inputs.append(scope + '/weights/read')
+        expected_inputs.append(conv_scope + delim + 'weights/read')
 
     self._AssertInputOpsAre(weights_quant, expected_inputs)
     if delay and delay > 0:
-      output_op_name = scope + '/weights_quant/delayed_quant/Switch_1'
+      output_op_name = (
+          conv_scope + delim + 'weights_quant/delayed_quant/Switch_1')
     else:
       if layer == 'DepthwiseConv2dNative':
-        output_op_name = scope + '/depthwise'
+        output_op_name = conv_scope + delim + 'depthwise'
       else:
-        output_op_name = scope + '/' + layer
+        output_op_name = conv_scope + delim + layer
 
     self._AssertOutputGoesToOps(weights_quant, graph, [output_op_name])
 
     if with_bypass:
-      conv_quant = graph.get_operation_by_name(scope + '/conv_quant/' +
-                                               quantization_node_name)
+      conv_quant = graph.get_operation_by_name(
+          conv_scope + delim + 'conv_quant/' + quantization_node_name)
       self.assertEqual(conv_quant.type, quantization_node_name)
 
       if use_resource:
         expected_inputs = [
-            scope + '/conv_quant/FakeQuantWithMinMaxVars/ReadVariableOp',
-            scope + '/conv_quant/FakeQuantWithMinMaxVars/ReadVariableOp_1',
-            scope + '/BiasAdd',
+            conv_scope + delim +
+            'conv_quant/FakeQuantWithMinMaxVars/ReadVariableOp',
+            conv_scope + delim +
+            'conv_quant/FakeQuantWithMinMaxVars/ReadVariableOp_1',
+            conv_scope + delim + 'BiasAdd',
         ]
       else:
         expected_inputs = [
-            scope + '/conv_quant/AssignMinEma',
-            scope + '/conv_quant/AssignMaxEma', scope + '/BiasAdd'
+            conv_scope + delim + 'conv_quant/AssignMinEma',
+            conv_scope + delim + 'conv_quant/AssignMaxEma',
+            conv_scope + delim + 'BiasAdd'
         ]
       self._AssertInputOpsAre(conv_quant, expected_inputs)
-      output_op_name = (scope + '/conv_quant/delayed_quant/Switch_1'
-                        if delay else 'test/Add')
+
+      output_op_name = (
+          conv_scope + delim + 'conv_quant/delayed_quant/Switch_1'
+          if delay else scope + 'Add')
       self._AssertOutputGoesToOps(conv_quant, graph, [output_op_name])
 
-    act_quant = graph.get_operation_by_name('test/act_quant/' +
+    act_quant = graph.get_operation_by_name(scope + 'act_quant/' +
                                             quantization_node_name)
     self.assertEqual(act_quant.type, quantization_node_name)
 
     if use_resource:
       expected_inputs = [
-          'test/act_quant/FakeQuantWithMinMaxVars/ReadVariableOp',
-          'test/act_quant/FakeQuantWithMinMaxVars/ReadVariableOp_1',
-          'test/' + activation_op_name,
+          scope + 'act_quant/FakeQuantWithMinMaxVars/ReadVariableOp',
+          scope + 'act_quant/FakeQuantWithMinMaxVars/ReadVariableOp_1',
+          scope + activation_op_name,
       ]
     else:
       expected_inputs = [
-          'test/act_quant/AssignMinEma', 'test/act_quant/AssignMaxEma',
-          'test/' + activation_op_name
+          scope + 'act_quant/AssignMinEma', scope + 'act_quant/AssignMaxEma',
+          scope + activation_op_name
       ]
     self._AssertInputOpsAre(act_quant, expected_inputs)
-    output_op_name = ('test/act_quant/delayed_quant/Switch_1'
-                      if delay else 'control_dependency')
+    output_op_name = (
+        scope + 'act_quant/delayed_quant/Switch_1'
+        if delay else 'control_dependency')
     self._AssertOutputGoesToOps(act_quant, graph, [output_op_name])
     self._AssertIdempotent(graph)
@@ -145,7 +162,8 @@ class QuantizeTest(test_util.TensorFlowTestCase):
         self._TestQuantize_Conv2dWithoutBatchNorm)
 
   def _TestQuantize_Conv2dWithoutBatchNorm(self, activation, activation_op_name,
-                                           with_bypass, delay, use_resource):
+                                           with_bypass, delay, use_resource,
+                                           scope):
     """Tests quantization: inputs -> Conv2d no batch norm -> Activation.
 
     Args:
@@ -156,6 +174,7 @@ class QuantizeTest(test_util.TensorFlowTestCase):
         inputs to just before Activation.
       delay: Int (optional), delay in number of steps until quantization starts.
       use_resource: Bool, when true uses resource variables.
+      scope: String, specifies top level scope for the graph
     """
     graph = ops.Graph()
     with graph.as_default():
@@ -165,7 +184,9 @@ class QuantizeTest(test_util.TensorFlowTestCase):
       stride = 1 if with_bypass else 2
       out_depth = 3 if with_bypass else 32
       activation_fn = None if with_bypass else activation
-      scope = 'test/test2' if with_bypass else 'test'
+      conv_scope = self._GetConvScope(scope, with_bypass)
+      scope = '' if scope is None else scope
+      delim = '/' if scope else ''
       node = conv2d(
           inputs,
           out_depth, [5, 5],
@@ -173,16 +194,19 @@ class QuantizeTest(test_util.TensorFlowTestCase):
           padding='SAME',
           weights_initializer=self._WeightInit(0.09),
          activation_fn=activation_fn,
-          scope=scope)
+          scope=conv_scope)
       if with_bypass:
-        node = math_ops.add(inputs, node, name='test/Add')
-      node = activation(node, name='test/' + activation_op_name)
+        node = math_ops.add(inputs, node, name=scope + delim + 'Add')
+      node = activation(node, name=scope + delim + activation_op_name)
       update_barrier = control_flow_ops.no_op(name='update_barrier')
       with ops.control_dependencies([update_barrier]):
         array_ops.identity(node, name='control_dependency')
 
       quantize.Quantize(graph, True, quant_delay=delay)
 
+    if conv_scope is None:
+      conv_scope = ''
+
     self._AssertCorrectQuantizedGraphWithoutBatchNorm(
         graph, scope, 'Conv2D', activation_op_name, with_bypass, delay,
         use_resource)
@@ -192,7 +216,7 @@ class QuantizeTest(test_util.TensorFlowTestCase):
         self._TestQuantize_FCWithoutBatchNorm)
 
   def _TestQuantize_FCWithoutBatchNorm(self, activation, activation_op_name,
-                                       with_bypass, delay, use_resource):
+                                       with_bypass, delay, use_resource, scope):
     """Tests quantization: inputs -> FC no batch norm -> Activation.
 
     Args:
@@ -203,6 +227,7 @@ class QuantizeTest(test_util.TensorFlowTestCase):
         inputs to just before Activation.
       delay: Int (optional), delay in number of steps until quantization starts.
       use_resource: Bool, when true uses resource variables.
+      scope: String, specifies top level scope for the graph
     """
     graph = ops.Graph()
     with graph.as_default():
@@ -211,16 +236,18 @@ class QuantizeTest(test_util.TensorFlowTestCase):
       inputs = array_ops.zeros((batch_size, depth))
       out_depth = 256 if with_bypass else 128
       activation_fn = None if with_bypass else activation
-      scope = 'test/test2' if with_bypass else 'test'
+      fc_scope = self._GetConvScope(scope, with_bypass)
+      scope = '' if scope is None else scope
+      delim = '/' if scope else ''
       node = fully_connected(
           inputs,
           out_depth,
          weights_initializer=self._WeightInit(0.03),
           activation_fn=activation_fn,
-          scope=scope)
+          scope=fc_scope)
       if with_bypass:
-        node = math_ops.add(inputs, node, name='test/Add')
-      node = activation(node, name='test/' + activation_op_name)
+        node = math_ops.add(inputs, node, name=scope + delim + 'Add')
+      node = activation(node, name=scope + delim + activation_op_name)
       update_barrier = control_flow_ops.no_op(name='update_barrier')
       with ops.control_dependencies([update_barrier]):
         array_ops.identity(node, name='control_dependency')
@@ -235,7 +262,8 @@ class QuantizeTest(test_util.TensorFlowTestCase):
         self._TestQuantize_DepthwiseConv2dWithoutBatchNorm)
 
   def _TestQuantize_DepthwiseConv2dWithoutBatchNorm(
-      self, activation, activation_op_name, with_bypass, delay, use_resource):
+      self, activation, activation_op_name, with_bypass, delay, use_resource,
+      scope):
     """Tests quantization: inputs -> DWConv2d no batch norm -> Activation.
 
     Args:
@@ -246,6 +274,7 @@ class QuantizeTest(test_util.TensorFlowTestCase):
         inputs to just before Activation.
       delay: Int (optional), delay in number of steps until quantization starts.
       use_resource: Bool, when true uses resource variables.
+      scope: String, specifies top level scope for the graph
     """
     graph = ops.Graph()
     with graph.as_default():
@@ -254,7 +283,10 @@ class QuantizeTest(test_util.TensorFlowTestCase):
      inputs = array_ops.zeros((batch_size, height, width, depth))
       stride = 1 if with_bypass else 2
       activation_fn = None if with_bypass else activation
-      scope = 'test/test2' if with_bypass else 'test'
+      conv_scope = self._GetConvScope(scope, with_bypass)
+      scope = '' if scope is None else scope
+      delim = '/' if scope else ''
+
       node = separable_conv2d(
           inputs,
           None, [5, 5],
@@ -263,10 +295,10 @@ class QuantizeTest(test_util.TensorFlowTestCase):
           padding='SAME',
           weights_initializer=self._WeightInit(0.09),
           activation_fn=activation_fn,
-          scope=scope)
+          scope=conv_scope)
       if with_bypass:
-        node = math_ops.add(inputs, node, name='test/Add')
-      node = activation(node, name='test/' + activation_op_name)
+        node = math_ops.add(inputs, node, name=scope + delim + 'Add')
+      node = activation(node, name=scope + delim + activation_op_name)
       update_barrier = control_flow_ops.no_op(name='update_barrier')
       with ops.control_dependencies([update_barrier]):
         array_ops.identity(node, name='control_dependency')
@@ -280,8 +312,9 @@ class QuantizeTest(test_util.TensorFlowTestCase):
     self._RunWithoutBatchNormTestOverParameters(
         self._TestQuantize_AtrousConvWithoutBatchNorm)
 
-  def _TestQuantize_AtrousConvWithoutBatchNorm(
-      self, activation, activation_op_name, with_bypass, delay, use_resource):
+  def _TestQuantize_AtrousConvWithoutBatchNorm(self, activation,
+                                               activation_op_name, with_bypass,
+                                               delay, use_resource, scope):
     """Tests quantization: inputs -> atrous conv no batch norm -> Activation.
 
     Args:
@@ -292,6 +325,7 @@ class QuantizeTest(test_util.TensorFlowTestCase):
         inputs to just before Activation.
       delay: Int (optional), delay in number of steps until quantization starts.
       use_resource: Bool, when true uses resource variables.
+      scope: String, specifies top level scope for the graph
     """
     graph = ops.Graph()
     with graph.as_default():
@@ -300,7 +334,10 @@ class QuantizeTest(test_util.TensorFlowTestCase):
       inputs = array_ops.zeros((batch_size, height, width, depth))
       dilation_rate = 2
       activation_fn = None if with_bypass else activation
-      scope = 'test/test2' if with_bypass else 'test'
+      conv_scope = self._GetConvScope(scope, with_bypass)
+      scope = '' if scope is None else scope
+      delim = '/' if scope else ''
+
       node = separable_conv2d(
           inputs,
           None, [3, 3],
@@ -309,10 +346,10 @@ class QuantizeTest(test_util.TensorFlowTestCase):
           padding='SAME',
           weights_initializer=self._WeightInit(0.09),
           activation_fn=activation_fn,
-          scope=scope)
+          scope=conv_scope)
       if with_bypass:
-        node = math_ops.add(inputs, node, name='test/Add')
-      node = activation(node, name='test/' + activation_op_name)
+        node = math_ops.add(inputs, node, name=scope + delim + 'Add')
+      node = activation(node, name=scope + delim + activation_op_name)
       update_barrier = control_flow_ops.no_op(name='update_barrier')
       with ops.control_dependencies([update_barrier]):
         array_ops.identity(node, name='control_dependency')
@@ -353,78 +390,96 @@ class QuantizeTest(test_util.TensorFlowTestCase):
     ]
     for params in parameters_list:
       # Test everything with resource variables and normal variables.
-      test_fn(params[0], params[1], params[2], params[3], params[4], False)
-      test_fn(params[0], params[1], params[2], params[3], params[4], True)
+      test_fn(params[0], params[1], params[2], params[3], params[4], False,
+              None)
+      test_fn(params[0], params[1], params[2], params[3], params[4], True, None)
+      test_fn(params[0], params[1], params[2], params[3], params[4], False,
+              'test')
+      test_fn(params[0], params[1], params[2], params[3], params[4], True,
+              'test')
 
   def _AssertCorrectQuantizedGraphWithBatchNorm(self, graph, scope, layer,
                                                 activation_op_name,
                                                 with_bypass, delay,
                                                 use_resource):
     quantization_node_name = 'FakeQuantWithMinMaxVars'
+    conv_scope = self._GetConvScope(scope, with_bypass)
+    delim = '/' if conv_scope else ''
+
+    if scope:
+      scope = scope + '/'
+
     weights_quant = graph.get_operation_by_name(
-        scope + '/weights_quant/' + quantization_node_name)
+        conv_scope + delim + 'weights_quant/' + quantization_node_name)
+
     self.assertEqual(weights_quant.type, quantization_node_name)
 
     if use_resource:
       expected_inputs = [
-          scope + '/weights_quant/FakeQuantWithMinMaxVars/ReadVariableOp',
-          scope + '/weights_quant/FakeQuantWithMinMaxVars/ReadVariableOp_1',
+          conv_scope + delim +
+          'weights_quant/FakeQuantWithMinMaxVars/ReadVariableOp',
+          conv_scope + delim +
+          'weights_quant/FakeQuantWithMinMaxVars/ReadVariableOp_1',
      ]
     else:
       expected_inputs = [
-          scope + '/weights_quant/' + 'AssignMinLast',
-          scope + '/weights_quant/' + 'AssignMaxLast'
+          conv_scope + delim + 'weights_quant/' + 'AssignMinLast',
+          conv_scope + delim + 'weights_quant/' + 'AssignMaxLast'
       ]
-    expected_inputs.append(scope + '/mul_fold')
+    expected_inputs.append(conv_scope + delim + 'mul_fold')
 
     self._AssertInputOpsAre(weights_quant, expected_inputs)
     if layer == 'DepthwiseConv2dNative':
-      output_op_name = scope + ('/weights_quant/delayed_quant/Switch_1'
-                                if delay else '/depthwise_Fold')
+      output_op_name = conv_scope + delim + (
+          'weights_quant/delayed_quant/Switch_1' if delay else 'depthwise_Fold')
     else:
-      output_op_name = scope + ('/weights_quant/delayed_quant/Switch_1'
-                                if delay else '/' + layer + '_Fold')
+      output_op_name = conv_scope + delim + (
+          'weights_quant/delayed_quant/Switch_1' if delay else layer + '_Fold')
     self._AssertOutputGoesToOps(weights_quant, graph, [output_op_name])
 
     if with_bypass:
       conv_quant = graph.get_operation_by_name(
-          scope + '/conv_quant/' + quantization_node_name)
+          conv_scope + delim + 'conv_quant/' + quantization_node_name)
       self.assertEqual(conv_quant.type, quantization_node_name)
 
       if use_resource:
         expected_inputs = [
-            scope + '/conv_quant/FakeQuantWithMinMaxVars/ReadVariableOp',
-            scope + '/conv_quant/FakeQuantWithMinMaxVars/ReadVariableOp_1',
+            conv_scope + delim +
+            'conv_quant/FakeQuantWithMinMaxVars/ReadVariableOp',
+            conv_scope + delim +
+            'conv_quant/FakeQuantWithMinMaxVars/ReadVariableOp_1',
         ]
       else:
         expected_inputs = [
-            scope + '/conv_quant/AssignMinEma',
-            scope + '/conv_quant/AssignMaxEma',
+            conv_scope + delim + 'conv_quant/AssignMinEma',
+            conv_scope + delim + 'conv_quant/AssignMaxEma',
         ]
-      expected_inputs.append(scope + '/add_fold')
+      expected_inputs.append(conv_scope + delim + 'add_fold')
 
       self._AssertInputOpsAre(conv_quant, expected_inputs)
       output_op_name = (
-          scope + '/conv_quant/delayed_quant/Switch_1' if delay else 'test/Add')
+          conv_scope + delim + 'conv_quant/delayed_quant/Switch_1'
+          if delay else scope + 'Add')
       self._AssertOutputGoesToOps(conv_quant, graph, [output_op_name])
 
-    act_quant = graph.get_operation_by_name(
-        'test/act_quant/' + quantization_node_name)
+    act_quant = graph.get_operation_by_name(scope + 'act_quant/' +
+                                            quantization_node_name)
     self.assertEqual(act_quant.type, quantization_node_name)
 
     if use_resource:
       expected_inputs = [
-          'test/act_quant/FakeQuantWithMinMaxVars/ReadVariableOp',
-          'test/act_quant/FakeQuantWithMinMaxVars/ReadVariableOp_1',
+          scope + 'act_quant/FakeQuantWithMinMaxVars/ReadVariableOp',
+          scope + 'act_quant/FakeQuantWithMinMaxVars/ReadVariableOp_1',
       ]
     else:
       expected_inputs = [
-          'test/act_quant/AssignMinEma',
-          'test/act_quant/AssignMaxEma',
+          scope + 'act_quant/AssignMinEma',
+          scope + 'act_quant/AssignMaxEma',
       ]
-    expected_inputs.append('test/' + activation_op_name)
+    expected_inputs.append(scope + activation_op_name)
 
     self._AssertInputOpsAre(act_quant, expected_inputs)
-    output_op_name = ('test/act_quant/delayed_quant/Switch_1'
-                      if delay else 'control_dependency')
+    output_op_name = (
+        scope + 'act_quant/delayed_quant/Switch_1'
+        if delay else 'control_dependency')
     self._AssertOutputGoesToOps(act_quant, graph, [output_op_name])
     self._AssertIdempotent(graph)
@@ -433,7 +488,7 @@ class QuantizeTest(test_util.TensorFlowTestCase):
 
   def _TestQuantize_Conv2dWithBatchNorm(self, activation, activation_op_name,
                                         with_bypass, delay, fused_batch_norm,
-                                        use_resource):
+                                        use_resource, scope):
     """Tests quantization: inputs -> Conv2d with batch norm -> Activation.
 
     Args:
@@ -445,6 +500,7 @@ class QuantizeTest(test_util.TensorFlowTestCase):
       delay: Int (optional), delay in number of steps until quantization starts.
       fused_batch_norm: Bool, when true use FusedBatchNorm.
       use_resource: Bool, when true uses resource variables.
+      scope: String, specifies top level scope for the graph
     """
     graph = ops.Graph()
     with graph.as_default():
@@ -453,7 +509,9 @@ class QuantizeTest(test_util.TensorFlowTestCase):
       inputs = array_ops.zeros((batch_size, height, width, depth))
       stride = 1 if with_bypass else 2
       out_depth = 3 if with_bypass else 32
-      scope = 'test/test2' if with_bypass else 'test'
+      conv_scope = self._GetConvScope(scope, with_bypass)
+      scope = '' if scope is None else scope
+      delim = '/' if scope else ''
       node = conv2d(
           inputs,
           out_depth, [5, 5],
@@ -463,13 +521,13 @@ class QuantizeTest(test_util.TensorFlowTestCase):
           activation_fn=None,
           normalizer_fn=batch_norm,
           normalizer_params=self._BatchNormParams(fused_batch_norm),
-          scope=scope)
+          scope=conv_scope)
 
       # Manually add a bypass (optional) and an activation.
       if with_bypass:
-        node = math_ops.add(inputs, node, name='test/Add')
+        node = math_ops.add(inputs, node, name=scope + delim + 'Add')
 
-      node = activation(node, name='test/' + activation_op_name)
+      node = activation(node, name=scope + delim + activation_op_name)
 
       update_barrier = control_flow_ops.no_op(name='update_barrier')
       with ops.control_dependencies([update_barrier]):
@@ -487,7 +545,7 @@ class QuantizeTest(test_util.TensorFlowTestCase):
 
   def _TestQuantize_FCWithBatchNorm(self, activation, activation_op_name,
                                     with_bypass, delay, fused_batch_norm,
-                                    use_resource):
+                                    use_resource, scope):
     """Tests quantization: inputs -> FC with batch norm -> Activation.
 
     Args:
@@ -499,6 +557,7 @@ class QuantizeTest(test_util.TensorFlowTestCase):
       delay: Int (optional), delay in number of steps until quantization starts.
       fused_batch_norm: Bool, when true use FusedBatchNorm.
       use_resource: Bool, when true uses resource variables.
+      scope: String, specifies top level scope for the graph
     """
     graph = ops.Graph()
     with graph.as_default():
@@ -506,7 +565,9 @@ class QuantizeTest(test_util.TensorFlowTestCase):
       batch_size, depth = 5, 256
       inputs = array_ops.zeros((batch_size, depth))
       out_depth = 256 if with_bypass else 128
-      scope = 'test/test2' if with_bypass else 'test'
+      conv_scope = self._GetConvScope(scope, with_bypass)
+      scope = '' if scope is None else scope
+      delim = '/' if scope else ''
       node = fully_connected(
          inputs,
           out_depth,
@@ -514,13 +575,13 @@ class QuantizeTest(test_util.TensorFlowTestCase):
          activation_fn=None,
           normalizer_fn=batch_norm,
           normalizer_params=self._BatchNormParams(fused_batch_norm),
-          scope=scope)
+          scope=conv_scope)
 
       # Manually add a bypass (optional) and an activation.
       if with_bypass:
-        node = math_ops.add(inputs, node, name='test/Add')
+        node = math_ops.add(inputs, node, name=scope + delim + 'Add')
 
-      node = activation(node, name='test/' + activation_op_name)
+      node = activation(node, name=scope + delim + activation_op_name)
 
       update_barrier = control_flow_ops.no_op(name='update_barrier')
       with ops.control_dependencies([update_barrier]):
@@ -540,7 +601,7 @@ class QuantizeTest(test_util.TensorFlowTestCase):
 
   def _TestQuantize_DepthwiseConv2dWithBatchNorm(
      self, activation, activation_op_name, with_bypass, delay,
-      fused_batch_norm, use_resource):
+      fused_batch_norm, use_resource, scope):
     """Tests quantization: inputs -> DWConv2d with batch norm -> Activation.
 
     Args:
@@ -552,6 +613,7 @@ class QuantizeTest(test_util.TensorFlowTestCase):
       delay: Int (optional), delay in number of steps until quantization starts.
       fused_batch_norm: Bool, when true use FusedBatchNorm.
       use_resource: Bool, when true uses resource variables.
+      scope: String, specifies top level scope for the graph
     """
     graph = ops.Graph()
     with graph.as_default():
@@ -559,7 +621,9 @@ class QuantizeTest(test_util.TensorFlowTestCase):
       batch_size, height, width, depth = 5, 128, 128, 3
       inputs = array_ops.zeros((batch_size, height, width, depth))
       stride = 1 if with_bypass else 2
-      scope = 'test/test2' if with_bypass else 'test'
+      conv_scope = self._GetConvScope(scope, with_bypass)
+      scope = '' if scope is None else scope
+      delim = '/' if scope else ''
       node = separable_conv2d(
          inputs,
           None, [5, 5],
@@ -570,13 +634,13 @@ class QuantizeTest(test_util.TensorFlowTestCase):
           activation_fn=None,
           normalizer_fn=batch_norm,
           normalizer_params=self._BatchNormParams(fused_batch_norm),
-          scope=scope)
+          scope=conv_scope)
 
       # Manually add a bypass (optional) and an activation.
       if with_bypass:
-        node = math_ops.add(inputs, node, name='test/Add')
+        node = math_ops.add(inputs, node, name=scope + delim + 'Add')
 
-      node = activation(node, name='test/' + activation_op_name)
+      node = activation(node, name=scope + delim + activation_op_name)
 
       update_barrier = control_flow_ops.no_op(name='update_barrier')
       with ops.control_dependencies([update_barrier]):
@@ -595,7 +659,7 @@ class QuantizeTest(test_util.TensorFlowTestCase):
 
   def _TestQuantize_AtrousConvWithBatchNorm(
       self, activation, activation_op_name, with_bypass, delay,
-      fused_batch_norm, use_resource):
+      fused_batch_norm, use_resource, scope):
     """Tests quantization: inputs -> atrous conv with batch norm -> Activation.
 
     Args:
@@ -607,6 +671,7 @@ class QuantizeTest(test_util.TensorFlowTestCase):
       delay: Int (optional), delay in number of steps until quantization starts.
      fused_batch_norm: Bool, when true use FusedBatchNorm.
       use_resource: Bool, when true uses resource variables.
+      scope: String, specifies top level scope for the graph
     """
     graph = ops.Graph()
     with graph.as_default():
@@ -614,7 +679,10 @@ class QuantizeTest(test_util.TensorFlowTestCase):
       batch_size, height, width, depth = 5, 128, 128, 3
       inputs = array_ops.zeros((batch_size, height, width, depth))
       dilation_rate = 2
-      scope = 'test/test2' if with_bypass else 'test'
+      conv_scope = self._GetConvScope(scope, with_bypass)
+      scope = '' if scope is None else scope
+      delim = '/' if scope else ''
+
       node = separable_conv2d(
           inputs,
           None, [3, 3],
@@ -625,13 +693,13 @@ class QuantizeTest(test_util.TensorFlowTestCase):
           activation_fn=None,
           normalizer_fn=batch_norm,
           normalizer_params=self._BatchNormParams(fused_batch_norm),
-          scope=scope)
+          scope=conv_scope)
 
       # Manually add a bypass (optional) and an activation.
       if with_bypass:
-        node = math_ops.add(inputs, node, name='test/Add')
+        node = math_ops.add(inputs, node, name=scope + delim + 'Add')
 
-      node = activation(node, name='test/' + activation_op_name)
+      node = activation(node, name=scope + delim + activation_op_name)
 
       update_barrier = control_flow_ops.no_op(name='update_barrier')
       with ops.control_dependencies([update_barrier]):
@@ -718,6 +786,18 @@ class QuantizeTest(test_util.TensorFlowTestCase):
     with open('/tmp/bn_quant_test.pbtxt', 'w') as f:
       f.write(str(graph.as_graph_def()))
 
+  def _GetConvScope(self, scope, with_bypass):
+    if scope is None:
+      scope = ''
+    delim = '/' if scope else ''
+
+    if with_bypass:
+      conv_scope = scope + delim + 'test2'
+    else:
+      conv_scope = scope
+
+    return conv_scope
+
   def _BatchNormParams(self, fused=False, force_updates=False):
     params = {
         'center': True,