author    Yao Zhang <yaozhang@google.com>    2017-11-03 12:30:37 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>    2017-11-03 12:33:57 -0700
commit    3015655fa4458bbc65222929f7b8f0ae0af4dd34
tree      a791907501b0c9a565dce35b1113a9b3bce48e19
parent    46c1e3b362f2ee16f8476a5eaf7e952e44c1b653
Add regularizer support for fused batch norm.
PiperOrigin-RevId: 174497943
-rw-r--r--  tensorflow/contrib/layers/python/layers/layers.py       43
-rw-r--r--  tensorflow/contrib/layers/python/layers/layers_test.py  20
-rw-r--r--  tensorflow/python/layers/normalization.py                7
3 files changed, 44 insertions, 26 deletions
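
For context, a minimal usage sketch of the contrib layer after this change. The regularizer scale and input shape below are illustrative assumptions, not taken from the commit:

import tensorflow as tf

images = tf.placeholder(tf.float32, [None, 28, 28, 3])   # rank-4 NHWC input
reg = tf.contrib.layers.l2_regularizer(scale=1e-4)        # illustrative weight decay

net = tf.contrib.layers.batch_norm(
    images,
    scale=True,                                      # create gamma so its regularizer applies
    param_regularizers={'beta': reg, 'gamma': reg},  # new: honored on the fused path too
    fused=True,
    is_training=True)

# The regularization terms end up in the usual collection.
reg_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
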
diff --git a/tensorflow/contrib/layers/python/layers/layers.py b/tensorflow/contrib/layers/python/layers/layers.py
index 78c1839e51..ad4a0b302f 100644
--- a/tensorflow/contrib/layers/python/layers/layers.py
+++ b/tensorflow/contrib/layers/python/layers/layers.py
@@ -198,23 +198,23 @@ def avg_pool3d(inputs,
return utils.collect_named_outputs(outputs_collections, sc, outputs)
-def _fused_batch_norm(
- inputs,
- decay=0.999,
- center=True,
- scale=False,
- epsilon=0.001,
- activation_fn=None,
- param_initializers=None,
- updates_collections=ops.GraphKeys.UPDATE_OPS,
- is_training=True,
- reuse=None,
- variables_collections=None,
- outputs_collections=None,
- trainable=True,
- data_format=DATA_FORMAT_NHWC,
- zero_debias_moving_mean=False,
- scope=None):
+def _fused_batch_norm(inputs,
+ decay=0.999,
+ center=True,
+ scale=False,
+ epsilon=0.001,
+ activation_fn=None,
+ param_initializers=None,
+ param_regularizers=None,
+ updates_collections=ops.GraphKeys.UPDATE_OPS,
+ is_training=True,
+ reuse=None,
+ variables_collections=None,
+ outputs_collections=None,
+ trainable=True,
+ data_format=DATA_FORMAT_NHWC,
+ zero_debias_moving_mean=False,
+ scope=None):
"""Adds a Batch Normalization layer from http://arxiv.org/abs/1502.03167.
"Batch Normalization: Accelerating Deep Network Training by Reducing
@@ -257,6 +257,7 @@ def _fused_batch_norm(
maintain a linear activation.
param_initializers: Optional initializers for beta, gamma, moving mean and
moving variance.
+ param_regularizers: Optional regularizer for beta and gamma.
updates_collections: Collections to collect the update ops for computation.
The updates_ops need to be executed with the train_op.
If None, a control dependency would be added to make sure the updates are
@@ -324,6 +325,11 @@ def _fused_batch_norm(
'beta')
if not param_initializers:
param_initializers = {}
+ if not param_regularizers:
+ param_regularizers = {}
+ beta_regularizer = param_regularizers.get('beta')
+ gamma_regularizer = param_regularizers.get('gamma')
+
if center:
beta_initializer = param_initializers.get('beta',
init_ops.zeros_initializer())
@@ -332,6 +338,7 @@ def _fused_batch_norm(
shape=params_shape,
dtype=dtype,
initializer=beta_initializer,
+ regularizer=beta_regularizer,
collections=beta_collections,
trainable=trainable_beta)
else:
@@ -347,6 +354,7 @@ def _fused_batch_norm(
shape=params_shape,
dtype=dtype,
initializer=gamma_initializer,
+ regularizer=gamma_regularizer,
collections=gamma_collections,
trainable=trainable)
else:
@@ -596,6 +604,7 @@ def batch_norm(inputs,
epsilon=epsilon,
activation_fn=activation_fn,
param_initializers=param_initializers,
+ param_regularizers=param_regularizers,
updates_collections=updates_collections,
is_training=is_training,
reuse=reuse,
diff --git a/tensorflow/contrib/layers/python/layers/layers_test.py b/tensorflow/contrib/layers/python/layers/layers_test.py
index 7c77e905f7..2837a3172d 100644
--- a/tensorflow/contrib/layers/python/layers/layers_test.py
+++ b/tensorflow/contrib/layers/python/layers/layers_test.py
@@ -1784,29 +1784,41 @@ class BatchNormTest(test.TestCase):
def testCreateOpFused(self):
self._testCreateOp(True)
- def testCreateOpBetaRegularizer(self):
+ def _testCreateOpBetaRegularizer(self, fused=True):
height, width = 3, 3
with self.test_session():
reg = lambda x: 0.1 * math_ops.reduce_sum(x)
images = np.random.uniform(size=(5, height, width, 3)).astype('f')
- _layers.batch_norm(images, param_regularizers={'beta': reg})
+ _layers.batch_norm(images, param_regularizers={'beta': reg}, fused=fused)
self.assertEqual(
len(ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)), 1)
beta_decay = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)[0]
self.assertEqual(beta_decay.op.name, 'BatchNorm/beta/Regularizer/mul')
- def testCreateOpGammaRegularizer(self):
+ def testCreateOpBetaRegularizerFused(self):
+ self._testCreateOpBetaRegularizer(fused=True)
+
+ def testCreateOpBetaRegularizerNonFused(self):
+ self._testCreateOpBetaRegularizer(fused=False)
+
+ def _testCreateOpGammaRegularizer(self, fused=True):
height, width = 3, 3
with self.test_session():
reg = lambda x: 0.1 * math_ops.reduce_sum(x)
images = np.random.uniform(size=(5, height, width, 3)).astype('f')
_layers.batch_norm(
- images, param_regularizers={'gamma': reg}, scale=True)
+ images, param_regularizers={'gamma': reg}, scale=True, fused=fused)
self.assertEqual(
len(ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)), 1)
gamma_decay = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)[0]
self.assertEqual(gamma_decay.op.name, 'BatchNorm/gamma/Regularizer/mul')
+ def testCreateOpGammaRegularizerFused(self):
+ self._testCreateOpGammaRegularizer(fused=True)
+
+ def testCreateOpGammaRegularizerNonFused(self):
+ self._testCreateOpGammaRegularizer(fused=False)
+
def testCreateVariables(self):
height, width = 3, 3
with self.test_session():
diff --git a/tensorflow/python/layers/normalization.py b/tensorflow/python/layers/normalization.py
index 01f56abc70..a9d59b25a3 100644
--- a/tensorflow/python/layers/normalization.py
+++ b/tensorflow/python/layers/normalization.py
@@ -211,16 +211,13 @@ class BatchNormalization(base.Layer):
'be specified')
if self.fused:
- # Currently fused batch norm doesn't support renorm and beta/gamma
- # regularizer; and only supports an input tensor of rank 4 and a channel
- # dimension on axis 1 and 3.
+ # Currently fused batch norm doesn't support renorm. It also only supports
+ # an input tensor of rank 4 and a channel dimension on axis 1 or 3.
# TODO(yaozhang): if input is not 4D, reshape it to 4D and reshape the
# output back to its original shape accordingly.
self.fused = (not self.renorm and
ndims == 4 and
self.axis in [[1], [3]] and
- self.beta_regularizer is None and
- self.gamma_regularizer is None and
self.virtual_batch_size is None and
self.adjustment is None)
# TODO(chrisying): fused batch norm is currently not supported for
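
With the normalization.py change above, the core layer no longer falls back to the non-fused implementation when beta/gamma regularizers are set. A hedged sketch of a call that now stays on the fused path; the input shape and the use of tf.nn.l2_loss as the regularizer are assumptions for illustration:

import tensorflow as tf

x = tf.placeholder(tf.float32, [None, 32, 32, 16])   # rank 4, channels on axis 3

y = tf.layers.batch_normalization(
    x,
    axis=3,
    beta_regularizer=tf.nn.l2_loss,    # previously these forced fused=False
    gamma_regularizer=tf.nn.l2_loss,
    fused=True,
    training=True)

# Regularization losses are collected the same way as in the non-fused case.
reg_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)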