path: root/tensorflow/contrib/quantize
author     Raghuraman Krishnamoorthi <raghuramank@google.com>  2018-04-24 11:20:04 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>     2018-04-24 11:24:09 -0700
commit     4a82acf286df1bc10581d91e13e0ab17458e83b4 (patch)
tree       76606297e61f008130e825f1953103adaceb1680 /tensorflow/contrib/quantize
parent     aeaec69869f13fc37c3ed28881741dd344e6a150 (diff)
Improve handling of scopes in folding unfused batch norms. This change allows folding to work for MobilenetV2 with unfused batch norms.
PiperOrigin-RevId: 194116535
Diffstat (limited to 'tensorflow/contrib/quantize')
-rw-r--r--  tensorflow/contrib/quantize/python/fold_batch_norms.py       24
-rw-r--r--  tensorflow/contrib/quantize/python/fold_batch_norms_test.py  79
2 files changed, 100 insertions, 3 deletions
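
The crux of the change is how the leading model scope is stripped from the
context string before it is used to suffix-match variable ops. A minimal
sketch of the old and new derivations, using the example scope strings quoted
in the new comments below (the helper names here are illustrative, not part
of the patch):

    # Sketch only: contrast the old and new base_context derivations.
    def old_base_context(context):
      # Old behaviour: keep only the last scope component.
      return context.split('/')[-1]

    def new_base_context(context):
      # New behaviour: drop only the leading model scope.
      return '/'.join(context.split('/')[1:])

    # MobilenetV1 repeats its header, so both forms still isolate the layer:
    print(old_base_context('MobilenetV1/MobilenetV1/Conv2d_3_depthwise'))
    # -> 'Conv2d_3_depthwise'
    print(new_base_context('MobilenetV1/MobilenetV1/Conv2d_3_depthwise'))
    # -> 'MobilenetV1/Conv2d_3_depthwise'

    # MobilenetV2 nests scopes, so the last component alone ('depthwise') is
    # shared by every expanded_conv_* block and no longer identifies a layer:
    print(old_base_context('MobilenetV2/expanded_conv_3/depthwise'))
    # -> 'depthwise'
    print(new_base_context('MobilenetV2/expanded_conv_3/depthwise'))
    # -> 'expanded_conv_3/depthwise'
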
diff --git a/tensorflow/contrib/quantize/python/fold_batch_norms.py b/tensorflow/contrib/quantize/python/fold_batch_norms.py
index aa0ef64308..6f41722748 100644
--- a/tensorflow/contrib/quantize/python/fold_batch_norms.py
+++ b/tensorflow/contrib/quantize/python/fold_batch_norms.py
@@ -501,8 +501,27 @@ def _GetBatchNormParams(graph, context, has_scaling):
bn_decay_var_tensor = None
split_context = context.split('/')
- base_context = split_context[-1]
-
+ # Matching variable names is brittle and relies on scoping
+ # conventions. Fused batch norm folding is more robust. Support for unfused
+ # batch norms will be deprecated as we move forward. Fused batch norms allow
+ # for faster training and should be used whenever possible.
+ # The context contains part of the names of the tensors we are interested in:
+ # For MobilenetV1, the context has a repeated header:
+ # MobilenetV1/MobilenetV1/Conv2d_3_depthwise
+ # while the moving_mean tensor has the name:
+ # MobilenetV1/Conv2d_3_depthwise/BatchNorm/moving_mean/read
+ # To pick the correct variable name, it is necessary to ignore the
+ # repeated header.
+
+ # For MobilenetV2, this problem does not exist:
+ # the context is MobilenetV2/expanded_conv_3/depthwise
+ # and the tensor names start with a single MobilenetV2.
+ # The moving mean, for example, has the name:
+ # MobilenetV2/expanded_conv_3/depthwise/BatchNorm/moving_mean/read
+ # We therefore ignore the first string (MobilenetV1 or MobilenetV2)
+ # in the context so that matching works in both cases.
+
+ base_context = '/'.join(split_context[1:])
oplist = graph.get_operations()
op_suffix_mean = base_context + '/BatchNorm/moments/Squeeze'
op_suffix_variance = base_context + '/BatchNorm/moments/Squeeze_1'
@@ -520,7 +539,6 @@ def _GetBatchNormParams(graph, context, has_scaling):
op_suffix_gamma = base_context + '/BatchNorm/gamma'
op_suffix_moving_variance = base_context + '/BatchNorm/moving_variance/read'
op_suffix_moving_mean = base_context + '/BatchNorm/moving_mean/read'
-
# Parse through list of ops to find relevant ops
for op in oplist:
if op.name.endswith(op_suffix_mean):
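
To see why the broader base_context matters, the suffix scan over op names
can be pictured with plain strings. This is a simplified stand-in for the
loop above, and the op names are made up for illustration; the real code
iterates over graph.get_operations():

    # Illustrative only: suffix matching against hypothetical op names.
    op_names = [
        'MobilenetV2/expanded_conv_1/depthwise/BatchNorm/moving_mean/read',
        'MobilenetV2/expanded_conv_3/depthwise/BatchNorm/moving_mean/read',
    ]
    context = 'MobilenetV2/expanded_conv_3/depthwise'
    base_context = '/'.join(context.split('/')[1:])
    suffix = base_context + '/BatchNorm/moving_mean/read'
    print([n for n in op_names if n.endswith(suffix)])
    # -> only the expanded_conv_3 op matches. With the old base_context
    #    ('depthwise'), both layers would have matched, so folding could pick
    #    up the wrong layer's moving_mean.
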
diff --git a/tensorflow/contrib/quantize/python/fold_batch_norms_test.py b/tensorflow/contrib/quantize/python/fold_batch_norms_test.py
index af31467476..64e8142e7c 100644
--- a/tensorflow/contrib/quantize/python/fold_batch_norms_test.py
+++ b/tensorflow/contrib/quantize/python/fold_batch_norms_test.py
@@ -134,6 +134,85 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase):
def testFoldConv2d(self):
self._RunTestOverParameters(self._TestFoldConv2d)
+ def testMultipleLayerConv2d(self,
+ relu=nn_ops.relu,
+ relu_op_name='Relu',
+ has_scaling=True,
+ fused_batch_norm=False,
+ freeze_batch_norm_delay=None):
+ """Tests folding cases for a network with multiple layers.
+
+ Args:
+ relu: Callable that returns an Operation, a factory method for the Relu*.
+ relu_op_name: String, name of the Relu* operation.
+ has_scaling: Bool, when true the batch norm has scaling.
+ fused_batch_norm: Bool, when true the batch norm is fused.
+ freeze_batch_norm_delay: None or the number of steps after which training
+ switches to using frozen mean and variance.
+ """
+ g = ops.Graph()
+ with g.as_default():
+ batch_size, height, width = 5, 128, 128
+ inputs = array_ops.zeros((batch_size, height, width, 3))
+ out_depth = 3
+ stride = 1
+ activation_fn = relu
+ scope = 'network/expanded_conv_1/conv'
+ layer1 = conv2d(
+ inputs,
+ out_depth, [5, 5],
+ stride=stride,
+ padding='SAME',
+ weights_initializer=self._WeightInit(0.09),
+ activation_fn=activation_fn,
+ normalizer_fn=batch_norm,
+ normalizer_params=self._BatchNormParams(
+ scale=has_scaling, fused=fused_batch_norm),
+ scope=scope)
+ # Add another layer
+ scope = 'network/expanded_conv_2/conv'
+
+ _ = conv2d(
+ layer1,
+ 2 * out_depth, [5, 5],
+ stride=stride,
+ padding='SAME',
+ weights_initializer=self._WeightInit(0.09),
+ activation_fn=activation_fn,
+ normalizer_fn=batch_norm,
+ normalizer_params=self._BatchNormParams(
+ scale=has_scaling, fused=fused_batch_norm),
+ scope=scope)
+
+ fold_batch_norms.FoldBatchNorms(
+ g, is_training=True, freeze_batch_norm_delay=freeze_batch_norm_delay)
+ folded_mul = g.get_operation_by_name(scope + '/mul_fold')
+ self.assertEqual(folded_mul.type, 'Mul')
+ self._AssertInputOpsAre(folded_mul, [
+ scope + '/correction_mult',
+ self._BatchNormMultiplierName(scope, has_scaling, fused_batch_norm)
+ ])
+ self._AssertOutputGoesToOps(folded_mul, g, [scope + '/Conv2D_Fold'])
+
+ folded_conv = g.get_operation_by_name(scope + '/Conv2D_Fold')
+ self.assertEqual(folded_conv.type, 'Conv2D')
+ # Strip the trailing ':0' from the tensor name prior to comparison.
+ self._AssertInputOpsAre(folded_conv,
+ [scope + '/mul_fold', layer1.name[:-2]])
+ self._AssertOutputGoesToOps(folded_conv, g, [scope + '/post_conv_mul'])
+
+ folded_add = g.get_operation_by_name(scope + '/add_fold')
+ self.assertEqual(folded_add.type, 'Add')
+ self._AssertInputOpsAre(folded_add, [
+ scope + '/correction_add',
+ self._BathNormBiasName(scope, fused_batch_norm)
+ ])
+ output_op_names = [scope + '/' + relu_op_name]
+ self._AssertOutputGoesToOps(folded_add, g, output_op_names)
+
+ for op in g.get_operations():
+ self.assertFalse('//' in op.name, 'Double slash in op %s' % op.name)
+
def _TestFoldConv2dUnknownShape(self, relu, relu_op_name, with_bypass,
has_scaling, fused_batch_norm,
freeze_batch_norm_delay):