author     Raghuraman Krishnamoorthi <raghuramank@google.com>   2018-09-20 14:31:25 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>      2018-09-20 14:39:13 -0700
commit     f2d30a68169fc00ea444e5bffb2134f8fce92562 (patch)
tree       44da2e37d963e027021858b8080c78c8712378db /tensorflow/contrib/quantize
parent     bf5324fd55a894ac00d10b7cfb2d26f3d9f7f5c9 (diff)
Remove restriction on scope for bypass operators. Previously, the scope had to be of the form 'scope/<arbitrary_text>'. Relax the restriction to handle empty scopes, and enable this change to work for both fused and unfused batch norm layers.
PiperOrigin-RevId: 213883621
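
Note (not part of the patch): a minimal sketch of the relaxed bypass handling this change introduces in quantize.py. With no enclosing scope, the old `re.search(...).group(1)` call failed because the search returned None; the patched code falls back to an empty context instead. The scope names below are illustrative only.

    import re

    def _bypass_context(context):
      # Mirrors the guarded pattern match added in Quantize(): take everything
      # before the last '/' as the bypass context, or '' if there is no '/'.
      match = re.search(r'^(.*)/([^/]+)', context)
      return match.group(1) if match is not None else ''

    assert _bypass_context('MobilenetV2/expanded_conv_3') == 'MobilenetV2'
    assert _bypass_context('expanded_conv_3') == ''
    assert _bypass_context('') == ''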
Diffstat (limited to 'tensorflow/contrib/quantize')
-rw-r--r--  tensorflow/contrib/quantize/BUILD                                  |   4
-rw-r--r--  tensorflow/contrib/quantize/python/common.py                       |   4
-rw-r--r--  tensorflow/contrib/quantize/python/common_test.py                  |  59
-rw-r--r--  tensorflow/contrib/quantize/python/fold_batch_norms.py             |  94
-rw-r--r--  tensorflow/contrib/quantize/python/quantize.py                     |  15
-rw-r--r--  tensorflow/contrib/quantize/python/quantize_parameterized_test.py  | 282
6 files changed, 308 insertions, 150 deletions
diff --git a/tensorflow/contrib/quantize/BUILD b/tensorflow/contrib/quantize/BUILD
index c59f667f6a..23e3a25d71 100644
--- a/tensorflow/contrib/quantize/BUILD
+++ b/tensorflow/contrib/quantize/BUILD
@@ -20,9 +20,13 @@ py_test(
srcs_version = "PY2AND3",
deps = [
":common",
+ "//tensorflow/contrib/layers:layers_py",
+ "//tensorflow/python:array_ops",
"//tensorflow/python:framework_ops",
"//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:init_ops",
"//tensorflow/python:math_ops",
+ "//tensorflow/python:nn_ops",
"//tensorflow/python:platform_test",
"//tensorflow/python:session",
"//tensorflow/python:variable_scope",
diff --git a/tensorflow/contrib/quantize/python/common.py b/tensorflow/contrib/quantize/python/common.py
index b27117dd48..e6c04bcf55 100644
--- a/tensorflow/contrib/quantize/python/common.py
+++ b/tensorflow/contrib/quantize/python/common.py
@@ -34,10 +34,10 @@ SKIPPED_PREFIXES = (
'ScalarSummary')
# Valid activation ops for quantization end points.
-_ACTIVATION_OP_SUFFIXES = ['/Relu6', '/Relu', '/Identity']
+_ACTIVATION_OP_SUFFIXES = ['Relu6', 'Relu', 'Identity']
# Regular expression for recognizing nodes that are part of batch norm group.
-_BATCHNORM_RE = re.compile(r'^(.*)/BatchNorm/batchnorm')
+_BATCHNORM_RE = re.compile(r'^(.*)BatchNorm/batchnorm')
def BatchNormGroups(graph):
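
Note (illustrative, not part of the patch): dropping the leading '/' lets the batch norm regex accept layers defined at the top level of the graph, where the scope prefix is empty, while still capturing the prefix (now including its trailing '/') when one exists.

    import re

    _BATCHNORM_RE = re.compile(r'^(.*)BatchNorm/batchnorm')

    # Scoped layer: the captured prefix now carries the trailing '/'.
    assert _BATCHNORM_RE.search('conv1/BatchNorm/batchnorm_1/add_1').group(1) == 'conv1/'
    # Empty scope: the old pattern r'^(.*)/BatchNorm/batchnorm' required a '/'
    # before 'BatchNorm', so this name did not match at all.
    assert _BATCHNORM_RE.search('BatchNorm/batchnorm_1/add_1').group(1) == ''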
diff --git a/tensorflow/contrib/quantize/python/common_test.py b/tensorflow/contrib/quantize/python/common_test.py
index 2b26302f8a..a3ce041cea 100644
--- a/tensorflow/contrib/quantize/python/common_test.py
+++ b/tensorflow/contrib/quantize/python/common_test.py
@@ -13,21 +13,26 @@
# limitations under the License.
# ==============================================================================
"""Tests for common utilities in this package."""
-
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-
+from tensorflow.contrib.layers.python.layers import layers
from tensorflow.contrib.quantize.python import common
from tensorflow.python.client import session
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import googletest
+batch_norm = layers.batch_norm
+conv2d = layers.conv2d
+
class CommonTest(test_util.TensorFlowTestCase):
@@ -87,6 +92,56 @@ class CommonTest(test_util.TensorFlowTestCase):
for i in inputs:
self.assertIn(i, op.inputs)
+ def testBatchNormScope(self):
+ batch_size, height, width, depth = 5, 128, 128, 3
+ g = ops.Graph()
+ with g.as_default():
+ inputs = array_ops.zeros((batch_size, height, width, depth))
+ stride = 1
+ out_depth = 32
+ scope = ''
+ node = conv2d(
+ inputs,
+ out_depth, [2, 2],
+ stride=stride,
+ padding='SAME',
+ weights_initializer=self._WeightInit(0.09),
+ activation_fn=None,
+ normalizer_fn=batch_norm,
+ normalizer_params=self._BatchNormParams(False),
+ scope=scope)
+
+ node = nn_ops.relu(node, name='Relu6')
+ bn_list = common.BatchNormGroups(g)
+ with open('/tmp/common_test.pbtxt', 'w') as f:
+ f.write(str(g.as_graph_def()))
+
+ # Exactly one batch norm layer with empty scope should be found
+ self.assertEqual(len(bn_list), 1)
+ self.assertEqual(bn_list[0], '')
+
+ def _BatchNormParams(self, fused=False, force_updates=False):
+ params = {
+ 'center': True,
+ 'scale': True,
+ 'decay': 1.0 - 0.003,
+ 'fused': fused
+ }
+ return params
+
+ def _WeightInit(self, stddev):
+ """Returns a truncated normal variable initializer.
+
+ Function is defined purely to shorten the name so that it stops wrapping.
+
+ Args:
+ stddev: Standard deviation of normal variable.
+
+ Returns:
+ An initializer that initializes with a truncated normal variable.
+ """
+ return init_ops.truncated_normal_initializer(stddev=stddev, seed=1234)
+
if __name__ == '__main__':
googletest.main()
diff --git a/tensorflow/contrib/quantize/python/fold_batch_norms.py b/tensorflow/contrib/quantize/python/fold_batch_norms.py
index 2971b28f45..e5790a6e13 100644
--- a/tensorflow/contrib/quantize/python/fold_batch_norms.py
+++ b/tensorflow/contrib/quantize/python/fold_batch_norms.py
@@ -95,8 +95,7 @@ def _FoldFusedBatchNorms(graph, is_training, freeze_batch_norm_delay):
_ComputeBatchNormCorrections(
context='',
match=match,
- freeze_batch_norm_delay=freeze_batch_norm_delay,
- fused_batch_norm=True))
+ freeze_batch_norm_delay=freeze_batch_norm_delay))
# The shape of depthwise weights is different, so we need to reshape the
# multiplier_tensor to ensure that the scaled_weight_tensor has the
# expected shape.
@@ -296,8 +295,7 @@ def _FindFusedBatchNorms(graph):
batch_to_space_op=batch_to_space_op)
-def _ComputeBatchNormCorrections(context, match, freeze_batch_norm_delay,
- fused_batch_norm):
+def _ComputeBatchNormCorrections(context, match, freeze_batch_norm_delay):
"""Computes batch norm correction params.
Before batch normalization is frozen:
@@ -327,14 +325,14 @@ def _ComputeBatchNormCorrections(context, match, freeze_batch_norm_delay,
computation.
freeze_batch_norm_delay: Delay in steps at which computation switches
from regular batch norm to frozen mean and variance.
- fused_batch_norm: Bool, true if fused batch norm is used.
+
Returns:
A tuple of correction_scale, correction_recip, correction_offset
"""
g = ops.get_default_graph()
- prefix = '' if not context else context + '/'
+ prefix = '' if not context else context
with g.name_scope(prefix + 'batch_norm_correction'):
recip_sigma_mv = math_ops.rsqrt(
match.moving_variance_tensor + match.batch_epsilon)
@@ -495,8 +493,23 @@ def _FoldUnfusedBatchNorms(graph, is_training, freeze_batch_norm_delay):
# Treat consumer ops in bypass modules differently since they have Add
# operations instead of Relu* above.
- add_bypass_ctx = re.search(r'^(.*)/([^/]+)', bn).group(1)
- add_bypass = graph.get_operation_by_name(add_bypass_ctx + '/Add')
+ # Changes to make sure that the correct scope is selected for the bypass add
+ # The rule here is that if the scope is of the form: str1/str2 for the
+ # batch norm,
+ # the bypass add is at scope str1. If bn is of scope just str1, then the
+ # bypass add is at scope ''.
+ # If there is no batch norm, then there is no bypass add.
+ add_bypass_ctx = ''
+ if bn:
+ try:
+ add_bypass_ctx = re.search(r'^(.*)/([^/]+)', bn).group(1)
+ except AttributeError:
+ add_bypass_ctx = ''
+
+ if add_bypass_ctx:
+ add_bypass_ctx = add_bypass_ctx + '/'
+
+ add_bypass = graph.get_operation_by_name(add_bypass_ctx + 'Add')
nodes_modified_count = common.RerouteTensor(
folded_op.outputs[0], original_op.outputs[0], can_modify=[add_bypass])
if nodes_modified_count != 1:
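
Note (illustrative): the scope rule described in the comment above, written out as a small helper that uses an equivalent None check instead of catching AttributeError. The op names are hypothetical.

    import re

    def _bypass_add_name(bn):
      add_bypass_ctx = ''
      if bn:
        match = re.search(r'^(.*)/([^/]+)', bn)
        if match is not None:
          add_bypass_ctx = match.group(1)
      if add_bypass_ctx:
        add_bypass_ctx += '/'
      return add_bypass_ctx + 'Add'

    assert _bypass_add_name('block1/conv1') == 'block1/Add'  # scope str1/str2 -> str1
    assert _bypass_add_name('conv1') == 'Add'                # scope str1 -> ''
    assert _bypass_add_name('') == 'Add'                     # empty scope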
@@ -505,8 +518,8 @@ def _FoldUnfusedBatchNorms(graph, is_training, freeze_batch_norm_delay):
def _IsValidUnfusedBatchNorm(graph, context):
"""Checks that the output of the unfused batch norm has consumers."""
- add_shift = graph.get_operation_by_name(
- context + '/BatchNorm/batchnorm_1/add_1')
+ add_shift = graph.get_operation_by_name(context +
+ 'BatchNorm/batchnorm_1/add_1')
# Ensure that the output tensor of batch norm has consumers, otherwise this
# is a dangling node and not a match.
return bool(add_shift.outputs[0].consumers())
@@ -538,7 +551,8 @@ def _FindMatchingTensor(graph, match_pattern, scope):
if op.name.endswith(match_pattern):
split_name = op.name.split('/')
num_matches = len(set(split_name) & split_context)
- if num_matches > 0:
+
+ if num_matches > 0 or not scope:
match_dict[op.name] = num_matches
# match_dict contains matching op names from graph with values being
# number of matches to scope. We pick the key with the most matches
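
Note (assumed simplification, not part of the patch): with an empty scope the name-overlap count above is zero, so the new `or not scope` clause is what allows suffix-only matches through.

    def _suffix_matches(op_name, match_pattern, scope):
      if not op_name.endswith(match_pattern):
        return False
      num_matches = len(set(op_name.split('/')) & set(scope.split('/')))
      return num_matches > 0 or not scope

    assert _suffix_matches('BatchNorm/moments/Squeeze',
                           'BatchNorm/moments/Squeeze', '')
    assert _suffix_matches('conv1/BatchNorm/moments/Squeeze',
                           'BatchNorm/moments/Squeeze', 'conv1')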
@@ -597,21 +611,21 @@ def _GetBatchNormParams(graph, context, has_scaling):
# op.name = MobilenetV2/expanded_conv_3/depthwise/BatchNorm/moving_mean/read
# will have 2 matches,scope with a different conv layer will have one match.
- op_suffix_mean = '/BatchNorm/moments/Squeeze'
- op_suffix_variance = '/BatchNorm/moments/Squeeze_1'
- op_suffix_epsilon = '/BatchNorm/batchnorm_1/add/y'
- op_suffix_bn_decay_mean = '/BatchNorm/AssignMovingAvg/decay'
- op_suffix_bn_decay_var = '/BatchNorm/AssignMovingAvg_1/decay'
+ op_suffix_mean = 'BatchNorm/moments/Squeeze'
+ op_suffix_variance = 'BatchNorm/moments/Squeeze_1'
+ op_suffix_epsilon = 'BatchNorm/batchnorm_1/add/y'
+ op_suffix_bn_decay_mean = 'BatchNorm/AssignMovingAvg/decay'
+ op_suffix_bn_decay_var = 'BatchNorm/AssignMovingAvg_1/decay'
if variable_scope.get_variable_scope().use_resource:
- op_suffix_gamma = '/BatchNorm/gamma/Read/ReadVariableOp'
+ op_suffix_gamma = 'BatchNorm/gamma/Read/ReadVariableOp'
op_suffix_moving_variance = (
- '/BatchNorm/moving_variance/Read/ReadVariableOp')
- op_suffix_moving_mean = ('/BatchNorm/moving_mean/Read/ReadVariableOp')
+ 'BatchNorm/moving_variance/Read/ReadVariableOp')
+ op_suffix_moving_mean = ('BatchNorm/moving_mean/Read/ReadVariableOp')
else:
- op_suffix_gamma = '/BatchNorm/gamma'
- op_suffix_moving_variance = '/BatchNorm/moving_variance/read'
- op_suffix_moving_mean = '/BatchNorm/moving_mean/read'
+ op_suffix_gamma = 'BatchNorm/gamma'
+ op_suffix_moving_variance = 'BatchNorm/moving_variance/read'
+ op_suffix_moving_mean = 'BatchNorm/moving_mean/read'
# Parse through list of ops to find relevant ops
batch_mean_tensor = _FindMatchingTensor(graph, op_suffix_mean, context)
@@ -679,8 +693,7 @@ def _CreateFoldedOp(graph, context, has_scaling, freeze_batch_norm_delay,
the folded graph (add_fold).
"""
mul_scale_name = 'mul_1' if has_scaling else 'mul'
- mul_scale = graph.get_operation_by_name(context +
- '/BatchNorm/batchnorm_1/' +
+ mul_scale = graph.get_operation_by_name(context + 'BatchNorm/batchnorm_1/' +
mul_scale_name)
op_below = mul_scale.inputs[0].op
# Skip over the BatchToSpace operation in the case of atrous convolutions.
@@ -697,8 +710,7 @@ def _CreateFoldedOp(graph, context, has_scaling, freeze_batch_norm_delay,
_ComputeBatchNormCorrections(
context=context,
match=match,
- freeze_batch_norm_delay=freeze_batch_norm_delay,
- fused_batch_norm=False))
+ freeze_batch_norm_delay=freeze_batch_norm_delay))
# Special handling for weights of depthwise convolution.
if op_below.type == 'DepthwiseConv2dNative':
new_shape = [
@@ -706,27 +718,27 @@ def _CreateFoldedOp(graph, context, has_scaling, freeze_batch_norm_delay,
weights.get_shape().as_list()[3]
]
scale_name = 'mul' if has_scaling else 'Rsqrt'
- scale = graph.get_operation_by_name(
- context + '/BatchNorm/batchnorm_1/' + scale_name)
+ scale = graph.get_operation_by_name(context + 'BatchNorm/batchnorm_1/' +
+ scale_name)
scale = array_ops.reshape(scale.outputs[0], new_shape,
- context + '/scale_reshape')
+ context + 'scale_reshape')
if correction_scale is not None:
correction_scale = array_ops.reshape(correction_scale, new_shape,
- context + '/correction_reshape')
+ context + 'correction_reshape')
with ops.device(mul_scale.device):
weights = math_ops.multiply(correction_scale, weights,
- context + '/correction_mult')
+ context + 'correction_mult')
- mul_fold = _CloneOp(mul_scale, context + '/mul_fold', [(0, weights),
- (1, scale)])
+ mul_fold = _CloneOp(mul_scale, context + 'mul_fold', [(0, weights),
+ (1, scale)])
elif op_below.type in ['Conv2D', 'MatMul']:
if correction_scale is not None:
with ops.device(mul_scale.device):
weights = math_ops.multiply(correction_scale, weights,
- context + '/correction_mult')
- mul_fold = _CloneOp(mul_scale, context + '/mul_fold', [(0, weights)])
+ context + 'correction_mult')
+ mul_fold = _CloneOp(mul_scale, context + 'mul_fold', [(0, weights)])
else:
raise ValueError('Cannot handle operation of type: %s' % op_below.type)
_AssertShapesMatch('mul_fold', mul_fold.inputs[0], mul_fold.outputs[0])
@@ -734,8 +746,8 @@ def _CreateFoldedOp(graph, context, has_scaling, freeze_batch_norm_delay,
conv_or_fc_folded = _CloneOp(op_below, op_below.name + '_Fold',
[(1, mul_fold.outputs[0])])
- add_shift = graph.get_operation_by_name(
- context + '/BatchNorm/batchnorm_1/add_1')
+ add_shift = graph.get_operation_by_name(context +
+ 'BatchNorm/batchnorm_1/add_1')
corrected_output = conv_or_fc_folded.outputs[0]
# Copy the batch to space operation if we have a atrous convolution.
@@ -748,10 +760,10 @@ def _CreateFoldedOp(graph, context, has_scaling, freeze_batch_norm_delay,
if correction_offset is not None:
with ops.device(conv_or_fc_folded.device):
corrected_output = math_ops.multiply(correction_recip, corrected_output,
- context + '/post_conv_mul')
+ context + 'post_conv_mul')
corrected_output = math_ops.add(corrected_output, (correction_offset),
- context + '/correction_add')
- add_fold = _CloneOp(add_shift, context + '/add_fold', [(0, corrected_output)])
+ context + 'correction_add')
+ add_fold = _CloneOp(add_shift, context + 'add_fold', [(0, corrected_output)])
_AssertShapesMatch('add_fold', add_fold.inputs[0], add_fold.outputs[0])
return add_shift, add_fold
@@ -930,7 +942,7 @@ def _HasScaling(graph, input_to_ops_map, bn):
Returns:
A boolean indicating whether this batch norm layer has scaling enabled.
"""
- rsqrt_op = graph.get_operation_by_name(bn + '/BatchNorm/batchnorm_1/Rsqrt')
+ rsqrt_op = graph.get_operation_by_name(bn + 'BatchNorm/batchnorm_1/Rsqrt')
rsqrt_consumers = input_to_ops_map.ConsumerOperations(rsqrt_op)
return sum(1 for op in rsqrt_consumers if op.type == 'Mul') == 1
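
Note (illustrative): throughout the fold pass above, the leading '/' moves from the literal op-name suffixes into the captured context, so name construction works for both scoped and unscoped layers.

    def _bn_op_name(context, suffix):
      # context is either '' or 'some/scope/' (note the trailing '/').
      return context + suffix

    assert _bn_op_name('', 'BatchNorm/batchnorm_1/add_1') == 'BatchNorm/batchnorm_1/add_1'
    assert _bn_op_name('conv1/', 'BatchNorm/batchnorm_1/add_1') == 'conv1/BatchNorm/batchnorm_1/add_1'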
diff --git a/tensorflow/contrib/quantize/python/quantize.py b/tensorflow/contrib/quantize/python/quantize.py
index e88db0acd5..5e63d33db8 100644
--- a/tensorflow/contrib/quantize/python/quantize.py
+++ b/tensorflow/contrib/quantize/python/quantize.py
@@ -97,8 +97,11 @@ def Quantize(graph,
layer_match.activation_op)
add_context = context
if layer_match.bypass_op:
- add_context = re.search(r'^(.*)/([^/]+)', context).group(1)
-
+ pattern_match_result = re.search(r'^(.*)/([^/]+)', context)
+ if pattern_match_result is not None:
+ add_context = pattern_match_result.group(1)
+ else:
+ add_context = ''
# If `scope` is given, only quantize it if the producer of weights
# (usually it's the layer op) is in the right scope.
_InsertQuantOp(
@@ -156,8 +159,12 @@ def Quantize(graph,
# Quantize bypass ops that occur after the activation.
if layer_match.post_activation_bypass_op is not None:
- post_activation_bypass_context = re.search(
- r'^(.*)/([^/]+)', layer_match.post_activation_bypass_op.name).group(1)
+ pattern_match_result = re.search(
+ r'^(.*)/([^/]+)', layer_match.post_activation_bypass_op.name)
+ if pattern_match_result is not None:
+ post_activation_bypass_context = pattern_match_result.group(1)
+ else:
+ post_activation_bypass_context = ''
# If `scope` is given, only quantize it if the producer is in the right
# scope.
# Make sure the op following this isn't an activation. In which case, we
diff --git a/tensorflow/contrib/quantize/python/quantize_parameterized_test.py b/tensorflow/contrib/quantize/python/quantize_parameterized_test.py
index 31a2955ddb..f6bf57a789 100644
--- a/tensorflow/contrib/quantize/python/quantize_parameterized_test.py
+++ b/tensorflow/contrib/quantize/python/quantize_parameterized_test.py
@@ -58,85 +58,102 @@ class QuantizeTest(test_util.TensorFlowTestCase):
]
for params in parameters_list:
# Test everything with resource variables and normal variables.
- test_fn(params[0], params[1], params[2], params[3], False)
- test_fn(params[0], params[1], params[2], params[3], True)
+ test_fn(params[0], params[1], params[2], params[3], False, None)
+ test_fn(params[0], params[1], params[2], params[3], True, None)
+ # Test with both empty scope and an example scope
+ test_fn(params[0], params[1], params[2], params[3], False, 'test')
+ test_fn(params[0], params[1], params[2], params[3], True, 'test')
def _AssertCorrectQuantizedGraphWithoutBatchNorm(
self, graph, scope, layer, activation_op_name, with_bypass, delay,
use_resource):
quantization_node_name = 'FakeQuantWithMinMaxVars'
- weights_quant = graph.get_operation_by_name(scope + '/weights_quant/' +
- quantization_node_name)
+ conv_scope = self._GetConvScope(scope, with_bypass)
+ delim = '/' if conv_scope else ''
+
+ if scope:
+ scope = scope + '/'
+ weights_quant = graph.get_operation_by_name(
+ conv_scope + delim + 'weights_quant/' + quantization_node_name)
self.assertEqual(weights_quant.type, quantization_node_name)
# Assemble the expected inputs.
if use_resource:
expected_inputs = [
- scope + '/weights_quant/FakeQuantWithMinMaxVars/ReadVariableOp',
- scope + '/weights_quant/FakeQuantWithMinMaxVars/ReadVariableOp_1',
+ conv_scope + delim +
+ 'weights_quant/FakeQuantWithMinMaxVars/ReadVariableOp',
+ conv_scope + delim +
+ 'weights_quant/FakeQuantWithMinMaxVars/ReadVariableOp_1',
]
if layer == 'DepthwiseConv2dNative':
- expected_inputs.append(scope + '/depthwise/ReadVariableOp')
+ expected_inputs.append(conv_scope + delim + 'depthwise/ReadVariableOp')
else:
- expected_inputs.append(scope + '/' + layer + '/ReadVariableOp')
+ expected_inputs.append(conv_scope + delim + layer + '/ReadVariableOp')
else:
expected_inputs = [
- scope + '/weights_quant/AssignMinLast',
- scope + '/weights_quant/AssignMaxLast',
+ conv_scope + delim + 'weights_quant/AssignMinLast',
+ conv_scope + delim + 'weights_quant/AssignMaxLast',
]
if layer == 'DepthwiseConv2dNative':
- expected_inputs.append(scope + '/depthwise_weights/read')
+ expected_inputs.append(conv_scope + delim + 'depthwise_weights/read')
else:
- expected_inputs.append(scope + '/weights/read')
+ expected_inputs.append(conv_scope + delim + 'weights/read')
self._AssertInputOpsAre(weights_quant, expected_inputs)
if delay and delay > 0:
- output_op_name = scope + '/weights_quant/delayed_quant/Switch_1'
+ output_op_name = (
+ conv_scope + delim + 'weights_quant/delayed_quant/Switch_1')
else:
if layer == 'DepthwiseConv2dNative':
- output_op_name = scope + '/depthwise'
+ output_op_name = conv_scope + delim + 'depthwise'
else:
- output_op_name = scope + '/' + layer
+ output_op_name = conv_scope + delim + layer
self._AssertOutputGoesToOps(weights_quant, graph, [output_op_name])
if with_bypass:
- conv_quant = graph.get_operation_by_name(scope + '/conv_quant/' +
- quantization_node_name)
+ conv_quant = graph.get_operation_by_name(
+ conv_scope + delim + 'conv_quant/' + quantization_node_name)
self.assertEqual(conv_quant.type, quantization_node_name)
if use_resource:
expected_inputs = [
- scope + '/conv_quant/FakeQuantWithMinMaxVars/ReadVariableOp',
- scope + '/conv_quant/FakeQuantWithMinMaxVars/ReadVariableOp_1',
- scope + '/BiasAdd',
+ conv_scope + delim +
+ 'conv_quant/FakeQuantWithMinMaxVars/ReadVariableOp',
+ conv_scope + delim +
+ 'conv_quant/FakeQuantWithMinMaxVars/ReadVariableOp_1',
+ conv_scope + delim + 'BiasAdd',
]
else:
expected_inputs = [
- scope + '/conv_quant/AssignMinEma',
- scope + '/conv_quant/AssignMaxEma', scope + '/BiasAdd'
+ conv_scope + delim + 'conv_quant/AssignMinEma',
+ conv_scope + delim + 'conv_quant/AssignMaxEma',
+ conv_scope + delim + 'BiasAdd'
]
self._AssertInputOpsAre(conv_quant, expected_inputs)
- output_op_name = (scope + '/conv_quant/delayed_quant/Switch_1'
- if delay else 'test/Add')
+
+ output_op_name = (
+ conv_scope + delim + 'conv_quant/delayed_quant/Switch_1'
+ if delay else scope + 'Add')
self._AssertOutputGoesToOps(conv_quant, graph, [output_op_name])
- act_quant = graph.get_operation_by_name('test/act_quant/' +
+ act_quant = graph.get_operation_by_name(scope + 'act_quant/' +
quantization_node_name)
self.assertEqual(act_quant.type, quantization_node_name)
if use_resource:
expected_inputs = [
- 'test/act_quant/FakeQuantWithMinMaxVars/ReadVariableOp',
- 'test/act_quant/FakeQuantWithMinMaxVars/ReadVariableOp_1',
- 'test/' + activation_op_name,
+ scope + 'act_quant/FakeQuantWithMinMaxVars/ReadVariableOp',
+ scope + 'act_quant/FakeQuantWithMinMaxVars/ReadVariableOp_1',
+ scope + activation_op_name,
]
else:
expected_inputs = [
- 'test/act_quant/AssignMinEma', 'test/act_quant/AssignMaxEma',
- 'test/' + activation_op_name
+ scope + 'act_quant/AssignMinEma', scope + 'act_quant/AssignMaxEma',
+ scope + activation_op_name
]
self._AssertInputOpsAre(act_quant, expected_inputs)
- output_op_name = ('test/act_quant/delayed_quant/Switch_1'
- if delay else 'control_dependency')
+ output_op_name = (
+ scope + 'act_quant/delayed_quant/Switch_1'
+ if delay else 'control_dependency')
self._AssertOutputGoesToOps(act_quant, graph, [output_op_name])
self._AssertIdempotent(graph)
@@ -145,7 +162,8 @@ class QuantizeTest(test_util.TensorFlowTestCase):
self._TestQuantize_Conv2dWithoutBatchNorm)
def _TestQuantize_Conv2dWithoutBatchNorm(self, activation, activation_op_name,
- with_bypass, delay, use_resource):
+ with_bypass, delay, use_resource,
+ scope):
"""Tests quantization: inputs -> Conv2d no batch norm -> Activation.
Args:
@@ -156,6 +174,7 @@ class QuantizeTest(test_util.TensorFlowTestCase):
inputs to just before Activation.
delay: Int (optional), delay in number of steps until quantization starts.
use_resource: Bool, when true uses resource variables.
+ scope: String, specifies top level scope for the graph
"""
graph = ops.Graph()
with graph.as_default():
@@ -165,7 +184,9 @@ class QuantizeTest(test_util.TensorFlowTestCase):
stride = 1 if with_bypass else 2
out_depth = 3 if with_bypass else 32
activation_fn = None if with_bypass else activation
- scope = 'test/test2' if with_bypass else 'test'
+ conv_scope = self._GetConvScope(scope, with_bypass)
+ scope = '' if scope is None else scope
+ delim = '/' if scope else ''
node = conv2d(
inputs,
out_depth, [5, 5],
@@ -173,16 +194,19 @@ class QuantizeTest(test_util.TensorFlowTestCase):
padding='SAME',
weights_initializer=self._WeightInit(0.09),
activation_fn=activation_fn,
- scope=scope)
+ scope=conv_scope)
if with_bypass:
- node = math_ops.add(inputs, node, name='test/Add')
- node = activation(node, name='test/' + activation_op_name)
+ node = math_ops.add(inputs, node, name=scope + delim + 'Add')
+ node = activation(node, name=scope + delim + activation_op_name)
update_barrier = control_flow_ops.no_op(name='update_barrier')
with ops.control_dependencies([update_barrier]):
array_ops.identity(node, name='control_dependency')
quantize.Quantize(graph, True, quant_delay=delay)
+ if conv_scope is None:
+ conv_scope = ''
+
self._AssertCorrectQuantizedGraphWithoutBatchNorm(
graph, scope, 'Conv2D', activation_op_name, with_bypass, delay,
use_resource)
@@ -192,7 +216,7 @@ class QuantizeTest(test_util.TensorFlowTestCase):
self._TestQuantize_FCWithoutBatchNorm)
def _TestQuantize_FCWithoutBatchNorm(self, activation, activation_op_name,
- with_bypass, delay, use_resource):
+ with_bypass, delay, use_resource, scope):
"""Tests quantization: inputs -> FC no batch norm -> Activation.
Args:
@@ -203,6 +227,7 @@ class QuantizeTest(test_util.TensorFlowTestCase):
inputs to just before Activation.
delay: Int (optional), delay in number of steps until quantization starts.
use_resource: Bool, when true uses resource variables.
+ scope: String, specifies top level scope for the graph
"""
graph = ops.Graph()
with graph.as_default():
@@ -211,16 +236,18 @@ class QuantizeTest(test_util.TensorFlowTestCase):
inputs = array_ops.zeros((batch_size, depth))
out_depth = 256 if with_bypass else 128
activation_fn = None if with_bypass else activation
- scope = 'test/test2' if with_bypass else 'test'
+ fc_scope = self._GetConvScope(scope, with_bypass)
+ scope = '' if scope is None else scope
+ delim = '/' if scope else ''
node = fully_connected(
inputs,
out_depth,
weights_initializer=self._WeightInit(0.03),
activation_fn=activation_fn,
- scope=scope)
+ scope=fc_scope)
if with_bypass:
- node = math_ops.add(inputs, node, name='test/Add')
- node = activation(node, name='test/' + activation_op_name)
+ node = math_ops.add(inputs, node, name=scope + delim + 'Add')
+ node = activation(node, name=scope + delim + activation_op_name)
update_barrier = control_flow_ops.no_op(name='update_barrier')
with ops.control_dependencies([update_barrier]):
array_ops.identity(node, name='control_dependency')
@@ -235,7 +262,8 @@ class QuantizeTest(test_util.TensorFlowTestCase):
self._TestQuantize_DepthwiseConv2dWithoutBatchNorm)
def _TestQuantize_DepthwiseConv2dWithoutBatchNorm(
- self, activation, activation_op_name, with_bypass, delay, use_resource):
+ self, activation, activation_op_name, with_bypass, delay, use_resource,
+ scope):
"""Tests quantization: inputs -> DWConv2d no batch norm -> Activation.
Args:
@@ -246,6 +274,7 @@ class QuantizeTest(test_util.TensorFlowTestCase):
inputs to just before Activation.
delay: Int (optional), delay in number of steps until quantization starts.
use_resource: Bool, when true uses resource variables.
+ scope: String, specifies top level scope for the graph
"""
graph = ops.Graph()
with graph.as_default():
@@ -254,7 +283,10 @@ class QuantizeTest(test_util.TensorFlowTestCase):
inputs = array_ops.zeros((batch_size, height, width, depth))
stride = 1 if with_bypass else 2
activation_fn = None if with_bypass else activation
- scope = 'test/test2' if with_bypass else 'test'
+ conv_scope = self._GetConvScope(scope, with_bypass)
+ scope = '' if scope is None else scope
+ delim = '/' if scope else ''
+
node = separable_conv2d(
inputs,
None, [5, 5],
@@ -263,10 +295,10 @@ class QuantizeTest(test_util.TensorFlowTestCase):
padding='SAME',
weights_initializer=self._WeightInit(0.09),
activation_fn=activation_fn,
- scope=scope)
+ scope=conv_scope)
if with_bypass:
- node = math_ops.add(inputs, node, name='test/Add')
- node = activation(node, name='test/' + activation_op_name)
+ node = math_ops.add(inputs, node, name=scope + delim + 'Add')
+ node = activation(node, name=scope + delim + activation_op_name)
update_barrier = control_flow_ops.no_op(name='update_barrier')
with ops.control_dependencies([update_barrier]):
array_ops.identity(node, name='control_dependency')
@@ -280,8 +312,9 @@ class QuantizeTest(test_util.TensorFlowTestCase):
self._RunWithoutBatchNormTestOverParameters(
self._TestQuantize_AtrousConvWithoutBatchNorm)
- def _TestQuantize_AtrousConvWithoutBatchNorm(
- self, activation, activation_op_name, with_bypass, delay, use_resource):
+ def _TestQuantize_AtrousConvWithoutBatchNorm(self, activation,
+ activation_op_name, with_bypass,
+ delay, use_resource, scope):
"""Tests quantization: inputs -> atrous conv no batch norm -> Activation.
Args:
@@ -292,6 +325,7 @@ class QuantizeTest(test_util.TensorFlowTestCase):
inputs to just before Activation.
delay: Int (optional), delay in number of steps until quantization starts.
use_resource: Bool, when true uses resource variables.
+ scope: String, specifies top level scope for the graph
"""
graph = ops.Graph()
with graph.as_default():
@@ -300,7 +334,10 @@ class QuantizeTest(test_util.TensorFlowTestCase):
inputs = array_ops.zeros((batch_size, height, width, depth))
dilation_rate = 2
activation_fn = None if with_bypass else activation
- scope = 'test/test2' if with_bypass else 'test'
+ conv_scope = self._GetConvScope(scope, with_bypass)
+ scope = '' if scope is None else scope
+ delim = '/' if scope else ''
+
node = separable_conv2d(
inputs,
None, [3, 3],
@@ -309,10 +346,10 @@ class QuantizeTest(test_util.TensorFlowTestCase):
padding='SAME',
weights_initializer=self._WeightInit(0.09),
activation_fn=activation_fn,
- scope=scope)
+ scope=conv_scope)
if with_bypass:
- node = math_ops.add(inputs, node, name='test/Add')
- node = activation(node, name='test/' + activation_op_name)
+ node = math_ops.add(inputs, node, name=scope + delim + 'Add')
+ node = activation(node, name=scope + delim + activation_op_name)
update_barrier = control_flow_ops.no_op(name='update_barrier')
with ops.control_dependencies([update_barrier]):
array_ops.identity(node, name='control_dependency')
@@ -353,78 +390,96 @@ class QuantizeTest(test_util.TensorFlowTestCase):
]
for params in parameters_list:
# Test everything with resource variables and normal variables.
- test_fn(params[0], params[1], params[2], params[3], params[4], False)
- test_fn(params[0], params[1], params[2], params[3], params[4], True)
+ test_fn(params[0], params[1], params[2], params[3], params[4], False,
+ None)
+ test_fn(params[0], params[1], params[2], params[3], params[4], True, None)
+ test_fn(params[0], params[1], params[2], params[3], params[4], False,
+ 'test')
+ test_fn(params[0], params[1], params[2], params[3], params[4], True,
+ 'test')
def _AssertCorrectQuantizedGraphWithBatchNorm(self, graph, scope, layer,
activation_op_name, with_bypass,
delay, use_resource):
quantization_node_name = 'FakeQuantWithMinMaxVars'
+ conv_scope = self._GetConvScope(scope, with_bypass)
+ delim = '/' if conv_scope else ''
+
+ if scope:
+ scope = scope + '/'
+
weights_quant = graph.get_operation_by_name(
- scope + '/weights_quant/' + quantization_node_name)
+ conv_scope + delim + 'weights_quant/' + quantization_node_name)
+
self.assertEqual(weights_quant.type, quantization_node_name)
if use_resource:
expected_inputs = [
- scope + '/weights_quant/FakeQuantWithMinMaxVars/ReadVariableOp',
- scope + '/weights_quant/FakeQuantWithMinMaxVars/ReadVariableOp_1',
+ conv_scope + delim +
+ 'weights_quant/FakeQuantWithMinMaxVars/ReadVariableOp',
+ conv_scope + delim +
+ 'weights_quant/FakeQuantWithMinMaxVars/ReadVariableOp_1',
]
else:
expected_inputs = [
- scope + '/weights_quant/' + 'AssignMinLast',
- scope + '/weights_quant/' + 'AssignMaxLast'
+ conv_scope + delim + 'weights_quant/' + 'AssignMinLast',
+ conv_scope + delim + 'weights_quant/' + 'AssignMaxLast'
]
- expected_inputs.append(scope + '/mul_fold')
+ expected_inputs.append(conv_scope + delim + 'mul_fold')
self._AssertInputOpsAre(weights_quant, expected_inputs)
if layer == 'DepthwiseConv2dNative':
- output_op_name = scope + ('/weights_quant/delayed_quant/Switch_1'
- if delay else '/depthwise_Fold')
+ output_op_name = conv_scope + delim + (
+ 'weights_quant/delayed_quant/Switch_1' if delay else 'depthwise_Fold')
else:
- output_op_name = scope + ('/weights_quant/delayed_quant/Switch_1'
- if delay else '/' + layer + '_Fold')
+ output_op_name = conv_scope + delim + (
+ 'weights_quant/delayed_quant/Switch_1' if delay else layer + '_Fold')
self._AssertOutputGoesToOps(weights_quant, graph, [output_op_name])
if with_bypass:
conv_quant = graph.get_operation_by_name(
- scope + '/conv_quant/' + quantization_node_name)
+ conv_scope + delim + 'conv_quant/' + quantization_node_name)
self.assertEqual(conv_quant.type, quantization_node_name)
if use_resource:
expected_inputs = [
- scope + '/conv_quant/FakeQuantWithMinMaxVars/ReadVariableOp',
- scope + '/conv_quant/FakeQuantWithMinMaxVars/ReadVariableOp_1',
+ conv_scope + delim +
+ 'conv_quant/FakeQuantWithMinMaxVars/ReadVariableOp',
+ conv_scope + delim +
+ 'conv_quant/FakeQuantWithMinMaxVars/ReadVariableOp_1',
]
else:
expected_inputs = [
- scope + '/conv_quant/AssignMinEma',
- scope + '/conv_quant/AssignMaxEma',
+ conv_scope + delim + 'conv_quant/AssignMinEma',
+ conv_scope + delim + 'conv_quant/AssignMaxEma',
]
- expected_inputs.append(scope + '/add_fold')
+ expected_inputs.append(conv_scope + delim + 'add_fold')
self._AssertInputOpsAre(conv_quant, expected_inputs)
output_op_name = (
- scope + '/conv_quant/delayed_quant/Switch_1' if delay else 'test/Add')
+ conv_scope + delim + 'conv_quant/delayed_quant/Switch_1'
+ if delay else scope + 'Add')
self._AssertOutputGoesToOps(conv_quant, graph, [output_op_name])
- act_quant = graph.get_operation_by_name(
- 'test/act_quant/' + quantization_node_name)
+ act_quant = graph.get_operation_by_name(scope + 'act_quant/' +
+ quantization_node_name)
self.assertEqual(act_quant.type, quantization_node_name)
if use_resource:
expected_inputs = [
- 'test/act_quant/FakeQuantWithMinMaxVars/ReadVariableOp',
- 'test/act_quant/FakeQuantWithMinMaxVars/ReadVariableOp_1',
+ scope + 'act_quant/FakeQuantWithMinMaxVars/ReadVariableOp',
+ scope + 'act_quant/FakeQuantWithMinMaxVars/ReadVariableOp_1',
]
else:
expected_inputs = [
- 'test/act_quant/AssignMinEma',
- 'test/act_quant/AssignMaxEma',
+ scope + 'act_quant/AssignMinEma',
+ scope + 'act_quant/AssignMaxEma',
]
- expected_inputs.append('test/' + activation_op_name)
+ expected_inputs.append(scope + activation_op_name)
self._AssertInputOpsAre(act_quant, expected_inputs)
- output_op_name = ('test/act_quant/delayed_quant/Switch_1'
- if delay else 'control_dependency')
+ output_op_name = (
+ scope + 'act_quant/delayed_quant/Switch_1'
+ if delay else 'control_dependency')
self._AssertOutputGoesToOps(act_quant, graph, [output_op_name])
self._AssertIdempotent(graph)
@@ -433,7 +488,7 @@ class QuantizeTest(test_util.TensorFlowTestCase):
def _TestQuantize_Conv2dWithBatchNorm(self, activation, activation_op_name,
with_bypass, delay, fused_batch_norm,
- use_resource):
+ use_resource, scope):
"""Tests quantization: inputs -> Conv2d with batch norm -> Activation.
Args:
@@ -445,6 +500,7 @@ class QuantizeTest(test_util.TensorFlowTestCase):
delay: Int (optional), delay in number of steps until quantization starts.
fused_batch_norm: Bool, when true use FusedBatchNorm.
use_resource: Bool, when true uses resource variables.
+ scope: String, specifies top level scope for the graph
"""
graph = ops.Graph()
with graph.as_default():
@@ -453,7 +509,9 @@ class QuantizeTest(test_util.TensorFlowTestCase):
inputs = array_ops.zeros((batch_size, height, width, depth))
stride = 1 if with_bypass else 2
out_depth = 3 if with_bypass else 32
- scope = 'test/test2' if with_bypass else 'test'
+ conv_scope = self._GetConvScope(scope, with_bypass)
+ scope = '' if scope is None else scope
+ delim = '/' if scope else ''
node = conv2d(
inputs,
out_depth, [5, 5],
@@ -463,13 +521,13 @@ class QuantizeTest(test_util.TensorFlowTestCase):
activation_fn=None,
normalizer_fn=batch_norm,
normalizer_params=self._BatchNormParams(fused_batch_norm),
- scope=scope)
+ scope=conv_scope)
# Manually add a bypass (optional) and an activation.
if with_bypass:
- node = math_ops.add(inputs, node, name='test/Add')
+ node = math_ops.add(inputs, node, name=scope + delim + 'Add')
- node = activation(node, name='test/' + activation_op_name)
+ node = activation(node, name=scope + delim + activation_op_name)
update_barrier = control_flow_ops.no_op(name='update_barrier')
with ops.control_dependencies([update_barrier]):
@@ -487,7 +545,7 @@ class QuantizeTest(test_util.TensorFlowTestCase):
def _TestQuantize_FCWithBatchNorm(self, activation, activation_op_name,
with_bypass, delay, fused_batch_norm,
- use_resource):
+ use_resource, scope):
"""Tests quantization: inputs -> FC with batch norm -> Activation.
Args:
@@ -499,6 +557,7 @@ class QuantizeTest(test_util.TensorFlowTestCase):
delay: Int (optional), delay in number of steps until quantization starts.
fused_batch_norm: Bool, when true use FusedBatchNorm.
use_resource: Bool, when true uses resource variables.
+ scope: String, specifies top level scope for the graph
"""
graph = ops.Graph()
with graph.as_default():
@@ -506,7 +565,9 @@ class QuantizeTest(test_util.TensorFlowTestCase):
batch_size, depth = 5, 256
inputs = array_ops.zeros((batch_size, depth))
out_depth = 256 if with_bypass else 128
- scope = 'test/test2' if with_bypass else 'test'
+ conv_scope = self._GetConvScope(scope, with_bypass)
+ scope = '' if scope is None else scope
+ delim = '/' if scope else ''
node = fully_connected(
inputs,
out_depth,
@@ -514,13 +575,13 @@ class QuantizeTest(test_util.TensorFlowTestCase):
activation_fn=None,
normalizer_fn=batch_norm,
normalizer_params=self._BatchNormParams(fused_batch_norm),
- scope=scope)
+ scope=conv_scope)
# Manually add a bypass (optional) and an activation.
if with_bypass:
- node = math_ops.add(inputs, node, name='test/Add')
+ node = math_ops.add(inputs, node, name=scope + delim + 'Add')
- node = activation(node, name='test/' + activation_op_name)
+ node = activation(node, name=scope + delim + activation_op_name)
update_barrier = control_flow_ops.no_op(name='update_barrier')
with ops.control_dependencies([update_barrier]):
@@ -540,7 +601,7 @@ class QuantizeTest(test_util.TensorFlowTestCase):
def _TestQuantize_DepthwiseConv2dWithBatchNorm(
self, activation, activation_op_name, with_bypass, delay,
- fused_batch_norm, use_resource):
+ fused_batch_norm, use_resource, scope):
"""Tests quantization: inputs -> DWConv2d with batch norm -> Activation.
Args:
@@ -552,6 +613,7 @@ class QuantizeTest(test_util.TensorFlowTestCase):
delay: Int (optional), delay in number of steps until quantization starts.
fused_batch_norm: Bool, when true use FusedBatchNorm.
use_resource: Bool, when true uses resource variables.
+ scope: String, specifies top level scope for the graph
"""
graph = ops.Graph()
with graph.as_default():
@@ -559,7 +621,9 @@ class QuantizeTest(test_util.TensorFlowTestCase):
batch_size, height, width, depth = 5, 128, 128, 3
inputs = array_ops.zeros((batch_size, height, width, depth))
stride = 1 if with_bypass else 2
- scope = 'test/test2' if with_bypass else 'test'
+ conv_scope = self._GetConvScope(scope, with_bypass)
+ scope = '' if scope is None else scope
+ delim = '/' if scope else ''
node = separable_conv2d(
inputs,
None, [5, 5],
@@ -570,13 +634,13 @@ class QuantizeTest(test_util.TensorFlowTestCase):
activation_fn=None,
normalizer_fn=batch_norm,
normalizer_params=self._BatchNormParams(fused_batch_norm),
- scope=scope)
+ scope=conv_scope)
# Manually add a bypass (optional) and an activation.
if with_bypass:
- node = math_ops.add(inputs, node, name='test/Add')
+ node = math_ops.add(inputs, node, name=scope + delim + 'Add')
- node = activation(node, name='test/' + activation_op_name)
+ node = activation(node, name=scope + delim + activation_op_name)
update_barrier = control_flow_ops.no_op(name='update_barrier')
with ops.control_dependencies([update_barrier]):
@@ -595,7 +659,7 @@ class QuantizeTest(test_util.TensorFlowTestCase):
def _TestQuantize_AtrousConvWithBatchNorm(
self, activation, activation_op_name, with_bypass, delay,
- fused_batch_norm, use_resource):
+ fused_batch_norm, use_resource, scope):
"""Tests quantization: inputs -> atrous conv with batch norm -> Activation.
Args:
@@ -607,6 +671,7 @@ class QuantizeTest(test_util.TensorFlowTestCase):
delay: Int (optional), delay in number of steps until quantization starts.
fused_batch_norm: Bool, when true use FusedBatchNorm.
use_resource: Bool, when true uses resource variables.
+ scope: String, specifies top level scope for the graph
"""
graph = ops.Graph()
with graph.as_default():
@@ -614,7 +679,10 @@ class QuantizeTest(test_util.TensorFlowTestCase):
batch_size, height, width, depth = 5, 128, 128, 3
inputs = array_ops.zeros((batch_size, height, width, depth))
dilation_rate = 2
- scope = 'test/test2' if with_bypass else 'test'
+ conv_scope = self._GetConvScope(scope, with_bypass)
+ scope = '' if scope is None else scope
+ delim = '/' if scope else ''
+
node = separable_conv2d(
inputs,
None, [3, 3],
@@ -625,13 +693,13 @@ class QuantizeTest(test_util.TensorFlowTestCase):
activation_fn=None,
normalizer_fn=batch_norm,
normalizer_params=self._BatchNormParams(fused_batch_norm),
- scope=scope)
+ scope=conv_scope)
# Manually add a bypass (optional) and an activation.
if with_bypass:
- node = math_ops.add(inputs, node, name='test/Add')
+ node = math_ops.add(inputs, node, name=scope + delim + 'Add')
- node = activation(node, name='test/' + activation_op_name)
+ node = activation(node, name=scope + delim + activation_op_name)
update_barrier = control_flow_ops.no_op(name='update_barrier')
with ops.control_dependencies([update_barrier]):
@@ -718,6 +786,18 @@ class QuantizeTest(test_util.TensorFlowTestCase):
with open('/tmp/bn_quant_test.pbtxt', 'w') as f:
f.write(str(graph.as_graph_def()))
+ def _GetConvScope(self, scope, with_bypass):
+ if scope is None:
+ scope = ''
+ delim = '/' if scope else ''
+
+ if with_bypass:
+ conv_scope = scope + delim + 'test2'
+ else:
+ conv_scope = scope
+
+ return conv_scope
+
def _BatchNormParams(self, fused=False, force_updates=False):
params = {
'center': True,