author    A. Unique TensorFlower <gardener@tensorflow.org> 2017-10-24 15:48:40 -0700
committer TensorFlower Gardener <gardener@tensorflow.org> 2017-10-24 15:52:11 -0700
commit    5b9cdb2dcbb057701b0ffb0ec4e0ab555a53390b (patch)
tree      e389fb8b6ed641d471d74ea0af2e18a1d73a64ba
parent    64ba163dc8fa1bdf780cbbb67811f9adce05e325 (diff)
Adding the batch norm adjustment to contrib/layers.
PiperOrigin-RevId: 173324074
-rw-r--r-- tensorflow/contrib/layers/python/layers/layers.py      | 20
-rw-r--r-- tensorflow/contrib/layers/python/layers/layers_test.py | 20
2 files changed, 38 insertions, 2 deletions
diff --git a/tensorflow/contrib/layers/python/layers/layers.py b/tensorflow/contrib/layers/python/layers/layers.py
index 29ab281b1a..deeafdf300 100644
--- a/tensorflow/contrib/layers/python/layers/layers.py
+++ b/tensorflow/contrib/layers/python/layers/layers.py
@@ -463,7 +463,8 @@ def batch_norm(inputs,
scope=None,
renorm=False,
renorm_clipping=None,
- renorm_decay=0.99):
+ renorm_decay=0.99,
+ adjustment=None):
"""Adds a Batch Normalization layer from http://arxiv.org/abs/1502.03167.
"Batch Normalization: Accelerating Deep Network Training by Reducing
@@ -546,6 +547,17 @@ def batch_norm(inputs,
and should be neither too small (which would add noise) nor too large
(which would give stale estimates). Note that `decay` is still applied
to get the means and variances for inference.
+ adjustment: A function taking the `Tensor` containing the (dynamic) shape of
+ the input tensor and returning a pair (scale, bias) to apply to the
+ normalized values (before gamma and beta), only during training. For
+ example,
+ `adjustment = lambda shape: (
+ tf.random_uniform(shape[-1:], 0.93, 1.07),
+ tf.random_uniform(shape[-1:], -0.1, 0.1))`
+ will scale the normalized value by up to 7% up or down, then shift the
+ result by up to 0.1 (with independent scaling and bias for each feature
+ but shared across all examples), and finally apply gamma and/or beta. If
+ `None`, no adjustment is applied.
Returns:
A `Tensor` representing the output of the operation.
@@ -569,7 +581,10 @@ def batch_norm(inputs,
# implementation in normalization_layers.BatchNormalization.
inputs = ops.convert_to_tensor(inputs)
rank = inputs.get_shape().ndims
- possible_to_fuse = batch_weights is None and not renorm and rank in [2, 4]
+ possible_to_fuse = (batch_weights is None and
+ not renorm and
+ rank in [2, 4] and
+ adjustment is None)
if fused and possible_to_fuse and (
zero_debias_moving_mean or rank == 2 or
updates_collections is not ops.GraphKeys.UPDATE_OPS):
@@ -636,6 +651,7 @@ def batch_norm(inputs,
renorm=renorm,
renorm_clipping=renorm_clipping,
renorm_momentum=renorm_decay,
+ adjustment=adjustment,
name=sc.name,
_scope=sc,
_reuse=reuse,
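
For context, a minimal usage sketch of the new `adjustment` argument (not part of this commit; it assumes the TF 1.x contrib API and mirrors the docstring example added above, with `_random_adjustment` as a hypothetical helper name):

    import tensorflow as tf
    from tensorflow.contrib import layers

    def _random_adjustment(shape):
      # Per-feature random scale in [0.93, 1.07] and bias in [-0.1, 0.1],
      # applied to the normalized values before gamma/beta, training only.
      scale = tf.random_uniform(shape[-1:], 0.93, 1.07)
      bias = tf.random_uniform(shape[-1:], -0.1, 0.1)
      return scale, bias

    images = tf.zeros([10, 5, 5, 8])
    output = layers.batch_norm(
        images, is_training=True, adjustment=_random_adjustment)

Because an adjustment callable is incompatible with the fused kernel, the change below also excludes it from `possible_to_fuse`.
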
diff --git a/tensorflow/contrib/layers/python/layers/layers_test.py b/tensorflow/contrib/layers/python/layers/layers_test.py
index 1040ad3ca7..7c77e905f7 100644
--- a/tensorflow/contrib/layers/python/layers/layers_test.py
+++ b/tensorflow/contrib/layers/python/layers/layers_test.py
@@ -2644,6 +2644,26 @@ class BatchNormTest(test.TestCase):
zero_debias_moving_mean=True)
sess.run(variables_lib.global_variables_initializer())
+ def testAdjustmentCreated(self):
+ # Tests that the adjustment is appropriately passed to and used by the core
+ # BN layer.
+ all_adjustments = []
+ def _create_adjustment(shape):
+ adjustments = [array_ops.ones(shape[-1:]), array_ops.zeros(shape[-1:])]
+ all_adjustments.extend(adjustments)
+ return adjustments
+ depth = 8
+ images = array_ops.zeros([10, 5, 5, depth])
+ output = _layers.batch_norm(
+ images,
+ is_training=True,
+ adjustment=_create_adjustment)
+ self.assertListEqual(output.shape.as_list(), images.shape.as_list())
+ self.assertEqual(len(all_adjustments), 2)
+ self.assertListEqual(all_adjustments[0].shape.as_list(), [depth])
+ self.assertListEqual(all_adjustments[1].shape.as_list(), [depth])
+
+
class LayerNormTest(test.TestCase):
def testUnknownShape(self):