From 3015655fa4458bbc65222929f7b8f0ae0af4dd34 Mon Sep 17 00:00:00 2001
From: Yao Zhang
Date: Fri, 3 Nov 2017 12:30:37 -0700
Subject: Add regularizer support for fused batch norm.

PiperOrigin-RevId: 174497943
---
 tensorflow/contrib/layers/python/layers/layers.py | 43 +++++++++++++---------
 .../contrib/layers/python/layers/layers_test.py   | 20 ++++++++--
 tensorflow/python/layers/normalization.py         |  7 +---
 3 files changed, 44 insertions(+), 26 deletions(-)

diff --git a/tensorflow/contrib/layers/python/layers/layers.py b/tensorflow/contrib/layers/python/layers/layers.py
index 78c1839e51..ad4a0b302f 100644
--- a/tensorflow/contrib/layers/python/layers/layers.py
+++ b/tensorflow/contrib/layers/python/layers/layers.py
@@ -198,23 +198,23 @@ def avg_pool3d(inputs,
     return utils.collect_named_outputs(outputs_collections, sc, outputs)
 
 
-def _fused_batch_norm(
-    inputs,
-    decay=0.999,
-    center=True,
-    scale=False,
-    epsilon=0.001,
-    activation_fn=None,
-    param_initializers=None,
-    updates_collections=ops.GraphKeys.UPDATE_OPS,
-    is_training=True,
-    reuse=None,
-    variables_collections=None,
-    outputs_collections=None,
-    trainable=True,
-    data_format=DATA_FORMAT_NHWC,
-    zero_debias_moving_mean=False,
-    scope=None):
+def _fused_batch_norm(inputs,
+                      decay=0.999,
+                      center=True,
+                      scale=False,
+                      epsilon=0.001,
+                      activation_fn=None,
+                      param_initializers=None,
+                      param_regularizers=None,
+                      updates_collections=ops.GraphKeys.UPDATE_OPS,
+                      is_training=True,
+                      reuse=None,
+                      variables_collections=None,
+                      outputs_collections=None,
+                      trainable=True,
+                      data_format=DATA_FORMAT_NHWC,
+                      zero_debias_moving_mean=False,
+                      scope=None):
   """Adds a Batch Normalization layer from http://arxiv.org/abs/1502.03167.
 
   "Batch Normalization: Accelerating Deep Network Training by Reducing
@@ -257,6 +257,7 @@ def _fused_batch_norm(
       maintain a linear activation.
     param_initializers: Optional initializers for beta, gamma, moving mean and
       moving variance.
+    param_regularizers: Optional regularizer for beta and gamma.
    updates_collections: Collections to collect the update ops for computation.
      The updates_ops need to be executed with the train_op.
      If None, a control dependency would be added to make sure the updates are
@@ -324,6 +325,11 @@ def _fused_batch_norm(
                                                        'beta')
     if not param_initializers:
       param_initializers = {}
+    if not param_regularizers:
+      param_regularizers = {}
+    beta_regularizer = param_regularizers.get('beta')
+    gamma_regularizer = param_regularizers.get('gamma')
+
     if center:
       beta_initializer = param_initializers.get('beta',
                                                 init_ops.zeros_initializer())
@@ -332,6 +338,7 @@ def _fused_batch_norm(
           shape=params_shape,
           dtype=dtype,
           initializer=beta_initializer,
+          regularizer=beta_regularizer,
          collections=beta_collections,
          trainable=trainable_beta)
    else:
@@ -347,6 +354,7 @@ def _fused_batch_norm(
           shape=params_shape,
           dtype=dtype,
           initializer=gamma_initializer,
+          regularizer=gamma_regularizer,
          collections=gamma_collections,
          trainable=trainable)
    else:
@@ -596,6 +604,7 @@ def batch_norm(inputs,
         epsilon=epsilon,
         activation_fn=activation_fn,
         param_initializers=param_initializers,
+        param_regularizers=param_regularizers,
        updates_collections=updates_collections,
        is_training=is_training,
        reuse=reuse,
diff --git a/tensorflow/contrib/layers/python/layers/layers_test.py b/tensorflow/contrib/layers/python/layers/layers_test.py
index 7c77e905f7..2837a3172d 100644
--- a/tensorflow/contrib/layers/python/layers/layers_test.py
+++ b/tensorflow/contrib/layers/python/layers/layers_test.py
@@ -1784,29 +1784,41 @@ class BatchNormTest(test.TestCase):
   def testCreateOpFused(self):
     self._testCreateOp(True)
 
-  def testCreateOpBetaRegularizer(self):
+  def _testCreateOpBetaRegularizer(self, fused=True):
     height, width = 3, 3
     with self.test_session():
       reg = lambda x: 0.1 * math_ops.reduce_sum(x)
       images = np.random.uniform(size=(5, height, width, 3)).astype('f')
-      _layers.batch_norm(images, param_regularizers={'beta': reg})
+      _layers.batch_norm(images, param_regularizers={'beta': reg}, fused=fused)
       self.assertEqual(
           len(ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)), 1)
       beta_decay = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)[0]
       self.assertEqual(beta_decay.op.name, 'BatchNorm/beta/Regularizer/mul')
 
-  def testCreateOpGammaRegularizer(self):
+  def testCreateOpBetaRegularizerFused(self):
+    self._testCreateOpBetaRegularizer(fused=True)
+
+  def testCreateOpBetaRegularizerNonFused(self):
+    self._testCreateOpBetaRegularizer(fused=False)
+
+  def _testCreateOpGammaRegularizer(self, fused=True):
     height, width = 3, 3
     with self.test_session():
       reg = lambda x: 0.1 * math_ops.reduce_sum(x)
       images = np.random.uniform(size=(5, height, width, 3)).astype('f')
       _layers.batch_norm(
-          images, param_regularizers={'gamma': reg}, scale=True)
+          images, param_regularizers={'gamma': reg}, scale=True, fused=fused)
       self.assertEqual(
           len(ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)), 1)
       gamma_decay = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)[0]
       self.assertEqual(gamma_decay.op.name, 'BatchNorm/gamma/Regularizer/mul')
 
+  def testCreateOpGammaRegularizerFused(self):
+    self._testCreateOpGammaRegularizer(fused=True)
+
+  def testCreateOpGammaRegularizerNonFused(self):
+    self._testCreateOpGammaRegularizer(fused=False)
+
   def testCreateVariables(self):
     height, width = 3, 3
     with self.test_session():
diff --git a/tensorflow/python/layers/normalization.py b/tensorflow/python/layers/normalization.py
index 01f56abc70..a9d59b25a3 100644
--- a/tensorflow/python/layers/normalization.py
+++ b/tensorflow/python/layers/normalization.py
@@ -211,16 +211,13 @@ class BatchNormalization(base.Layer):
                         'be specified')
 
     if self.fused:
-      # Currently fused batch norm doesn't support renorm and beta/gamma
-      # regularizer; and only supports an input tensor of rank 4 and a channel
-      # dimension on axis 1 and 3.
+      # Currently fused batch norm doesn't support renorm. It also only supports
+      # an input tensor of rank 4 and a channel dimension on axis 1 or 3.
       # TODO(yaozhang): if input is not 4D, reshape it to 4D and reshape the
       # output back to its original shape accordingly.
       self.fused = (not self.renorm and
                     ndims == 4 and
                     self.axis in [[1], [3]] and
-                    self.beta_regularizer is None and
-                    self.gamma_regularizer is None and
                     self.virtual_batch_size is None and
                     self.adjustment is None)
       # TODO(chrisying): fused batch norm is currently not supported for
-- 
cgit v1.2.3
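
A minimal usage sketch of what this patch enables: with param_regularizers now
forwarded to _fused_batch_norm, beta/gamma penalties land in
GraphKeys.REGULARIZATION_LOSSES even when fused=True. This assumes the
TF 1.4-era tf.contrib.layers graph-mode API; the placeholder shape and the
1e-4 L2 scale are arbitrary illustration values, not part of the patch.

import tensorflow as tf

# Input tensor and an L2 regularizer to apply to the batch norm parameters.
images = tf.placeholder(tf.float32, shape=(None, 32, 32, 3))
l2 = tf.contrib.layers.l2_regularizer(scale=1e-4)  # illustrative scale

net = tf.contrib.layers.batch_norm(
    images,
    scale=True,  # create gamma so the 'gamma' regularizer has a variable to act on
    param_regularizers={'beta': l2, 'gamma': l2},
    fused=True,
    is_training=True)

# The regularizer losses are collected in the standard collection and can be
# summed into the training loss.
reg_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
total_regularization = tf.add_n(reg_losses)

In typical TF 1.x training code, total_regularization would simply be added to
the task loss before building the optimizer, the same way it is for conv or
dense kernel regularizers.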